package com.omega.common.data;

import com.omega.common.utils.JsonUtils;
import com.omega.common.utils.MatrixUtils;
import com.omega.common.utils.RandomUtils;
import com.omega.engine.ad.Graph;
import com.omega.engine.ad.op.OPType;
import com.omega.engine.ad.op.TensorOP;
import com.omega.engine.ad.op.gpu.OPKernel;
import com.omega.engine.gpu.CUDAMemoryManager;
import java.io.Serializable;
import java.util.Arrays;
import java.util.UUID;
import jcuda.Pointer;
import jcuda.runtime.JCuda;
import jcuda.runtime.cudaError;
import jcuda.runtime.cudaStream_t;

/* loaded from: input_file:com/omega/common/data/Tensor.class */
/**
 * A 4-D float tensor with shape (number, channel, height, width).
 *
 * <p>Host values live in {@link #data}. Element (n, c, h, w) is stored at offset
 * {@code n*C*H*W + c*H*W + h*W + w}, so width is the fastest-varying axis. When {@code hasGPU}
 * is set, a device buffer obtained from {@link CUDAMemoryManager} mirrors the host array. The
 * two copies are NOT kept in sync automatically: call {@link #hostToDevice()} after mutating
 * {@code data} and {@link #syncHost()} before reading device results.
 *
 * <p>The arithmetic methods ({@link #add(Tensor)}, {@link #mul(float)}, ...) delegate to the
 * attached {@link Graph}. They throw {@link NullPointerException} when no graph has been set.
 */
public class Tensor implements Serializable {
    private static final long serialVersionUID = 5844762745177624845L;

    /** Size of one float element in bytes, used to size CUDA transfers. */
    private static final int FLOAT_BYTES = 4;
    /** Numeric value of {@code cudaMemcpyKind.cudaMemcpyHostToDevice}. */
    private static final int HOST_TO_DEVICE = 1;
    /** Numeric value of {@code cudaMemcpyKind.cudaMemcpyDeviceToHost}. */
    private static final int DEVICE_TO_HOST = 2;

    public String id;
    public int number;
    public int channel;
    public int height;
    public int width;
    /** Element count recorded at construction/resize; may diverge from the current view's product. */
    public int dataLength;
    /** Host-side values. */
    public float[] data;
    private Pointer gpuData;
    /** Reusable buffer returned by {@link #getByNumber(int)} and {@link #getByNumberAndChannel(int, int)}. */
    public float[] once;
    public int onceSize;
    private boolean hasGPU = false;
    private boolean requiresGrad = false;
    private Tensor grad;
    private Graph g;
    /** Shape recorded at construction (and by {@code createTensor}); restored by {@link #viewOrg()}. */
    private int[] orgShape;
    private Tensor tmp;
    private Tensor tmp_once;

    /**
     * Returns a lazily created scratch tensor with this tensor's shape and GPU flag.
     *
     * <p>NOTE(review): the scratch tensor is created once and is not reallocated by
     * {@link #resize} or {@link #view}, so its shape can become stale.
     */
    public Tensor getTmp() {
        if (this.tmp == null) {
            this.tmp = new Tensor(this.number, this.channel, this.height, this.width, this.hasGPU);
        }
        return this.tmp;
    }

    /** Returns a lazily created 1x1x1x1 scratch tensor, which is always GPU-backed. */
    public Tensor getTmpOnce() {
        if (this.tmp_once == null) {
            this.tmp_once = new Tensor(1, 1, 1, 1, true);
        }
        return this.tmp_once;
    }

    /** Returns this tensor's id, generating a random UUID on first access. */
    public String getId() {
        if (this.id == null) {
            this.id = UUID.randomUUID().toString();
        }
        return this.id;
    }

    /** Returns a new tensor with the same shape and GPU flag and a copy of the host array (no device sync). */
    public Tensor copy() {
        float[] copied = new float[this.dataLength];
        System.arraycopy(this.data, 0, copied, 0, this.dataLength);
        return new Tensor(this.number, this.channel, this.height, this.width, copied, this.hasGPU);
    }

    /** Like {@link #copy()}, but first pulls the device buffer into the host array. */
    public Tensor copyGPU() {
        float[] copied = new float[this.dataLength];
        System.arraycopy(syncHost(), 0, copied, 0, this.dataLength);
        return new Tensor(this.number, this.channel, this.height, this.width, copied, this.hasGPU);
    }

