package org.neo4j.gds.ml.decisiontree;

import java.lang.Number;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Objects;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.mem.Estimate;
import org.neo4j.gds.mem.MemoryRange;
import org.neo4j.gds.ml.models.Features;

/* loaded from: input_file:org/neo4j/gds/ml/decisiontree/DecisionTreeTrainer.class */
/**
 * Base class for trainers that build a binary decision tree.
 *
 * <p>The tree is grown iteratively in {@link #train(ReadOnlyHugeLongArray)}: every node that
 * was split is put on an explicit stack together with its {@link Split} and depth, and its two
 * children are resolved when the record is popped again. Subclasses decide what value a leaf
 * predicts by implementing {@link #toTerminal(Group)}.
 *
 * <p>NOTE(review): this file is decompiler output. The decompiler reported that the nested
 * interfaces, the constructor and the four-argument {@code estimateTree} were originally
 * package-private; their visibility is left as found here.
 *
 * @param <PREDICTION> numeric type stored in the leaves of the produced tree
 */
public abstract class DecisionTreeTrainer<PREDICTION extends Number> {
    private final ImpurityCriterion impurityCriterion;
    private final Features features;
    private final DecisionTreeTrainerConfig config;
    private final FeatureBagger featureBagger;
    // Created anew on every call to train(...); not set before the first call.
    private Splitter splitter;

    /**
     * Result of splitting one group into a left and a right group.
     *
     * <p>{@link #index()} and {@link #value()} are handed to the {@code TreeNode} constructor
     * of the inner node created for this split (presumably the feature index and the
     * threshold value — confirm against {@code Splitter}).
     */
    @ValueClass
    public interface Split {
        int index();

        double value();

        /** The two groups produced by the split. */
        Groups groups();
    }

    /**
     * Work item on the training stack: an inner node whose children still have to be attached.
     *
     * @param <PREDICTION> numeric type stored in the leaves of the tree
     */
    @ValueClass
    public interface StackRecord<PREDICTION extends Number> {
        /** The inner node that receives the left and right child. */
        TreeNode<PREDICTION> node();

        /** The split that produced {@link #node()}; its groups become the children. */
        Split split();

        /** Depth of {@link #node()}; the root is recorded with depth 1. */
        int depth();
    }

    /**
     * Stores the collaborators used during training.
     *
     * @param features                  feature source passed on to the {@code Splitter}
     * @param decisionTreeTrainerConfig supplies {@code maxDepth}, {@code minSplitSize} and
     *                                  {@code minLeafSize}
     * @param impurityCriterion         computes the impurity of the root group and is passed on
     *                                  to the {@code Splitter}
     * @param featureBagger             passed on to the {@code Splitter}
     */
    public DecisionTreeTrainer(Features features, DecisionTreeTrainerConfig decisionTreeTrainerConfig, ImpurityCriterion impurityCriterion, FeatureBagger featureBagger) {
        this.impurityCriterion = impurityCriterion;
        this.features = features;
        this.config = decisionTreeTrainerConfig;
        this.featureBagger = featureBagger;
    }

    /**
     * Estimates the memory needed while training a tree: the finished tree (see
     * {@link #estimateTree(DecisionTreeTrainerConfig, long, long)}), the training stack and the
     * {@code Splitter}.
     *
     * @param decisionTreeTrainerConfig supplies {@code maxDepth} and {@code minSplitSize}
     * @param j                         sample count the estimate is based on (presumably the
     *                                  number of training samples — confirm against callers)
     * @param j2                        size factor applied per leaf node, forwarded to the
     *                                  three-argument overload
     * @param j3                        forwarded as second argument to
     *                                  {@code Splitter.memoryEstimation}
     * @return the combined memory range
     */
    public static MemoryRange estimateTree(DecisionTreeTrainerConfig decisionTreeTrainerConfig, long j, long j2, long j3) {
        MemoryRange treeEstimate = estimateTree(decisionTreeTrainerConfig, j, j2);

        // The depth used for the stack bound is capped both by the configured maximum
        // and by a value derived from the sample count.
        long cappedDepth = Math.min(
            decisionTreeTrainerConfig.maxDepth(),
            Math.max(1L, (j - decisionTreeTrainerConfig.minSplitSize()) + 2)
        );
        // NOTE(review): assumes maxDepth() >= 1; with 0 the division below would fail — confirm
        // that the config validates this.
        long maxStackEntries = 2 * cappedDepth;

        MemoryRange stackEstimate = MemoryRange.of(Estimate.sizeOfInstance(ArrayDeque.class))
            .add(MemoryRange.of(1L, maxStackEntries).times(Estimate.sizeOfInstance(ImmutableStackRecord.class)))
            .add(MemoryRange.of(0L, HugeLongArray.memoryEstimation(j / maxStackEntries) * maxStackEntries));

        return treeEstimate
            .add(stackEstimate)
            .add(Splitter.memoryEstimation(j, j3));
    }

