package com.gengoai.apollo.ml.evaluation;

import com.gengoai.apollo.math.linalg.NDArray;
import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.Datum;
import com.gengoai.apollo.ml.Split;
import com.gengoai.apollo.ml.model.Model;
import com.gengoai.string.TableFormatter;
import java.io.PrintStream;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Iterator;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Logger;
import lombok.NonNull;
import org.apache.mahout.math.list.DoubleArrayList;

/* loaded from: input_file:com/gengoai/apollo/ml/evaluation/RegressionEvaluation.class */
/**
 * Accumulates gold and predicted scalar values for a regression model and reports
 * squared-error based metrics (SSE, MSE, RMSE), R^2, and adjusted R^2.
 *
 * <p>Values are added either by {@link #evaluate(Model, DataSet)}, which reads the gold value
 * and the model's prediction from the same {@code predictedSource} observation, or manually
 * via {@link #entry(double, NDArray)}. Instances are mutable and no synchronization is
 * performed.
 */
public class RegressionEvaluation implements Evaluation, Serializable {
    private static final Logger log = Logger.getLogger(RegressionEvaluation.class.getName());
    private static final long serialVersionUID = 1;
    private final String inputSource;
    private final String predictedSource;
    private DoubleArrayList gold = new DoubleArrayList();
    // Predictor count used by adjustedR2(); set via setP or derived in evaluate() from the
    // largest input NDArray length seen (presumably the number of features).
    private double p = 0.0d;
    private DoubleArrayList predicted = new DoubleArrayList();

    /**
     * Runs k-fold cross validation: the data set is shuffled and split into folds, and for each
     * fold the model is re-estimated on the training split and evaluated on the test split. All
     * fold results accumulate into a single evaluation.
     *
     * @param dataSet         the data to cross validate over
     * @param model           the model to (re)train on each fold
     * @param inputSource     name of the observation holding the input features
     * @param predictedSource name of the observation holding the gold / predicted value
     * @param nFolds          number of folds
     * @return the accumulated evaluation over all folds
     * @throws NullPointerException if any object argument is null
     */
    public static RegressionEvaluation crossValidation(@NonNull DataSet dataSet, @NonNull Model model, @NonNull String inputSource, @NonNull String predictedSource, int nFolds) {
        if (dataSet == null) {
            throw new NullPointerException("dataset is marked non-null but is null");
        }
        if (model == null) {
            throw new NullPointerException("regression is marked non-null but is null");
        }
        if (inputSource == null) {
            throw new NullPointerException("inputSource is marked non-null but is null");
        }
        if (predictedSource == null) {
            throw new NullPointerException("predictedSource is marked non-null but is null");
        }
        RegressionEvaluation regressionEvaluation = new RegressionEvaluation(inputSource, predictedSource);
        for (Split split : Split.createFolds(dataSet.shuffle(), nFolds)) {
            model.estimate(split.train);
            regressionEvaluation.evaluate(model, split.test);
        }
        return regressionEvaluation;
    }

    /**
     * Creates a new evaluation and evaluates the given model against the testing data.
     *
     * @param model           the trained model to evaluate
     * @param dataSet         the testing data
     * @param inputSource     name of the observation holding the input features
     * @param predictedSource name of the observation holding the gold / predicted value
     * @return the populated evaluation
     * @throws NullPointerException if any argument is null
     */
    public static RegressionEvaluation evaluate(@NonNull Model model, @NonNull DataSet dataSet, @NonNull String inputSource, @NonNull String predictedSource) {
        if (model == null) {
            throw new NullPointerException("model is marked non-null but is null");
        }
        if (dataSet == null) {
            throw new NullPointerException("testingData is marked non-null but is null");
        }
        if (inputSource == null) {
            throw new NullPointerException("inputSource is marked non-null but is null");
        }
        if (predictedSource == null) {
            throw new NullPointerException("predictedSource is marked non-null but is null");
        }
        RegressionEvaluation regressionEvaluation = new RegressionEvaluation(inputSource, predictedSource);
        regressionEvaluation.evaluate(model, dataSet);
        return regressionEvaluation;
    }

    /**
     * Creates an empty evaluation.
     *
     * @param inputSource     name of the observation holding the input features (may be null,
     *                        but {@link #evaluate(Model, DataSet)} reads it)
     * @param predictedSource name of the observation holding the gold / predicted value
     * @throws NullPointerException if {@code predictedSource} is null
     */
    public RegressionEvaluation(String inputSource, @NonNull String predictedSource) {
        if (predictedSource == null) {
            throw new NullPointerException("predictedSource is marked non-null but is null");
        }
        this.inputSource = inputSource;
        this.predictedSource = predictedSource;
    }

    /**
     * Computes adjusted R^2 as {@code R2 - (1 - R2) * p / (n - p - 1)}, where {@code n} is the
     * number of entries and {@code p} the predictor count.
     *
     * @return the adjusted R^2 (non-finite when {@code n - p - 1 == 0})
     */
    public double adjustedR2() {
        double r2 = r2();
        return r2 - (((1.0d - r2) * this.p) / ((this.gold.size() - this.p) - 1.0d));
    }

