package hex.tree.dt;

import hex.ModelBuilder;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.tree.dt.DTModel;
import hex.tree.dt.binning.BinAccumulatedStatistics;
import hex.tree.dt.binning.BinningStrategy;
import hex.tree.dt.binning.Histogram;
import hex.tree.dt.mrtasks.GetClassCountsMRTask;
import hex.tree.dt.mrtasks.ScoreDTTask;
import java.util.Arrays;
import java.util.Comparator;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;
import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.apache.commons.math3.util.Precision;
import org.apache.log4j.Logger;
import water.DKV;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.util.Log;
import water.util.MathUtils;
import water.util.Pair;
import water.util.RandomUtils;

/* loaded from: input_file:hex/tree/dt/DT.class */
public class DT extends ModelBuilder<DTModel, DTModel.DTParameters, DTModel.DTOutput> {
    /** Row-count limit copied from {@code DTParameters._min_rows}; consulted when filtering candidate splits and when deciding whether a node becomes a leaf. */
    private int _min_rows;
    /** Node counter; set to 0 by the parameterized constructor and written to the debug log after the tree is built. */
    int _nodesCount;
    /**
     * Tree storage, one row of three doubles per node.
     * Leaf rows hold {@code {1.0, decision value, fraction of class-0 rows}};
     * split rows hold {@code {0.0, split feature index, threshold}}.
     */
    private double[][] _tree;
    /** Model being built; created and locked by the driver, unlocked when the driver finishes. */
    private DTModel _model;
    /** Random generator seeded from the parameters' {@code _seed} when the driver starts. */
    transient Random _rand;
    /** Small tolerance constant (1e-6). */
    public static final double EPSILON = 1.0E-6d;
    /** Threshold constant (1e-6); presumably the smallest criterion improvement that justifies a split - confirm against usage. */
    public static final double MIN_IMPROVEMENT = 1.0E-6d;
    /** Log4j logger of this class. */
    private static final Logger LOG = Logger.getLogger(DT.class);

    /* loaded from: input_file:hex/tree/dt/DT$DTDriver.class */
    private class DTDriver extends ModelBuilder<DTModel, DTModel.DTParameters, DTModel.DTOutput>.Driver {
        private DTDriver() {
            super();
        }

        /**
         * Validates parameters and training data, registering one builder error per violated
         * constraint. Nothing is thrown here; the caller inspects the error count afterwards.
         */
        private void dtChecks() {
            if (DT.this._parms._max_depth < 1) {
                DT.this.error("_parms._max_depth", "Max depth has to be at least 1");
            }
            if (DT.this._train.hasNAs()) {
                DT.this.error("_train", "NaNs are not supported yet");
            }
            if (DT.this._train.hasInfs()) {
                DT.this.error("_train", "Infs are not supported");
            }
            // The last column is excluded from the scan (presumably the response column).
            final boolean hasCategoricalFeature = IntStream.range(0, DT.this._train.numCols() - 1)
                    .anyMatch(column -> DT.this._train.vec(column).isCategorical());
            if (hasCategoricalFeature) {
                DT.this.error("_train", "Categorical features are not supported yet");
            }
            if (!DT.this._response.isCategorical()) {
                DT.this.error("_response", "Only categorical response is supported");
            }
            if (!DT.this._response.isBinary()) {
                DT.this.error("_response", "Only binary response is supported");
            }
        }

        /**
         * Driver entry point: initializes and validates the builder, creates and locks the model,
         * builds the tree, and always unlocks the model again if one was created.
         *
         * @throws H2OModelBuilderIllegalArgumentException if initialization or validation registered any error
         */
        @Override
        public void computeImpl() {
            DT.this._model = null;
            try {
                DT.this.init(true);
                dtChecks();
                if (DT.this.error_count() > 0) {
                    throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(DT.this);
                }
                final DTModel.DTParameters parameters = DT.this._parms;
                DT.this._rand = RandomUtils.getRNG(parameters._seed);
                final DTModel freshModel = new DTModel(DT.this.dest(), parameters, new DTModel.DTOutput(DT.this));
                DT.this._model = freshModel;
                freshModel.delete_and_lock(DT.this._job);
                buildDT();
                DT.LOG.info(DT.this._model.toString());
            } finally {
                // The model stays null when a failure happened before it was created.
                if (DT.this._model != null) {
                    DT.this._model.unlock(DT.this._job);
                }
            }
        }

