package greycat.ml.neuralnet.loss;

import greycat.ml.neuralnet.process.ExMatrix;
import greycat.struct.DMatrix;
import greycat.struct.matrix.MatrixOps;
import greycat.struct.matrix.VolatileDMatrix;

/* loaded from: input_file:greycat/ml/neuralnet/loss/ArgMax.class */
/**
 * Loss that compares, column by column, the row index at which two equally-shaped matrices reach
 * their maximum value.
 *
 * <p>For each column the forward result holds {@code 0.0} when both matrices have their maximum on
 * the same row and {@code 1.0} otherwise. When a column contains several equal maxima, the lowest
 * row index wins. NOTE(review): presumably each column is one sample and each row one class, which
 * would make this a per-sample misclassification indicator — confirm against callers.
 *
 * <p>Only {@link #forward(ExMatrix, ExMatrix)} is implemented; {@link #backward(ExMatrix, ExMatrix)}
 * always throws.
 */
class ArgMax implements Loss {

    /**
     * Shared instance. The class has no instance state, so one object suffices. Eager static
     * initialization makes {@link #instance()} safe to call from multiple threads without locking.
     */
    private static final ArgMax static_unit = new ArgMax();

    /** Package-private constructor; {@link #instance()} returns the shared instance. */
    ArgMax() {
    }

    /**
     * Returns the shared {@code ArgMax} instance.
     *
     * @return the singleton instance, never {@code null}
     */
    public static ArgMax instance() {
        return static_unit;
    }

    /**
     * Not supported. The forward output is a 0/1 indicator, for which this class provides no
     * gradient.
     *
     * @param exMatrix  first matrix (unused)
     * @param exMatrix2 second matrix (unused)
     * @throws UnsupportedOperationException always
     */
    @Override
    public void backward(ExMatrix exMatrix, ExMatrix exMatrix2) {
        throw new UnsupportedOperationException("not implemented");
    }

    /**
     * Computes the per-column argmax mismatch between two matrices.
     *
     * @param exMatrix  first matrix (presumably the network output — confirm against callers)
     * @param exMatrix2 second matrix (presumably the expected output — confirm); checked against
     *                  {@code exMatrix} by {@code MatrixOps.testDim}
     * @return a {@code 1 x exMatrix.columns()} matrix whose entry {@code (0, c)} is {@code 0.0} if
     *         both inputs have their maximum on the same row of column {@code c}, else {@code 1.0}
     */
    @Override
    public DMatrix forward(ExMatrix exMatrix, ExMatrix exMatrix2) {
        // NOTE(review): how testDim reports a mismatch (throw vs. other) is not visible here.
        MatrixOps.testDim(exMatrix, exMatrix2);
        final int columns = exMatrix.columns();
        final VolatileDMatrix result = VolatileDMatrix.empty(1, columns);
        if (exMatrix.rows() == 1) {
            // With a single row, row 0 is the argmax of every column in both inputs.
            result.fill(0.0d);
            return result;
        }
        for (int col = 0; col < columns; col++) {
            final boolean sameRow = argMaxRow(exMatrix, col) == argMaxRow(exMatrix2, col);
            result.set(0, col, sameRow ? 0.0d : 1.0d);
        }
        return result;
    }

    /** Returns the lowest row index holding the maximum value of column {@code col}. */
    private static int argMaxRow(ExMatrix matrix, int col) {
        final int rows = matrix.rows();
        int bestRow = 0;
        double best = matrix.get(0, col);
        for (int row = 1; row < rows; row++) {
            final double value = matrix.get(row, col);
            // Strict comparison keeps the first (lowest-index) maximum on ties.
            if (value > best) {
                best = value;
                bestRow = row;
            }
        }
        return bestRow;
    }
}
