package org.nd4j.linalg.jcublas.rng.distribution;

import jcuda.NativePointerObject;
import jcuda.Pointer;
import jcuda.jcurand.JCurand;
import jcuda.runtime.JCuda;
import org.apache.commons.math3.exception.NumberIsTooLargeException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.jcublas.buffer.CudaDoubleDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaFloatDataBuffer;
import org.nd4j.linalg.jcublas.buffer.JCudaBuffer;
import org.nd4j.linalg.jcublas.kernel.KernelFunctions;
import org.nd4j.linalg.jcublas.rng.JcudaRandom;
import org.nd4j.linalg.jcublas.util.KernelParamsWrapper;
import org.nd4j.linalg.jcublas.util.PointerUtil;

/* loaded from: input_file:org/nd4j/linalg/jcublas/rng/distribution/BaseJCudaDistribution.class */
public abstract class BaseJCudaDistribution implements Distribution {

    /**
     * Block count passed to {@code PointerUtil.getNumBlocks} by the overloads that do not use
     * {@code KernelFunctions.BLOCKS}; this is the value those overloads always hard-coded.
     */
    private static final int FIXED_BLOCKS = 128;

    /** Thread count paired with {@link #FIXED_BLOCKS}; the value those overloads always hard-coded. */
    private static final int FIXED_THREADS = 64;

    /** Value of the CUDA runtime enumerator {@code cudaMemcpyDeviceToDevice}. */
    private static final int CUDA_MEMCPY_DEVICE_TO_DEVICE = 3;

    /** Size in bytes of one {@code float} element in device memory. */
    private static final long FLOAT_BYTES = 4L;

    /** Size in bytes of one {@code double} element in device memory. */
    private static final long DOUBLE_BYTES = 8L;

    /** Source of the cuRAND generator handle and of host-side uniform doubles. */
    protected JcudaRandom random;

    /**
     * Creates a distribution backed by the given random source.
     *
     * @param jcudaRandom supplies the cuRAND generator used by the sampling methods
     */
    public BaseJCudaDistribution(JcudaRandom jcudaRandom) {
        this.random = jcudaRandom;
    }

    /**
     * Reseeds the underlying random source.
     *
     * @param j new seed, forwarded to {@code random.setSeed}
     */
    public void reseedRandomGenerator(long j) {
        this.random.setSeed(j);
    }

    /**
     * Launches the {@code binomial} kernel (float variant) with a per-element parameter array,
     * then copies {@code jCudaBuffer} back to the host.
     *
     * <p>The kernel-parameter wrapper is closed on every path, including when the launch fails.
     *
     * @param iNDArray array whose data buffer is passed to the kernel (presumably per-element
     *     probabilities; confirm against the kernel source)
     * @param jCudaBuffer output buffer; copied to the host after the launch
     * @param i second scalar kernel argument; also a factor of the scratch-buffer size
     * @param i2 first scalar kernel argument; also sizes the launch grid
     * @throws RuntimeException wrapping any failure to build, launch or close the kernel call
     */
    public void doBinomial(INDArray iNDArray, JCudaBuffer jCudaBuffer, int i, int i2) {
        int numBlocks = PointerUtil.getNumBlocks(i2, FIXED_BLOCKS, FIXED_THREADS);
        int numThreads = PointerUtil.getNumThreads(i2, FIXED_THREADS);
        // NOTE(review): the scratch buffer is handed to the wrapper and not freed here;
        // confirm that KernelParamsWrapper.close() releases it.
        try (KernelParamsWrapper params = new KernelParamsWrapper(
                Integer.valueOf(i2), Integer.valueOf(i), (JCudaBuffer) iNDArray.data(),
                new CudaFloatDataBuffer(i2 * i), jCudaBuffer, this.random.generator())) {
            KernelFunctions.invoke(numBlocks, numThreads, "binomial", "float", params.getKernelParameters());
            jCudaBuffer.copyToHost();
        } catch (Exception e) {
            throw new RuntimeException("Cannot run kernel", e);
        }
    }

