package org.tribuo.regression.rtree.impl;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.IOException;
import java.io.NotSerializableException;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.SplittableRandom;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.common.tree.AbstractTrainingNode;
import org.tribuo.common.tree.LeafNode;
import org.tribuo.common.tree.Node;
import org.tribuo.common.tree.impl.IntArrayContainer;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorIterator;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.regression.ImmutableRegressionInfo;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.rtree.impurity.RegressorImpurity;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/regression/rtree/impl/RegressorTrainingNode.class */
/**
 * A training-time decision tree node for a single regression output dimension.
 * <p>
 * Each node holds an inverted (feature-wise) view of the examples that reach it, searches it for
 * the split that most reduces the weighted impurity, and either creates child training nodes or
 * leaves. {@link #convertTree()} turns a trained node into its inference-time {@link Node} form.
 * <p>
 * This class is only used during training; {@code writeObject} refuses serialization.
 */
public class RegressorTrainingNode extends AbstractTrainingNode<Regressor> {
    private static final long serialVersionUID = 1;

    private static final Logger logger = Logger.getLogger(RegressorTrainingNode.class.getName());

    // Per-thread scratch buffers handed to IntArrayContainer.merge and TreeFeature.split in
    // splitAtBest, so each split does not need to allocate its own merge buffers.
    private static final ThreadLocal<IntArrayContainer> mergeBufferOne = ThreadLocal.withInitial(() -> new IntArrayContainer(16));
    private static final ThreadLocal<IntArrayContainer> mergeBufferTwo = ThreadLocal.withInitial(() -> new IntArrayContainer(16));

    // Inverted feature data for the examples reaching this node; set to null once the node is built.
    private transient ArrayList<TreeFeature> data;
    private final ImmutableOutputInfo<Regressor> labelIDMap;
    private final ImmutableFeatureMap featureIDMap;
    private final RegressorImpurity impurity;
    // Ids of the examples reaching this node; these index into targets and weights.
    private final int[] indices;
    // Target values for this tree's output dimension, indexed by example id (the same array is passed to children).
    private final float[] targets;
    // Example weights, indexed by example id (the same array is passed to children).
    private final float[] weights;
    // Name of the output dimension this tree predicts.
    private final String dimName;
    // Sum of the weights of the examples in indices.
    private final float weightSum;

    /**
     * Tuple containing an inverted dataset, i.e., stored feature-wise rather than example-wise,
     * along with the example ids, the per-dimension targets and the example weights.
     */
    public static class InvertedData {
        final ArrayList<TreeFeature> data;
        final int[] indices;
        final float[][] targets;
        final float[] weights;

        /**
         * Constructs an inverted data tuple.
         * @param data The per-feature inverted data.
         * @param indices The example ids.
         * @param targets The targets, indexed by [output dimension id][example id].
         * @param weights The example weights, indexed by example id.
         */
        InvertedData(ArrayList<TreeFeature> data, int[] indices, float[][] targets, float[] weights) {
            this.data = data;
            this.indices = indices;
            this.targets = targets;
            this.weights = weights;
        }

        /**
         * Deep copies the feature data, so each tree can consume its own copy.
         * @return A deep copy of the feature list.
         */
        ArrayList<TreeFeature> copyData() {
            ArrayList<TreeFeature> copy = new ArrayList<>(data.size());
            for (TreeFeature feature : data) {
                copy.add(feature.deepCopy());
            }
            return copy;
        }
    }

    /**
     * Constructs a root training node for one output dimension.
     * @param impurity The impurity function used to score splits.
     * @param tuple The inverted dataset; its feature data is deep copied.
     * @param dimIndex The id of the output dimension this tree predicts.
     * @param dimName The name of the output dimension this tree predicts.
     * @param numExamples The number of examples reaching this node.
     * @param featureIDMap The feature domain.
     * @param outputInfo The output domain.
     * @param leafDeterminer Controls when a node becomes a leaf.
     */
    public RegressorTrainingNode(RegressorImpurity impurity, InvertedData tuple, int dimIndex, String dimName, int numExamples, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> outputInfo, AbstractTrainingNode.LeafDeterminer leafDeterminer) {
        this(impurity, tuple.copyData(), tuple.indices, tuple.targets[dimIndex], tuple.weights, dimName, numExamples, 0, featureIDMap, outputInfo, leafDeterminer);
    }

