package edu.uci.jforestsx.learning.boosting;

import edu.uci.jforestsx.config.TrainingConfig;
import edu.uci.jforestsx.eval.EvaluationMetric;
import edu.uci.jforestsx.learning.LearningModule;
import edu.uci.jforestsx.learning.LearningUtils;
import edu.uci.jforestsx.learning.trees.Ensemble;
import edu.uci.jforestsx.learning.trees.Tree;
import edu.uci.jforestsx.learning.trees.TreeLeafInstances;
import edu.uci.jforestsx.learning.trees.regression.RegressionTree;
import edu.uci.jforestsx.sample.Sample;
import edu.uci.jforestsx.util.ConfigHolder;
import java.util.Arrays;
import java.util.Random;

/* loaded from: input_file:edu/uci/jforestsx/learning/boosting/GradientBoosting.class */
/**
 * Gradient boosting learner: each iteration fits the sub-learner to the residuals
 * ({@code target - current prediction}) of the training sample and appends the resulting trees
 * to an {@link Ensemble}.
 *
 * <p>When a validation sample is supplied, validation scores are tracked after every iteration.
 * The returned ensemble is truncated to the last iteration that the evaluation metric judged
 * better (using {@code earlyStoppingTolerance}) than the best measurement so far. Training stops
 * once more than {@link #MAX_ITERATIONS_WITHOUT_IMPROVEMENT} iterations pass without a strict
 * improvement.
 */
public class GradientBoosting extends LearningModule {
    /** Stop once this many iterations have passed since the last strict validation improvement. */
    protected static final int MAX_ITERATIONS_WITHOUT_IMPROVEMENT = 100;

    /** Accumulated scores for each training instance, index-aligned with the training sample. */
    protected double[] trainPredictions;
    /** Accumulated scores for each validation instance, index-aligned with the validation sample. */
    protected double[] validPredictions;
    /** Reusable buffer holding {@code target - trainPrediction} for each training instance. */
    protected double[] residuals;
    protected int numInstances;
    /** Maximum number of boosting iterations (from {@code GradientBoostingConfig.numTrees}). */
    private int numSubModules;
    /** Factor applied to every new tree's leaf outputs (shrinkage). */
    protected double learningRate;
    /** Fraction passed to {@code Sample.getRandomSubSample} when building each sub-learner sample. */
    protected double samplingRate;
    /** Tolerance passed to {@code EvaluationMetric.isFirstBetter} when choosing the cut-off iteration. */
    protected double earlyStoppingTolerance;
    protected Sample curTrainSet;
    protected Sample curValidSet;
    /** 1-based index of the boosting iteration currently running in {@link #learn}. */
    protected int curIteration;
    protected double bestValidationMeasurement;
    protected boolean printIntermediateValidMeasurements;
    protected EvaluationMetric evaluationMetric;
    protected Random rnd;

    public GradientBoosting(String str) {
        super(str);
    }

    public GradientBoosting() {
        super("GradientBoosting");
    }

    /**
     * Reads the boosting and training settings and allocates the score buffers.
     *
     * @param configHolder source of {@link GradientBoostingConfig} and {@link TrainingConfig}
     * @param maxNumTrainInstances capacity of the training-score and residual buffers
     * @param maxNumValidInstances capacity of the validation-score buffer
     * @param evaluationMetric metric used to evaluate and compare measurements
     * @throws Exception if reading a configuration fails
     */
    public void init(ConfigHolder configHolder, int maxNumTrainInstances, int maxNumValidInstances,
            EvaluationMetric evaluationMetric) throws Exception {
        this.evaluationMetric = evaluationMetric;

        GradientBoostingConfig boostingConfig =
                (GradientBoostingConfig) configHolder.getConfig(GradientBoostingConfig.class);
        this.numSubModules = boostingConfig.numTrees;
        this.learningRate = boostingConfig.learningRate;
        this.samplingRate = boostingConfig.samplingRate;
        this.earlyStoppingTolerance = boostingConfig.earlyStoppingTolerance;

        this.trainPredictions = new double[maxNumTrainInstances];
        this.residuals = new double[maxNumTrainInstances];
        this.validPredictions = new double[maxNumValidInstances];

        // Both the print flag and the RNG seed come from the same TrainingConfig instance.
        TrainingConfig trainingConfig = (TrainingConfig) configHolder.getConfig(TrainingConfig.class);
        this.printIntermediateValidMeasurements = trainingConfig.printIntermediateValidMeasurements;
        this.rnd = new Random(trainingConfig.randomSeed);
    }

    /** Clears the accumulated training (and, if present, validation) scores before learning. */
    protected void preprocess() {
        Arrays.fill(this.trainPredictions, 0, this.curTrainSet.size, 0.0d);
        // learn() accepts a null validation sample, so only clear validation scores when one exists.
        if (this.curValidSet != null) {
            Arrays.fill(this.validPredictions, 0, this.curValidSet.size, 0.0d);
        }
    }

