package com.omega.engine.ad.op.data;

import com.omega.common.data.Tensor;
import com.omega.common.utils.MatrixOperation;
import com.omega.engine.ad.Tape;
import com.omega.engine.ad.op.OP;
import com.omega.engine.ad.op.OPType;
import com.omega.engine.ad.op.gpu.OPKernel;

/* loaded from: input_file:com/omega/engine/ad/op/data/SetOP.class */
/**
 * Autograd op that writes one tensor ({@code y}) into a slot of another ({@code x}).
 *
 * <p>The tape's position array selects the slot: {@code position[0]} is a dimension selector
 * and {@code position[1]} an index. Only dimension {@code 0} is implemented by
 * {@link #setByPosition}. It writes the source at element offset
 * {@code index * channel * height * width}, with the per-sample size taken from the source
 * tensor. Any other dimension value is currently a no-op on the forward path.
 */
public class SetOP extends OP {
    private static final long serialVersionUID = 7010180428917414516L;

    /** Lazily created shared instance; see {@link #getInstance()}. */
    public static SetOP op = null;

    /** Operation type tag assigned to the shared instance. */
    public static final OPType opt = OPType.set;

    /**
     * Returns the shared instance, creating it and tagging it with {@link #opt} on first call.
     *
     * <p>The lazy creation uses no synchronization.
     *
     * @return the shared {@code SetOP}
     */
    public static SetOP getInstance() {
        if (op == null) {
            op = new SetOP();
            op.setOpType(opt);
        }
        return op;
    }

    /**
     * Writes {@code tape.getY()} into {@code tape.getX()} at {@code tape.getPosition()}.
     *
     * <p>If {@code x} requires gradients, {@code y} is marked as requiring them too.
     *
     * @param tape tape entry supplying x (destination), y (source) and the position
     * @return {@code y}, the tensor that was written into {@code x}
     */
    @Override // com.omega.engine.ad.op.OP
    public Tensor forward(Tape tape) {
        Tensor x = tape.getX();
        Tensor y = tape.getY();
        setByPosition(x, y, tape.getPosition());
        if (x.isRequiresGrad()) {
            y.setRequiresGrad(true);
        }
        return y;
    }

    /**
     * Accumulates the incoming gradient into {@code y}'s gradient, if {@code y} requires one.
     *
     * @param tensor incoming gradient (delta) passed to this op
     * @param tape tape entry supplying y and the position
     */
    @Override // com.omega.engine.ad.op.OP
    public void backward(Tensor tensor, Tape tape) {
        Tensor y = tape.getY();
        if (y.isRequiresGrad()) {
            addByPosition(y.getGrad(), tensor, tape.getPosition());
        }
    }

    /**
     * Adds the slot of {@code tensor2} selected by {@code position} into {@code tensor}.
     *
     * <p>GPU path: only dimension 0 is handled. It calls
     * {@code axpy_gpu(tensor2, tensor, offset, 0)} with
     * {@code offset = index * channel * height * width}, where the dimensions are those of
     * {@code tensor}. CPU path: delegates to {@link MatrixOperation#add} with the full position
     * array, whatever the dimension.
     *
     * @param tensor accumulation target
     * @param tensor2 tensor whose slot is added into {@code tensor}
     * @param position {@code {dimension, index}}
     */
    public void addByPosition(Tensor tensor, Tensor tensor2, int[] position) {
        int dim = position[0];
        int index = position[1];
        if (tensor.isHasGPU()) {
            // NOTE(review): non-zero dimensions are silently skipped on the GPU, while the CPU
            // branch forwards them to MatrixOperation.add. Confirm this asymmetry is intended.
            if (dim == 0) {
                // Assumes axpy_gpu(x, y, offX, offY) adds x (from offX) into y (from offY).
                // TODO confirm.
                int offset = index * tensor.channel * tensor.height * tensor.width;
                OPKernel.getInstance().axpy_gpu(tensor2, tensor, offset, 0);
            }
        } else {
            MatrixOperation.add(tensor2.data, tensor.data, tensor.getNumber(), tensor.getChannel(),
                    tensor.getHeight(), tensor.getWidth(), position);
        }
    }

    /**
     * Writes {@code tensor2} into {@code tensor} at the slot selected by {@code position}.
     *
     * @param tensor destination
     * @param tensor2 source
     * @param position {@code {dimension, index}}; only dimension 0 is implemented
     */
    public void setByPosition(Tensor tensor, Tensor tensor2, int[] position) {
        int dim = position[0];
        int index = position[1];
        // Any dimension other than 0 is currently a no-op.
        if (dim == 0) {
            setByNumber(tensor, tensor2, index);
        }
    }

    /**
     * Copies all of {@code tensor2} into {@code tensor}, starting at sample index {@code i}.
     *
     * <p>The destination element offset is {@code i * channel * height * width}, using the
     * dimensions of {@code tensor2}. On the CPU, {@code tensor2.dataLength} elements are copied.
     *
     * @param tensor destination tensor
     * @param tensor2 source tensor; its channel/height/width define the per-sample size
     * @param i destination sample index; expected in {@code [0, tensor.getNumber())}
     */
    public void setByNumber(Tensor tensor, Tensor tensor2, int i) {
        // Valid indices are [0, number). An index at or past number (or a negative one) would
        // start the write outside the destination buffer, which is unchecked on the GPU path.
        assert i >= 0 && i < tensor.getNumber()
                : "index " + i + " out of range for number " + tensor.getNumber();
        int offset = i * tensor2.channel * tensor2.height * tensor2.width;
        if (tensor.isHasGPU()) {
            OPKernel.getInstance().copy_gpu(tensor2, tensor, 0, offset);
        } else {
            System.arraycopy(tensor2.data, 0, tensor.data, offset, tensor2.dataLength);
        }
    }
}