    /**
     * Copies this tensor's device values (via {@link #syncHost()}) into {@code tensor}'s host
     * array, then uploads them if {@code tensor} is GPU-backed.
     */
    public void copy(Tensor tensor) {
        System.arraycopy(syncHost(), 0, tensor.data, 0, this.dataLength);
        if (tensor.hasGPU) {
            tensor.hostToDevice();
        }
    }

    /** Creates a host-only, zero-filled tensor. */
    public Tensor(int number, int channel, int height, int width) {
        this(number, channel, height, width, new float[number * channel * height * width]);
    }

    /** Creates a zero-filled tensor, also allocating and uploading a device buffer when {@code hasGPU}. */
    public Tensor(int number, int channel, int height, int width, boolean hasGPU) {
        this(number, channel, height, width, hasGPU, (Graph) null);
    }

    /** Same as {@link #Tensor(int, int, int, int, boolean)} with an attached autograd graph. */
    public Tensor(int number, int channel, int height, int width, boolean hasGPU, Graph graph) {
        this(number, channel, height, width, new float[number * channel * height * width], hasGPU, graph);
    }

    /** Creates a host-only tensor that wraps (does not copy) {@code values}. */
    public Tensor(int number, int channel, int height, int width, float[] values) {
        initShape(number, channel, height, width, values);
    }

    /** Wraps {@code values}; when {@code hasGPU}, also allocates a device buffer and uploads them. */
    public Tensor(int number, int channel, int height, int width, float[] values, boolean hasGPU) {
        this(number, channel, height, width, values, hasGPU, (Graph) null);
    }

    /** Same as {@link #Tensor(int, int, int, int, float[], boolean)} with an attached autograd graph. */
    public Tensor(int number, int channel, int height, int width, float[] values, boolean hasGPU, Graph graph) {
        this.g = graph;
        initShape(number, channel, height, width, values);
        initDevice(hasGPU);
    }

    /** Creates a tensor whose host array comes from {@code MatrixUtils.val(length, value)}. */
    public Tensor(int number, int channel, int height, int width, int value, boolean hasGPU) {
        this(number, channel, height, width, value, hasGPU, (Graph) null);
    }

    /** Same as {@link #Tensor(int, int, int, int, int, boolean)} with an attached autograd graph. */
    public Tensor(int number, int channel, int height, int width, int value, boolean hasGPU, Graph graph) {
        this(number, channel, height, width,
                MatrixUtils.val(number * channel * height * width, value), hasGPU, graph);
    }

    /** Records the shape (also as the original shape for {@link #viewOrg()}) and the host array. */
    private void initShape(int number, int channel, int height, int width, float[] values) {
        this.number = number;
        this.channel = channel;
        this.height = height;
        this.width = width;
        this.dataLength = number * channel * height * width;
        this.data = values;
        this.orgShape = new int[]{number, channel, height, width};
    }

    /** Sets the GPU flag and, when enabled, allocates a device buffer and uploads {@link #data}. */
    private void initDevice(boolean useGPU) {
        setHasGPU(useGPU);
        if (useGPU) {
            this.gpuData = CUDAMemoryManager.getPointer(this.dataLength);
            JCuda.cudaMemcpy(this.gpuData, Pointer.to(this.data), byteCount(), HOST_TO_DEVICE);
            JCuda.cudaDeviceSynchronize();
        }
    }

    /** Byte size of {@code dataLength} floats, computed in long to avoid int overflow. */
    private long byteCount() {
        return (long) this.dataLength * FLOAT_BYTES;
    }

    /** Returns {@code tensor} resized to wrap {@code values}, or a new tensor if it is null. */
    public static Tensor createTensor(Tensor tensor, int i, int i2, int i3, int i4, float[] fArr, boolean z) {
        if (tensor == null) {
            return new Tensor(i, i2, i3, i4, fArr, z);
        }
        tensor.resize(i, i2, i3, i4, fArr);
        return tensor;
    }

    /** Returns {@code tensor} resized (and its original shape reset), or a new zero tensor if it is null. */
    public static Tensor createTensor(Tensor tensor, int i, int i2, int i3, int i4, boolean z) {
        if (tensor == null) {
            return new Tensor(i, i2, i3, i4, z);
        }
        tensor.resize(i, i2, i3, i4);
        tensor.orgShape = new int[]{i, i2, i3, i4};
        return tensor;
    }

