package com.omega.engine.updater.gpu;

import com.omega.common.data.Tensor;
import com.omega.common.utils.RandomUtils;
import com.omega.engine.gpu.CUDAMemoryManager;
import com.omega.engine.gpu.CUDAModules;
import com.omega.engine.loss.SoftmaxWithCrossEntropyLoss;
import com.omega.engine.nn.network.BPNetwork;
import com.omega.engine.nn.network.Network;
import java.util.Map;
import jcuda.NativePointerObject;
import jcuda.Pointer;
import jcuda.driver.CUfunction;
import jcuda.driver.CUstream;
import jcuda.driver.JCudaDriver;

/* loaded from: input_file:com/omega/engine/updater/gpu/AdamWKernel.class */
/**
 * Host-side launcher for the {@code adamw}, {@code adamwr} and {@code adamw_bn} functions
 * looked up from the {@code updater.cu} module.
 *
 * <p>The instance owns two pairs of state tensors: {@link #mw}/{@link #vw} are forwarded by the
 * {@code updateW}/{@code updateGama} methods and {@link #mb}/{@link #vb} by the
 * {@code updateB}/{@code updateBeta} methods. Judging by the class name these are presumably the
 * AdamW first/second moment estimates; the actual arithmetic lives in {@code updater.cu}, which
 * is not part of this file.
 *
 * <p>Every update method catches any exception, prints its stack trace and returns normally, so
 * a failed launch is reported on stderr only.
 */
public class AdamWKernel {
    private static final float DEFAULT_BETA1 = 0.9f;
    private static final float DEFAULT_BETA2 = 0.95f;

    /** State tensor passed as the third kernel argument by {@code updateW}/{@code updateGama}. */
    public Tensor mw;
    /** State tensor passed as the fourth kernel argument by {@code updateW}/{@code updateGama}. */
    public Tensor vw;
    /**
     * State tensor passed as the third kernel argument by {@code updateB}/{@code updateBeta};
     * stays {@code null} when the {@code (int, float)} constructor is used.
     */
    public Tensor mb;
    /**
     * State tensor passed as the fourth kernel argument by {@code updateB}/{@code updateBeta};
     * stays {@code null} when the {@code (int, float)} constructor is used.
     */
    public Tensor vb;
    private float beta1 = DEFAULT_BETA1;
    private float beta2 = DEFAULT_BETA2;
    private CUfunction function;
    // Resolved in initFunction() but not launched by any method of this class.
    private CUfunction r_function;
    private CUfunction bn_function;
    // Threads per block used for every launch.
    private final int CAFFE_CUDA_NUM_THREADS = 1024;
    // Most recent argument lists; kept in fields so they stay reachable after the launch call.
    private Pointer kernelParameters;
    private Pointer kernelBiasParameters;
    private float weight_decay = 0.0f;

    /**
     * Creates a kernel wrapper that only owns the {@link #mw}/{@link #vw} state tensors.
     *
     * <p>{@link #mb} and {@link #vb} are left {@code null}, so {@code updateB} and
     * {@code updateBeta} cannot be used on an instance built this way.
     *
     * @param i last dimension of the {@code 1 x 1 x 1 x i} state tensors {@code mw} and {@code vw}
     * @param f weight decay value forwarded by {@code updateW}
     */
    public AdamWKernel(int i, float f) {
        this.mw = new Tensor(1, 1, 1, i, true);
        this.vw = new Tensor(1, 1, 1, i, true);
        this.weight_decay = f;
        init();
    }

    /**
     * Creates a kernel wrapper with both state tensor pairs and the default beta values
     * (0.9 and 0.95).
     *
     * @param i  last dimension of the state tensors {@code mw} and {@code vw}
     * @param i2 last dimension of the state tensors {@code mb} and {@code vb}
     * @param f  weight decay value forwarded by {@code updateW}
     */
    public AdamWKernel(int i, int i2, float f) {
        this(i, i2, DEFAULT_BETA1, DEFAULT_BETA2, f);
    }

    /**
     * Creates a kernel wrapper with both state tensor pairs and explicit beta values.
     *
     * @param i  last dimension of the state tensors {@code mw} and {@code vw}
     * @param i2 last dimension of the state tensors {@code mb} and {@code vb}
     * @param f  value stored as {@code beta1}
     * @param f2 value stored as {@code beta2}
     * @param f3 weight decay value forwarded by {@code updateW}
     */
    public AdamWKernel(int i, int i2, float f, float f2, float f3) {
        this.mw = new Tensor(1, 1, 1, i, true);
        this.vw = new Tensor(1, 1, 1, i, true);
        this.mb = new Tensor(1, 1, 1, i2, true);
        this.vb = new Tensor(1, 1, 1, i2, true);
        this.beta1 = f;
        this.beta2 = f2;
        this.weight_decay = f3;
        init();
    }

