package org.tribuo.regression.rtree.impl;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.IOException;
import java.io.NotSerializableException;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.common.tree.AbstractTrainingNode;
import org.tribuo.common.tree.LeafNode;
import org.tribuo.common.tree.Node;
import org.tribuo.common.tree.SplitNode;
import org.tribuo.common.tree.impl.IntArrayContainer;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorIterator;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.rtree.impurity.RegressorImpurity;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/regression/rtree/impl/JointRegressorTrainingNode.class */
/**
 * A training-time tree node for multi-dimensional regression where all output
 * dimensions share a single tree structure.
 * <p>
 * The node stores the feature data in inverted form (one {@link TreeFeature} per
 * feature), the indices of the examples which reached this node, and the shared
 * target and weight arrays. Candidate splits are scored by the impurity averaged
 * across all output dimensions.
 * <p>
 * This is a runtime-only class: {@code writeObject} always throws
 * {@link NotSerializableException}.
 */
public class JointRegressorTrainingNode extends AbstractTrainingNode<Regressor> {
    private static final long serialVersionUID = 1L;
    private static final Logger logger = Logger.getLogger(JointRegressorTrainingNode.class.getName());

    /** Initial capacity of each thread-local merge buffer. */
    private static final int MERGE_BUFFER_INITIAL_SIZE = 16;

    // Per-thread scratch buffers. buildTree uses buffers one and two when merging index arrays
    // and splitting features; buffers three to five are not referenced by the code in this class.
    private static final ThreadLocal<IntArrayContainer> mergeBufferOne =
            ThreadLocal.withInitial(() -> new IntArrayContainer(MERGE_BUFFER_INITIAL_SIZE));
    private static final ThreadLocal<IntArrayContainer> mergeBufferTwo =
            ThreadLocal.withInitial(() -> new IntArrayContainer(MERGE_BUFFER_INITIAL_SIZE));
    private static final ThreadLocal<IntArrayContainer> mergeBufferThree =
            ThreadLocal.withInitial(() -> new IntArrayContainer(MERGE_BUFFER_INITIAL_SIZE));
    private static final ThreadLocal<IntArrayContainer> mergeBufferFour =
            ThreadLocal.withInitial(() -> new IntArrayContainer(MERGE_BUFFER_INITIAL_SIZE));
    private static final ThreadLocal<IntArrayContainer> mergeBufferFive =
            ThreadLocal.withInitial(() -> new IntArrayContainer(MERGE_BUFFER_INITIAL_SIZE));

    /** Inverted feature data for this node; set to null once {@link #buildTree(int[])} has run. */
    private transient ArrayList<TreeFeature> data;
    /** If true, leaf predictions are divided by their sum across the output dimensions. */
    private final boolean normalize;
    private final ImmutableOutputInfo<Regressor> labelIDMap;
    private final ImmutableFeatureMap featureIDMap;
    private final RegressorImpurity impurity;
    /** Indices (into {@link #targets} and {@link #weights}) of the examples in this node. */
    private final int[] indices;
    /** Target values indexed as {@code [outputDimension][exampleIndex]}; shared with child nodes. */
    private final float[][] targets;
    /** Example weights indexed by example index; shared with child nodes. */
    private final float[] weights;

    /**
     * Carrier for the results of {@link #invertData(Dataset)}, used to hand them
     * to the delegating constructor.
     */
    public static class InvertedData {
        final ArrayList<TreeFeature> data;
        final int[] indices;
        final float[][] targets;
        final float[] weights;

        InvertedData(ArrayList<TreeFeature> data, int[] indices, float[][] targets, float[] weights) {
            this.data = data;
            this.indices = indices;
            this.targets = targets;
            this.weights = weights;
        }
    }

    /**
     * Builds a root node (depth 0) from the supplied dataset, inverting the examples
     * into per-feature form.
     *
     * @param impurity  The impurity function used to score splits.
     * @param examples  The training data.
     * @param normalize If true, leaf outputs are normalised to sum to one.
     */
    public JointRegressorTrainingNode(RegressorImpurity impurity, Dataset<Regressor> examples, boolean normalize) {
        this(impurity, invertData(examples), examples.size(), examples.getFeatureIDMap(), examples.getOutputIDInfo(), normalize);
    }