    /**
     * Reallocates the host array (zero-filled) for the new shape. When GPU-backed, it releases the
     * old device buffer and obtains a new one; the new device buffer is not uploaded here.
     */
    public void resize(int i, int i2, int i3, int i4) {
        this.number = i;
        this.channel = i2;
        this.height = i3;
        this.width = i4;
        this.dataLength = i * i2 * i3 * i4;
        this.data = new float[this.dataLength];
        if (this.hasGPU) {
            if (this.gpuData != null) {
                CUDAMemoryManager.free(this.gpuData);
            }
            this.gpuData = CUDAMemoryManager.getPointer(this.dataLength);
        }
    }

    /** Adopts {@code fArr} as the host array for the new shape and, when GPU-backed, uploads it to a fresh device buffer. */
    public void resize(int i, int i2, int i3, int i4, float[] fArr) {
        this.number = i;
        this.channel = i2;
        this.height = i3;
        this.width = i4;
        this.dataLength = i * i2 * i3 * i4;
        this.data = fArr;
        if (this.hasGPU) {
            if (this.gpuData != null) {
                CUDAMemoryManager.free(this.gpuData);
            }
            this.gpuData = CUDAMemoryManager.getPointer(this.dataLength);
            JCuda.cudaMemcpy(this.gpuData, Pointer.to(fArr), byteCount(), HOST_TO_DEVICE);
            JCuda.cudaDeviceSynchronize();
        }
    }

    /**
     * Copies sample {@code i} (C*H*W values) from the host array into {@code fArr}.
     *
     * @throws RuntimeException if {@code i} is negative or not less than {@code number}
     */
    public void copy(int i, float[] fArr) {
        if (i < 0 || i >= this.number) {
            throw new RuntimeException("获取数据失败[下标超出长度].");
        }
        int size = this.channel * this.height * this.width;
        System.arraycopy(this.data, i * size, fArr, 0, size);
    }

    /** Reinterprets the shape in place without touching data (element count is not validated). */
    public Tensor view(int i, int i2, int i3, int i4) {
        this.number = i;
        this.channel = i2;
        this.height = i3;
        this.width = i4;
        return this;
    }

    /** Same as {@link #view(int, int, int, int)} with the shape given as {n, c, h, w}. */
    public Tensor view(int[] iArr) {
        return view(iArr[0], iArr[1], iArr[2], iArr[3]);
    }

    /** Restores the shape recorded at construction (or by {@code createTensor}). */
    public Tensor viewOrg() {
        return view(this.orgShape);
    }

    /** Returns the current shape as {number, channel, height, width}. */
    public int[] shape() {
        return new int[]{this.number, this.channel, this.height, this.width};
    }

    /** Prints the current shape as JSON to stdout. */
    public void showShape() {
        System.out.println(JsonUtils.toJson(shape()));
    }

    public int getNumber() {
        return this.number;
    }

    public void setNumber(int i) {
        this.number = i;
    }

    public int getChannel() {
        return this.channel;
    }

    public void setChannel(int i) {
        this.channel = i;
    }

    public int getHeight() {
        return this.height;
    }

    public void setHeight(int i) {
        this.height = i;
    }

    public int getWidth() {
        return this.width;
    }

    public void setWidth(int i) {
        this.width = i;
    }

    /** Returns the element count of the CURRENT view (not the {@code dataLength} field). */
    public int getDataLength() {
        return this.number * this.channel * this.height * this.width;
    }

    public void setDataLength(int i) {
        this.dataLength = i;
    }

    /** Returns the number of elements in one sample (C*H*W). */
    public int getOnceSize() {
        return this.channel * this.height * this.width;
    }

    public float[] getData() {
        return this.data;
    }

    /** Replaces the host array and, when GPU-backed, uploads it. */
    public void setData(float[] fArr) {
        this.data = fArr;
        if (isHasGPU()) {
            hostToDevice();
        }
    }

    /** Copies {@code fArr} into the existing host array and, when GPU-backed, uploads it. */
    public void copyData(float[] fArr) {
        System.arraycopy(fArr, 0, this.data, 0, fArr.length);
        if (isHasGPU()) {
            hostToDevice();
        }
    }