    /**
     * Estimates the memory of a finished tree: the predictor instance, between one and the
     * maximum number of leaves, and up to (maximum leaves - 1) split nodes.
     *
     * @param decisionTreeTrainerConfig supplies {@code maxDepth}, {@code minLeafSize} and
     *                                  {@code minSplitSize}
     * @param j                         sample count the estimate is based on (presumably the
     *                                  number of training samples — confirm against callers)
     * @param j2                        size factor applied per leaf node
     * @return the memory range, or an empty range when {@code j} is 0
     */
    public static MemoryRange estimateTree(DecisionTreeTrainerConfig decisionTreeTrainerConfig, long j, long j2) {
        if (j == 0) {
            return MemoryRange.empty();
        }

        double leavesByDepth = Math.pow(2.0d, decisionTreeTrainerConfig.maxDepth());
        // Integer (long) division, widened to double afterwards — kept exactly as found.
        double leavesByLeafSize = j / decisionTreeTrainerConfig.minLeafSize();
        double leavesBySplitSize = (2.0d * j) / decisionTreeTrainerConfig.minSplitSize();
        long maxLeafNodes = (long) Math.ceil(
            Math.min(leavesByDepth, Math.min(leavesByLeafSize, leavesBySplitSize))
        );

        return MemoryRange.of(Estimate.sizeOfInstance(DecisionTreePredictor.class))
            .add(MemoryRange.of(1L, maxLeafNodes).times(j2))
            .add(MemoryRange.of(0L, maxLeafNodes - 1).times(TreeNode.splitMemoryEstimation()));
    }

    /**
     * Trains a decision tree on the given entries.
     *
     * <p>The input is copied into a new {@code HugeLongArray}, which forms the root group. The
     * root is split first; afterwards stack records are popped until none are left, attaching
     * a left and a right child to each recorded node.
     *
     * @param readOnlyHugeLongArray the entries to train on (presumably indices of the training
     *                              samples — confirm against callers); must not be {@code null}
     * @return a predictor wrapping the root of the trained tree
     * @throws NullPointerException if {@code readOnlyHugeLongArray} is {@code null}
     */
    public DecisionTreePredictor<PREDICTION> train(ReadOnlyHugeLongArray readOnlyHugeLongArray) {
        Objects.requireNonNull(readOnlyHugeLongArray, "readOnlyHugeLongArray");

        this.splitter = new Splitter(readOnlyHugeLongArray.size(), this.impurityCriterion, this.featureBagger, this.features, this.config.minLeafSize());

        Deque<StackRecord<PREDICTION>> stack = new ArrayDeque<>();

        HugeLongArray trainSet = HugeLongArray.newArray(readOnlyHugeLongArray.size());
        trainSet.setAll(readOnlyHugeLongArray::get);

        Group rootGroup = ImmutableGroup.of(
            trainSet,
            0L,
            trainSet.size(),
            this.impurityCriterion.groupImpurity(trainSet, 0L, trainSet.size())
        );
        TreeNode<PREDICTION> root = splitAndPush(stack, rootGroup, 1);

        int maxDepth = this.config.maxDepth();
        int minSplitSize = this.config.minSplitSize();

        while (!stack.isEmpty()) {
            StackRecord<PREDICTION> current = stack.pop();
            Groups groups = current.split().groups();

            // Left is handled before right, so a right sub-split ends up on top of the stack.
            current.node().setLeftChild(growChild(stack, groups.left(), current.depth(), maxDepth, minSplitSize));
            current.node().setRightChild(growChild(stack, groups.right(), current.depth(), maxDepth, minSplitSize));
        }

        return new DecisionTreePredictor<>(root);
    }

    /**
     * Turns a group into the value predicted by a leaf node.
     *
     * @param group the group that ends up in the leaf
     * @return the prediction stored in the leaf
     */
    protected abstract PREDICTION toTerminal(Group group);

    /**
     * Builds the child node for one side of a split: a leaf when the parent already sits at the
     * maximum depth or the group is smaller than the minimum split size, otherwise the result
     * of splitting the group one level deeper.
     */
    private TreeNode<PREDICTION> growChild(Deque<StackRecord<PREDICTION>> stack, Group childGroup, int parentDepth, int maxDepth, int minSplitSize) {
        if (parentDepth >= maxDepth || childGroup.size() < minSplitSize) {
            return new TreeNode<>(toTerminal(childGroup));
        }
        return splitAndPush(stack, childGroup, parentDepth + 1);
    }

    /**
     * Creates the node for {@code group}. If the group is split into two non-empty groups, the
     * new inner node is pushed onto {@code deque} so that its children get attached later;
     * otherwise a leaf is returned and nothing is pushed.
     *
     * @param deque stack of nodes that still need children
     * @param group the group to create a node for; expected to be non-empty
     * @param i     depth of the node to create; expected to be at least 1
     * @return the created node, either a leaf or an inner node without children yet
     */
    private TreeNode<PREDICTION> splitAndPush(Deque<StackRecord<PREDICTION>> deque, Group group, int i) {
        assert group.size() > 0;
        assert i >= 1;

        if (group.size() < this.config.minSplitSize()) {
            return new TreeNode<>(toTerminal(group));
        }

        Split bestSplit = this.splitter.findBestSplit(group);

        // A split that leaves one side empty does not divide the group: make a leaf from the
        // non-empty side instead.
        if (bestSplit.groups().right().size() == 0) {
            return new TreeNode<>(toTerminal(bestSplit.groups().left()));
        }
        if (bestSplit.groups().left().size() == 0) {
            return new TreeNode<>(toTerminal(bestSplit.groups().right()));
        }

        TreeNode<PREDICTION> innerNode = new TreeNode<>(bestSplit.index(), bestSplit.value());
        deque.push(ImmutableStackRecord.of(innerNode, bestSplit, i));
        return innerNode;
    }
}