    // Computes the weight sum and impurity from the supplied indices, then delegates.
    private RegressorTrainingNode(RegressorImpurity impurity, ArrayList<TreeFeature> data, int[] indices, float[] targets, float[] weights, String dimName, int numExamples, int depth, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> labelIDMap, AbstractTrainingNode.LeafDeterminer leafDeterminer) {
        this(impurity, data, indices, targets, weights, dimName, numExamples, depth, featureIDMap, labelIDMap, leafDeterminer,
                Util.sum(indices, indices.length, weights), impurity.impurity(indices, targets, weights));
    }

    // Uses a precomputed weight sum and impurity (already known when a parent splits).
    private RegressorTrainingNode(RegressorImpurity impurity, ArrayList<TreeFeature> data, int[] indices, float[] targets, float[] weights, String dimName, int numExamples, int depth, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Regressor> labelIDMap, AbstractTrainingNode.LeafDeterminer leafDeterminer, float weightSum, double impurityScore) {
        super(depth, numExamples, leafDeterminer);
        this.data = data;
        this.featureIDMap = featureIDMap;
        this.labelIDMap = labelIDMap;
        this.impurity = impurity;
        this.indices = indices;
        this.targets = targets;
        this.weights = weights;
        this.dimName = dimName;
        this.weightSum = weightSum;
        this.impurityScore = impurityScore;
    }

    @Override
    public double getImpurity() {
        return impurityScore;
    }

    @Override
    public float getWeightSum() {
        return weightSum;
    }

    /**
     * Finds the best split over the supplied features and, if it is worth taking, splits this node.
     * Releases this node's inverted data afterwards.
     * @param featureIDs The feature ids to consider.
     * @param rng The RNG used when {@code useRandomSplitPoints} is true.
     * @param useRandomSplitPoints If true, tries one random split point per feature (extra trees);
     *                             otherwise tries every split point.
     * @return The child training nodes that still need building (possibly empty).
     */
    @Override
    public List<AbstractTrainingNode<Regressor>> buildTree(int[] featureIDs, SplittableRandom rng, boolean useRandomSplitPoints) {
        return useRandomSplitPoints ? buildRandomTree(featureIDs, rng) : buildGreedyTree(featureIDs);
    }

    /**
     * Exhaustively scores every split point of every candidate feature.
     * @param featureIDs The feature ids to consider.
     * @return The child training nodes that still need building (possibly empty).
     */
    private List<AbstractTrainingNode<Regressor>> buildGreedyTree(int[] featureIDs) {
        int bestID = -1;
        double bestSplitValue = 0.0;
        double bestScore = getImpurity();
        List<int[]> featureIndices = new ArrayList<>();
        List<int[]> bestLeftIndices = new ArrayList<>();
        List<int[]> bestRightIndices = new ArrayList<>();
        for (int i = 0; i < featureIDs.length; i++) {
            List<InvertedFeature> feature = data.get(featureIDs[i]).getFeature();
            collectIndices(feature, featureIndices);
            // Try each boundary between adjacent InvertedFeature entries
            // (presumed sorted by value, as the midpoint split value implies).
            for (int j = 0; j < feature.size() - 1; j++) {
                List<int[]> left = featureIndices.subList(0, j + 1);
                List<int[]> right = featureIndices.subList(j + 1, feature.size());
                double score = splitScore(left, right);
                if (score < bestScore) {
                    bestID = i;
                    bestScore = score;
                    bestSplitValue = (feature.get(j).value + feature.get(j + 1).value) / 2.0;
                    // Copy out of the sublist views; featureIndices is rebuilt for the next feature.
                    bestLeftIndices.clear();
                    bestLeftIndices.addAll(left);
                    bestRightIndices.clear();
                    bestRightIndices.addAll(right);
                }
            }
        }
        return splitOrStop(featureIDs, bestID, bestSplitValue, bestScore, bestLeftIndices, bestRightIndices);
    }

