package edu.uci.jforestsx.learning.trees;

import edu.uci.jforestsx.config.TrainingConfig;
import edu.uci.jforestsx.dataset.Dataset;
import edu.uci.jforestsx.dataset.Feature;
import edu.uci.jforestsx.dataset.Histogram;
import edu.uci.jforestsx.learning.LearningModule;
import edu.uci.jforestsx.sample.Sample;
import edu.uci.jforestsx.util.ConfigHolder;
import edu.uci.jforestsx.util.concurrency.BlockingThreadPoolExecutor;
import edu.uci.jforestsx.util.concurrency.TaskCollection;
import edu.uci.jforestsx.util.concurrency.TaskItem;
import java.util.Arrays;
import java.util.Random;

/* loaded from: input_file:edu/uci/jforestsx/learning/trees/TreeLearner.class */
public abstract class TreeLearner extends LearningModule {
    /**
     * Fraction of features considered when choosing a leaf's best split. Values below 1.0 make
     * {@link #setBestTreeSplitForLeaf} sample features using {@link #rand}.
     */
    protected double featureSamplingPerSplit;
    /** Copied from {@code TreesConfig.randomizedSplits}; not read in this class (presumably used by subclasses). */
    protected boolean randomizedSplits;
    /** Minimum leaf size expressed as a percentage of the current training sample size. */
    protected double minInstancePercentagePerLeaf;
    /** Minimum leaf size in instances, recomputed from {@link #minInstancePercentagePerLeaf} on each {@link #learn} call. */
    protected int minInstancesPerLeaf;
    /** Configured number of leaves ({@code TreesConfig.numLeaves}) that bounds tree growth. */
    protected int maxLeaves;
    /** Features eligible for splitting during the current {@link #learn} call. */
    protected boolean[] selectedFeatures;
    /** Features excluded from splitting, configured via {@code featuresToInclude}/{@code featuresToDiscard}. */
    protected boolean[] featuresToDiscard;
    /** Random source used for per-split feature sampling. */
    protected Random rand;
    /** Assignment of training instances to the leaves of the tree being grown. */
    private TreeLeafInstances trainTreeLeafInstances;
    /** Training sample of the current {@link #learn} call. */
    protected Sample curTrainSet;
    /**
     * Per-leaf arrays of per-feature histograms. Arrays are allocated lazily, reused across
     * {@link #learn} calls, and swapped between sibling leaves as the tree grows.
     */
    private Histogram[][] perNodeHistograms;
    /** Best split found so far for each leaf index. */
    protected TreeSplit[] perLeafBestSplit;
    /** Leaf that was just split, or -1 while the root is being evaluated. */
    private int parentNodeIndex;
    /** Leaf index of the child whose histograms are built from its own instances. */
    private int smallerChildIndex;
    /** Leaf index of the child whose histograms are derived from the parent's by subtraction. */
    private int largerChildIndex;
    private CandidateSplitsForLeaf candidateSplitsForSmallerChild;
    private CandidateSplitsForLeaf candidateSplitsForLargerChild;
    /** Feature-range workers that fill the candidate splits of the current child pair. */
    private TaskCollection<BestThresholdForFeatureFinder> leafCandidateSplitsCalculationTask;
    private static final int ROOT_LEAF_INDEX = 0;

    /**
     * Computes, for every selected feature in {@code [beginIdx, endIdx)}, the best threshold of
     * the smaller child and, when a parent exists, of the larger child.
     *
     * <p>The smaller child's histogram is built from its own instances. The larger child's
     * histogram is obtained by subtracting the smaller child's histogram from the parent's; the
     * parent's histograms live in the larger child's slot when this task runs (see
     * {@link TreeLearner#learn}).
     */
    private class BestThresholdForFeatureFinder extends TaskItem {
        private final int beginIdx;
        private final int endIdx;

        public BestThresholdForFeatureFinder(int beginIdx, int endIdx) {
            this.beginIdx = beginIdx;
            this.endIdx = endIdx;
        }