    /**
     * Launches the {@code binomial} kernel (double variant) with a per-element parameter array,
     * then copies {@code jCudaBuffer} back to the host.
     *
     * <p>The kernel-parameter wrapper is closed on every path, including when the launch fails.
     *
     * @param iNDArray array whose data buffer is passed to the kernel
     * @param jCudaBuffer output buffer; copied to the host after the launch
     * @param i second scalar kernel argument
     * @param i2 first scalar kernel argument; also sizes the launch grid and the scratch buffer
     * @throws RuntimeException wrapping any failure to build, launch or close the kernel call
     */
    public void doBinomialDouble(INDArray iNDArray, JCudaBuffer jCudaBuffer, int i, int i2) {
        int numBlocks = PointerUtil.getNumBlocks(i2, FIXED_BLOCKS, FIXED_THREADS);
        int numThreads = PointerUtil.getNumThreads(i2, FIXED_THREADS);
        // NOTE(review): the float overload allocates i2 * i scratch elements, this one only i2;
        // confirm which size the "binomial" double kernel expects.
        try (KernelParamsWrapper params = new KernelParamsWrapper(
                Integer.valueOf(i2), Integer.valueOf(i), (JCudaBuffer) iNDArray.data(),
                new CudaDoubleDataBuffer(i2), jCudaBuffer, this.random.generator())) {
            KernelFunctions.invoke(numBlocks, numThreads, "binomial", "double", params.getKernelParameters());
            jCudaBuffer.copyToHost();
        } catch (Exception e) {
            throw new RuntimeException("Cannot run kernel", e);
        }
    }

    /**
     * Launches the {@code binomial_scalar} kernel (float variant) with a single scalar parameter,
     * then copies {@code jCudaBuffer} back to the host.
     *
     * <p>The kernel-parameter wrapper is closed on every path, including when the launch fails.
     *
     * @param f scalar kernel argument (presumably the success probability)
     * @param jCudaBuffer output buffer; copied to the host after the launch
     * @param i second scalar kernel argument; also a factor of the scratch-buffer size
     * @param i2 first scalar kernel argument; also sizes the launch grid
     * @throws RuntimeException wrapping any failure to build, launch or close the kernel call
     */
    public void doBinomial(float f, JCudaBuffer jCudaBuffer, int i, int i2) {
        int numBlocks = PointerUtil.getNumBlocks(i2, KernelFunctions.BLOCKS, KernelFunctions.THREADS);
        int numThreads = PointerUtil.getNumThreads(i2, KernelFunctions.THREADS);
        try (KernelParamsWrapper params = new KernelParamsWrapper(
                Integer.valueOf(i2), Integer.valueOf(i), Float.valueOf(f),
                new CudaFloatDataBuffer(i2 * i), jCudaBuffer, this.random.generator())) {
            KernelFunctions.invoke(numBlocks, numThreads, "binomial_scalar", "float", params.getKernelParameters());
            jCudaBuffer.copyToHost();
        } catch (Exception e) {
            throw new RuntimeException("Cannot run kernel", e);
        }
    }

    /**
     * Draws one sample on the host by inverse-transform sampling.
     *
     * <p>A double from {@code random.nextDouble()} (presumably uniform on [0, 1)) is mapped
     * through {@code inverseCumulativeProbability}.
     *
     * @return the sampled value
     */
    public double sample() {
        return inverseCumulativeProbability(this.random.nextDouble());
    }

    /**
     * Launches the {@code binomial_scalar} kernel (double variant) with a single scalar parameter,
     * then copies {@code jCudaBuffer} back to the host.
     *
     * <p>The kernel-parameter wrapper is closed on every path, including when the launch fails.
     *
     * @param d scalar kernel argument
     * @param jCudaBuffer output buffer; copied to the host after the launch
     * @param i second scalar kernel argument
     * @param i2 first scalar kernel argument; also sizes the launch grid and the scratch buffer
     * @throws RuntimeException wrapping any failure to build, launch or close the kernel call
     */
    public void doBinomialDouble(double d, JCudaBuffer jCudaBuffer, int i, int i2) {
        int numBlocks = PointerUtil.getNumBlocks(i2, KernelFunctions.BLOCKS, KernelFunctions.THREADS);
        int numThreads = PointerUtil.getNumThreads(i2, KernelFunctions.THREADS);
        // NOTE(review): only this overload wraps the generator in Pointer.to(...); the others pass
        // it directly. Confirm which form KernelParamsWrapper expects.
        try (KernelParamsWrapper params = new KernelParamsWrapper(
                Integer.valueOf(i2), Integer.valueOf(i), Double.valueOf(d),
                new CudaDoubleDataBuffer(i2), jCudaBuffer,
                Pointer.to(new NativePointerObject[]{this.random.generator()}))) {
            KernelFunctions.invoke(numBlocks, numThreads, "binomial_scalar", "double", params.getKernelParameters());
            jCudaBuffer.copyToHost();
        } catch (Exception e) {
            throw new RuntimeException("Cannot run kernel", e);
        }
    }