    /**
     * Scores a single uniformly chosen split point per candidate feature (extremely randomized trees).
     * @param featureIDs The feature ids to consider.
     * @param rng The RNG used to pick each feature's split point.
     * @return The child training nodes that still need building (possibly empty).
     */
    private List<AbstractTrainingNode<Regressor>> buildRandomTree(int[] featureIDs, SplittableRandom rng) {
        int bestID = -1;
        double bestSplitValue = 0.0;
        double bestScore = getImpurity();
        List<int[]> featureIndices = new ArrayList<>();
        List<int[]> bestLeftIndices = new ArrayList<>();
        List<int[]> bestRightIndices = new ArrayList<>();
        for (int i = 0; i < featureIDs.length; i++) {
            List<InvertedFeature> feature = data.get(featureIDs[i]).getFeature();
            // A feature with a single value group has no split point.
            if (feature.size() < 2) {
                continue;
            }
            int splitIdx = rng.nextInt(feature.size() - 1);
            // Rebuilt per feature, so left/right only ever contain this feature's indices.
            collectIndices(feature, featureIndices);
            List<int[]> left = featureIndices.subList(0, splitIdx + 1);
            List<int[]> right = featureIndices.subList(splitIdx + 1, feature.size());
            double score = splitScore(left, right);
            if (score < bestScore) {
                bestID = i;
                bestScore = score;
                bestSplitValue = (feature.get(splitIdx).value + feature.get(splitIdx + 1).value) / 2.0;
                bestLeftIndices.clear();
                bestLeftIndices.addAll(left);
                bestRightIndices.clear();
                bestRightIndices.addAll(right);
            }
        }
        return splitOrStop(featureIDs, bestID, bestSplitValue, bestScore, bestLeftIndices, bestRightIndices);
    }

    // Replaces the contents of out with the index arrays of each InvertedFeature, in order.
    private static void collectIndices(List<InvertedFeature> feature, List<int[]> out) {
        out.clear();
        for (InvertedFeature invertedFeature : feature) {
            out.add(invertedFeature.indices());
        }
    }

    // Weighted average impurity of a candidate split, normalized by this node's total weight.
    private double splitScore(List<int[]> left, List<int[]> right) {
        RegressorImpurity.ImpurityTuple leftTuple = impurity.impurityTuple(left, targets, weights);
        RegressorImpurity.ImpurityTuple rightTuple = impurity.impurityTuple(right, targets, weights);
        return ((leftTuple.impurity * leftTuple.weight) + (rightTuple.impurity * rightTuple.weight)) / weightSum;
    }

    // Splits on the best candidate when one exists and its weighted impurity decrease meets the
    // leaf determiner's threshold; always releases this node's inverted data afterwards.
    private List<AbstractTrainingNode<Regressor>> splitOrStop(int[] featureIDs, int bestID, double bestSplitValue, double bestScore, List<int[]> bestLeftIndices, List<int[]> bestRightIndices) {
        List<AbstractTrainingNode<Regressor>> children;
        if (bestID == -1 || ((double) weightSum) * (getImpurity() - bestScore) < ((double) leafDeterminer.getScaledMinImpurityDecrease())) {
            children = Collections.emptyList();
        } else {
            children = splitAtBest(featureIDs, bestID, bestSplitValue, bestLeftIndices, bestRightIndices);
        }
        data = null;
        return children;
    }