        @Override
        public void run() {
            // Go through the outer instance explicitly so members inherited from TaskItem can never shadow these names.
            final TreeLearner learner = TreeLearner.this;
            final boolean hasParent = learner.parentNodeIndex != -1;
            for (int f = beginIdx; f < endIdx; f++) {
                if (!learner.selectedFeatures[f]) {
                    continue;
                }
                Histogram smaller = learner.perNodeHistograms[learner.smallerChildIndex][f];
                // The parent's histograms sit in the larger child's slot (learn() keeps or moves them
                // there). perNodeHistograms[parentNodeIndex] is NOT the parent's array when the left
                // child is the smaller one, because that slot then holds the smaller child's spare array.
                if (hasParent && !learner.perNodeHistograms[learner.largerChildIndex][f].splittable) {
                    smaller.splittable = false;
                    continue;
                }
                smaller.init(learner.candidateSplitsForSmallerChild, learner.curTrainSet.indicesInDataset);
                learner.setBestThresholdForSplit(learner.candidateSplitsForSmallerChild.getFeatureSplit(f), smaller);
                if (hasParent) {
                    Histogram larger = learner.perNodeHistograms[learner.largerChildIndex][f];
                    try {
                        larger.subtractFromMe(smaller);
                        learner.setBestThresholdForSplit(learner.candidateSplitsForLargerChild.getFeatureSplit(f), larger);
                    } catch (Exception e) {
                        // NOTE(review): the failure is only printed and this feature's larger-child
                        // candidate split is left un-updated; kept as before, consider propagating.
                        e.printStackTrace();
                    }
                }
            }
        }
    }

    /**
     * @param name string forwarded unchanged to the {@link LearningModule} constructor
     */
    public TreeLearner(String name) {
        super(name);
    }

    /**
     * Configures the learner for {@code dataset}: reads tree/training settings, partitions the
     * features into worker tasks, allocates per-leaf bookkeeping and resolves the configured
     * feature include/discard lists.
     *
     * @param dataset dataset whose features define the histogram layout and feature names
     * @param configHolder source of the {@code TrainingConfig} and {@code TreesConfig}
     * @param numInstances forwarded to {@link #getNewCandidateSplitsForLeaf} and to
     *     {@code TreeLeafInstances}; presumably an upper bound on training instances — TODO confirm
     * @throws Exception if a name in {@code featuresToInclude} or {@code featuresToDiscard} is not
     *     a feature of {@code dataset}, or if a callee throws
     */
    public void init(Dataset dataset, ConfigHolder configHolder, int numInstances) throws Exception {
        TrainingConfig trainingConfig = (TrainingConfig) configHolder.getConfig(TrainingConfig.class);
        TreesConfig treesConfig = (TreesConfig) configHolder.getConfig(TreesConfig.class);
        final int numFeatures = dataset.numFeatures;

        this.minInstancePercentagePerLeaf = treesConfig.minInstancePercentagePerLeaf;
        this.maxLeaves = treesConfig.numLeaves;
        this.perLeafBestSplit = new TreeSplit[treesConfig.numLeaves];

        // Cut the features into contiguous ranges, at most one range per pool thread.
        this.leafCandidateSplitsCalculationTask = new TaskCollection<>();
        int featuresPerTask = 1 + numFeatures / BlockingThreadPoolExecutor.getInstance().getMaximumPoolSize();
        for (int begin = 0; begin < numFeatures; begin += featuresPerTask) {
            int end = Math.min(numFeatures, begin + featuresPerTask);
            this.leafCandidateSplitsCalculationTask.addTask(new BestThresholdForFeatureFinder(begin, end));
        }

        // One slot per leaf; the per-feature arrays themselves are allocated lazily in learn().
        this.perNodeHistograms = new Histogram[treesConfig.numLeaves][];
        this.candidateSplitsForSmallerChild = getNewCandidateSplitsForLeaf(numFeatures, numInstances);
        this.candidateSplitsForLargerChild = getNewCandidateSplitsForLeaf(numFeatures, numInstances);
        this.rand = new Random(trainingConfig.randomSeed);
        this.featureSamplingPerSplit = treesConfig.featureSamplingPerSplit;
        this.randomizedSplits = treesConfig.randomizedSplits;
        this.selectedFeatures = new boolean[numFeatures];
        this.trainTreeLeafInstances = new TreeLeafInstances(numInstances, this.maxLeaves);

        // A non-blank include list discards every feature not named in it. The discard list is
        // applied afterwards, so it takes precedence over the include list.
        this.featuresToDiscard = new boolean[numFeatures];
        String includeList = treesConfig.featuresToInclude;
        if (hasText(includeList)) {
            Arrays.fill(this.featuresToDiscard, true);
            markFeaturesFromList(dataset, includeList, false);
        }
        String discardList = treesConfig.featuresToDiscard;
        if (hasText(discardList)) {
            markFeaturesFromList(dataset, discardList, true);
        }
    }