    /**
     * Manually records a gold value and its prediction. Note that this does not update the
     * predictor count; call {@link #setP(double)} if {@link #adjustedR2()} is needed.
     *
     * @param goldValue  the expected value
     * @param prediction the predicted value (its scalar is recorded)
     * @throws NullPointerException if {@code prediction} is null
     */
    public void entry(double goldValue, @NonNull NDArray prediction) {
        if (prediction == null) {
            throw new NullPointerException("predicted is marked non-null but is null");
        }
        this.gold.add(goldValue);
        this.predicted.add(prediction.scalar());
    }

    /**
     * Evaluates the model over every datum: the gold value is read from the datum's
     * {@code predictedSource} observation, and the prediction from the same observation of the
     * transformed datum. The predictor count is updated to the largest input length seen.
     *
     * @param model   the model used to transform each datum
     * @param dataSet the data to evaluate on
     * @throws NullPointerException if either argument is null
     */
    @Override // com.gengoai.apollo.ml.evaluation.Evaluation
    public void evaluate(@NonNull Model model, @NonNull DataSet dataSet) {
        if (model == null) {
            throw new NullPointerException("model is marked non-null but is null");
        }
        if (dataSet == null) {
            throw new NullPointerException("dataset is marked non-null but is null");
        }
        Iterator<Datum> it = dataSet.iterator();
        while (it.hasNext()) {
            Datum next = it.next();
            this.p = Math.max(this.p, next.get(this.inputSource).asNDArray().length());
            this.gold.add(next.get(this.predictedSource).asNDArray().scalar());
            this.predicted.add(model.transform(next).get(this.predictedSource).asNDArray().scalar());
        }
    }

    /**
     * @return the mean squared error (NaN when no entries have been recorded)
     */
    public double meanSquaredError() {
        return squaredError() / this.gold.size();
    }

    /**
     * Merges another evaluation's recorded values (and predictor count) into this one.
     *
     * @param regressionEvaluation the evaluation to merge in
     */
    public void merge(RegressionEvaluation regressionEvaluation) {
        this.gold.addAllOf(regressionEvaluation.gold);
        this.predicted.addAllOf(regressionEvaluation.predicted);
        // Keep the predictor count consistent with evaluate(), which tracks the maximum seen;
        // otherwise adjustedR2() after a merge may use a stale p.
        this.p = Math.max(this.p, regressionEvaluation.p);
    }

    /**
     * Prints a table of RMSE, R^2 and adjusted R^2 to the given stream.
     *
     * @param printStream the stream to print to
     * @throws NullPointerException if {@code printStream} is null
     */
    @Override // com.gengoai.apollo.ml.evaluation.Evaluation
    public void output(@NonNull PrintStream printStream) {
        if (printStream == null) {
            throw new NullPointerException("printStream is marked non-null but is null");
        }
        TableFormatter tableFormatter = new TableFormatter();
        tableFormatter.title("Regression Metrics");
        tableFormatter.header(Arrays.asList("Metric", "Value"));
        tableFormatter.content(Arrays.asList("RMSE", Double.valueOf(rootMeanSquaredError())));
        tableFormatter.content(Arrays.asList("R^2", Double.valueOf(r2())));
        tableFormatter.content(Arrays.asList("Adj. R^2", Double.valueOf(adjustedR2())));
        tableFormatter.print(printStream);
    }

    /**
     * Computes the coefficient of determination {@code 1 - SSE / SST}, where SST is the sum of
     * squared deviations of the gold values from their mean.
     *
     * @return R^2 (non-finite when SST is zero, e.g. with no entries)
     */
    public double r2() {
        // elements() exposes the backing array, which may contain unused capacity past size();
        // restrict the streams to the valid range so padding zeros do not skew the statistics.
        int n = this.gold.size();
        double[] values = this.gold.elements();
        double mean = Arrays.stream(values, 0, n).average().orElse(0.0d);
        double totalSumOfSquares = Arrays.stream(values, 0, n)
                .map(d -> Math.pow(d - mean, 2.0d))
                .sum();
        return 1.0d - (squaredError() / totalSumOfSquares);
    }

    /**
     * @return the square root of {@link #meanSquaredError()}
     */
    public double rootMeanSquaredError() {
        return Math.sqrt(meanSquaredError());
    }

    /**
     * Sets the predictor count used by {@link #adjustedR2()}.
     *
     * @param p the number of predictors
     */
    public void setP(double p) {
        this.p = p;
    }

    /**
     * @return the sum of squared differences between predicted and gold values
     */
    public double squaredError() {
        double sum = 0.0d;
        for (int i = 0; i < this.gold.size(); i++) {
            sum += Math.pow(this.predicted.get(i) - this.gold.get(i), 2.0d);
        }
        return sum;
    }
}
