package ai.libs.jaicore.ml.regression.loss;

import ai.libs.jaicore.ml.regression.loss.dataset.ARegressionMeasure;
import ai.libs.jaicore.ml.regression.loss.dataset.MeanAbsoluteError;
import ai.libs.jaicore.ml.regression.loss.dataset.MeanAbsolutePercentageError;
import ai.libs.jaicore.ml.regression.loss.dataset.MeanSquaredError;
import ai.libs.jaicore.ml.regression.loss.dataset.RootMeanSquaredError;
import java.util.List;
import java.util.stream.IntStream;
import org.api4.java.ai.ml.core.evaluation.IPredictionAndGroundTruthTable;
import org.api4.java.ai.ml.core.evaluation.supervised.loss.IDeterministicPredictionPerformanceMeasure;
import org.api4.java.ai.ml.regression.evaluation.IRegressionPrediction;

/**
 * Catalogue of deterministic regression performance measures.
 *
 * <p>Each constant wraps a concrete measure instance and forwards every {@code loss}/{@code score}
 * call to it. The enum can therefore be used wherever an
 * {@link IDeterministicPredictionPerformanceMeasure} over {@code Double} ground truths and
 * {@link IRegressionPrediction} predictions is expected.
 */
public enum ERegressionPerformanceMeasure implements IDeterministicPredictionPerformanceMeasure<Double, IRegressionPrediction> {
    /** Mean squared error. */
    MSE(new MeanSquaredError()),
    /** Root mean squared error. */
    RMSE(new RootMeanSquaredError()),
    /**
     * Root mean squared logarithmic error:
     * {@code sqrt(mean((ln(1 + predicted) - ln(1 + expected))^2))}.
     */
    RMSLE(new ARegressionMeasure() {
        /**
         * Computes the root mean squared logarithmic error between ground truth and predictions.
         *
         * <p>Any value {@code <= -1} makes {@code ln(1 + x)} undefined, so the result is {@code NaN}
         * in that case. Empty inputs cause {@link java.util.OptionalDouble#getAsDouble()} to throw
         * {@link java.util.NoSuchElementException}.
         *
         * <p>NOTE(review): this error value is exposed through {@code score}. Confirm how
         * {@code ARegressionMeasure} derives {@code loss} from it.
         *
         * @param expected the ground-truth values
         * @param predicted the predictions, index-aligned with {@code expected}
         * @return the RMSLE of the predictions
         */
        @Override
        public double score(List<? extends Double> expected, List<? extends IRegressionPrediction> predicted) {
            checkConsistency(expected, predicted);
            double meanSquaredLogError = IntStream.range(0, expected.size()).mapToDouble(i -> {
                // log1p(x) == ln(1 + x), and it is accurate for x close to 0.
                double logDiff = Math.log1p(predicted.get(i).getPrediction().doubleValue())
                        - Math.log1p(expected.get(i).doubleValue());
                return logDiff * logDiff;
            }).average().getAsDouble();
            return Math.sqrt(meanSquaredLogError);
        }
    }),
    /** Mean absolute error. */
    MAE(new MeanAbsoluteError()),
    /** Mean absolute percentage error. */
    MAPE(new MeanAbsolutePercentageError()),
    /** Coefficient of determination: {@code 1 - SS_res / SS_tot}. */
    R2(new ARegressionMeasure() {
        /**
         * Computes the coefficient of determination R^2 of the predictions.
         *
         * <p>{@code SS_res} is the sum of squared residuals {@code (expected - predicted)^2}.
         * {@code SS_tot} is the sum of squared deviations of the ground truth from its mean.
         * A perfect predictor scores 1. Always predicting the mean scores 0. Worse predictors
         * score negative values.
         *
         * @param expected the ground-truth values
         * @param predicted the predictions, index-aligned with {@code expected}
         * @return the R^2 score of the predictions
         * @throws IllegalStateException if the ground truth is constant, so that {@code SS_tot} is zero
         */
        @Override
        public double score(List<? extends Double> expected, List<? extends IRegressionPrediction> predicted) {
            checkConsistency(expected, predicted);
            double meanExpected = expected.stream().mapToDouble(Double::doubleValue).average().getAsDouble();
            double residualSumOfSquares = 0.0d;
            double totalSumOfSquares = 0.0d;
            for (int i = 0; i < expected.size(); i++) {
                double actual = expected.get(i).doubleValue();
                double residual = actual - predicted.get(i).getPrediction().doubleValue();
                double deviation = actual - meanExpected;
                residualSumOfSquares += residual * residual;
                totalSumOfSquares += deviation * deviation;
            }
            if (totalSumOfSquares == 0.0d) {
                // A constant ground truth leaves R^2 undefined (division by zero).
                throw new IllegalStateException("Sum of expected squares must not be zero.");
            }
            return 1.0d - residualSumOfSquares / totalSumOfSquares;
        }
    });

    /** The concrete measure that every call is delegated to. */
    private final IDeterministicPredictionPerformanceMeasure<Double, IRegressionPrediction> measure;

    /**
     * Binds an enum constant to its measure implementation.
     *
     * @param measure the measure to delegate to
     */
    ERegressionPerformanceMeasure(final IDeterministicPredictionPerformanceMeasure<Double, IRegressionPrediction> measure) {
        this.measure = measure;
    }

    /** Delegates to the wrapped measure's {@code loss(List, List)}. */
    @Override
    public double loss(List<? extends Double> expected, List<? extends IRegressionPrediction> predicted) {
        return this.measure.loss(expected, predicted);
    }

    /** Delegates to the wrapped measure's {@code score(List, List)}. */
    @Override
    public double score(List<? extends Double> expected, List<? extends IRegressionPrediction> predicted) {
        return this.measure.score(expected, predicted);
    }

    /** Delegates to the wrapped measure's {@code loss(IPredictionAndGroundTruthTable)}. */
    @Override
    public double loss(IPredictionAndGroundTruthTable<? extends Double, ? extends IRegressionPrediction> pairTable) {
        return this.measure.loss(pairTable);
    }

    /** Delegates to the wrapped measure's {@code score(IPredictionAndGroundTruthTable)}. */
    @Override
    public double score(IPredictionAndGroundTruthTable<? extends Double, ? extends IRegressionPrediction> pairTable) {
        return this.measure.score(pairTable);
    }
}
