package com.omega.engine.ad;

import com.omega.common.data.Tensor;
import com.omega.engine.ad.op.OP;
import com.omega.engine.ad.op.OPType;
import java.io.Serializable;

/**
 * One recorded operation on the autodiff tape: the operands, the operator, the output tensor
 * allocated for it, and the scalar/constant/position arguments the operator reads.
 *
 * <p>{@link #forward()} and {@link #backward(Tensor)} simply delegate to the stored {@link OP},
 * passing this tape entry so the operator can read its operands and write its output.
 */
public class Tape implements Serializable {
    private static final long serialVersionUID = 9147342370353517536L;

    /** First operand. For {@link OPType#set} it is also used as the output. */
    private Tensor x;
    /** Second operand; may be null (zeroGrad tolerates that). */
    private Tensor y;
    /** Output tensor allocated by the constructor; may be null, see {@link #allocateAxisOutput}. */
    private Tensor output;
    /** Optional operator arguments; position[0] selects the dimension an axis-op changes. */
    private int[] position;
    /** Operator that performs the forward/backward computation for this entry. */
    private OP op;
    /** Scalar argument passed through to the operator (semantics defined by the OP). */
    private float scalar;
    /** Constant argument passed through to the operator (semantics defined by the OP). */
    private float constant;
    /** Lazily allocated scratch tensor with the same shape as {@code x}; see {@link #getTmp()}. */
    private Tensor tmp;
    /** Flag exposed via isSub/setSub; its meaning is defined by the operators using it. */
    private boolean sub = false;

    /**
     * Records an operation and allocates its output tensor.
     *
     * <p>Fields are assigned directly rather than through the public setters so that subclasses
     * overriding a setter cannot observe a partially constructed object.
     *
     * @param op      operator for this entry (must be non-null; its op type is read immediately)
     * @param tensor  first operand, stored as {@code x}
     * @param tensor2 second operand, stored as {@code y}; only dereferenced for {@code dot}
     * @param f       scalar argument, stored as {@code scalar}
     * @param f2      constant argument, stored as {@code constant}
     * @param iArr    optional position arguments, stored as {@code position}; may be null
     * @param graph   passed to every {@link Tensor} the constructor allocates
     */
    public Tape(OP op, Tensor tensor, Tensor tensor2, float f, float f2, int[] iArr, Graph graph) {
        this.x = tensor;
        this.y = tensor2;
        this.output = allocateOutput(op.getOpType(), tensor, tensor2, iArr, graph);
        this.op = op;
        this.scalar = f;
        this.constant = f2;
        this.position = iArr;
    }

    /**
     * Chooses or allocates the output tensor for an operation of type {@code opType}.
     *
     * <ul>
     *   <li>{@code set}: the first operand itself is the output (nothing is allocated).</li>
     *   <li>non-null {@code position}: delegated to {@link #allocateAxisOutput}.</li>
     *   <li>{@code dot}: shape (x.number, x.channel, x.height, y.width).</li>
     *   <li>{@code transpose}: number and width of {@code x} swapped.</li>
     *   <li>anything else: same shape as {@code x}.</li>
     * </ul>
     *
     * @return the output tensor, or null when {@link #allocateAxisOutput} returns null
     */
    private static Tensor allocateOutput(OPType opType, Tensor x, Tensor y, int[] position, Graph graph) {
        if (opType.equals(OPType.set)) {
            return x;
        }
        if (position != null) {
            return allocateAxisOutput(opType, x, position, graph);
        }
        if (opType.equals(OPType.dot)) {
            return new Tensor(x.number, x.channel, x.height, y.width, x.isHasGPU(), graph);
        }
        if (opType.equals(OPType.transpose)) {
            return new Tensor(x.width, x.channel, x.height, x.number, x.isHasGPU(), graph);
        }
        return new Tensor(x.number, x.channel, x.height, x.width, x.isHasGPU(), graph);
    }

