package org.nd4j.linalg.jcublas.kernel;

import java.io.IOException;
import java.io.InputStream;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.ConcurrentSkipListSet;
import jcuda.driver.CUstream;
import jcuda.utils.KernelLauncher;
import org.nd4j.linalg.jcublas.buffer.CudaDoubleDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaFloatDataBuffer;
import org.nd4j.linalg.jcublas.buffer.JCudaBuffer;
import org.nd4j.linalg.jcublas.context.ContextHolder;
import org.springframework.core.io.ClassPathResource;

/* loaded from: input_file:org/nd4j/linalg/jcublas/kernel/KernelFunctions.class */
public class KernelFunctions {
    public static final String NAME_SPACE = "org.nd4j.linalg.jcuda.jcublas";
    public static final String DOUBLE = "org.nd4j.linalg.jcuda.jcublas.double.functions";
    public static final String FLOAT = "org.nd4j.linalg.jcuda.jcublas.float.functions";
    /** Property key holding a comma-separated list of reduce-function names. */
    public static final String REDUCE = "org.nd4j.linalg.jcuda.jcublas.reducefunctions";
    public static final String SHARED_MEM_KEY = "org.nd4j.linalg.jcuda.jcublas.sharedmem";
    public static final String THREADS_KEY = "org.nd4j.linalg.jcuda.jcublas.threads";
    public static final String BLOCKS_KEY = "org.nd4j.linalg.jcuda.jcublas.blocks";

    /** Classpath location of the configuration file read by {@link #register()}. */
    private static final String PROPERTIES_RESOURCE = "/cudafunctions.properties";

    // Fallbacks used when the properties file omits the corresponding key. These are the
    // values the fields actually end up with after class initialisation, because register()
    // runs from the static initializer and overwrites the field initializers below.
    private static final int DEFAULT_SHARED_MEM = 512;
    private static final int DEFAULT_THREADS = 128;
    private static final int DEFAULT_BLOCKS = 64;

    /** Element sizes, in bytes, used to size the dynamic shared memory in {@link #invoke}. */
    private static final int FLOAT_BYTES = 4;
    private static final int DOUBLE_BYTES = 8;

    /** Shared-memory setting read from {@link #SHARED_MEM_KEY}; not used within this class. */
    public static int SHARED_MEM = DEFAULT_SHARED_MEM;
    /** Thread-count setting read from {@link #THREADS_KEY}; not used within this class. */
    public static int THREADS = DEFAULT_THREADS;
    /** Block-count setting read from {@link #BLOCKS_KEY}; not used within this class. */
    public static int BLOCKS = DEFAULT_BLOCKS;

    // NOTE(review): populated by register() but never read inside this class — confirm whether
    // anything consumes it (it is private, so only this class can).
    private static final Set<String> reduceFunctions = new ConcurrentSkipListSet<String>();

    private KernelFunctions() {
    }

    /**
     * Loads {@code cudafunctions.properties} from the classpath, triggers the kernel loader,
     * and applies the configured reduce-function list and launch settings.
     *
     * @throws IllegalStateException if the properties file is missing, the {@link #REDUCE} key
     *     is absent, or an integer setting cannot be parsed
     * @throws Exception propagated from reading the properties file or from
     *     {@code KernelFunctionLoader.load()}
     */
    public static void register() throws Exception {
        Properties properties = loadProperties();
        KernelFunctionLoader.getInstance().load();

        String reduceList = properties.getProperty(REDUCE);
        if (reduceList == null) {
            throw new IllegalStateException("Property " + REDUCE + " is missing from cudafunctions.properties");
        }
        for (String name : reduceList.split(",")) {
            // Tolerate whitespace around entries and skip empty ones (e.g. "a, b" or "a,,b").
            String trimmed = name.trim();
            if (!trimmed.isEmpty()) {
                reduceFunctions.add(trimmed);
            }
        }

        SHARED_MEM = intProperty(properties, SHARED_MEM_KEY, DEFAULT_SHARED_MEM);
        THREADS = intProperty(properties, THREADS_KEY, DEFAULT_THREADS);
        BLOCKS = intProperty(properties, BLOCKS_KEY, DEFAULT_BLOCKS);
    }

    /** Reads the classpath properties file, always closing the underlying stream. */
    private static Properties loadProperties() throws IOException {
        ClassPathResource resource = new ClassPathResource(PROPERTIES_RESOURCE);
        if (!resource.exists()) {
            throw new IllegalStateException("Please put a cudafunctions.properties in your class path");
        }
        Properties properties = new Properties();
        try (InputStream in = resource.getInputStream()) {
            properties.load(in);
        }
        return properties;
    }

    /** Parses an integer property (whitespace-tolerant), returning {@code defaultValue} if absent. */
    private static int intProperty(Properties properties, String key, int defaultValue) {
        String value = properties.getProperty(key);
        if (value == null) {
            return defaultValue;
        }
        try {
            return Integer.parseInt(value.trim());
        } catch (NumberFormatException e) {
            throw new IllegalStateException("Property " + key + " must be an integer but was '" + value + "'", e);
        }
    }

    /**
     * Launches the kernel {@code functionName + "_" + dataType} on the current context stream,
     * then synchronizes that stream before returning.
     *
     * <p>Dynamic shared memory is sized at one element per thread: {@code threads * 4} bytes when
     * {@code dataType} is {@code "float"}, otherwise {@code threads * 8} bytes.
     *
     * @param blocks grid size (x dimension)
     * @param threads block size (x dimension)
     * @param functionName base kernel name
     * @param dataType data-type suffix, e.g. {@code "float"}
     * @param args kernel arguments passed through to the launcher
     * @throws IllegalArgumentException if no launcher exists for the function/type pair
     */
    public static void invoke(int blocks, int threads, String functionName, String dataType, Object... args) {
        CUstream stream = ContextHolder.getInstance().getStream();
        int sharedMemBytes = threads * (dataType.equals("float") ? FLOAT_BYTES : DOUBLE_BYTES);
        KernelLauncher launcher = KernelFunctionLoader.launcher(functionName, dataType);
        if (launcher == null) {
            throw new IllegalArgumentException("Launcher for function " + functionName + " and data type " + dataType + " does not exist!");
        }
        launcher.forFunction(functionName + "_" + dataType)
                .setBlockSize(threads, 1, 1)
                .setGridSize(blocks, 1, 1)
                .setStream(stream)
                .setSharedMemSize(sharedMemBytes)
                .call(args);
        // Block until the launched kernel finishes so callers observe its results.
        ContextHolder.syncStream();
    }

    /** Wraps a {@code double[]} in a new {@link CudaDoubleDataBuffer}. */
    public static JCudaBuffer alloc(double[] dArr) {
        return new CudaDoubleDataBuffer(dArr);
    }

    /** Wraps a {@code float[]} in a new {@link CudaFloatDataBuffer}. */
    public static JCudaBuffer alloc(float[] fArr) {
        return new CudaFloatDataBuffer(fArr);
    }

    // Runs after the static field initializers above; any failure aborts class initialisation.
    static {
        try {
            register();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
}