    /**
     * Fills {@code jCudaBuffer} via the {@code uniform} float kernel and copies it to the host.
     *
     * <p>cuRAND first writes {@code i} uniform floats into a temporary device buffer. That buffer
     * is then passed to the kernel together with {@code f} and {@code f2}. The temporary buffer
     * and the kernel-parameter wrapper are released even if the launch fails.
     *
     * @param jCudaBuffer output buffer; copied to the host after the launch
     * @param f first scalar kernel argument (presumably the lower bound)
     * @param f2 second scalar kernel argument (presumably the upper bound)
     * @param i number of values to generate; also sizes the launch grid
     * @throws RuntimeException wrapping any failure to build, launch or close the kernel call
     */
    public void doSampleUniform(JCudaBuffer jCudaBuffer, float f, float f2, int i) {
        int numBlocks = PointerUtil.getNumBlocks(i, KernelFunctions.BLOCKS, KernelFunctions.THREADS);
        int numThreads = PointerUtil.getNumThreads(i, KernelFunctions.THREADS);
        CudaFloatDataBuffer uniforms = new CudaFloatDataBuffer(i);
        try {
            JCurand.curandGenerateUniform(this.random.generator(), uniforms.mo9getDevicePointer(), i);
            try (KernelParamsWrapper params = new KernelParamsWrapper(
                    Integer.valueOf(i), Float.valueOf(f), Float.valueOf(f2),
                    uniforms.mo9getDevicePointer(), jCudaBuffer)) {
                KernelFunctions.invoke(numBlocks, numThreads, "uniform", "float", params.getKernelParameters());
                jCudaBuffer.copyToHost();
            } catch (Exception e) {
                throw new RuntimeException("Cannot run kernel", e);
            }
        } finally {
            uniforms.freeDevicePointer();
        }
    }

    /**
     * Fills {@code jCudaBuffer} via the {@code uniform} double kernel and copies it to the host.
     *
     * <p>cuRAND first writes {@code i} uniform doubles into a temporary device buffer. That buffer
     * is then passed to the kernel together with {@code d} and {@code d2}. The temporary buffer
     * and the kernel-parameter wrapper are released even if the launch fails.
     *
     * @param jCudaBuffer output buffer; copied to the host after the launch
     * @param d first scalar kernel argument (presumably the lower bound)
     * @param d2 second scalar kernel argument (presumably the upper bound)
     * @param i number of values to generate; also sizes the launch grid
     * @throws RuntimeException wrapping any failure to build, launch or close the kernel call
     */
    public void doSampleUniformDouble(JCudaBuffer jCudaBuffer, double d, double d2, int i) {
        int numBlocks = PointerUtil.getNumBlocks(i, FIXED_BLOCKS, FIXED_THREADS);
        int numThreads = PointerUtil.getNumThreads(i, FIXED_THREADS);
        CudaDoubleDataBuffer uniforms = new CudaDoubleDataBuffer(i);
        try {
            JCurand.curandGenerateUniformDouble(this.random.generator(), uniforms.mo9getDevicePointer(), i);
            try (KernelParamsWrapper params = new KernelParamsWrapper(
                    Integer.valueOf(i), Double.valueOf(d), Double.valueOf(d2),
                    uniforms.mo9getDevicePointer(), jCudaBuffer)) {
                KernelFunctions.invoke(numBlocks, numThreads, "uniform", "double", params.getKernelParameters());
                jCudaBuffer.copyToHost();
            } catch (Exception e) {
                throw new RuntimeException("Cannot run kernel", e);
            }
        } finally {
            uniforms.freeDevicePointer();
        }
    }