        /**
         * Builds the tree, stores its compressed form in the DKV under the key recorded in the
         * model output, advances the job, updates the model and computes the model metrics.
         */
        private void buildDT() {
            buildDTIteratively();
            Log.debug("depth: " + DT.this._parms._max_depth + ", nodes count: " + DT.this._nodesCount);

            final CompressedDT compressedTree = new CompressedDT(DT.this._tree);
            DT.this._model._output._treeKey = compressedTree._key;
            DKV.put(compressedTree);

            DT.this._job.update(1L);
            DT.this._model.update(DT.this._job);
            DT.this.makeModelMetrics();

            Log.debug("Tree:");
            Log.debug(Arrays.deepToString(DT.this._tree));
        }

        /**
         * Allocates storage for a complete binary tree of the configured maximum depth
         * ({@code 2^(max_depth + 1) - 1} rows of three doubles) and fills it node by node,
         * visiting the node indices in increasing order.
         */
        private void buildDTIteratively() {
            final int maxNodes = (int) Math.pow(2.0d, DT.this._parms._max_depth + 1) - 1;
            DT.this._tree = new double[maxNodes][3];
            // Must remain a LinkedList: buildNextNode enqueues null placeholders for the children
            // of leaves and of missing nodes, and ArrayDeque rejects null elements.
            final Queue<DataFeaturesLimits> limitsQueue = new LinkedList<>();
            limitsQueue.add(DT.getInitialFeaturesLimits(DT.this._train));
            for (int nodeIndex = 0; nodeIndex < DT.this._tree.length; nodeIndex++) {
                DT.this.buildNextNode(limitsQueue, nodeIndex);
            }
        }
    }

    /**
     * Creates a builder for the given parameters, resets the tree state and runs
     * {@code init(true)}.
     *
     * @param dTParameters parameters of the model to build; {@code _min_rows} is copied from them
     */
    public DT(DTModel.DTParameters dTParameters) {
        super(dTParameters);
        _min_rows = dTParameters._min_rows;
        _nodesCount = 0;
        _tree = null;
        init(true);
    }

    /**
     * Creates a builder with default {@link DTModel.DTParameters}.
     *
     * @param z flag forwarded unchanged to the {@code ModelBuilder} constructor
     *          (presumably the "startup once" registration flag - confirm against ModelBuilder)
     */
    public DT(boolean z) {
        super(new DTModel.DTParameters(), z);
    }

    /**
     * Searches all non-constant features of the histogram for the split with the lowest
     * criterion value.
     * <p>
     * For each feature, only bins whose left and right row counts are both at least
     * {@code _min_rows} are considered. A feature replaces the current best only when its
     * minimum criterion is strictly lower.
     *
     * @param histogram binned statistics of the data belonging to the current node
     * @return the best split found, or {@code null} when no feature yields an admissible split
     */
    private SplitInfo findBestSplit(Histogram histogram) {
        final int featuresCount = histogram.featuresCount();
        // (threshold, criterion) of the best split so far; the initial criterion acts as "+infinity".
        Pair<Double, Double> best = new Pair<Double, Double>(-1.0d, Double.MAX_VALUE);
        int bestFeature = -1;
        for (int feature = 0; feature < featuresCount; feature++) {
            if (histogram.isConstant(feature)) {
                continue;
            }
            final Pair<Double, Double> candidate = histogram.calculateBinsStatisticsForFeature(feature).stream()
                    .filter(stats -> stats._leftCount >= _min_rows && stats._rightCount >= _min_rows)
                    .peek(stats -> Log.debug("counts: " + stats._maxBinValue + " " + stats._leftCount + " " + stats._rightCount))
                    .map(stats -> new Pair<Double, Double>(stats._maxBinValue, calculateCriterionOfSplit(stats)))
                    .min(Comparator.comparing((Pair<Double, Double> thresholdAndCriterion) -> thresholdAndCriterion._2()))
                    .orElse(null);
            if (candidate != null && candidate._2() < best._2()) {
                best = candidate;
                bestFeature = feature;
            }
        }
        if (bestFeature == -1) {
            return null;
        }
        return new SplitInfo(bestFeature, best._1(), best._2());
    }

