package smile.classification;

import java.io.Serializable;
import java.util.Arrays;
import smile.data.Attribute;
import smile.data.NumericAttribute;
import smile.math.Math;
import smile.regression.RegressionTree;
import smile.util.SmileUtils;
import smile.validation.Accuracy;
import smile.validation.ClassificationMeasure;

/* loaded from: input_file:libarx-3.7.1.jar:smile/classification/GradientTreeBoost.class */
public class GradientTreeBoost extends SoftClassifier<double[]> implements Serializable {
    private static final long serialVersionUID = 1;
    private int k;
    private RegressionTree[] trees;
    private RegressionTree[][] forest;
    private double[] importance;
    private double b;
    private double shrinkage;
    private int maxNodes;
    private int ntrees;
    private double subsample;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:libarx-3.7.1.jar:smile/classification/GradientTreeBoost$L2NodeOutput.class */
    /**
     * Leaf-value calculator used for the two-class boosting path.
     *
     * <p>For the samples that fall into a node it returns
     * {@code sum(y) / sum(|y| * (2 - |y|))}, where {@code y} is the residual
     * array handed to the constructor. The array is kept by reference (not
     * copied), so values written into it later by the caller are the ones
     * seen by {@link #calculate(int[])}.
     */
    public class L2NodeOutput implements RegressionTree.NodeOutput {
        /** Residuals, indexed by sample; shared with whoever created this object. */
        double[] y;

        /**
         * @param dArr residual array, stored by reference
         */
        public L2NodeOutput(double[] dArr) {
            this.y = dArr;
        }