    /** Returns true when {@code s} is non-null and contains a non-whitespace character. */
    private static boolean hasText(String s) {
        return s != null && s.trim().length() > 0;
    }

    /**
     * Sets {@code featuresToDiscard[idx] = discard} for every feature named in the comma-separated
     * {@code featureList}.
     *
     * @throws Exception if a listed name is not a feature of {@code dataset}
     */
    private void markFeaturesFromList(Dataset dataset, String featureList, boolean discard) throws Exception {
        for (String featureName : featureList.split(",")) {
            int featureIdx = dataset.getFeatureIdx(featureName);
            if (featureIdx < 0) {
                throw new Exception("Unknown feature: '" + featureName + "'");
            }
            this.featuresToDiscard[featureIdx] = discard;
        }
    }

    /** Replaces {@link #rand} with a generator seeded with the constant 1. */
    public void setRnd() {
        this.rand = new Random(1L);
    }

    /** Creates the empty tree that {@link #learn} grows. */
    protected abstract Tree getNewTree();

    /** Creates a fresh split record. */
    protected abstract TreeSplit getNewSplit();

    /**
     * Creates the candidate-split holder for one leaf.
     *
     * @param i number of features
     * @param i2 value passed to {@link #init} as {@code numInstances}
     */
    protected abstract CandidateSplitsForLeaf getNewCandidateSplitsForLeaf(int i, int i2);

    /** Creates an empty histogram for {@code feature}. */
    protected abstract Histogram getNewHistogram(Feature feature);