    /**
     * Computes the size-weighted sum of the binary entropies of the two sides of a split.
     *
     * @param i  number of rows on the left side
     * @param i2 number of left-side rows counted for one class (class 0, judging by the caller's field names)
     * @param i3 number of rows on the right side
     * @param i4 number of right-side rows counted for the same class
     * @return weighted entropy of the split
     */
    private Double binaryEntropy(int i, int i2, int i3, int i4) {
        final int totalCount = i + i3;
        final double leftEntropy = entropyBinarySplit((i2 * 1.0d) / i);
        final double rightEntropy = entropyBinarySplit((i4 * 1.0d) / i3);
        final double weightedLeft = (leftEntropy * i) / totalCount;
        final double weightedRight = (rightEntropy * i3) / totalCount;
        return Double.valueOf(weightedLeft + weightedRight);
    }

    /**
     * Computes the binary entropy {@code -(d*ln(d) + (1-d)*ln(1-d))} using the natural logarithm.
     * A term whose probability is below {@code Precision.EPSILON} contributes zero, which avoids
     * evaluating {@code 0 * ln(0)}.
     *
     * @param d relative frequency of one of the two classes
     * @return entropy of the two-class distribution {@code (d, 1 - d)}
     */
    private double entropyBinarySplit(double d) {
        final double complement = 1.0d - d;
        // A plain zero literal replaces the unrelated CMAESOptimizer.DEFAULT_STOPFITNESS constant (value 0).
        final double term = d < Precision.EPSILON ? 0.0d : d * Math.log(d);
        final double complementTerm = complement < Precision.EPSILON ? 0.0d : complement * Math.log(complement);
        return (-1.0d) * (term + complementTerm);
    }

    /**
     * Evaluates the split criterion (weighted binary entropy) for one candidate bin boundary.
     *
     * @param binAccumulatedStatistics accumulated left/right counts for the candidate split
     * @return criterion value as computed by {@code binaryEntropy}
     */
    private Double calculateCriterionOfSplit(BinAccumulatedStatistics binAccumulatedStatistics) {
        final int leftTotal = binAccumulatedStatistics._leftCount;
        final int leftZeros = binAccumulatedStatistics._leftCount0;
        final int rightTotal = binAccumulatedStatistics._rightCount;
        final int rightZeros = binAccumulatedStatistics._rightCount0;
        return binaryEntropy(leftTotal, leftZeros, rightTotal, rightZeros);
    }

    /**
     * Picks the decision value for a leaf: the index of the class with the highest count,
     * preferring the lowest index on ties.
     *
     * @param iArr per-class row counts
     * @return index of the majority class; when {@code _nclass == 1} the first count itself
     */
    private int selectDecisionValue(int[] iArr) {
        // NOTE(review): with a single class this returns the count, not a class index - confirm intent.
        if (_nclass == 1) {
            return iArr[0];
        }
        int majorityClass = 0;
        for (int candidate = 1; candidate < _nclass; candidate++) {
            if (iArr[candidate] > iArr[majorityClass]) {
                majorityClass = candidate;
            }
        }
        return majorityClass;
    }

    /**
     * Converts per-class counts into relative frequencies.
     *
     * @param iArr per-class row counts
     * @return array of the same length holding {@code count / sum of counts} for every class
     */
    private double[] calculateProbability(int[] iArr) {
        int total = 0;
        for (int count : iArr) {
            total += count;
        }
        final double[] frequencies = new double[iArr.length];
        for (int cls = 0; cls < iArr.length; cls++) {
            frequencies[cls] = (double) iArr[cls] / total;
        }
        return frequencies;
    }

