package greycat.ml.neuralnet.loss;

import greycat.ml.neuralnet.process.ExMatrix;
import greycat.struct.DMatrix;
import greycat.struct.matrix.MatrixOps;
import greycat.struct.matrix.VolatileDMatrix;

/**
 * Softmax cross-entropy loss for a single one-hot encoded target.
 *
 * <p>The actual output is treated as unnormalised log-probabilities (logits). They are turned
 * into probabilities with a numerically stabilised softmax before the loss or its gradient is
 * computed. The target is expected to be one-hot: the first entry equal to {@code 1.0} selects
 * the target class.
 */
class Softmax implements Loss {

    /** Shared instance; the class holds no state, so one eagerly created instance suffices. */
    private static final Softmax INSTANCE = new Softmax();

    Softmax() {
    }

    /**
     * Returns the shared instance of this loss.
     *
     * @return the singleton {@code Softmax} loss
     */
    public static Softmax instance() {
        return INSTANCE;
    }

    /**
     * Writes the gradient of the cross-entropy loss with respect to the logits into
     * {@code actualOutput.getDw()}, namely {@code softmax(actual) - onehot(target)}.
     *
     * <p>The gradient buffer is overwritten, not accumulated into.
     *
     * @param actualOutput logits produced by the network; its {@code dw} is overwritten
     * @param targetOutput one-hot target with the same dimensions as {@code actualOutput}
     * @throws RuntimeException if {@code MatrixOps.testDim} rejects the dimensions, or
     *                          {@link IllegalArgumentException} if no target entry equals 1.0
     */
    @Override // greycat.ml.neuralnet.loss.Loss
    public void backward(ExMatrix actualOutput, ExMatrix targetOutput) {
        // Same shape check as forward(): a mismatched target would index dw inconsistently.
        MatrixOps.testDim(actualOutput, targetOutput);
        final int targetIndex = getTargetIndex(targetOutput);
        final VolatileDMatrix probs = getSoftmaxProbs(actualOutput, 1.0d);
        final int length = probs.length();
        for (int i = 0; i < length; i++) {
            actualOutput.getDw().unsafeSet(i, probs.unsafeGet(i));
        }
        // d(-log p_t) / d z_i = p_i - [i == t]
        actualOutput.getDw().unsafeSet(targetIndex, probs.unsafeGet(targetIndex) - 1.0d);
    }

    /**
     * Computes the cross-entropy loss {@code -log(softmax(actual)[t])} for the one-hot target
     * class {@code t}.
     *
     * <p>The loss is returned element-wise: the result has the dimensions of
     * {@code actualOutput}, holds the loss at the target index and {@code 0} everywhere else.
     * Its sum is therefore the scalar loss. If the target probability underflows to 0, the loss
     * is {@code +Infinity}.
     *
     * @param actualOutput logits produced by the network (not modified)
     * @param targetOutput one-hot target with the same dimensions as {@code actualOutput}
     * @return a new matrix holding the element-wise cross-entropy
     * @throws RuntimeException if {@code MatrixOps.testDim} rejects the dimensions, or
     *                          {@link IllegalArgumentException} if no target entry equals 1.0
     */
    @Override // greycat.ml.neuralnet.loss.Loss
    public DMatrix forward(ExMatrix actualOutput, ExMatrix targetOutput) {
        MatrixOps.testDim(actualOutput, targetOutput);
        final VolatileDMatrix losses = getSoftmaxProbs(actualOutput, 1.0d);
        final int targetIndex = getTargetIndex(targetOutput);
        final double loss = -Math.log(losses.unsafeGet(targetIndex));
        // The probability matrix is freshly allocated and private to this call, so reuse it
        // as the output buffer: -target_i * log(p_i) is zero except at the target index.
        final int length = losses.length();
        for (int i = 0; i < length; i++) {
            losses.unsafeSet(i, i == targetIndex ? loss : 0.0d);
        }
        return losses;
    }

    /**
     * Computes {@code softmax(logits / temperature)} into a new matrix.
     *
     * <p>Before exponentiation, the largest scaled logit is subtracted so that large logits do
     * not overflow {@link Math#exp(double)}. The input matrix is never modified.
     *
     * @param logits      unnormalised log-probabilities
     * @param temperature divisor applied to every logit; {@code 1.0} leaves them unchanged
     * @return a new matrix with the dimensions of {@code logits}, normalised to sum to 1
     */
    public static VolatileDMatrix getSoftmaxProbs(ExMatrix logits, double temperature) {
        final int length = logits.length();
        final VolatileDMatrix probs = VolatileDMatrix.empty(logits.rows(), logits.columns());

        // Stage the scaled logits in the output buffer rather than dividing the caller's
        // matrix in place. x / 1.0 == x exactly, so temperature 1.0 needs no special case.
        double max = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < length; i++) {
            final double scaled = logits.unsafeGet(i) / temperature;
            probs.unsafeSet(i, scaled);
            if (scaled > max) {
                max = scaled;
            }
        }

        double sum = 0.0d;
        for (int i = 0; i < length; i++) {
            final double e = Math.exp(probs.unsafeGet(i) - max);
            probs.unsafeSet(i, e);
            sum += e;
        }

        for (int i = 0; i < length; i++) {
            probs.unsafeSet(i, probs.unsafeGet(i) / sum);
        }
        return probs;
    }

    /**
     * Returns the index of the first entry exactly equal to {@code 1.0}, i.e. the class
     * selected by a one-hot target.
     *
     * @param target one-hot target matrix
     * @return index of the selected class
     * @throws IllegalArgumentException if no entry equals {@code 1.0}
     */
    private static int getTargetIndex(ExMatrix target) {
        final int length = target.length();
        for (int i = 0; i < length; i++) {
            if (target.unsafeGet(i) == 1.0d) {
                return i;
            }
        }
        throw new IllegalArgumentException("no target index selected");
    }
}