    /** Returns the host value at (n, c, h, w). */
    public float getByIndex(int i, int i2, int i3, int i4) {
        return this.data[(i * this.channel * this.height * this.width) + (i2 * this.height * this.width) + (i3 * this.width) + i4];
    }

    /** Copies sample {@code i} into the shared {@link #once} buffer and returns it (overwritten by later calls). */
    public float[] getByNumber(int i) {
        int size = this.channel * this.height * this.width;
        System.arraycopy(this.data, i * size, getOnce(), 0, size);
        return this.once;
    }

    /** Writes C*H*W values from {@code fArr} into sample {@code i} of the host array. */
    public void setByNumber(int i, float[] fArr) {
        int size = this.channel * this.height * this.width;
        System.arraycopy(fArr, 0, this.data, i * size, size);
    }

    /**
     * Copies sample {@code i} (C*H*W values) into {@code fArr}.
     *
     * @throws IllegalArgumentException if {@code fArr} is null or shorter than C*H*W. Before this
     *     check, a mismatched buffer was silently replaced by a local array and the copy was lost.
     */
    public void getByNumber(int i, float[] fArr) {
        int size = this.channel * this.height * this.width;
        requireBuffer(fArr, size);
        System.arraycopy(this.data, i * size, fArr, 0, size);
    }

    /** Copies plane (n=i, c=i2) into the shared {@link #once} buffer and returns it. */
    public float[] getByNumberAndChannel(int i, int i2) {
        int plane = this.height * this.width;
        System.arraycopy(this.data, (i * this.channel * plane) + (i2 * plane), getOnce(), 0, plane);
        return this.once;
    }

    /** Writes H*W values from {@code fArr} into plane (n=i, c=i2) of the host array. */
    public void setByNumberAndChannel(int i, int i2, float[] fArr) {
        int plane = this.height * this.width;
        System.arraycopy(fArr, 0, this.data, (i * this.channel * plane) + (i2 * plane), plane);
    }

    /**
     * Copies plane (n=i, c=i2), H*W values, into {@code fArr}.
     *
     * @throws IllegalArgumentException if {@code fArr} is null or shorter than H*W
     */
    public void getByNumberAndChannel(int i, int i2, float[] fArr) {
        int plane = this.height * this.width;
        requireBuffer(fArr, plane);
        System.arraycopy(this.data, (i * this.channel * plane) + (i2 * plane), fArr, 0, plane);
    }

    /** Rejects a destination buffer that cannot hold {@code size} floats. */
    private static void requireBuffer(float[] buffer, int size) {
        if (buffer == null || buffer.length < size) {
            throw new IllegalArgumentException("destination buffer must hold at least " + size
                    + " floats, got " + (buffer == null ? "null" : String.valueOf(buffer.length)));
        }
    }

    /** Zeroes the first {@code dataLength} host values (device buffer untouched). */
    public void clear() {
        Arrays.fill(this.data, 0, this.dataLength, 0.0f);
    }

    /** Sets the first {@code dataLength} host values to {@code f} (device buffer untouched). */
    public void val_cpu(float f) {
        Arrays.fill(this.data, 0, this.dataLength, f);
    }

    /** Resets the shape and allocates a new zero host array; the device buffer is not changed. */
    public void clear(int i, int i2, int i3, int i4) {
        this.number = i;
        this.channel = i2;
        this.height = i3;
        this.width = i4;
        this.dataLength = i * i2 * i3 * i4;
        this.data = new float[this.dataLength];
    }

    public Pointer getGpuData() {
        return this.gpuData;
    }

    public void setGpuData(Pointer pointer) {
        this.gpuData = pointer;
    }

    /** Copies {@code dataLength} floats from the device buffer into {@link #data} and returns it (error code ignored). */
    public float[] syncHost() {
        JCuda.cudaMemcpy(Pointer.to(this.data), this.gpuData, byteCount(), DEVICE_TO_HOST);
        return this.data;
    }

    /** When GPU-backed, uploads {@link #data} to the device (allocating a buffer if none exists) and synchronizes. */
    public void hostToDevice() {
        if (this.hasGPU) {
            if (this.gpuData == null) {
                this.gpuData = CUDAMemoryManager.getPointer(this.dataLength);
            }
            JCuda.cudaMemcpy(this.gpuData, Pointer.to(this.data), byteCount(), HOST_TO_DEVICE);
            JCuda.cudaDeviceSynchronize();
        }
    }

