package com.omega.engine.loss;

import com.omega.common.data.Tensor;
import com.omega.common.utils.JsonUtils;
import com.omega.common.utils.MatrixOperation;
import com.omega.engine.nn.network.Network;

/* loaded from: input_file:com/omega/engine/loss/LossFunction.class */
/**
 * Base class for loss functions: subclasses provide the loss value ({@code loss}) and its
 * gradient ({@code diff}) for several argument shapes, and this class provides a
 * finite-difference {@link #gradientCheck(Tensor, Tensor)} to validate a subclass's
 * {@code diff} against its {@code loss}.
 */
public abstract class LossFunction {
    /** Network this loss is attached to; not read by this base class. */
    public Network net;

    /** Loss type tag; not read by this base class (see {@link #getLossType()}). */
    public LossType lossType;

    /** Finite-difference step size used by {@link #gradientCheck(Tensor, Tensor)}. */
    public float eta = 1.0E-5f;

    /** Extra numeric parameters; not read by this base class (subclass-specific). */
    public float[] params;

    /**
     * Computes the loss of {@code tensor} against {@code tensor2}.
     * NOTE(review): presumably (prediction, label) — confirm against subclasses.
     */
    public abstract Tensor loss(Tensor tensor, Tensor tensor2);

    /**
     * Loss variant taking an extra {@code int}; its meaning is defined by subclasses
     * (NOTE(review): confirm, e.g. an ignore index).
     */
    public abstract Tensor loss(Tensor tensor, Tensor tensor2, int i);

    /** Loss variant taking a third tensor; its role is defined by subclasses. */
    public abstract Tensor loss(Tensor tensor, Tensor tensor2, Tensor tensor3);

    /** Multi-input loss variant; presumably one result per entry of {@code tensorArr}. */
    public abstract Tensor[] loss(Tensor[] tensorArr, Tensor tensor);

    /** Computes the gradient of {@link #loss(Tensor, Tensor)} w.r.t. its first argument. */
    public abstract Tensor diff(Tensor tensor, Tensor tensor2);

    /** Gradient counterpart of {@link #loss(Tensor, Tensor, int)}. */
    public abstract Tensor diff(Tensor tensor, Tensor tensor2, int i);

    /** Gradient counterpart of {@link #loss(Tensor, Tensor, Tensor)}. */
    public abstract Tensor diff(Tensor tensor, Tensor tensor2, Tensor tensor3);

    /** Gradient counterpart of {@link #loss(Tensor[], Tensor)}. */
    public abstract Tensor[] diff(Tensor[] tensorArr, Tensor tensor);

    /** Returns the {@link LossType} identifying this loss. */
    public abstract LossType getLossType();

    /**
     * Compares the analytic gradient from {@link #diff(Tensor, Tensor)} with a centred
     * finite-difference estimate, printing both to stdout.
     *
     * <p>Every element of {@code tensor.data} is shifted by {@code +eta} and {@code -eta} at
     * the same time, and the numeric gradient is taken element-wise as
     * {@code (loss(x + eta) - loss(x - eta)) / (2 * eta)}. That equals the per-element
     * gradient only when each loss element depends solely on the matching input element —
     * NOTE(review): confirm this holds for the loss being checked.
     *
     * @param tensor input to perturb (first argument of {@code loss}/{@code diff})
     * @param tensor2 second argument passed unchanged to {@code loss}/{@code diff}
     * @return sum of absolute element-wise differences between the analytic and numeric
     *     gradients (an L1 distance); absolute values are used so that positive and
     *     negative errors cannot cancel into a misleadingly small result
     * @throws IllegalStateException if the analytic and numeric gradients differ in length,
     *     in which case an element-wise comparison is meaningless
     */
    public float gradientCheck(Tensor tensor, Tensor tensor2) {
        // Copy each loss result before the next loss() call, in case an implementation
        // reuses its output tensor (which would otherwise zero out the difference).
        float[] lossPlus = loss(withShapeOf(tensor, MatrixOperation.add(tensor.data, this.eta)), tensor2).data.clone();
        float[] lossMinus = loss(withShapeOf(tensor, MatrixOperation.subtraction(tensor.data, this.eta)), tensor2).data.clone();
        Tensor diff = diff(tensor, tensor2);
        float[] numeric = MatrixOperation.division(MatrixOperation.subtraction(lossPlus, lossMinus), 2.0f * this.eta);

        // syncHost() is called before diff.data is read below, as in the original ordering.
        System.out.println("diff:" + JsonUtils.toJson(diff.syncHost()));
        System.out.println("gradientCheck:" + JsonUtils.toJson(numeric));

        float[] analytic = diff.data;
        if (analytic.length != numeric.length) {
            throw new IllegalStateException("gradientCheck: analytic gradient has " + analytic.length
                    + " elements but numeric estimate has " + numeric.length);
        }
        float error = 0.0f;
        for (int k = 0; k < analytic.length; k++) {
            error += Math.abs(analytic[k] - numeric[k]);
        }
        return error;
    }

    /**
     * Builds a tensor with the same (number, channel, height, width) shape as {@code like}
     * holding {@code data}. Passes the same trailing {@code true} flag the original
     * construction used — NOTE(review): confirm its meaning (possibly GPU allocation).
     */
    private static Tensor withShapeOf(Tensor like, float[] data) {
        return new Tensor(like.number, like.channel, like.height, like.width, data, true);
    }
}