    /**
     * Unpacks an {@link InvertedData} and delegates to the full constructor at depth 0.
     */
    private JointRegressorTrainingNode(RegressorImpurity impurity, InvertedData tuple, int numExamples,
                                       ImmutableFeatureMap featureIDMap,
                                       ImmutableOutputInfo<Regressor> outputInfo, boolean normalize) {
        this(impurity, tuple.data, tuple.indices, tuple.targets, tuple.weights, numExamples, 0, featureIDMap, outputInfo, normalize);
    }

    /**
     * Full constructor, used for the root and for child nodes created by a split.
     * The arrays are stored by reference, not copied.
     */
    private JointRegressorTrainingNode(RegressorImpurity impurity, ArrayList<TreeFeature> data, int[] indices,
                                       float[][] targets, float[] weights, int numExamples, int depth,
                                       ImmutableFeatureMap featureIDMap,
                                       ImmutableOutputInfo<Regressor> labelIDMap, boolean normalize) {
        super(depth, numExamples);
        this.data = data;
        this.normalize = normalize;
        this.featureIDMap = featureIDMap;
        this.labelIDMap = labelIDMap;
        this.impurity = impurity;
        this.indices = indices;
        this.targets = targets;
        this.weights = weights;
    }

    /**
     * Computes the impurity of this node's examples, averaged over the output dimensions.
     *
     * @return The mean per-dimension impurity.
     */
    public double getImpurity() {
        double total = 0.0;
        for (float[] target : targets) {
            total += impurity.impurity(indices, target, weights);
        }
        return total / targets.length;
    }

    /**
     * Searches the supplied features for the split which most reduces the averaged impurity,
     * and if one is found creates the two child nodes.
     * <p>
     * The inverted feature data held by this node is released (set to null) before returning,
     * whether or not a split was found.
     *
     * @param featureIDs Indices of the features to consider when splitting.
     * @return A list containing the less-than-or-equal child followed by the greater-than child,
     *         or an empty list if no split improved on this node's impurity.
     */
    public List<AbstractTrainingNode<Regressor>> buildTree(int[] featureIDs) {
        int bestID = -1;
        double bestSplitValue = 0.0;
        double weightSum = Util.sum(indices, indices.length, weights);
        double bestScore = getImpurity();
        // Scratch lists reused across all candidate features.
        List<int[]> curIndices = new ArrayList<>();
        List<int[]> bestLeftIndices = new ArrayList<>();
        List<int[]> bestRightIndices = new ArrayList<>();

        for (int i = 0; i < featureIDs.length; i++) {
            List<InvertedFeature> feature = data.get(featureIDs[i]).getFeature();

            curIndices.clear();
            for (InvertedFeature value : feature) {
                curIndices.add(value.indices());
            }

            // Try a split point between each pair of adjacent feature values.
            for (int j = 0; j < feature.size() - 1; j++) {
                List<int[]> leftIndices = curIndices.subList(0, j + 1);
                List<int[]> rightIndices = curIndices.subList(j + 1, feature.size());
                double leftScore = 0.0;
                double rightScore = 0.0;
                for (float[] target : targets) {
                    RegressorImpurity.ImpurityTuple left = impurity.impurityTuple(leftIndices, target, weights);
                    leftScore += left.impurity * left.weight;
                    RegressorImpurity.ImpurityTuple right = impurity.impurityTuple(rightIndices, target, weights);
                    rightScore += right.impurity * right.weight;
                }
                double score = (leftScore + rightScore) / (targets.length * weightSum);
                if (score < bestScore) {
                    bestID = i;
                    bestScore = score;
                    bestSplitValue = (feature.get(j).value + feature.get(j + 1).value) / 2.0;
                    // The sub lists are views of curIndices, so copy them before it is cleared.
                    bestLeftIndices.clear();
                    bestLeftIndices.addAll(leftIndices);
                    bestRightIndices.clear();
                    bestRightIndices.addAll(rightIndices);
                }
            }
        }

        List<AbstractTrainingNode<Regressor>> children;
        if (bestID != -1) {
            this.splitID = featureIDs[bestID];
            this.split = true;
            this.splitValue = bestSplitValue;

            IntArrayContainer firstBuffer = mergeBufferOne.get();
            firstBuffer.size = 0;
            firstBuffer.grow(indices.length);
            IntArrayContainer secondBuffer = mergeBufferTwo.get();
            secondBuffer.size = 0;
            secondBuffer.grow(indices.length);
            int[] leftIndices = IntArrayContainer.merge(bestLeftIndices, firstBuffer, secondBuffer);
            int[] rightIndices = IntArrayContainer.merge(bestRightIndices, firstBuffer, secondBuffer);

            ArrayList<TreeFeature> lessThanData = new ArrayList<>(data.size());
            ArrayList<TreeFeature> greaterThanData = new ArrayList<>(data.size());
            for (TreeFeature feature : data) {
                Pair<TreeFeature, TreeFeature> halves = feature.split(leftIndices, rightIndices, firstBuffer, secondBuffer);
                lessThanData.add(halves.getA());
                greaterThanData.add(halves.getB());
            }

            JointRegressorTrainingNode lessThanNode = new JointRegressorTrainingNode(impurity, lessThanData, leftIndices,
                    targets, weights, leftIndices.length, this.depth + 1, featureIDMap, labelIDMap, normalize);
            JointRegressorTrainingNode greaterThanNode = new JointRegressorTrainingNode(impurity, greaterThanData, rightIndices,
                    targets, weights, rightIndices.length, this.depth + 1, featureIDMap, labelIDMap, normalize);
            this.lessThanOrEqual = lessThanNode;
            this.greaterThan = greaterThanNode;

            children = new ArrayList<>(2);
            children.add(lessThanNode);
            children.add(greaterThanNode);
        } else {
            children = Collections.emptyList();
        }
        // The inverted data is no longer needed by this node once the split decision is made.
        this.data = null;
        return children;
    }