    /**
     * Frees the device buffer with {@code JCuda.cudaFree}.
     *
     * <p>NOTE(review): despite the name, this immediately obtains a new buffer of
     * {@code dataLength} floats from {@link CUDAMemoryManager}, so {@code gpuData} stays non-null.
     * Also, {@link #resize} releases buffers via {@code CUDAMemoryManager.free} rather than
     * {@code cudaFree}. Confirm which release path the memory manager expects.
     */
    public void freeGPU() {
        if (this.gpuData != null) {
            JCuda.cudaFree(this.gpuData);
            this.gpuData = CUDAMemoryManager.getPointer(this.dataLength);
        }
    }

    /** Syncs from the device and prints the whole host array as JSON. */
    public void showDM() {
        syncHost();
        System.out.println(JsonUtils.toJson(this.data));
    }

    /** Syncs from the device and prints sample {@code i} as JSON. */
    public void showDMByNumber(int i) {
        syncHost();
        System.out.println(JsonUtils.toJson(getByNumber(i)));
    }

    /**
     * Syncs from the device and prints {@code count} consecutive host values starting at
     * {@code offset} as JSON.
     *
     * <p>NOTE(review): the (offset, count) meaning of the parameters is inferred from the method
     * name. The previous body ignored both arguments and indexed sample {@code number}, which is
     * out of range.
     *
     * @throws IndexOutOfBoundsException if the range is not within {@code [0, dataLength)}
     */
    public void showDMByOffset(int offset, int count) {
        if (offset < 0 || count < 0 || offset > this.dataLength - count) {
            throw new IndexOutOfBoundsException(
                    "offset=" + offset + ", count=" + count + ", dataLength=" + this.dataLength);
        }
        syncHost();
        System.out.println(JsonUtils.toJson(Arrays.copyOfRange(this.data, offset, offset + count)));
    }

    /** Syncs from the device and prints host value {@code i}. */
    public void showDM(int i) {
        syncHost();
        System.out.println(this.data[i]);
    }

    /** Syncs from the device and returns true if any value is strictly positive. */
    public boolean checkDM() {
        for (float value : syncHost()) {
            if (value > 0.0f) {
                return true;
            }
        }
        return false;
    }

    /** Asynchronously zeroes the device buffer on {@code cudastream_t}. */
    public void clearGPU(cudaStream_t cudastream_t) {
        checkCUDA(JCuda.cudaMemsetAsync(this.gpuData, 0, byteCount(), cudastream_t));
    }

    /** Zeroes the device buffer if one exists. */
    public void clearGPU() {
        if (this.gpuData != null) {
            checkCUDA(JCuda.cudaMemset(this.gpuData, 0, byteCount()));
        }
    }

    /**
     * Memsets the device buffer with {@code i} and synchronizes.
     *
     * <p>NOTE(review): cudaMemset writes a byte value into every byte, so only {@code 0} yields
     * float {@code 0.0f}. Other values do not produce the float {@code i}; use {@link #fill(float)}
     * for that.
     */
    public void valueGPU(int i) {
        if (this.gpuData != null) {
            checkCUDA(JCuda.cudaMemset(this.gpuData, i, byteCount()));
            JCuda.cudaDeviceSynchronize();
        }
    }

    /** Throws a RuntimeException carrying the CUDA error name if {@code i} is non-zero. */
    public void checkCUDA(int i) {
        if (i != 0) {
            throw new RuntimeException(cudaError.stringFor(i));
        }
    }

    public boolean isHasGPU() {
        return this.hasGPU;
    }

    public void setHasGPU(boolean z) {
        this.hasGPU = z;
    }

    public boolean isRequiresGrad() {
        return this.requiresGrad;
    }

    /** Sets the flag and eagerly creates the gradient tensor when enabling. */
    public void setRequiresGrad(boolean z) {
        this.requiresGrad = z;
        if (this.requiresGrad) {
            getGrad();
        }
    }

    /** Returns the gradient tensor, creating a zero tensor of the current shape on first access. */
    public Tensor getGrad() {
        if (this.grad == null) {
            this.grad = new Tensor(this.number, this.channel, this.height, this.width, this.hasGPU);
        }
        return this.grad;
    }