    /**
     * Writes a leaf into the tree array: flag {@code 1.0}, the decision value, and the relative
     * frequency of the first class.
     *
     * @param iArr per-class row counts of the data that reached the node
     * @param i    index of the node in the tree array
     */
    public void makeLeafFromNode(int[] iArr, int i) {
        final double[] node = _tree[i];
        node[0] = 1.0d;
        node[1] = selectDecisionValue(iArr);
        node[2] = calculateProbability(iArr)[0];
    }

    /**
     * Builds the node with the given index from the next entry of the queue.
     * <p>
     * Every call enqueues exactly two entries for the node's children: real feature limits when
     * the node is split, {@code null} placeholders when the node is a leaf or is itself missing.
     * When called for consecutive indices starting at 0, the children of node {@code i} are
     * therefore processed at indices {@code 2i + 1} and {@code 2i + 2}.
     * <p>
     * A node becomes a leaf when its depth reaches {@code _max_depth}, when either class count is
     * not greater than {@code _min_rows}, when no admissible split exists, or when the best split
     * changes the entropy by less than {@link #MIN_IMPROVEMENT}.
     *
     * @param queue pending feature limits in node-index order; {@code null} marks a missing node
     * @param i     index of the node to build
     */
    public void buildNextNode(Queue<DataFeaturesLimits> queue, int i) {
        final DataFeaturesLimits nodeLimits = queue.poll();
        if (nodeLimits == null) {
            // Missing node: keep the queue aligned with the node indices.
            addChildPlaceholders(queue);
            return;
        }
        final int[] classCounts = countClasses(nodeLimits);
        if (i == 0) {
            Log.info("Classes counts in dataset: 0 - " + classCounts[0] + ", 1 - " + classCounts[1]);
        }
        final int depth = (int) Math.floor(MathUtils.log2(i + 1));
        if (depth >= _parms._max_depth || classCounts[0] <= _min_rows || classCounts[1] <= _min_rows) {
            addChildPlaceholders(queue);
            makeLeafFromNode(classCounts, i);
            return;
        }
        final SplitInfo bestSplit = findBestSplit(new Histogram(_train, nodeLimits, BinningStrategy.EQUAL_WIDTH));
        final double parentEntropy = entropyBinarySplit((1.0d * classCounts[0]) / (classCounts[0] + classCounts[1]));
        if (bestSplit == null || Math.abs(parentEntropy - bestSplit._criterionValue) < MIN_IMPROVEMENT) {
            addChildPlaceholders(queue);
            makeLeafFromNode(classCounts, i);
            return;
        }
        _tree[i][0] = 0.0d;
        _tree[i][1] = bestSplit._splitFeatureIndex;
        _tree[i][2] = bestSplit._threshold;
        final DataFeaturesLimits leftLimits = nodeLimits.updateMax(bestSplit._splitFeatureIndex, bestSplit._threshold);
        final DataFeaturesLimits rightLimits = nodeLimits.updateMin(bestSplit._splitFeatureIndex, bestSplit._threshold);
        // The node's own counts are already known; reusing them saves one full pass over the training frame.
        Log.debug("root: " + Arrays.toString(classCounts) + ", left: " + Arrays.toString(countClasses(leftLimits)) + ", right: " + Arrays.toString(countClasses(rightLimits)) + ", best feature: " + bestSplit._splitFeatureIndex + ", threshold: " + bestSplit._threshold);
        Log.debug("feature: " + bestSplit._splitFeatureIndex + ", threshold: " + bestSplit._threshold);
        Log.debug("Left min-max: " + leftLimits.getFeatureLimits(bestSplit._splitFeatureIndex)._min + " " + leftLimits.getFeatureLimits(bestSplit._splitFeatureIndex)._max);
        Log.debug("Right min-max: " + rightLimits.getFeatureLimits(bestSplit._splitFeatureIndex)._min + " " + rightLimits.getFeatureLimits(bestSplit._splitFeatureIndex)._max);
        queue.add(leftLimits);
        queue.add(rightLimits);
    }

