package com.omega.engine.nn.layer.gpu;

import com.omega.common.data.Tensor;
import com.omega.common.utils.JsonUtils;
import com.omega.common.utils.MatrixUtils;
import com.omega.common.utils.RandomUtils;
import com.omega.engine.gpu.BaseKernel;
import com.omega.engine.gpu.CUDAModules;
import jcuda.NativePointerObject;
import jcuda.Pointer;
import jcuda.driver.CUfunction;
import jcuda.driver.CUstream;
import jcuda.driver.JCudaDriver;
import jcuda.runtime.cudaError;

/* loaded from: input_file:com/omega/engine/nn/layer/gpu/UpSampleKernel.class */
/**
 * Host-side wrapper that launches the {@code upsample_kernel} CUDA function looked up in
 * {@code UpSampleKernel.cu}.
 *
 * <p>A kernel instance is configured with a stride and a scale. A negative stride passed to the
 * constructor selects "reverse" mode, in which {@link #forward} and {@link #backward} swap the
 * roles of their two tensors and flip the direction flag handed to the kernel.
 *
 * <p>{@link #upsample_cpu} is a pure-Java routine with the same argument layout; {@link #main}
 * prints its result next to the GPU result.
 * NOTE(review): the GPU kernel presumably implements the same arithmetic as
 * {@code upsample_cpu}; the {@code .cu} source is not visible here - confirm.
 */
public class UpSampleKernel extends BaseKernel {
    /** Threads per block used for every launch and for the block-count computation. */
    private static final int CAFFE_CUDA_NUM_THREADS = 1024;

    /** Sampling factor with the sign of the constructor argument stripped. */
    private final int stride;
    /** Scale factor forwarded unchanged to the kernel. */
    private final float scale;
    /** {@code true} when the constructor received a negative stride. */
    private final boolean reverse;
    /** Handle of {@code upsample_kernel}; stays {@code null} if the lookup in {@link #initFunction} failed. */
    private CUfunction forward_function;
    /** Parameter block of the most recent {@link #upsample} launch. */
    private Pointer forwardKernelParameters;
    /** Parameter block of the most recent {@link #upsampleDelta} launch. */
    private Pointer backwardKernelParameters;

    /**
     * Creates the kernel wrapper and resolves the CUDA function.
     *
     * @param i stride; a negative value enables reverse mode and is stored negated
     * @param f scale factor passed to the kernel
     */
    public UpSampleKernel(int i, float f) {
        this.reverse = i < 0;
        this.stride = this.reverse ? -i : i;
        this.scale = f;
        init();
    }