    /**
     * Allocates the output for an operation that carries position arguments.
     *
     * <p>position[0] picks the dimension: 0 affects {@code number}, 1 affects {@code channel}.
     * For {@code sum}/{@code max} the output collapses to (1,1,1,1) or (number,1,1,1); for all
     * other ops position[2] gives the new size of the chosen dimension.
     *
     * <p>NOTE(review): any other value of position[0] yields null, i.e. no output is allocated;
     * presumably unsupported — confirm whether this should fail fast instead.
     *
     * @throws ArrayIndexOutOfBoundsException if position is too short (position[2] is read for
     *     every non-sum/max op, even when position[0] is not 0 or 1)
     */
    private static Tensor allocateAxisOutput(OPType opType, Tensor x, int[] position, Graph graph) {
        int axis = position[0];
        boolean reduction = opType.equals(OPType.sum) || opType.equals(OPType.max);
        if (reduction) {
            switch (axis) {
                case 0:
                    return new Tensor(1, 1, 1, 1, x.isHasGPU(), graph);
                case 1:
                    return new Tensor(x.number, 1, 1, 1, x.isHasGPU(), graph);
                default:
                    return null;
            }
        }
        int size = position[2];
        switch (axis) {
            case 0:
                return new Tensor(size, x.channel, x.height, x.width, x.isHasGPU(), graph);
            case 1:
                return new Tensor(x.number, size, x.height, x.width, x.isHasGPU(), graph);
            default:
                return null;
        }
    }

    /** @return the operator recorded for this entry */
    public OP getOp() {
        return this.op;
    }

    /** @param op the operator to record for this entry */
    public void setOp(OP op) {
        this.op = op;
    }

    /**
     * Zeroes the gradients of x, y and the output, for each tensor that is non-null and
     * requires a gradient. Null tensors are skipped; the output can be null when the
     * constructor could not allocate one.
     */
    public void zeroGrad() {
        zeroGradIfRequired(getX());
        zeroGradIfRequired(getY());
        zeroGradIfRequired(getOutput());
    }

    /** Zeroes {@code tensor}'s gradient when it is non-null and requires one. */
    private static void zeroGradIfRequired(Tensor tensor) {
        if (tensor != null && tensor.isRequiresGrad()) {
            tensor.zeroGrad();
        }
    }

    /** Runs the operator's forward pass on this entry and returns its result. */
    public Tensor forward() {
        return this.op.forward(this);
    }

    /**
     * Runs the operator's backward pass for this entry.
     *
     * @param tensor incoming gradient handed to the operator
     */
    public void backward(Tensor tensor) {
        this.op.backward(tensor, this);
    }

    /** Runs the backward pass using the output's own gradient; the output must be non-null. */
    public void backward() {
        backward(getOutput().getGrad());
    }

    /** @return the scalar argument */
    public float getScalar() {
        return this.scalar;
    }

    /** @param f the new scalar argument */
    public void setScalar(float f) {
        this.scalar = f;
    }

    /** @return the position arguments (the stored array itself, not a copy); may be null */
    public int[] getPosition() {
        return this.position;
    }

    /** @return the first operand */
    public Tensor getX() {
        return this.x;
    }

    /** @param tensor the new first operand */
    public void setX(Tensor tensor) {
        this.x = tensor;
    }

    /** @return the second operand; may be null */
    public Tensor getY() {
        return this.y;
    }

    /** @param tensor the new second operand */
    public void setY(Tensor tensor) {
        this.y = tensor;
    }

    /** @return the output tensor; may be null */
    public Tensor getOutput() {
        return this.output;
    }

    /** @param tensor the new output tensor */
    public void setOutput(Tensor tensor) {
        this.output = tensor;
    }

    /** @param iArr the new position arguments (stored by reference) */
    public void setPosition(int[] iArr) {
        this.position = iArr;
    }

    /**
     * Returns a scratch tensor shaped like {@code x}, allocating it on first use.
     * Unlike the output, it is created with the 5-argument Tensor constructor (no Graph).
     */
    public Tensor getTmp() {
        if (this.tmp == null) {
            this.tmp = new Tensor(this.x.number, this.x.channel, this.x.height, this.x.width, this.x.isHasGPU());
        }
        return this.tmp;
    }

    /** @param tensor the scratch tensor to use instead of the lazily allocated one */
    public void setTmp(Tensor tensor) {
        this.tmp = tensor;
    }

    /** @return the {@code sub} flag */
    public boolean isSub() {
        return this.sub;
    }

    /** @param z the new value of the {@code sub} flag */
    public void setSub(boolean z) {
        this.sub = z;
    }

    /** @return the constant argument */
    public float getConstant() {
        return this.constant;
    }

    /** @param f the new constant argument */
    public void setConstant(float f) {
        this.constant = f;
    }
}