    /** Enqueues two {@code null} entries standing in for the children of a node that has none. */
    private void addChildPlaceholders(Queue<DataFeaturesLimits> queue) {
        queue.add(null);
        queue.add(null);
    }

    /**
     * Builds the feature limits that cover the whole frame: for every column except the last one,
     * the range from the column minimum lowered by {@link #EPSILON} up to the column maximum.
     *
     * @param frame frame whose leading columns are treated as features
     * @return limits for all feature columns
     */
    public static DataFeaturesLimits getInitialFeaturesLimits(Frame frame) {
        final int featureCount = frame.numCols() - 1;
        final List<FeatureLimits> perFeatureLimits = IntStream.range(0, featureCount)
                .mapToObj(column -> frame.vec(column))
                .map(featureVec -> new FeatureLimits(featureVec.min() - EPSILON, featureVec.max()))
                .collect(Collectors.toList());
        return new DataFeaturesLimits(perFeatureLimits);
    }

    /**
     * Supplies the driver that performs the training.
     *
     * @return a new {@code DTDriver} bound to this builder
     */
    @Override // hex.ModelBuilder
    protected ModelBuilder<DTModel, DTModel.DTParameters, DTModel.DTOutput>.Driver trainModelImpl() {
        return new DTDriver();
    }

    /**
     * Reports this builder's visibility.
     *
     * @return always {@code BuilderVisibility.Experimental}
     */
    @Override // hex.ModelBuilder
    public ModelBuilder.BuilderVisibility builderVisibility() {
        return ModelBuilder.BuilderVisibility.Experimental;
    }

    /**
     * Lists the model categories this builder can produce.
     *
     * @return a single-element array containing {@code ModelCategory.Binomial}
     */
    @Override // hex.ModelBuilder
    public ModelCategory[] can_build() {
        return new ModelCategory[]{ModelCategory.Binomial};
    }

    /**
     * Declares the algorithm as supervised.
     *
     * @return always {@code true}
     */
    @Override // hex.ModelBuilder
    public boolean isSupervised() {
        return true;
    }

    /**
     * Scores the model on the training frame and stores the result as training metrics; when a
     * validation frame is configured, scores a copy of it as well and stores validation metrics.
     */
    protected final void makeModelMetrics() {
        final ModelMetrics.MetricBuilder trainingBuilder = new ScoreDTTask(_model).doAll(_train).getMetricsBuilder();
        _model._output._training_metrics = trainingBuilder.makeModelMetrics(_model, _parms.train(), null, null);

        if (_parms._valid == null) {
            return;
        }
        final Frame validationFrame = new Frame(valid());
        final ModelMetrics.MetricBuilder validationBuilder = new ScoreDTTask(_model).doAll(validationFrame).getMetricsBuilder();
        _model._output._validation_metrics = validationBuilder.makeModelMetrics(_model, validationFrame, null, null);
    }

    /**
     * Counts the rows of each class in the training frame that fall within the given feature
     * limits, by running a {@code GetClassCountsMRTask} over the frame.
     *
     * @param dataFeaturesLimits limits restricting the counted rows; {@code null} means
     *                           unrestricted, i.e. {@code [-Double.MAX_VALUE, Double.MAX_VALUE]}
     *                           for every column except the last one
     * @return the task's per-class row counts
     */
    private int[] countClasses(DataFeaturesLimits dataFeaturesLimits) {
        final double[][] limits;
        if (dataFeaturesLimits == null) {
            // The generator must create the outer double[][]; each element is one {min, max} pair.
            limits = Stream.generate(() -> new double[]{-Double.MAX_VALUE, Double.MAX_VALUE})
                    .limit(_train.numCols() - 1)
                    .toArray(double[][]::new);
        } else {
            limits = dataFeaturesLimits.toDoubles();
        }
        final GetClassCountsMRTask countingTask = new GetClassCountsMRTask(limits, _nclass);
        countingTask.doAll(_train);
        return countingTask._countsByClass;
    }
}