    /**
     * Grows one tree on {@code trainSet}, one leaf at a time.
     *
     * <p>The root is evaluated and split first. Each following iteration evaluates the two
     * children of the last split (when at least one has {@code 2 * minInstancesPerLeaf}
     * instances), then splits the leaf with the highest gain. Growth stops when the leaf budget
     * {@link #maxLeaves} is reached or the best gain is not positive (or is NaN).
     *
     * @param trainSet sample to fit the tree on
     * @param validSet not read by this implementation
     * @return a one-tree ensemble weighted by {@code treeWeight}, or {@code null} if the root's
     *     best gain is infinite
     * @throws Exception propagated from the callees
     */
    @Override
    public Ensemble learn(Sample trainSet, Sample validSet) throws Exception {
        this.curTrainSet = trainSet;
        this.trainTreeLeafInstances.init(trainSet.size);
        this.minInstancesPerLeaf = (int) ((trainSet.size * this.minInstancePercentagePerLeaf) / 100.0d);
        for (int f = 0; f < this.selectedFeatures.length; f++) {
            this.selectedFeatures[f] = !this.featuresToDiscard[f];
        }
        resetHistogramSplittableFlags();

        Tree tree = getNewTree();

        // Evaluate the root: it is leaf 0 and has no parent histograms to subtract from.
        this.candidateSplitsForSmallerChild.init(ROOT_LEAF_INDEX, this.trainTreeLeafInstances, this.curTrainSet);
        this.parentNodeIndex = -1;
        this.smallerChildIndex = ROOT_LEAF_INDEX;
        if (this.perNodeHistograms[ROOT_LEAF_INDEX] == null) {
            this.perNodeHistograms[ROOT_LEAF_INDEX] = getNewHistogramArray();
        }
        this.candidateSplitsForLargerChild.init(-1);
        this.leafCandidateSplitsCalculationTask.run();
        setBestTreeSplitForLeaf(this.candidateSplitsForSmallerChild);

        TreeSplit rootSplit = this.perLeafBestSplit[ROOT_LEAF_INDEX];
        if (Double.isInfinite(rootSplit.gain)) {
            return null;
        }

        // After a split, one child keeps the split leaf's index and the other gets a new index.
        int leftChild = ROOT_LEAF_INDEX;
        int rightChild = splitLeafAt(tree, ROOT_LEAF_INDEX, rootSplit);
        for (int numLeaves = 2; numLeaves < this.maxLeaves; numLeaves++) {
            int leftCount = this.trainTreeLeafInstances.getNumberOfInstancesInLeaf(leftChild);
            int rightCount = this.trainTreeLeafInstances.getNumberOfInstancesInLeaf(rightChild);
            if (rightCount >= 2 * this.minInstancesPerLeaf || leftCount >= 2 * this.minInstancesPerLeaf) {
                computeChildrenBestSplits(leftChild, rightChild, leftCount, rightCount);
            } else {
                // Neither child reaches 2 * minInstancesPerLeaf: make both unselectable for splitting.
                this.perLeafBestSplit[leftChild].gain = Double.NEGATIVE_INFINITY;
                this.perLeafBestSplit[rightChild] = getNewSplit();
                this.perLeafBestSplit[rightChild].gain = Double.NEGATIVE_INFINITY;
            }

            int bestLeaf = findLeafWithBestGain(tree);
            TreeSplit bestSplit = this.perLeafBestSplit[bestLeaf];
            if (bestSplit.gain <= 0.0d || Double.isNaN(bestSplit.gain)) {
                break;
            }
            leftChild = bestLeaf;
            rightChild = splitLeafAt(tree, bestLeaf, bestSplit);
        }

        if (this.parentLearner != null) {
            this.parentLearner.postProcess(tree, this.trainTreeLeafInstances);
        }
        Ensemble ensemble = new Ensemble();
        ensemble.addTree(tree, this.treeWeight);
        return ensemble;
    }

    /** Marks every allocated histogram as splittable again before growing a new tree. */
    private void resetHistogramSplittableFlags() {
        for (Histogram[] histograms : this.perNodeHistograms) {
            if (histograms == null) {
                continue;
            }
            for (Histogram histogram : histograms) {
                if (histogram != null) {
                    histogram.splittable = true;
                }
            }
        }
    }

    /**
     * Applies {@code split} to {@code leaf} in both the tree and the instance partition, and
     * records {@code leaf} as the parent for the next child evaluation.
     *
     * @return index of the newly created leaf; the other child keeps {@code leaf}'s index
     */
    private int splitLeafAt(Tree tree, int leaf, TreeSplit split) throws Exception {
        // NOTE(review): the complement assumes Tree.getRightChild encodes leaves as ~leafIndex — confirm.
        int newLeaf = ~tree.getRightChild(tree.split(leaf, split));
        this.parentNodeIndex = leaf;
        this.trainTreeLeafInstances.split(leaf, this.curTrainSet.dataset, split.feature, split.threshold,
                newLeaf, this.curTrainSet.indicesInDataset);
        return newLeaf;
    }