    /**
     * Overrides the hyper-parameters for which the map holds a non-null value.
     *
     * <p>Recognised keys are {@code "beta1"}, {@code "beta2"} and {@code "weight_decay"}; any
     * other entry is ignored, and a {@code null} map leaves everything unchanged.
     *
     * @param map hyper-parameter overrides, may be {@code null}
     */
    public void setParams(Map<String, Float> map) {
        if (map == null) {
            return;
        }
        Float beta1Override = map.get("beta1");
        if (beta1Override != null) {
            this.beta1 = beta1Override.floatValue();
        }
        Float beta2Override = map.get("beta2");
        if (beta2Override != null) {
            this.beta2 = beta2Override.floatValue();
        }
        Float decayOverride = map.get("weight_decay");
        if (decayOverride != null) {
            this.weight_decay = decayOverride.floatValue();
        }
    }

    /** Initialises the wrapper; currently this only resolves the CUDA functions. */
    public void init() {
        initFunction();
    }

    /**
     * Looks up the three CUDA functions in {@code updater.cu}, skipping any that are already set.
     *
     * <p>A lookup failure is printed and otherwise ignored, which leaves the affected handle
     * {@code null}.
     */
    public void initFunction() {
        try {
            if (this.function == null) {
                this.function = CUDAModules.getLocalFunctionByModule("updater.cu", "adamw");
            }
            if (this.r_function == null) {
                this.r_function = CUDAModules.getLocalFunctionByModule("updater.cu", "adamwr");
            }
            if (this.bn_function == null) {
                this.bn_function = CUDAModules.getLocalFunctionByModule("updater.cu", "adamw_bn");
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Returns the number of blocks needed to cover {@code i} elements with
     * {@code CAFFE_CUDA_NUM_THREADS} (1024) threads per block, i.e. the ceiling of
     * {@code i / 1024}.
     *
     * @param i element count
     * @return block count for the launch grid
     */
    public int CAFFE_GET_BLOCKS(int i) {
        return ((i + this.CAFFE_CUDA_NUM_THREADS) - 1) / this.CAFFE_CUDA_NUM_THREADS;
    }

    /**
     * Launches {@code adamw} with the {@code mw}/{@code vw} state, the configured weight decay
     * and {@code network.number} as the tenth kernel argument.
     *
     * @param tensor  first kernel argument; its {@code dataLength} is the element count of the
     *                launch (presumably the gradient - confirm against updater.cu)
     * @param tensor2 second kernel argument (presumably the weights being updated - confirm)
     * @param network supplies {@code number} and {@code train_time}
     * @param f       seventh kernel argument (presumably the learning rate - confirm)
     */
    public void updateW(Tensor tensor, Tensor tensor2, Network network, float f) {
        try {
            this.kernelParameters = buildArguments(tensor, tensor2, this.mw, this.vw, f, this.weight_decay, network.number, network.train_time);
            launch(this.function, tensor.dataLength, this.kernelParameters);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Same as {@link #updateW(Tensor, Tensor, Network, float)} but passes {@code i} instead of
     * {@code network.number} as the tenth kernel argument.
     *
     * @param tensor  first kernel argument; its {@code dataLength} is the element count
     * @param tensor2 second kernel argument
     * @param network supplies {@code train_time}
     * @param f       seventh kernel argument (presumably the learning rate - confirm)
     * @param i       value used in place of {@code network.number}
     */
    public void updateW(Tensor tensor, Tensor tensor2, Network network, float f, int i) {
        try {
            this.kernelParameters = buildArguments(tensor, tensor2, this.mw, this.vw, f, this.weight_decay, i, network.train_time);
            launch(this.function, tensor.dataLength, this.kernelParameters);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches {@code adamw_bn} with the {@code mw}/{@code vw} state and a weight decay of zero.
     *
     * @param tensor  first kernel argument; its {@code dataLength} is the element count
     * @param tensor2 second kernel argument
     * @param network supplies {@code number} and {@code train_time}
     * @param f       seventh kernel argument (presumably the learning rate - confirm)
     */
    public void updateGama(Tensor tensor, Tensor tensor2, Network network, float f) {
        try {
            this.kernelParameters = buildArguments(tensor, tensor2, this.mw, this.vw, f, 0.0f, network.number, network.train_time);
            launch(this.bn_function, tensor.dataLength, this.kernelParameters);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches {@code adamw} with the {@code mb}/{@code vb} state and a weight decay of zero.
     *
     * @param tensor  first kernel argument; its {@code dataLength} is the element count
     * @param tensor2 second kernel argument
     * @param network supplies {@code number} and {@code train_time}
     * @param f       seventh kernel argument (presumably the learning rate - confirm)
     */
    public void updateB(Tensor tensor, Tensor tensor2, Network network, float f) {
        try {
            this.kernelBiasParameters = buildArguments(tensor, tensor2, this.mb, this.vb, f, 0.0f, network.number, network.train_time);
            launch(this.function, tensor.dataLength, this.kernelBiasParameters);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Same as {@link #updateB(Tensor, Tensor, Network, float)} but passes {@code i} instead of
     * {@code network.number} as the tenth kernel argument.
     *
     * @param tensor  first kernel argument; its {@code dataLength} is the element count
     * @param tensor2 second kernel argument
     * @param network supplies {@code train_time}
     * @param f       seventh kernel argument (presumably the learning rate - confirm)
     * @param i       value used in place of {@code network.number}
     */
    public void updateB(Tensor tensor, Tensor tensor2, Network network, float f, int i) {
        try {
            this.kernelBiasParameters = buildArguments(tensor, tensor2, this.mb, this.vb, f, 0.0f, i, network.train_time);
            launch(this.function, tensor.dataLength, this.kernelBiasParameters);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Launches {@code adamw_bn} with the {@code mb}/{@code vb} state and a weight decay of zero.
     *
     * @param tensor  first kernel argument; its {@code dataLength} is the element count
     * @param tensor2 second kernel argument
     * @param network supplies {@code number} and {@code train_time}
     * @param f       seventh kernel argument (presumably the learning rate - confirm)
     */
    public void updateBeta(Tensor tensor, Tensor tensor2, Network network, float f) {
        try {
            this.kernelBiasParameters = buildArguments(tensor, tensor2, this.mb, this.vb, f, 0.0f, network.number, network.train_time);
            launch(this.bn_function, tensor.dataLength, this.kernelBiasParameters);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Builds the eleven-entry kernel argument list shared by all update methods, in the order
     * the launches in this class have always used: two data tensors, two state tensors, beta1,
     * beta2, {@code f}, weight decay, element count, {@code number}, {@code trainTime}.
     *
     * @throws IllegalStateException if either state tensor is {@code null}, which happens when
     *         the bias methods are used on an instance created without {@code mb}/{@code vb}
     */
    private Pointer buildArguments(Tensor tensor, Tensor tensor2, Tensor firstState, Tensor secondState, float f, float decay, int number, int trainTime) {
        if (firstState == null || secondState == null) {
            throw new IllegalStateException("AdamWKernel state tensors are not allocated; use a constructor that sizes them before calling this update.");
        }
        return Pointer.to(
                Pointer.to(tensor.getGpuData()),
                Pointer.to(tensor2.getGpuData()),
                Pointer.to(firstState.getGpuData()),
                Pointer.to(secondState.getGpuData()),
                Pointer.to(new float[]{this.beta1}),
                Pointer.to(new float[]{this.beta2}),
                Pointer.to(new float[]{f}),
                Pointer.to(new float[]{decay}),
                Pointer.to(new int[]{tensor.dataLength}),
                Pointer.to(new int[]{number}),
                Pointer.to(new int[]{trainTime}));
    }

    /**
     * Launches {@code kernel} on a one-dimensional grid sized for {@code length} elements, with
     * no dynamic shared memory and a {@code null} stream.
     */
    private void launch(CUfunction kernel, int length, Pointer arguments) {
        JCudaDriver.cuLaunchKernel(kernel, CAFFE_GET_BLOCKS(length), 1, 1, this.CAFFE_CUDA_NUM_THREADS, 1, 1, 0, null, arguments, null);
    }

    /**
     * Manual smoke test: runs one {@code updateGama} step on a fixed 38-element array and prints
     * both tensors. Requires a working CUDA environment.
     *
     * @param strArr unused
     */
    public static void main(String[] strArr) {
        float[] values = {0.0075240037f, 0.022312285f, 0.037100658f, 0.05188888f, 0.06667703f, 0.08146531f, 0.09625361f, 0.111041375f, 0.12582973f, 0.14061777f, 0.15540561f, 0.17019409f, 0.18498187f, 0.19977006f, 0.21455756f, 0.22934535f, 0.24413382f, 0.25892144f, 0.27370873f, 0.2884969f, 0.303285f, 0.3180732f, 0.33286023f, 0.347648f, 0.36243534f, 0.37722275f, 0.39201018f, 0.40679908f, 0.42158592f, 0.43637308f, 0.45116156f, 0.46594855f, 0.48073593f, 0.4955238f, 0.5103103f, 0.52509815f, 0.53988534f, 0.55467236f};
        float[] zeros = new float[values.length];
        Tensor tensor = new Tensor(1, 1, 1, values.length, zeros, true);
        Tensor tensor2 = new Tensor(1, 1, 1, values.length, values, true);
        BPNetwork bPNetwork = new BPNetwork(new SoftmaxWithCrossEntropyLoss());
        bPNetwork.train_time = 1;
        bPNetwork.number = 2;
        // The state tensors must be as long as the tensors being updated: the launch runs over
        // tensor2.dataLength elements, so sizing the kernel with any smaller length would leave
        // the state buffers shorter than the range the launch covers.
        AdamWKernel kernel = new AdamWKernel(values.length, 0.001f);
        kernel.updateGama(tensor2, tensor, bPNetwork, 1.0E-4f);
        tensor2.showDM();
        tensor.showDM();
        CUDAMemoryManager.free();
    }
}