    /**
     * Converts this training node (and recursively its children) into an inference-time node.
     * <p>
     * A split node becomes a {@link SplitNode}. Otherwise a {@link LeafNode} is produced whose
     * prediction is the weighted mean of the targets in this node, computed incrementally.
     * When normalising, the means are divided by their sum and no variances are supplied;
     * otherwise a weighted variance per dimension is supplied alongside the means.
     *
     * @return The converted node.
     */
    public Node<Regressor> convertTree() {
        if (this.split) {
            return new SplitNode<>(this.splitValue, this.splitID, getImpurity(),
                    this.greaterThan.convertTree(), this.lessThanOrEqual.convertTree());
        }

        double weightSum = 0.0;
        double[] mean = new double[targets.length];
        Regressor leafPred;
        if (normalize) {
            for (int i = 0; i < indices.length; i++) {
                int idx = indices[i];
                float weight = weights[idx];
                weightSum += weight;
                for (int j = 0; j < targets.length; j++) {
                    // Incremental weighted mean update.
                    mean[j] += (weight / weightSum) * (targets[j][idx] - mean[j]);
                }
            }
            String[] names = outputDimensionNames();
            double sum = 0.0;
            for (int j = 0; j < targets.length; j++) {
                sum += mean[j];
            }
            // Normalise the outputs so they sum to one.
            for (int j = 0; j < targets.length; j++) {
                mean[j] /= sum;
            }
            leafPred = new Regressor(names, mean);
        } else {
            double[] variance = new double[targets.length];
            for (int i = 0; i < indices.length; i++) {
                int idx = indices[i];
                float weight = weights[idx];
                weightSum += weight;
                for (int j = 0; j < targets.length; j++) {
                    float value = targets[j][idx];
                    double oldMean = mean[j];
                    // Incremental weighted mean and sum-of-squares update.
                    mean[j] += (weight / weightSum) * (value - oldMean);
                    variance[j] += weight * (value - oldMean) * (value - mean[j]);
                }
            }
            String[] names = outputDimensionNames();
            for (int j = 0; j < targets.length; j++) {
                variance[j] = indices.length > 1 ? variance[j] / (weightSum - 1.0) : 0.0;
            }
            leafPred = new Regressor(names, mean, variance);
        }
        return new LeafNode<>(getImpurity(), leafPred, Collections.emptyMap(), false);
    }