    /**
     * Evaluates the best splits of both children of the last split.
     *
     * <p>The larger child ends up owning the parent's histogram array (later reduced by
     * subtraction). The smaller child gets a spare array that is rebuilt from its own instances.
     */
    private void computeChildrenBestSplits(int leftChild, int rightChild, int leftCount, int rightCount)
            throws Exception {
        if (leftCount < rightCount) {
            // Move the parent's array (stored under leftChild) to the larger right child.
            Histogram[] spare = this.perNodeHistograms[rightChild];
            this.perNodeHistograms[rightChild] = this.perNodeHistograms[leftChild];
            this.perNodeHistograms[leftChild] = (spare != null) ? spare : getNewHistogramArray();
            this.largerChildIndex = rightChild;
            this.smallerChildIndex = leftChild;
        } else {
            if (this.perNodeHistograms[rightChild] == null) {
                this.perNodeHistograms[rightChild] = getNewHistogramArray();
            }
            this.largerChildIndex = leftChild;
            this.smallerChildIndex = rightChild;
        }
        this.candidateSplitsForSmallerChild.init(this.smallerChildIndex, this.trainTreeLeafInstances, this.curTrainSet);
        this.candidateSplitsForLargerChild.init(this.largerChildIndex, this.trainTreeLeafInstances, this.curTrainSet);
        this.leafCandidateSplitsCalculationTask.run();
        setBestTreeSplitForLeaf(this.candidateSplitsForSmallerChild);
        setBestTreeSplitForLeaf(this.candidateSplitsForLargerChild);
    }

    /** Returns the lowest leaf index whose best split has the strictly highest gain (0 if none exceeds -infinity). */
    private int findLeafWithBestGain(Tree tree) {
        int bestLeaf = 0;
        double bestGain = Double.NEGATIVE_INFINITY;
        for (int leaf = 0; leaf < tree.numLeaves; leaf++) {
            if (this.perLeafBestSplit[leaf].gain > bestGain) {
                bestGain = this.perLeafBestSplit[leaf].gain;
                bestLeaf = leaf;
            }
        }
        return bestLeaf;
    }

    /**
     * Copies the best feature split of {@code candidateSplitsForLeaf} into
     * {@link #perLeafBestSplit} for that leaf, allocating the record on first use.
     *
     * <p>Features are sampled when {@link #featureSamplingPerSplit} is below 1.0. If no feature is
     * returned (negative index), feature 0's split is copied and its gain forced to -infinity.
     */
    protected void setBestTreeSplitForLeaf(CandidateSplitsForLeaf candidateSplitsForLeaf) {
        int bestFeature = (this.featureSamplingPerSplit < 1.0d)
                ? candidateSplitsForLeaf.getBestFeature(this.featureSamplingPerSplit, this.rand)
                : candidateSplitsForLeaf.getBestFeature();
        int leafIndex = candidateSplitsForLeaf.getLeafIndex();
        TreeSplit best = this.perLeafBestSplit[leafIndex];
        if (best == null) {
            best = getNewSplit();
            this.perLeafBestSplit[leafIndex] = best;
        }
        if (bestFeature >= 0) {
            best.copy(candidateSplitsForLeaf.getFeatureSplit(bestFeature));
        } else {
            // Keep a well-formed split record but make it impossible to select.
            best.copy(candidateSplitsForLeaf.getFeatureSplit(0));
            best.gain = Double.NEGATIVE_INFINITY;
        }
    }

    /** Fills {@code treeSplit} with the best threshold found in {@code histogram}. */
    protected abstract void setBestThresholdForSplit(TreeSplit treeSplit, Histogram histogram);

    /**
     * Not supported for tree learners.
     *
     * @throws Exception always
     */
    @Override
    public double getValidationMeasurement() throws Exception {
        throw new Exception("Validation Measurement should not be computed for TreeLearner.");
    }

    /** Allocates one new histogram per feature of the current training dataset. */
    private Histogram[] getNewHistogramArray() {
        Histogram[] histogramArr = new Histogram[this.curTrainSet.dataset.numFeatures];
        for (int i = 0; i < this.curTrainSet.dataset.numFeatures; i++) {
            histogramArr[i] = getNewHistogram(this.curTrainSet.dataset.features[i]);
        }
        return histogramArr;
    }
}
