package com.omega.engine.ad.op.data;

import com.omega.common.data.Tensor;
import com.omega.common.utils.MatrixOperation;
import com.omega.engine.ad.Tape;
import com.omega.engine.ad.op.OP;
import com.omega.engine.ad.op.OPType;
import com.omega.engine.ad.op.gpu.OPKernel;

/* loaded from: input_file:com/omega/engine/ad/op/data/GetOP.class */
/**
 * Tape-driven operation that copies a region of the tape's X tensor into the tape's
 * output tensor (forward) and adds an incoming tensor back into the same region of X's
 * grad tensor (backward).
 *
 * <p>The region is described by the tape's position array {@code {axis, start, count}}:
 * axis {@code 0} selects along the number dimension and axis {@code 1} along the channel
 * dimension. Any other axis value is silently ignored by the copy and by the GPU add
 * path.
 *
 * <p>Tensor data is addressed as a flat array laid out as number, channel, height, width
 * (this is what the CPU index arithmetic below assumes).
 */
public class GetOP extends OP {
    private static final long serialVersionUID = 7010180428917414516L;

    /** Shared instance, created on first call to {@link #getInstance()}. */
    public static GetOP op;

    /** Operation type assigned to the shared instance. */
    public static final OPType opt = OPType.get;

    /** Position axis value selecting the number dimension. */
    private static final int AXIS_NUMBER = 0;

    /** Position axis value selecting the channel dimension. */
    private static final int AXIS_CHANNEL = 1;

    /**
     * Returns the shared instance, creating it and setting its op type on first use.
     *
     * <p>The lazy initialization is not synchronized.
     *
     * @return the shared {@code GetOP}
     */
    public static GetOP getInstance() {
        if (op == null) {
            op = new GetOP();
            op.setOpType(opt);
        }
        return op;
    }

    /**
     * Copies the region selected by the tape's position from the tape's X tensor into
     * the tape's output tensor.
     *
     * @param tape supplies the X tensor, the output tensor and the position array
     * @return the tape's output tensor, flagged as requiring grad when X requires grad
     */
    @Override // com.omega.engine.ad.op.OP
    public Tensor forward(Tape tape) {
        Tensor input = tape.getX();
        Tensor result = tape.getOutput();
        getByPosition(input, result, tape.getPosition());
        if (input.isRequiresGrad()) {
            result.setRequiresGrad(true);
        }
        return result;
    }

    /**
     * Adds {@code delta} into the region of X's grad tensor selected by the tape's
     * position. Does nothing when X does not require grad.
     *
     * @param delta tensor to add into X's grad
     * @param tape supplies the X tensor and the position array
     */
    @Override // com.omega.engine.ad.op.OP
    public void backward(Tensor delta, Tape tape) {
        Tensor input = tape.getX();
        if (!input.isRequiresGrad()) {
            return;
        }
        addByPosition(input.getGrad(), delta, tape.getPosition());
    }

    /**
     * Adds {@code delta} into the region of {@code target} described by {@code position}.
     *
     * <p>On the GPU path only axes 0 and 1 are handled; other axis values are a no-op.
     * The CPU path delegates to {@code MatrixOperation.add} with the full position array.
     *
     * @param target tensor that receives the addition
     * @param delta tensor whose values are added
     * @param position {@code {axis, start, count}}; must have at least three elements
     */
    public void addByPosition(Tensor target, Tensor delta, int[] position) {
        // All three entries are read up front so a short array fails on either path.
        int axis = position[0];
        int start = position[1];
        int count = position[2];

        if (!target.isHasGPU()) {
            MatrixOperation.add(target.data, delta.data, target.getNumber(), target.getChannel(),
                    target.getHeight(), target.getWidth(), position);
            return;
        }

        // NOTE(review): both kernels receive start * count, whereas getByChannel passes
        // the unscaled start to copy_channel_gpu - confirm the kernels' offset convention.
        if (axis == AXIS_NUMBER) {
            OPKernel.getInstance().add_number_gpu(target, delta, start * count);
        } else if (axis == AXIS_CHANNEL) {
            OPKernel.getInstance().add_channel_gpu(target, delta, start * count);
        }
    }

    /**
     * Copies the region of {@code source} described by {@code position} into
     * {@code target}. Axis values other than 0 and 1 are a no-op.
     *
     * @param source tensor to read from
     * @param target tensor to write into
     * @param position {@code {axis, start, count}}; must have at least three elements
     */
    public void getByPosition(Tensor source, Tensor target, int[] position) {
        int axis = position[0];
        int start = position[1];
        int count = position[2];

        if (axis == AXIS_NUMBER) {
            getByNumber(source, target, start, count);
        } else if (axis == AXIS_CHANNEL) {
            getByChannel(source, target, start, count);
        }
    }

    /**
     * Copies a run of samples from {@code source} into the beginning of {@code target}.
     *
     * <p>The copy starts at sample offset {@code start * count} and, on the CPU path,
     * spans {@code target.dataLength} elements.
     *
     * @param source tensor to read from
     * @param target tensor that receives the copied data, starting at offset 0
     * @param start multiplied by {@code count} to obtain the first sample copied
     * @param count multiplier applied to {@code start}
     */
    public void getByNumber(Tensor source, Tensor target, int start, int count) {
        // NOTE(review): this bound does not match the start * count offset used below;
        // it is kept as-is because the intended meaning of start along the number axis
        // (sample index vs. block index) cannot be determined from this class alone.
        assert source.getNumber() >= start + count - 1
                : "number slice out of range: start=" + start + ", count=" + count
                        + ", number=" + source.getNumber();

        int firstSample = start * count;
        if (source.isHasGPU()) {
            OPKernel.getInstance().copy_number_gpu(source, target, firstSample, 0);
            return;
        }
        int sampleSize = source.channel * source.height * source.width;
        System.arraycopy(source.data, firstSample * sampleSize, target.data, 0, target.dataLength);
    }

    /**
     * Copies {@code count} consecutive channels, beginning at channel {@code start}, from
     * every sample of {@code source} into {@code target}, packing them contiguously per
     * sample.
     *
     * @param source tensor to read from
     * @param target tensor that receives {@code count} channels per sample
     * @param start index of the first channel to copy
     * @param count number of channels to copy from each sample
     */
    public void getByChannel(Tensor source, Tensor target, int start, int count) {
        // The slice [start, start + count) must lie inside the channel range. The former
        // check (channel >= start + count - 1) let the slice overrun by one channel.
        assert start + count <= source.getChannel()
                : "channel slice out of range: start=" + start + ", count=" + count
                        + ", channel=" + source.getChannel();

        if (source.isHasGPU()) {
            OPKernel.getInstance().copy_channel_gpu(source, target, start, 0);
            return;
        }

        int planeSize = source.height * source.width;
        int sliceSize = count * planeSize;
        int sourceStride = source.channel * planeSize;
        int sourceOffset = start * planeSize;
        for (int n = 0; n < source.number; n++) {
            System.arraycopy(source.data, n * sourceStride + sourceOffset, target.data, n * sliceSize, sliceSize);
        }
    }
}