        /**
         * Computes the output value of a node.
         *
         * @param iArr per-sample flags; only entries greater than zero contribute
         * @return the ratio described on the class, or a finite fallback when the
         *         denominator is numerically zero: the plain mean of the selected
         *         residuals, or {@code 0} when no sample is selected
         */
        @Override // smile.regression.RegressionTree.NodeOutput
        public double calculate(int[] iArr) {
            int selected = 0;
            double numerator = 0.0d;
            double denominator = 0.0d;
            for (int i = 0; i < iArr.length; i++) {
                if (iArr[i] > 0) {
                    selected++;
                    double abs = Math.abs(this.y[i]);
                    numerator += this.y[i];
                    denominator += abs * (2.0d - abs);
                }
            }
            // Same threshold as LKNodeOutput: dividing by a (near-)zero denominator
            // would yield NaN/Infinity and poison every later boosting round.
            if (denominator < 1.0E-10d) {
                return selected == 0 ? 0.0d : numerator / selected;
            }
            return numerator / denominator;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:libarx-3.7.1.jar:smile/classification/GradientTreeBoost$LKNodeOutput.class */
    /**
     * Leaf-value calculator used for the multi-class boosting path.
     *
     * <p>Holds a reference to one class's residual array and turns the
     * residuals of the samples in a node into that node's output value.
     */
    public class LKNodeOutput implements RegressionTree.NodeOutput {
        /** Residuals, indexed by sample; stored by reference, not copied. */
        double[] y;

        /**
         * @param dArr residual array, stored by reference
         */
        public LKNodeOutput(double[] dArr) {
            this.y = dArr;
        }

        /**
         * Computes the output value of a node.
         *
         * @param iArr per-sample flags; only entries greater than zero contribute
         * @return {@code ((k - 1) / k) * sum(y) / sum(|y| * (1 - |y|))}, or the
         *         plain mean of the selected residuals when that denominator is
         *         below {@code 1e-10}
         */
        @Override // smile.regression.RegressionTree.NodeOutput
        public double calculate(int[] iArr) {
            int selected = 0;
            double residualSum = 0.0d;
            double weightSum = 0.0d;
            int index = 0;
            for (int flag : iArr) {
                if (flag > 0) {
                    selected++;
                    double magnitude = Math.abs(this.y[index]);
                    residualSum += this.y[index];
                    weightSum += magnitude * (1.0d - magnitude);
                }
                index++;
            }
            if (weightSum < 1.0E-10d) {
                return residualSum / selected;
            }
            final int classes = GradientTreeBoost.this.k;
            return ((classes - 1.0d) / classes) * (residualSum / weightSum);
        }
    }

    /* loaded from: input_file:libarx-3.7.1.jar:smile/classification/GradientTreeBoost$Trainer.class */
    /**
     * Configurable trainer for {@link GradientTreeBoost}.
     *
     * <p>Defaults: 500 trees, shrinkage 0.005, at most 6 leaf nodes per tree and
     * a sampling fraction of 0.7. Setters validate their argument, store it and
     * return {@code this} so calls can be chained.
     */
    public static class Trainer extends ClassifierTrainer<double[]> {
        private int ntrees = 500;
        private double shrinkage = 0.005d;
        private int maxNodes = 6;
        private double subsample = 0.7d;

        /**
         * @param attributeArr attributes passed to the superclass
         * @param i number of trees, at least 1
         * @throws IllegalArgumentException if {@code i < 1}
         */
        public Trainer(Attribute[] attributeArr, int i) {
            super(attributeArr);
            this.ntrees = checkedTreeCount(i);
        }

        /**
         * @param i number of trees, at least 1
         * @param trainingInterrupt interrupt handle passed to the superclass
         * @throws IllegalArgumentException if {@code i < 1}
         */
        public Trainer(int i, TrainingInterrupt trainingInterrupt) {
            super(trainingInterrupt);
            this.ntrees = checkedTreeCount(i);
        }

        /**
         * Creates a trainer with all default settings.
         *
         * @param trainingInterrupt interrupt handle passed to the superclass
         */
        public Trainer(TrainingInterrupt trainingInterrupt) {
            super(trainingInterrupt);
        }

        /**
         * @param i maximum number of leaf nodes per tree, at least 2
         * @return this trainer
         * @throws IllegalArgumentException if {@code i < 2}
         */
        public Trainer setMaxNodes(int i) {
            if (i < 2) {
                throw new IllegalArgumentException("Invalid maximum number of leaf nodes: " + i);
            }
            this.maxNodes = i;
            return this;
        }

        /**
         * @param i number of trees, at least 1
         * @return this trainer
         * @throws IllegalArgumentException if {@code i < 1}
         */
        public Trainer setNumTrees(int i) {
            this.ntrees = checkedTreeCount(i);
            return this;
        }

        /**
         * @param d sampling fraction in (0, 1]
         * @return this trainer
         * @throws IllegalArgumentException if {@code d} is outside (0, 1]
         */
        public Trainer setSamplingRates(double d) {
            if (outsideUnitInterval(d)) {
                throw new IllegalArgumentException("Invalid sampling fraction: " + d);
            }
            this.subsample = d;
            return this;
        }

        /**
         * @param d shrinkage in (0, 1]
         * @return this trainer
         * @throws IllegalArgumentException if {@code d} is outside (0, 1]
         */
        public Trainer setShrinkage(double d) {
            if (outsideUnitInterval(d)) {
                throw new IllegalArgumentException("Invalid shrinkage: " + d);
            }
            this.shrinkage = d;
            return this;
        }

        /**
         * Trains a model with the current settings.
         *
         * @param dArr training instances
         * @param iArr class labels, one per instance
         * @return the trained model
         */
        @Override // smile.classification.ClassifierTrainer
        public GradientTreeBoost train(double[][] dArr, int[] iArr) {
            return new GradientTreeBoost(this.attributes, dArr, iArr, this.ntrees, this.maxNodes, this.shrinkage, this.subsample, this.interrupt);
        }

        /** Returns {@code count} unchanged, or throws if it is smaller than 1. */
        private static int checkedTreeCount(int count) {
            if (count < 1) {
                throw new IllegalArgumentException("Invalid number of trees: " + count);
            }
            return count;
        }

        /** True when {@code value <= 0} or {@code value > 1}; NaN is reported as inside. */
        private static boolean outsideUnitInterval(double value) {
            return value <= 0.0d || value > 1.0d;
        }
    }

    /**
     * Validates the arguments, trains the ensemble and sums up the variable
     * importance reported by every tree.
     *
     * @param attributeArr attribute descriptions; when {@code null}, numeric
     *        attributes named "V1", "V2", ... are created, one per column of
     *        {@code dArr[0]}
     * @param dArr training instances
     * @param iArr class labels, one per instance; the number of classes is taken
     *        to be the largest label plus one
     * @param i number of trees, at least 1
     * @param i2 maximum number of leaf nodes per tree, at least 2
     * @param d shrinkage in (0, 1]
     * @param d2 sampling fraction in (0, 1]
     * @param trainingInterrupt interrupt handle passed to the superclass
     * @throws IllegalArgumentException if the sizes of {@code dArr} and
     *         {@code iArr} differ, a numeric argument is out of range, or fewer
     *         than two classes are found
     */
    public GradientTreeBoost(Attribute[] attributeArr, double[][] dArr, int[] iArr, int i, int i2, double d, double d2, TrainingInterrupt trainingInterrupt) {
        super(trainingInterrupt);
        if (dArr.length != iArr.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", Integer.valueOf(dArr.length), Integer.valueOf(iArr.length)));
        }
        if (i < 1) {
            throw new IllegalArgumentException("Invalid number of trees: " + i);
        }
        if (i2 < 2) {
            throw new IllegalArgumentException("Invalid maximum leaves: " + i2);
        }
        if (d <= 0.0d || d > 1.0d) {
            throw new IllegalArgumentException("Invalid shrinkage: " + d);
        }
        if (d2 <= 0.0d || d2 > 1.0d) {
            throw new IllegalArgumentException("Invalid sampling fraction: " + d2);
        }

        Attribute[] schema = attributeArr;
        if (schema == null) {
            // No schema supplied: treat every column as numeric.
            final int columns = dArr[0].length;
            schema = new Attribute[columns];
            for (int column = 0; column < columns; column++) {
                schema[column] = new NumericAttribute("V" + (column + 1));
            }
        }

        this.ntrees = i;
        this.maxNodes = i2;
        this.shrinkage = d;
        this.subsample = d2;
        this.k = Math.max(iArr) + 1;
        if (this.k < 2) {
            throw new IllegalArgumentException("Only one class or negative class labels.");
        }

        this.importance = new double[schema.length];
        if (this.k == 2) {
            train2(schema, dArr, iArr);
            addImportanceOf(this.trees);
        } else {
            traink(schema, dArr, iArr);
            for (RegressionTree[] perClassTrees : this.forest) {
                addImportanceOf(perClassTrees);
            }
        }
    }

    /** Adds the importance vector of every tree in {@code ensemble} to {@code this.importance}. */
    private void addImportanceOf(RegressionTree[] ensemble) {
        for (RegressionTree tree : ensemble) {
            final double[] contribution = tree.importance();
            for (int feature = 0; feature < contribution.length; feature++) {
                this.importance[feature] += contribution[feature];
            }
        }
    }

    /**
     * Convenience constructor that fills in the remaining settings: at most 6
     * leaf nodes per tree, a sampling fraction of 0.7, and a shrinkage of 0.005
     * when there are fewer than 2000 training instances, 0.05 otherwise.
     *
     * @param attributeArr attribute descriptions (may be {@code null})
     * @param dArr training instances
     * @param iArr class labels, one per instance
     * @param i number of trees
     * @param trainingInterrupt interrupt handle passed on to the full constructor
     */
    public GradientTreeBoost(Attribute[] attributeArr, double[][] dArr, int[] iArr, int i, TrainingInterrupt trainingInterrupt) {
        this(attributeArr, dArr, iArr, i, 6, dArr.length < 2000 ? 0.005d : 0.05d, 0.7d, trainingInterrupt);
    }

    /**
     * Same as the full constructor but without attribute descriptions; passes
     * {@code null} for them, which makes the full constructor create numeric
     * attributes for every column.
     *
     * @param dArr training instances
     * @param iArr class labels, one per instance
     * @param i number of trees
     * @param i2 maximum number of leaf nodes per tree
     * @param d shrinkage
     * @param d2 sampling fraction
     * @param trainingInterrupt interrupt handle passed on to the full constructor
     */
    public GradientTreeBoost(double[][] dArr, int[] iArr, int i, int i2, double d, double d2, TrainingInterrupt trainingInterrupt) {
        this(null, dArr, iArr, i, i2, d, d2, trainingInterrupt);
    }

    /**
     * Shortest form: no attribute descriptions and default settings for
     * everything except the number of trees. Delegates to the constructor that
     * takes attributes, passing {@code null} for them.
     *
     * @param dArr training instances
     * @param iArr class labels, one per instance
     * @param i number of trees
     * @param trainingInterrupt interrupt handle passed on unchanged
     */
    public GradientTreeBoost(double[][] dArr, int[] iArr, int i, TrainingInterrupt trainingInterrupt) {
        this(null, dArr, iArr, i, trainingInterrupt);
    }

    /**
     * Returns the internal array of regression trees (not a copy).
     *
     * <p>NOTE: this field is only populated by the two-class training path; for
     * models with more than two classes the trees live in a separate per-class
     * forest and this method returns {@code null}.
     *
     * @return the two-class tree array, or {@code null} for multi-class models
     */
    public RegressionTree[] getTrees() {
        return this.trees;
    }

    /**
     * Returns the variable importance accumulated at construction time: the
     * element-wise sum of {@code importance()} over every tree that was trained.
     * The internal array is returned directly, not a copy.
     *
     * @return importance scores, one per attribute
     */
    public double[] importance() {
        return this.importance;
    }

    /**
     * Predicts the class label of an instance.
     *
     * <p>Two-class models add the shrunken tree outputs to the initial bias and
     * answer 1 when the total is positive, 0 otherwise. Multi-class models sum
     * the shrunken tree outputs per class and answer the class with the largest
     * total (the lowest index wins ties).
     *
     * @param dArr the instance to classify
     * @return the predicted class label
     */
    @Override // smile.classification.Classifier
    public int predict(double[] dArr) {
        if (this.k != 2) {
            int bestClass = -1;
            double bestScore = Double.NEGATIVE_INFINITY;
            for (int cls = 0; cls < this.k; cls++) {
                final RegressionTree[] classTrees = this.forest[cls];
                double score = 0.0d;
                for (int round = 0; round < this.ntrees; round++) {
                    score += this.shrinkage * classTrees[round].predict(dArr);
                }
                if (score > bestScore) {
                    bestScore = score;
                    bestClass = cls;
                }
            }
            return bestClass;
        }

        double margin = this.b;
        for (int round = 0; round < this.ntrees; round++) {
            margin += this.shrinkage * this.trees[round].predict(dArr);
        }
        if (margin > 0.0d) {
            return 1;
        }
        return 0;
    }

    /**
     * Predicts the class label of an instance and fills in class probabilities.
     *
     * <p>Two-class models write {@code 1 / (1 + exp(2 * score))} to
     * {@code dArr2[0]} and its complement to {@code dArr2[1]}. Multi-class
     * models write a softmax over the per-class scores, shifted by the largest
     * score before exponentiation.
     *
     * @param dArr the instance to classify
     * @param dArr2 output array of length {@code k}; overwritten with the
     *        posterior probabilities
     * @return the predicted class label
     * @throws IllegalArgumentException if {@code dArr2.length != k}
     */
    @Override // smile.classification.SoftClassifier
    public int predict(double[] dArr, double[] dArr2) {
        if (dArr2.length != this.k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", Integer.valueOf(dArr2.length), Integer.valueOf(this.k)));
        }

        if (this.k != 2) {
            int bestClass = -1;
            double bestScore = Double.NEGATIVE_INFINITY;
            for (int cls = 0; cls < this.k; cls++) {
                dArr2[cls] = 0.0d;
                for (int round = 0; round < this.ntrees; round++) {
                    dArr2[cls] += this.shrinkage * this.forest[cls][round].predict(dArr);
                }
                if (dArr2[cls] > bestScore) {
                    bestScore = dArr2[cls];
                    bestClass = cls;
                }
            }

            // Softmax, shifted by the maximum score.
            double normalizer = 0.0d;
            for (int cls = 0; cls < this.k; cls++) {
                dArr2[cls] = Math.exp(dArr2[cls] - bestScore);
                normalizer += dArr2[cls];
            }
            for (int cls = 0; cls < this.k; cls++) {
                dArr2[cls] /= normalizer;
            }
            return bestClass;
        }

        double margin = this.b;
        for (int round = 0; round < this.ntrees; round++) {
            margin += this.shrinkage * this.trees[round].predict(dArr);
        }
        dArr2[0] = 1.0d / (1.0d + Math.exp(2.0d * margin));
        dArr2[1] = 1.0d - dArr2[0];
        if (margin > 0.0d) {
            return 1;
        }
        return 0;
    }

    /**
     * Returns the number of boosting rounds currently held by the model.
     *
     * <p>Two-class models keep their trees in {@code trees}; multi-class models
     * keep one row of trees per class in {@code forest} and never assign
     * {@code trees}, so the size is read from whichever structure is in use
     * (the same distinction {@link #trim(int)} makes).
     *
     * @return the number of trees (per class, for multi-class models)
     */
    public int size() {
        return this.k == 2 ? this.trees.length : this.forest[0].length;
    }

    /**
     * Measures accuracy on a test set after each boosting round.
     *
     * <p>Scores are accumulated tree by tree, so entry {@code t} of the result
     * is the accuracy of the model made of the first {@code t + 1} trees.
     *
     * @param dArr test instances
     * @param iArr true class labels, one per instance
     * @return accuracies, one per tree
     */
    public double[] test(double[][] dArr, int[] iArr) {
        final int samples = dArr.length;
        final double[] accuracyPerRound = new double[this.ntrees];
        final int[] predicted = new int[samples];
        final Accuracy metric = new Accuracy();

        if (this.k != 2) {
            final double[][] scores = new double[samples][this.k];
            for (int round = 0; round < this.ntrees; round++) {
                for (int s = 0; s < samples; s++) {
                    final double[] row = scores[s];
                    for (int cls = 0; cls < this.k; cls++) {
                        row[cls] += this.shrinkage * this.forest[cls][round].predict(dArr[s]);
                    }
                    predicted[s] = Math.whichMax(row);
                }
                accuracyPerRound[round] = metric.measure(iArr, predicted);
            }
            return accuracyPerRound;
        }

        final double[] margins = new double[samples];
        Arrays.fill(margins, this.b);
        for (int round = 0; round < this.ntrees; round++) {
            final RegressionTree tree = this.trees[round];
            for (int s = 0; s < samples; s++) {
                margins[s] += this.shrinkage * tree.predict(dArr[s]);
                predicted[s] = margins[s] > 0.0d ? 1 : 0;
            }
            accuracyPerRound[round] = metric.measure(iArr, predicted);
        }
        return accuracyPerRound;
    }

    /**
     * Evaluates the given measures on a test set after each boosting round.
     *
     * <p>Scores are accumulated tree by tree, so row {@code t} of the result
     * describes the model made of the first {@code t + 1} trees; column
     * {@code m} holds the value of {@code classificationMeasureArr[m]}.
     *
     * @param dArr test instances
     * @param iArr true class labels, one per instance
     * @param classificationMeasureArr measures to evaluate
     * @return a {@code ntrees x measures} matrix of results
     */
    public double[][] test(double[][] dArr, int[] iArr, ClassificationMeasure[] classificationMeasureArr) {
        final int measures = classificationMeasureArr.length;
        final double[][] results = new double[this.ntrees][measures];
        final int samples = dArr.length;
        final int[] predicted = new int[samples];

        if (this.k != 2) {
            final double[][] scores = new double[samples][this.k];
            for (int round = 0; round < this.ntrees; round++) {
                for (int s = 0; s < samples; s++) {
                    final double[] row = scores[s];
                    for (int cls = 0; cls < this.k; cls++) {
                        row[cls] += this.shrinkage * this.forest[cls][round].predict(dArr[s]);
                    }
                    predicted[s] = Math.whichMax(row);
                }
                for (int m = 0; m < measures; m++) {
                    results[round][m] = classificationMeasureArr[m].measure(iArr, predicted);
                }
            }
            return results;
        }

        final double[] margins = new double[samples];
        Arrays.fill(margins, this.b);
        for (int round = 0; round < this.ntrees; round++) {
            final RegressionTree tree = this.trees[round];
            for (int s = 0; s < samples; s++) {
                margins[s] += this.shrinkage * tree.predict(dArr[s]);
                predicted[s] = margins[s] > 0.0d ? 1 : 0;
            }
            for (int m = 0; m < measures; m++) {
                results[round][m] = classificationMeasureArr[m].measure(iArr, predicted);
            }
        }
        return results;
    }

    /**
     * Shrinks the model to its first {@code i} boosting rounds.
     *
     * <p>Two-class models truncate the tree array; multi-class models truncate
     * every per-class row of the forest. Nothing changes when {@code i} equals
     * the current size.
     *
     * @param i the new number of trees, between 1 and the current size
     * @throws IllegalArgumentException if {@code i < 1} or {@code i} exceeds the
     *         current size
     */
    public void trim(int i) {
        if (i < 1) {
            throw new IllegalArgumentException("Invalid new model size: " + i);
        }

        final boolean binary = this.k == 2;
        final int currentSize = binary ? this.trees.length : this.forest[0].length;
        if (i > currentSize) {
            throw new IllegalArgumentException(binary
                    ? "The new model size is larger than the current size."
                    : "The new model size is larger than the current one.");
        }
        if (i == currentSize) {
            return;
        }

        if (binary) {
            this.trees = Arrays.copyOf(this.trees, i);
        } else {
            for (int cls = 0; cls < this.forest.length; cls++) {
                this.forest[cls] = Arrays.copyOf(this.forest[cls], i);
            }
        }
        this.ntrees = i;
    }

    /**
     * Trains the two-class model: sets the initial bias {@code b} and fills
     * {@code trees} with {@code ntrees} regression trees, each fitted to the
     * current residuals on a per-class random subsample.
     *
     * @param attributeArr attribute descriptions, one per column
     * @param dArr training instances
     * @param iArr class labels; label 1 is mapped to +1, anything else to -1
     */
    private void train2(Attribute[] attributeArr, double[][] dArr, int[] iArr) {
        int length = dArr.length;
        // iArr2[c] = number of training instances with label c.
        int[] iArr2 = new int[this.k];
        for (int i = 0; i < length; i++) {
            int i2 = iArr[i];
            iArr2[i2] = iArr2[i2] + 1;
        }
        // iArr3 = labels recoded as +1 / -1.
        int[] iArr3 = new int[length];
        for (int i3 = 0; i3 < length; i3++) {
            if (iArr[i3] == 1) {
                iArr3[i3] = 1;
            } else {
                iArr3[i3] = -1;
            }
        }
        // dArr2 = current score of every instance, dArr3 = current residuals.
        double[] dArr2 = new double[length];
        double[] dArr3 = new double[length];
        double mean = Math.mean(iArr3);
        // NOTE(review): assuming Math.mean is the arithmetic mean, a training set
        // containing only label 1 gives mean == 1 and a non-finite bias here — confirm
        // callers always supply both classes.
        this.b = 0.5d * Math.log((1.0d + mean) / (1.0d - mean));
        for (int i4 = 0; i4 < length; i4++) {
            dArr2[i4] = this.b;
        }
        // Computed once and handed to every tree; presumably a per-attribute sort
        // order of the instances — verify against SmileUtils.sort.
        int[][] sort = SmileUtils.sort(attributeArr, dArr);
        // The node-output object keeps a reference to dArr3, so it sees the residuals
        // that are rewritten in place at the start of every round below.
        L2NodeOutput l2NodeOutput = new L2NodeOutput(dArr3);
        this.trees = new RegressionTree[this.ntrees];
        // iArr4 = instance indices (re-permuted every round), iArr5 = 0/1 sampling mask.
        int[] iArr4 = new int[length];
        int[] iArr5 = new int[length];
        for (int i5 = 0; i5 < length; i5++) {
            iArr4[i5] = i5;
        }
        for (int i6 = 0; i6 < this.ntrees; i6++) {
            Arrays.fill(iArr5, 0);
            // Presumably shuffles the index array in place — verify against Math.permutate.
            Math.permutate(iArr4);
            // Per class, mark the first round(count * subsample) instances of that class
            // in the permuted order.
            for (int i7 = 0; i7 < this.k; i7++) {
                int round = (int) Math.round(iArr2[i7] * this.subsample);
                int i8 = 0;
                for (int i9 = 0; i9 < length && i8 < round; i9++) {
                    int i10 = iArr4[i9];
                    if (iArr[i10] == i7) {
                        iArr5[i10] = 1;
                        i8++;
                    }
                }
            }
            // Residual of every instance: 2y / (1 + exp(2 * y * score)), y in {-1, +1}.
            for (int i11 = 0; i11 < length; i11++) {
                dArr3[i11] = (2.0d * iArr3[i11]) / (1.0d + Math.exp((2 * iArr3[i11]) * dArr2[i11]));
            }
            // The literal 5 and dArr[0].length are positional RegressionTree settings;
            // presumably minimum node size and number of candidate attributes — TODO confirm.
            this.trees[i6] = new RegressionTree(attributeArr, dArr, dArr3, this.maxNodes, 5, dArr[0].length, sort, iArr5, l2NodeOutput);
            // Update the score of every instance, sampled or not, with the new tree.
            for (int i12 = 0; i12 < length; i12++) {
                int i13 = i12;
                dArr2[i13] = dArr2[i13] + (this.shrinkage * this.trees[i6].predict(dArr[i12]));
            }
        }
    }

    /**
     * Trains the multi-class model: fills {@code forest} with {@code ntrees}
     * regression trees for each of the {@code k} classes. Every round computes
     * softmax probabilities from the current scores, then fits one tree per
     * class to that class's residuals on a per-class random subsample.
     *
     * @param attributeArr attribute descriptions, one per column
     * @param dArr training instances
     * @param iArr class labels in {@code 0 .. k-1}
     */
    private void traink(Attribute[] attributeArr, double[][] dArr, int[] iArr) {
        int length = dArr.length;
        // iArr2[c] = number of training instances with label c.
        int[] iArr2 = new int[this.k];
        for (int i = 0; i < length; i++) {
            int i2 = iArr[i];
            iArr2[i2] = iArr2[i2] + 1;
        }
        // All three are indexed [class][instance]:
        // dArr2 = scores (start at 0), dArr3 = probabilities, dArr4 = residuals.
        double[][] dArr2 = new double[this.k][length];
        double[][] dArr3 = new double[this.k][length];
        double[][] dArr4 = new double[this.k][length];
        // Computed once and handed to every tree; presumably a per-attribute sort
        // order of the instances — verify against SmileUtils.sort.
        int[][] sort = SmileUtils.sort(attributeArr, dArr);
        this.forest = new RegressionTree[this.k][this.ntrees];
        // One node-output object per class, each holding a reference to that class's
        // residual row so it sees the values rewritten in place every round.
        LKNodeOutput[] lKNodeOutputArr = new LKNodeOutput[this.k];
        for (int i3 = 0; i3 < this.k; i3++) {
            lKNodeOutputArr[i3] = new LKNodeOutput(dArr4[i3]);
        }
        // iArr3 = instance indices (re-permuted for every tree), iArr4 = 0/1 sampling mask.
        int[] iArr3 = new int[length];
        int[] iArr4 = new int[length];
        for (int i4 = 0; i4 < length; i4++) {
            iArr3[i4] = i4;
        }
        for (int i5 = 0; i5 < this.ntrees; i5++) {
            // Softmax over the class scores of each instance, shifted by the maximum
            // score before exponentiation. Computed once per round, before any of this
            // round's trees are fitted.
            for (int i6 = 0; i6 < length; i6++) {
                double d = Double.NEGATIVE_INFINITY;
                for (int i7 = 0; i7 < this.k; i7++) {
                    if (d < dArr2[i7][i6]) {
                        d = dArr2[i7][i6];
                    }
                }
                double d2 = 0.0d;
                for (int i8 = 0; i8 < this.k; i8++) {
                    dArr3[i8][i6] = Math.exp(dArr2[i8][i6] - d);
                    d2 += dArr3[i8][i6];
                }
                for (int i9 = 0; i9 < this.k; i9++) {
                    double[] dArr5 = dArr3[i9];
                    int i10 = i6;
                    dArr5[i10] = dArr5[i10] / d2;
                }
            }
            for (int i11 = 0; i11 < this.k; i11++) {
                // Residual = (1 if the instance belongs to this class, else 0) - probability.
                for (int i12 = 0; i12 < length; i12++) {
                    if (iArr[i12] == i11) {
                        dArr4[i11][i12] = 1.0d;
                    } else {
                        dArr4[i11][i12] = 0.0d;
                    }
                    double[] dArr6 = dArr4[i11];
                    int i13 = i12;
                    dArr6[i13] = dArr6[i13] - dArr3[i11][i12];
                }
                // A fresh subsample is drawn for every (round, class) tree.
                Arrays.fill(iArr4, 0);
                // Presumably shuffles the index array in place — verify against Math.permutate.
                Math.permutate(iArr3);
                // Per label, mark the first round(count * subsample) instances with that
                // label in the permuted order.
                for (int i14 = 0; i14 < this.k; i14++) {
                    int round = (int) Math.round(iArr2[i14] * this.subsample);
                    int i15 = 0;
                    for (int i16 = 0; i16 < length && i15 < round; i16++) {
                        int i17 = iArr3[i16];
                        if (iArr[i17] == i14) {
                            iArr4[i17] = 1;
                            i15++;
                        }
                    }
                }
                // The literal 5 and dArr[0].length are positional RegressionTree settings;
                // presumably minimum node size and number of candidate attributes — TODO confirm.
                this.forest[i11][i5] = new RegressionTree(attributeArr, dArr, dArr4[i11], this.maxNodes, 5, dArr[0].length, sort, iArr4, lKNodeOutputArr[i11]);
                // Update this class's score for every instance, sampled or not.
                for (int i18 = 0; i18 < length; i18++) {
                    double[] dArr7 = dArr2[i11];
                    int i19 = i18;
                    dArr7[i19] = dArr7[i19] + (this.shrinkage * this.forest[i11][i5].predict(dArr[i18]));
                }
            }
        }
    }
}