    /** Returns the gradient tensor, creating it around {@code fArr} on first access. */
    public Tensor getGrad(float[] fArr) {
        if (this.grad == null) {
            this.grad = new Tensor(this.number, this.channel, this.height, this.width, fArr, this.hasGPU);
        }
        return this.grad;
    }

    public void setGrad(Tensor tensor) {
        this.grad = tensor;
    }

    /** Creates the gradient around {@code fArr}, or replaces the existing gradient's host array and uploads it. */
    public void setGrad(float[] fArr) {
        if (this.grad == null) {
            this.grad = new Tensor(this.number, this.channel, this.height, this.width, fArr, this.hasGPU);
        } else {
            this.grad.data = fArr;
            this.grad.hostToDevice();
        }
    }

    /**
     * Writes {@code tensor} into a slice of the gradient.
     *
     * @param iArr {dim, start, count}; dim 0 slices along number, dim 1 along channel, any other
     *     dim is a no-op (all three entries are read regardless)
     */
    public void setGrad(Tensor tensor, int[] iArr) {
        if (this.grad == null) {
            this.grad = new Tensor(this.number, this.channel, this.height, this.width, this.hasGPU);
        }
        int dim = iArr[0];
        int start = iArr[1];
        int count = iArr[2];
        switch (dim) {
            case 0:
                if (isHasGPU()) {
                    OPKernel.getInstance().copy_number_gpu(this.grad, tensor, start, 1);
                } else {
                    setGradByNumber(tensor.data, start, count);
                }
                return;
            case 1:
                if (isHasGPU()) {
                    OPKernel.getInstance().copy_channel_gpu(this.grad, tensor, start, 1);
                } else {
                    setGradByChannel(tensor.data, start, count);
                }
                return;
            default:
                return;
        }
    }

    /** Fills the gradient with zeros if it exists. */
    public void zeroGrad() {
        if (this.grad != null) {
            this.grad.fill(0.0f);
        }
    }

    /** Fills the host array via {@code RandomUtils.gaussianRandom(data, 1.0f)} and uploads it when GPU-backed. */
    public void random() {
        RandomUtils.gaussianRandom(this.data, 1.0f);
        if (isHasGPU()) {
            hostToDevice();
        }
    }

    /** Returns {@code MatrixUtils.isZero} of the (device-synced, if GPU-backed) values. */
    public boolean isZero() {
        return isHasGPU() ? MatrixUtils.isZero(syncHost()) : MatrixUtils.isZero(this.data);
    }

    public Tensor add(Tensor tensor) {
        return this.g.OP(OPType.add, this, tensor);
    }

    public Tensor add(float f) {
        return this.g.OP(OPType.add, this, f);
    }

    public Tensor sub(Tensor tensor) {
        return this.g.OP(OPType.subtraction, this, tensor);
    }

    public Tensor sub(float f) {
        return this.g.OP(OPType.subtraction, this, f);
    }

    public Tensor scalarSub(float f) {
        return this.g.OP(OPType.scalarSubtraction, this, f);
    }

    public Tensor mul(Tensor tensor) {
        return this.g.OP(OPType.multiplication, this, tensor);
    }

    public Tensor mul(float f) {
        return this.g.OP(OPType.multiplication, this, f);
    }

    public Tensor div(Tensor tensor) {
        return this.g.OP(OPType.division, this, tensor);
    }

    public Tensor div(float f) {
        return this.g.OP(OPType.division, this, f);
    }

    public Tensor scalarDiv(float f) {
        return this.g.OP(OPType.scalarDivision, this, f);
    }

    public Tensor maximum(Tensor tensor) {
        return this.g.OP(OPType.maximum, this, tensor);
    }

    public Tensor minimum(Tensor tensor) {
        return this.g.OP(OPType.minimum, this, tensor);
    }

    public Tensor dot(Tensor tensor) {
        return this.g.OP(OPType.dot, this, tensor);
    }

    public Tensor log() {
        return this.g.OP(OPType.log, this);
    }

    public Tensor transpose() {
        return this.g.OP(OPType.transpose, this);
    }

    /** Graph op raising this tensor to the power 2. */
    public Tensor pow() {
        return this.g.OP(OPType.pow, this, 2.0f);
    }