    /**
     * Looks up {@code upsample_kernel} once. A lookup failure is printed and otherwise ignored,
     * leaving the function handle {@code null}.
     */
    public void initFunction() {
        try {
            if (this.forward_function == null) {
                this.forward_function = CUDAModules.getLocalFunctionByModule("UpSampleKernel.cu", "upsample_kernel");
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Initialises the kernel; currently only resolves the CUDA function. */
    public void init() {
        initFunction();
    }

    /**
     * Returns the number of blocks needed to cover {@code i} elements with
     * {@value #CAFFE_CUDA_NUM_THREADS} threads per block (ceiling division).
     *
     * @param i number of elements to cover
     * @return block count
     */
    @Override // com.omega.engine.gpu.BaseKernel
    public int CAFFE_GET_BLOCKS(int i) {
        return ((i + CAFFE_CUDA_NUM_THREADS) - 1) / CAFFE_CUDA_NUM_THREADS;
    }

    /**
     * Runs the forward pass. In normal mode the kernel is launched with {@code tensor} first,
     * {@code tensor2} last and flag {@code 1}; in reverse mode the two tensors are swapped and
     * the flag is {@code 0}.
     *
     * @param tensor  layer input
     * @param tensor2 layer output
     */
    public void forward(Tensor tensor, Tensor tensor2) {
        if (this.reverse) {
            upsample(tensor2, tensor, 0);
        } else {
            upsample(tensor, tensor2, 1);
        }
    }

    /**
     * Runs the backward pass. In normal mode the kernel is launched with {@code tensor2} first,
     * {@code tensor} last and flag {@code 0}; in reverse mode the order is kept and the flag is
     * {@code 1}.
     *
     * @param tensor  first gradient tensor ({@code main} passes the one holding gradient data)
     * @param tensor2 second gradient tensor ({@code main} passes the empty one)
     */
    public void backward(Tensor tensor, Tensor tensor2) {
        if (this.reverse) {
            upsampleDelta(tensor, tensor2, 1);
        } else {
            upsampleDelta(tensor2, tensor, 0);
        }
    }

    /**
     * Launches the kernel for the forward direction and records the batch size in {@code N}.
     * Any exception is printed and swallowed, matching the rest of this class.
     *
     * @param tensor  tensor whose dimensions size the launch; passed as the kernel's first buffer
     * @param tensor2 tensor passed as the kernel's last buffer
     * @param i       direction flag forwarded to the kernel
     */
    public void upsample(Tensor tensor, Tensor tensor2, int i) {
        try {
            int size = elementCount(tensor);
            this.N = tensor.number;
            // Rebuilt on every call: the block captures the element count, the dimensions and
            // the device pointers by value, so a cached copy goes stale as soon as a different
            // tensor or batch size is passed in.
            this.forwardKernelParameters = buildKernelParameters(size, tensor, tensor2, i);
            launch(size, this.forwardKernelParameters);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches the kernel for the backward direction. Any exception is printed and swallowed,
     * matching the rest of this class.
     *
     * @param tensor  tensor whose dimensions size the launch; passed as the kernel's first buffer
     * @param tensor2 tensor passed as the kernel's last buffer
     * @param i       direction flag forwarded to the kernel
     */
    public void upsampleDelta(Tensor tensor, Tensor tensor2, int i) {
        try {
            int size = elementCount(tensor);
            // Rebuilt on every call for the same reason as in upsample(): a block cached from
            // the first backward pass would keep pointing at the first call's tensors and sizes.
            this.backwardKernelParameters = buildKernelParameters(size, tensor, tensor2, i);
            launch(size, this.backwardKernelParameters);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Number of kernel work items: every element of {@code tensor} times {@code stride * stride}. */
    private int elementCount(Tensor tensor) {
        return tensor.channel * tensor.number * tensor.width * tensor.height * this.stride * this.stride;
    }

    /** Packs the kernel arguments in the order the launch expects: size, first buffer, width, height, channel, number, stride, flag, scale, last buffer. */
    private Pointer buildKernelParameters(int size, Tensor first, Tensor last, int flag) {
        return Pointer.to(new NativePointerObject[]{
                Pointer.to(new int[]{size}),
                Pointer.to(new NativePointerObject[]{first.getGpuData()}),
                Pointer.to(new int[]{first.width}),
                Pointer.to(new int[]{first.height}),
                Pointer.to(new int[]{first.channel}),
                Pointer.to(new int[]{first.number}),
                Pointer.to(new int[]{this.stride}),
                Pointer.to(new int[]{flag}),
                Pointer.to(new float[]{this.scale}),
                Pointer.to(new NativePointerObject[]{last.getGpuData()})
        });
    }

    /** Launches the kernel on a one-dimensional grid and reports a non-zero status through {@link #checkCUDA}. */
    private void launch(int size, Pointer kernelParameters) {
        int status = JCudaDriver.cuLaunchKernel(this.forward_function, CAFFE_GET_BLOCKS(size), 1, 1, CAFFE_CUDA_NUM_THREADS, 1, 1, 0, (CUstream) null, kernelParameters, (Pointer) null);
        checkCUDA(status);
    }

    /**
     * Prints a diagnostic to {@code System.err} when {@code i} is non-zero; does not throw.
     * NOTE(review): the message text comes from {@code cudaError.stringFor}, while launch
     * statuses originate from the driver API - confirm the two code tables agree.
     *
     * @param i status code to inspect
     */
    @Override // com.omega.engine.gpu.BaseKernel
    public void checkCUDA(int i) {
        if (i != 0) {
            System.err.println("Error code " + i + ":" + cudaError.stringFor(i));
        }
    }

    /**
     * Manual smoke test: runs a forward and a backward pass on a 2x3x4x4 tensor with stride 2 and
     * prints the GPU tensors alongside the {@link #upsample_cpu} results.
     *
     * @param strArr unused
     */
    public static void main(String[] strArr) {
        CUDAModules.initContext();
        final int batch = 2;
        final int channels = 3;
        final int height = 4;
        final int width = 4;
        final int stride = 2;
        final int outHeight = height * stride;
        final int outWidth = width * stride;

        float[] inputData = MatrixUtils.order(batch * channels * height * width, 1, 1);
        float[] deltaData = RandomUtils.order(batch * channels * outHeight * outWidth, 0.1f, 0.1f);
        Tensor input = new Tensor(batch, channels, height, width, inputData, true);
        Tensor output = new Tensor(batch, channels, outHeight, outWidth, true);
        float[] cpuOutput = new float[output.dataLength];
        Tensor delta = new Tensor(batch, channels, outHeight, outWidth, deltaData, true);
        Tensor diff = new Tensor(batch, channels, height, width, true);
        float[] cpuDiff = new float[diff.dataLength];

        UpSampleKernel kernel = new UpSampleKernel(stride, 1.0f);

        long start = System.nanoTime();
        kernel.forward(input, output);
        System.out.println(((System.nanoTime() - start) / 1000000.0d) + "ms.");
        input.showDM();
        output.showDM();
        upsample_cpu(input.data, width, height, channels, batch, stride, 1, 1.0f, cpuOutput);
        System.out.println(JsonUtils.toJson(cpuOutput));

        kernel.backward(delta, diff);
        delta.showDM();
        diff.showDM();
        upsample_cpu(cpuDiff, width, height, channels, batch, stride, 0, 1.0f, delta.data);
        System.out.println(JsonUtils.toJson(cpuDiff));
    }

    /**
     * CPU reference implementation. Each element of the small array {@code fArr} corresponds to
     * an {@code i5 x i5} patch of the large array {@code fArr2}.
     *
     * <p>When {@code i6 == 1} every element of {@code fArr2} is overwritten with {@code f} times
     * its source element of {@code fArr}. Otherwise every element of {@code fArr2}, multiplied by
     * {@code f}, is accumulated into its source element of {@code fArr}.
     *
     * @param fArr  small array, indexed as batch x channel x row x column
     * @param i     width of the small array
     * @param i2    height of the small array
     * @param i3    channel count
     * @param i4    batch size
     * @param i5    stride (patch edge length)
     * @param i6    direction flag: {@code 1} writes {@code fArr2}, anything else accumulates into {@code fArr}
     * @param f     scale factor
     * @param fArr2 large array, {@code i5 * i5} times the size of {@code fArr}
     */
    public static void upsample_cpu(float[] fArr, int i, int i2, int i3, int i4, int i5, int i6, float f, float[] fArr2) {
        final int width = i;
        final int height = i2;
        final int channels = i3;
        final int batch = i4;
        final int stride = i5;
        final boolean writeLarge = i6 == 1;

        for (int b = 0; b < batch; b++) {
            for (int c = 0; c < channels; c++) {
                for (int row = 0; row < height * stride; row++) {
                    for (int col = 0; col < width * stride; col++) {
                        // Integer division maps every position of a stride x stride patch back
                        // to the single small-array element it belongs to.
                        int smallIndex = (b * width * height * channels) + (c * width * height) + ((row / stride) * width) + (col / stride);
                        int largeIndex = (b * width * height * channels * stride * stride) + (c * width * height * stride * stride) + (row * width * stride) + col;
                        if (writeLarge) {
                            fArr2[largeIndex] = f * fArr[smallIndex];
                        } else {
                            fArr[smallIndex] = fArr[smallIndex] + (f * fArr2[largeIndex]);
                        }
                    }
                }
            }
        }
    }
}
