package greycat.ml.neuralnet.loss;

import greycat.ml.neuralnet.process.ExMatrix;
import greycat.struct.DMatrix;
import greycat.struct.matrix.MatrixOps;
import greycat.struct.matrix.VolatileDMatrix;

/* loaded from: input_file:greycat/ml/neuralnet/loss/SumOfSquares.class */
/**
 * Sum-of-squares loss: for every element {@code i} the loss is
 * {@code 0.5 * (actual[i] - target[i])^2}, and its gradient with respect to
 * {@code actual[i]} is {@code actual[i] - target[i]}.
 * <p>
 * The class holds no instance state; use {@link #instance()} to obtain the shared instance.
 */
class SumOfSquares implements Loss {
    private static SumOfSquares static_unit = null;

    SumOfSquares() {
    }

    /**
     * Returns the shared instance, creating it on first use.
     * <p>
     * Not synchronized: concurrent first calls may create more than one instance,
     * which is harmless because the class has no instance fields.
     *
     * @return the shared {@code SumOfSquares} loss
     */
    public static SumOfSquares instance() {
        if (static_unit == null) {
            static_unit = new SumOfSquares();
        }
        return static_unit;
    }

    /**
     * Accumulates this loss's gradient into {@code actualOutput}'s gradient matrix:
     * for every element {@code i}, {@code dw[i] += actual[i] - target[i]}.
     * The gradient is added to, not overwritten.
     *
     * @param actualOutput matrix whose values are compared against the target and whose
     *                     gradient ({@code getDw()}) is updated in place
     * @param targetOutput expected values; must have the same dimensions as {@code actualOutput}
     */
    @Override // greycat.ml.neuralnet.loss.Loss
    public void backward(ExMatrix actualOutput, ExMatrix targetOutput) {
        // Same validation as forward(). The loop below is bounded by targetOutput.length()
        // and uses unchecked accessors, so mismatched shapes would otherwise index past a
        // smaller actualOutput or leave part of a larger one's gradient untouched.
        MatrixOps.testDim(actualOutput, targetOutput);
        int length = targetOutput.length();
        for (int i = 0; i < length; i++) {
            double delta = actualOutput.unsafeGet(i) - targetOutput.unsafeGet(i);
            actualOutput.getDw().unsafeSet(i, actualOutput.getDw().unsafeGet(i) + delta);
        }
    }

    /**
     * Computes the element-wise loss {@code 0.5 * (actual[i] - target[i])^2}.
     * Neither input is modified.
     *
     * @param actualOutput actual (predicted) values
     * @param targetOutput expected values; dimensions are validated against
     *                     {@code actualOutput} with {@code MatrixOps.testDim}
     * @return a new matrix with {@code actualOutput}'s rows and columns holding the
     *         per-element loss
     */
    @Override // greycat.ml.neuralnet.loss.Loss
    public DMatrix forward(ExMatrix actualOutput, ExMatrix targetOutput) {
        MatrixOps.testDim(actualOutput, targetOutput);
        VolatileDMatrix losses = VolatileDMatrix.empty(actualOutput.rows(), actualOutput.columns());
        int length = targetOutput.length();
        for (int i = 0; i < length; i++) {
            double diff = actualOutput.unsafeGet(i) - targetOutput.unsafeGet(i);
            losses.unsafeSet(i, 0.5d * diff * diff);
        }
        return losses;
    }
}