    /**
     * Looks up the first name of each output dimension from the output info.
     */
    private String[] outputDimensionNames() {
        String[] names = new String[targets.length];
        for (int j = 0; j < targets.length; j++) {
            names[j] = labelIDMap.getOutput(j).getNames()[0];
        }
        return names;
    }

    /**
     * Inverts a dataset from example-major form into one {@link TreeFeature} per feature,
     * recording an explicit 0.0 for every feature absent from an example, and extracts the
     * targets and weights into primitive arrays.
     *
     * @param examples The dataset to invert.
     * @return The inverted features, the identity index array, the targets and the weights.
     * @throws IllegalStateException If an example's sparse vector yields repeated or
     *                               out-of-order feature indices.
     */
    private static InvertedData invertData(Dataset<Regressor> examples) {
        ImmutableFeatureMap featureInfos = examples.getFeatureIDMap();
        ImmutableOutputInfo<Regressor> labelInfo = examples.getOutputIDInfo();
        int numLabels = labelInfo.size();
        int numFeatures = featureInfos.size();
        int numExamples = examples.size();
        int[] indices = new int[numExamples];
        float[][] targets = new float[numLabels][numExamples];
        float[] weights = new float[numExamples];

        logger.fine("Building initial List<TreeFeature> for " + numFeatures + " features and " + numLabels + " outputs");

        ArrayList<TreeFeature> data = new ArrayList<>(numFeatures);
        for (int i = 0; i < numFeatures; i++) {
            data.add(new TreeFeature(i));
        }

        for (int i = 0; i < numExamples; i++) {
            Example<Regressor> e = examples.getExample(i);
            indices[i] = i;
            weights[i] = e.getWeight();
            double[] output = e.getOutput().getValues();
            for (int j = 0; j < targets.length; j++) {
                targets[j][i] = (float) output[j];
            }

            // 'start' is the next feature id which has not yet been observed for this example.
            int start = 0;
            VectorIterator itr = SparseVector.createSparseVector(e, featureInfos, false).iterator();
            while (itr.hasNext()) {
                VectorTuple f = itr.next();
                int curID = f.index;
                // Validate before touching the feature lists. The repeated check must come first,
                // as a repeated id also satisfies the out-of-order condition.
                if (start - 1 == curID) {
                    logger.severe("Example = " + e.toString());
                    throw new IllegalStateException("Features are repeated. At id " + i + ", lastID = " + start + ", curID = " + curID);
                }
                if (start > curID) {
                    logger.severe("Example = " + e.toString());
                    throw new IllegalStateException("Features aren't ordered. At id " + i + ", lastID = " + start + ", curID = " + curID);
                }
                // Fill in zeros for the features skipped by the sparse vector.
                for (int j = start; j < curID; j++) {
                    data.get(j).observeValue(0.0, i);
                }
                data.get(curID).observeValue(f.value, i);
                start = curID + 1;
            }
            // Fill in zeros for any trailing features.
            for (int j = start; j < numFeatures; j++) {
                data.get(j).observeValue(0.0, i);
            }
            if (i % 1000 == 0) {
                logger.fine("Processed example " + i);
            }
        }

        logger.fine("Sorting features");
        data.forEach(TreeFeature::sort);

        logger.fine("Fixing InvertedFeature sizes");
        data.forEach(TreeFeature::fixSize);

        logger.fine("Built initial List<TreeFeature>");
        return new InvertedData(data, indices, targets, weights);
    }

    /**
     * Blocks Java serialization of this class.
     *
     * @param stream The stream that would have been written to; unused.
     * @throws IOException Always, as a {@link NotSerializableException}.
     */
    private void writeObject(ObjectOutputStream stream) throws IOException {
        throw new NotSerializableException("JointRegressorTrainingNode is a runtime class only, and should not be serialized.");
    }
}