    /**
     * Trains a boosted ensemble on {@code trainSet}, optionally monitoring {@code validSet}.
     *
     * @param trainSet training sample
     * @param validSet validation sample, or {@code null} to keep every tree produced
     * @return the trained ensemble, truncated to the selected early-stopping iteration
     * @throws Exception propagated from the sub-learner or from metric evaluation
     */
    @Override
    public Ensemble learn(Sample trainSet, Sample validSet) throws Exception {
        this.curTrainSet = trainSet;
        this.curValidSet = validSet;
        preprocess();

        Ensemble ensemble = new Ensemble();
        this.bestValidationMeasurement = Double.NaN;
        // Last iteration accepted under earlyStoppingTolerance (0 = none yet); decides truncation.
        int earlyStoppingIteration = 0;
        // Last iteration that strictly improved bestValidationMeasurement; drives the patience check.
        int bestIteration = 0;
        // treeCountAfter[k] = ensemble size once iteration k + 1 finished.
        int[] treeCountAfter = new int[this.numSubModules];

        this.subLearner.setTreeWeight(this.treeWeight);
        for (this.curIteration = 1; this.curIteration <= this.numSubModules; this.curIteration++) {
            Ensemble subEnsemble = this.subLearner.learn(getSubLearnerSample(), validSet);
            if (subEnsemble == null) {
                break;
            }
            for (int t = 0; t < subEnsemble.getNumTrees(); t++) {
                Tree tree = subEnsemble.getTreeAt(t);
                ensemble.addTree(tree, subEnsemble.getWeightAt(t));
                if (validSet != null) {
                    LearningUtils.updateScores(validSet, this.validPredictions, (RegressionTree) tree, 1.0d);
                }
            }
            treeCountAfter[this.curIteration - 1] = ensemble.getNumTrees();

            if (validSet == null) {
                earlyStoppingIteration = this.curIteration;
            } else {
                double validMeasurement = getValidMeasurement();
                if (this.evaluationMetric.isFirstBetter(
                        validMeasurement, this.bestValidationMeasurement, this.earlyStoppingTolerance)) {
                    earlyStoppingIteration = this.curIteration;
                    if (this.evaluationMetric.isFirstBetter(validMeasurement, this.bestValidationMeasurement, 0.0d)) {
                        this.bestValidationMeasurement = validMeasurement;
                        bestIteration = this.curIteration;
                    }
                }
                if (this.curIteration - bestIteration > MAX_ITERATIONS_WITHOUT_IMPROVEMENT) {
                    break;
                }
                if (this.printIntermediateValidMeasurements) {
                    printTrainAndValidMeasurement(
                            this.curIteration, validMeasurement, getTrainMeasurement(), this.evaluationMetric);
                }
            }
            onIterationEnd();
        }

        if (earlyStoppingIteration > 0) {
            int treesToKeep = treeCountAfter[earlyStoppingIteration - 1];
            ensemble.removeLastTrees(ensemble.getNumTrees() - treesToKeep);
        }
        onLearningEnd();
        return ensemble;
    }

    /**
     * Returns the best validation measurement recorded by the last {@link #learn} call, or
     * {@code Double.NaN} if that call recorded none (for example, when no validation sample was given).
     */
    @Override
    public double getValidationMeasurement() {
        return this.bestValidationMeasurement;
    }

    /** Evaluates the current validation scores with the configured metric. */
    protected double getValidMeasurement() throws Exception {
        return this.curValidSet.evaluate(this.validPredictions, this.evaluationMetric);
    }

    /** Evaluates the current training scores with the configured metric. */
    protected double getTrainMeasurement() throws Exception {
        return this.curTrainSet.evaluate(this.trainPredictions, this.evaluationMetric);
    }

    /**
     * Builds the sample for the next sub-learner.
     *
     * <p>The sample is a random sub-sample (rate {@link #samplingRate}) of a clone of the training
     * set whose targets are the current residuals. The clone's targets reference the shared
     * {@link #residuals} buffer, which the next call overwrites.
     * NOTE(review): assumes the sub-learner does not retain the sample across iterations.
     */
    protected Sample getSubLearnerSample() {
        for (int i = 0; i < this.curTrainSet.size; i++) {
            this.residuals[i] = this.curTrainSet.targets[i] - this.trainPredictions[i];
        }
        Sample residualSample = this.curTrainSet.getClone();
        residualSample.targets = this.residuals;
        return residualSample.getRandomSubSample(this.samplingRate, this.rnd);
    }

    /**
     * Scales the tree's leaf outputs by {@link #learningRate}.
     * {@code treeLeafInstances} is unused here but is available to overriding subclasses.
     */
    protected void adjustOutputs(Tree tree, TreeLeafInstances treeLeafInstances) {
        ((RegressionTree) tree).multiplyLeafOutputs(this.learningRate);
    }

    /**
     * Finalizes a newly built tree.
     *
     * <p>Adjusts the tree's outputs, adds the tree's outputs to the training scores, and then runs
     * the {@link #postProcessScores()} hook. Presumably this is invoked by the sub-learner for each
     * tree it builds; TODO confirm the caller.
     */
    @Override
    public void postProcess(Tree tree, TreeLeafInstances treeLeafInstances) {
        adjustOutputs(tree, treeLeafInstances);
        LearningUtils.updateScores(this.curTrainSet, this.trainPredictions, (RegressionTree) tree, 1.0d);
        postProcessScores();
    }

    /** Hook run after the training scores are updated for a new tree; a no-op by default. */
    protected void postProcessScores() {
    }
}