    /**
     * Records the split on this node and creates the two children, as leaves or training nodes.
     * @param featureIDs The candidate feature ids.
     * @param bestID The position in {@code featureIDs} of the chosen feature.
     * @param bestSplitValue The threshold; values less than or equal to it go left.
     * @param bestLeftIndices The index arrays for the left side.
     * @param bestRightIndices The index arrays for the right side.
     * @return The child training nodes that still need building (possibly empty).
     */
    private List<AbstractTrainingNode<Regressor>> splitAtBest(int[] featureIDs, int bestID, double bestSplitValue, List<int[]> bestLeftIndices, List<int[]> bestRightIndices) {
        this.splitID = featureIDs[bestID];
        this.split = true;
        this.splitValue = bestSplitValue;

        IntArrayContainer firstBuffer = mergeBufferOne.get();
        firstBuffer.size = 0;
        firstBuffer.grow(indices.length);
        IntArrayContainer secondBuffer = mergeBufferTwo.get();
        secondBuffer.size = 0;
        secondBuffer.grow(indices.length);

        int[] leftIndices = IntArrayContainer.merge(bestLeftIndices, firstBuffer, secondBuffer);
        int[] rightIndices = IntArrayContainer.merge(bestRightIndices, firstBuffer, secondBuffer);

        float leftWeight = Util.sum(leftIndices, leftIndices.length, weights);
        double leftImpurity = impurity.impurity(leftIndices, targets, weights);
        float rightWeight = Util.sum(rightIndices, rightIndices.length, weights);
        double rightImpurity = impurity.impurity(rightIndices, targets, weights);

        boolean leftIsLeaf = shouldMakeLeaf(leftImpurity, leftWeight);
        boolean rightIsLeaf = shouldMakeLeaf(rightImpurity, rightWeight);

        if (leftIsLeaf && rightIsLeaf) {
            // Neither side needs the inverted data, so skip splitting it.
            this.lessThanOrEqual = createLeaf(leftImpurity, leftIndices);
            this.greaterThan = createLeaf(rightImpurity, rightIndices);
            return Collections.emptyList();
        }

        ArrayList<TreeFeature> leftData = new ArrayList<>(data.size());
        ArrayList<TreeFeature> rightData = new ArrayList<>(data.size());
        for (TreeFeature feature : data) {
            Pair<TreeFeature, TreeFeature> halves = feature.split(leftIndices, rightIndices, firstBuffer, secondBuffer);
            leftData.add(halves.getA());
            rightData.add(halves.getB());
        }

        List<AbstractTrainingNode<Regressor>> children = new ArrayList<>(2);
        if (leftIsLeaf) {
            this.lessThanOrEqual = createLeaf(leftImpurity, leftIndices);
        } else {
            RegressorTrainingNode leftChild = new RegressorTrainingNode(impurity, leftData, leftIndices, targets, weights, dimName, leftIndices.length, depth + 1, featureIDMap, labelIDMap, leafDeterminer, leftWeight, leftImpurity);
            this.lessThanOrEqual = leftChild;
            children.add(leftChild);
        }
        if (rightIsLeaf) {
            this.greaterThan = createLeaf(rightImpurity, rightIndices);
        } else {
            RegressorTrainingNode rightChild = new RegressorTrainingNode(impurity, rightData, rightIndices, targets, weights, dimName, rightIndices.length, depth + 1, featureIDMap, labelIDMap, leafDeterminer, rightWeight, rightImpurity);
            this.greaterThan = rightChild;
            children.add(rightChild);
        }
        return children;
    }

    /**
     * Converts this training node into an inference-time node.
     * @return A split node if this node was split, otherwise a leaf over this node's examples.
     */
    @Override
    public Node<Regressor> convertTree() {
        if (split) {
            return createSplitNode();
        }
        return createLeaf(getImpurity(), indices);
    }

    /**
     * Builds a leaf predicting the weighted mean of the targets, with a weighted variance.
     * The mean and variance are accumulated in a single pass.
     * @param leafImpurity The impurity recorded on the leaf.
     * @param leafIndices The example ids in the leaf.
     * @return The leaf node.
     */
    private LeafNode<Regressor> createLeaf(double leafImpurity, int[] leafIndices) {
        double mean = 0.0;
        double leafWeightSum = 0.0;
        double variance = 0.0;
        for (int idx : leafIndices) {
            float value = targets[idx];
            float weight = weights[idx];
            leafWeightSum += weight;
            double oldMean = mean;
            mean += (weight / leafWeightSum) * (value - oldMean);
            variance += weight * (value - oldMean) * (value - mean);
        }
        // NOTE(review): the denominator (weight sum - 1) treats weights as frequency weights; a leaf
        // whose weights sum to exactly 1 with more than one example yields a non-finite variance.
        variance = leafIndices.length > 1 ? variance / (leafWeightSum - 1.0) : 0.0;
        return new LeafNode<>(leafImpurity, new Regressor.DimensionTuple(dimName, mean, variance), Collections.emptyMap(), false);
    }