    /**
     * Computes sqrt(sum(x^2)) using the scratch tensors and returns the scalar copied to the host.
     *
     * <p>NOTE(review): assumes {@code TensorOP.sum(..., 0)} reduces all elements into the 1x1x1x1
     * scratch. The scratch from {@link #getTmp()} may have a stale shape after a resize.
     */
    public float norm() {
        getTmpOnce().valueGPU(0);
        TensorOP.pow(this, 2.0f, getTmp());
        TensorOP.sum(getTmp(), getTmpOnce(), 0);
        TensorOP.sqrt(getTmpOnce(), getTmpOnce());
        return getTmpOnce().syncHost()[0];
    }

    public Tensor pow(float f) {
        return this.g.OP(OPType.pow, this, f);
    }

    public Tensor sin() {
        return this.g.OP(OPType.sin, this);
    }

    public Tensor cos() {
        return this.g.OP(OPType.cos, this);
    }

    public Tensor tan() {
        return this.g.OP(OPType.tan, this);
    }

    public Tensor atan() {
        return this.g.OP(OPType.atan, this);
    }

    public Tensor exp() {
        return this.g.OP(OPType.exp, this);
    }

    public Tensor sum(int i) {
        return this.g.OP(OPType.sum, this, new int[]{i});
    }

    public Tensor max(int i) {
        return this.g.OP(OPType.max, this, new int[]{i});
    }

    public Tensor clamp(float f, float f2) {
        return this.g.OP(OPType.clamp, this, f, f2);
    }

    public Tensor get(int[] iArr) {
        return this.g.OP(OPType.get, this, iArr);
    }

    public Tensor set(Tensor tensor, int[] iArr) {
        return this.g.OP(OPType.set, this, tensor, iArr);
    }

    public Tensor get(int i, int i2, int i3) {
        return this.g.OP(OPType.get, this, new int[]{i, i2, i3});
    }

    public Tensor set(Tensor tensor, int i, int i2) {
        return this.g.OP(OPType.set, this, tensor, new int[]{i, i2});
    }

    /**
     * Copies all of {@code fArr} into the host gradient starting at sample {@code i}.
     * The assertion (checked only with -ea) requires samples [i, i + i2) to be within {@code number}.
     */
    public void setGradByNumber(float[] fArr, int i, int i2) {
        assert i + i2 <= this.number
                : "number slice [" + i + ", " + (i + i2) + ") exceeds number=" + this.number;
        System.arraycopy(fArr, 0, this.grad.data, i * this.channel * this.height * this.width, fArr.length);
    }

    /**
     * For every sample, copies {@code i2} channels from {@code fArr}, laid out as (number, i2, H, W),
     * into gradient channels [i, i + i2). The assertion (checked only with -ea) rejects ranges past
     * {@code channel}, which would otherwise spill into the next sample.
     */
    public void setGradByChannel(float[] fArr, int i, int i2) {
        assert i + i2 <= this.channel
                : "channel slice [" + i + ", " + (i + i2) + ") exceeds channel=" + this.channel;
        int plane = this.height * this.width;
        for (int n = 0; n < this.number; n++) {
            System.arraycopy(fArr, n * i2 * plane, this.grad.data, (n * this.channel * plane) + (i * plane), i2 * plane);
        }
    }

    /** Sets every element to {@code f}: via a GPU kernel when GPU-backed, otherwise on the host array. */
    public void fill(float f) {
        if (isHasGPU()) {
            OPKernel.getInstance().fill_gpu(this, f);
        } else {
            MatrixUtils.val(this.data, f);
        }
    }

    /** Returns the shared per-sample buffer, reallocating it when C*H*W changed. */
    public float[] getOnce() {
        int size = this.channel * this.height * this.width;
        if (this.once == null || this.once.length != size) {
            this.once = new float[size];
        }
        return this.once;
    }

    public Graph getG() {
        return this.g;
    }

    public void setG(Graph graph) {
        this.g = graph;
    }

    /** Fills the host array with {@code RandomUtils.uniformFloat(f, f2)} draws and uploads it when GPU-backed. */
    public void uniform(float f, float f2) {
        for (int i = 0; i < this.dataLength; i++) {
            this.data[i] = RandomUtils.uniformFloat(f, f2);
        }
        if (isHasGPU()) {
            hostToDevice();
        }
    }
}