    /**
     * Writes one normal sample per element of {@code iNDArray} into device memory at
     * {@code pointer}.
     *
     * <p>Element {@code k} is drawn with mean {@code iNDArray.data().asFloat()[k]} and standard
     * deviation {@code f}. It is copied device-to-device to byte offset {@code 4 * k}.
     *
     * @param pointer device destination, at least {@code iNDArray.length()} floats long
     * @param iNDArray supplies the per-element means
     * @param f standard deviation shared by all elements
     */
    public void doSampleNormal(Pointer pointer, INDArray iNDArray, float f) {
        // NOTE(review): asFloat() reads the raw data buffer, so offset/stride of a view are not
        // applied; confirm that callers only pass contiguous, zero-offset arrays.
        float[] means = iNDArray.data().asFloat();
        for (int i = 0; i < iNDArray.length(); i++) {
            // Two values are generated although only the first is kept, presumably because
            // cuRAND's pseudorandom normal generators require an even count.
            CudaFloatDataBuffer scratch = new CudaFloatDataBuffer(2);
            try {
                JCurand.curandGenerateNormal(this.random.generator(), scratch.mo9getDevicePointer(), 2L, means[i], f);
                // Long arithmetic so the byte offset cannot overflow int for large indices.
                JCuda.cudaMemcpy(pointer.withByteOffset(FLOAT_BYTES * i), scratch.mo9getDevicePointer(),
                        FLOAT_BYTES, CUDA_MEMCPY_DEVICE_TO_DEVICE);
            } finally {
                scratch.freeDevicePointer();
            }
        }
    }

    /**
     * Writes one normal sample per element of {@code iNDArray} into device memory at
     * {@code pointer}.
     *
     * <p>Element {@code k} is drawn with mean {@code iNDArray.data().asDouble()[k]} and standard
     * deviation {@code d}. It is copied device-to-device to byte offset {@code 8 * k}.
     *
     * @param pointer device destination, at least {@code iNDArray.length()} doubles long
     * @param iNDArray supplies the per-element means
     * @param d standard deviation shared by all elements
     */
    public void doSampleNormalDouble(Pointer pointer, INDArray iNDArray, double d) {
        // NOTE(review): asDouble() reads the raw data buffer, so offset/stride of a view are not
        // applied; confirm that callers only pass contiguous, zero-offset arrays.
        double[] means = iNDArray.data().asDouble();
        for (int i = 0; i < iNDArray.length(); i++) {
            // Two values are generated although only the first is kept, presumably because
            // cuRAND's pseudorandom normal generators require an even count.
            CudaDoubleDataBuffer scratch = new CudaDoubleDataBuffer(2);
            try {
                JCurand.curandGenerateNormalDouble(this.random.generator(), scratch.mo9getDevicePointer(), 2L, means[i], d);
                // Long arithmetic so the byte offset cannot overflow int for large indices.
                JCuda.cudaMemcpy(pointer.withByteOffset(DOUBLE_BYTES * i), scratch.mo9getDevicePointer(),
                        DOUBLE_BYTES, CUDA_MEMCPY_DEVICE_TO_DEVICE);
            } finally {
                scratch.freeDevicePointer();
            }
        }
    }

    /**
     * Fills {@code i} floats at {@code pointer} with normal samples from cuRAND.
     *
     * <p>NOTE(review): the cuRAND status code is ignored. cuRAND's pseudorandom normal generators
     * are documented to reject odd counts, so an odd {@code i} may leave the buffer unfilled
     * without any error being reported.
     *
     * @param f mean
     * @param f2 standard deviation
     * @param pointer device destination
     * @param i number of values to generate
     */
    public void doSampleNormal(float f, float f2, Pointer pointer, int i) {
        JCurand.curandGenerateNormal(this.random.generator(), pointer, i, f, f2);
    }

    /**
     * Fills {@code i} doubles at {@code pointer} with normal samples from cuRAND.
     *
     * <p>NOTE(review): the cuRAND status code is ignored. cuRAND's pseudorandom normal generators
     * are documented to reject odd counts, so an odd {@code i} may leave the buffer unfilled
     * without any error being reported.
     *
     * @param d mean
     * @param d2 standard deviation
     * @param pointer device destination
     * @param i number of values to generate
     */
    public void doSampleNormal(double d, double d2, Pointer pointer, int i) {
        JCurand.curandGenerateNormalDouble(this.random.generator(), pointer, i, d, d2);
    }

    /**
     * Returns the probability mass between two bounds (presumably {@code P(d < X <= d2)}, following
     * the commons-math convention; confirm in subclasses).
     *
     * @param d lower bound
     * @param d2 upper bound
     * @return the probability
     * @throws NumberIsTooLargeException presumably when {@code d > d2}
     */
    public abstract double probability(double d, double d2) throws NumberIsTooLargeException;
}