    /**
     * Inverts a dataset so that it is stored feature-wise rather than example-wise.
     * Every feature records a value for every example; absent features record 0.0.
     * @param examples The dataset to invert.
     * @return The inverted data, with targets indexed by [output dimension id][example id].
     * @throws IllegalStateException If an example's features are out of order or repeated.
     */
    public static InvertedData invertData(Dataset<Regressor> examples) {
        ImmutableFeatureMap featureInfos = examples.getFeatureIDMap();
        ImmutableRegressionInfo labelInfo = (ImmutableRegressionInfo) examples.getOutputIDInfo();
        int numLabels = labelInfo.size();
        int numFeatures = featureInfos.size();
        int numExamples = examples.size();
        int[] indices = new int[numExamples];
        float[][] targets = new float[numLabels][numExamples];
        float[] weights = new float[numExamples];

        logger.fine("Building initial List<TreeFeature> for " + numFeatures + " features and " + numLabels + " outputs");
        ArrayList<TreeFeature> data = new ArrayList<>(numFeatures);
        for (int i = 0; i < numFeatures; i++) {
            data.add(new TreeFeature(i));
        }

        // Regressor.getValues() is in natural order; map it to output dimension ids.
        int[] naturalOrderToID = labelInfo.getNaturalOrderToIDMapping();
        for (int i = 0; i < numExamples; i++) {
            Example<Regressor> example = examples.getExample(i);
            indices[i] = i;
            weights[i] = example.getWeight();
            double[] outputValues = example.getOutput().getValues();
            for (int j = 0; j < outputValues.length; j++) {
                targets[naturalOrderToID[j]][i] = (float) outputValues[j];
            }

            SparseVector vec = SparseVector.createSparseVector(example, featureInfos, false);
            // nextID is the first feature id not yet observed for this example.
            int nextID = 0;
            for (VectorTuple tuple : vec) {
                int curID = tuple.index;
                // Defensive checks: the repeated case is tested first, because it is a special case
                // of going backwards and would otherwise be reported as unordered.
                if (curID == nextID - 1) {
                    logger.severe("Example = " + example.toString());
                    throw new IllegalStateException("Features are repeated. At id " + i + ", lastID = " + nextID + ", curID = " + curID);
                } else if (curID < nextID) {
                    logger.severe("Example = " + example.toString());
                    throw new IllegalStateException("Features aren't ordered. At id " + i + ", lastID = " + nextID + ", curID = " + curID);
                }
                // Features skipped by the sparse vector are implicit zeros.
                for (int j = nextID; j < curID; j++) {
                    data.get(j).observeValue(0.0, i);
                }
                data.get(curID).observeValue(tuple.value, i);
                nextID = curID + 1;
            }
            for (int j = nextID; j < numFeatures; j++) {
                data.get(j).observeValue(0.0, i);
            }
            if (i % 1000 == 0) {
                logger.fine("Processed example " + i);
            }
        }

        logger.fine("Sorting features");
        data.forEach(TreeFeature::sort);
        logger.fine("Fixing InvertedFeature sizes");
        data.forEach(TreeFeature::fixSize);
        logger.fine("Built initial List<TreeFeature>");
        return new InvertedData(data, indices, targets, weights);
    }

    // Training nodes hold transient training state and must never be serialized.
    private void writeObject(ObjectOutputStream stream) throws IOException {
        throw new NotSerializableException("RegressorTrainingNode is a runtime class only, and should not be serialized.");
    }
}
