package hex.genmodel.algos.xgboost;

import biz.k11i.xgboost.Predictor;
import biz.k11i.xgboost.gbm.GBTree;
import biz.k11i.xgboost.learner.ObjFunction;
import biz.k11i.xgboost.tree.RegTree;
import biz.k11i.xgboost.tree.TreeSHAPHelper;
import biz.k11i.xgboost.util.FVec;
import hex.genmodel.PredictContributions;
import hex.genmodel.PredictContributionsFactory;
import hex.genmodel.algos.tree.ContributionsPredictor;
import hex.genmodel.algos.tree.ConvertTreeOptions;
import hex.genmodel.algos.tree.SharedTreeGraph;
import hex.genmodel.algos.tree.SharedTreeMojoModel;
import hex.genmodel.algos.tree.TreeSHAPEnsemble;
import hex.genmodel.algos.tree.TreeSHAPPredictor;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.ArrayList;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;

/**
 * XGBoost MOJO model that scores with the pure-Java {@code biz.k11i.xgboost} {@link Predictor}.
 *
 * <p>Input rows are converted into {@link FVec} feature vectors by an {@link OneHotEncoderFactory}
 * created in {@link #postReadInit()}. TreeSHAP feature contributions can be built eagerly (constructor
 * flag) or on demand via {@link #makeContributionsPredictor()}; they are rejected for models with
 * more than two classes.
 */
public final class XGBoostJavaMojoModel extends XGBoostMojoModel implements PredictContributionsFactory {

    /** Pure-Java XGBoost predictor built from the booster bytes; cleared by {@link #close()}. */
    private Predictor _predictor;
    /** Eagerly built TreeSHAP predictor, or {@code null} when contributions were not enabled at construction. */
    private TreeSHAPPredictor<FVec> _treeSHAPPredictor;
    /** Turns raw input rows into feature vectors; initialized in {@link #postReadInit()}. */
    private OneHotEncoderFactory _1hotFactory;

    /** Adapts this model's TreeSHAP predictor to the generic {@link ContributionsPredictor} API. */
    private final class XGBoostContributionsPredictor extends ContributionsPredictor<FVec> {
        private XGBoostContributionsPredictor(XGBoostMojoModel model, TreeSHAPPredictor<FVec> treeSHAPPredictor) {
            // Output width: numeric columns + expanded categorical columns + 1 extra trailing slot
            // (NOTE(review): presumably the bias term - confirm against ContributionsPredictor).
            super(XGBoostJavaMojoModel.this._nums + XGBoostJavaMojoModel.this._catOffsets[XGBoostJavaMojoModel.this._cats] + 1,
                    XGBoostJavaMojoModel.makeFeatureContributionNames(model), treeSHAPPredictor);
        }

        /** Encodes a raw row with the enclosing model's encoder factory. */
        @Override // hex.genmodel.algos.tree.ContributionsPredictor
        public FVec toInputRow(double[] input) {
            return XGBoostJavaMojoModel.this._1hotFactory.fromArray(input);
        }
    }

    /**
     * Creates a model without eagerly building the TreeSHAP predictor.
     *
     * @param boosterBytes   serialized XGBoost booster understood by {@link Predictor}
     * @param columns        column names, passed to the superclass
     * @param domains        column domains, passed to the superclass
     * @param responseColumn response column name, passed to the superclass
     * @throws IllegalStateException if the booster bytes cannot be loaded
     */
    public XGBoostJavaMojoModel(byte[] boosterBytes, String[] columns, String[][] domains, String responseColumn) {
        this(boosterBytes, columns, domains, responseColumn, false);
    }

    /**
     * Creates a model, optionally building the TreeSHAP predictor up front.
     *
     * @param boosterBytes   serialized XGBoost booster understood by {@link Predictor}
     * @param columns        column names, passed to the superclass
     * @param domains        column domains, passed to the superclass
     * @param responseColumn response column name, passed to the superclass
     * @param enableTreeSHAP if {@code true}, build the TreeSHAP predictor now so that
     *                       {@link #makeContributionsWorkspace()} / {@link #calculateContributions} can be used
     * @throws IllegalStateException         if the booster bytes cannot be loaded
     * @throws UnsupportedOperationException if {@code enableTreeSHAP} is set for a model with more than two classes
     */
    public XGBoostJavaMojoModel(byte[] boosterBytes, String[] columns, String[][] domains, String responseColumn,
                                boolean enableTreeSHAP) {
        super(columns, domains, responseColumn);
        this._predictor = makePredictor(boosterBytes);
        this._treeSHAPPredictor = enableTreeSHAP ? makeTreeSHAPPredictor(this._predictor) : null;
    }

    /** Builds the row encoder once the MOJO metadata (offsets, sparsity, factor-level settings) has been read. */
    @Override // hex.genmodel.algos.xgboost.XGBoostMojoModel
    public void postReadInit() {
        this._1hotFactory = new OneHotEncoderFactory(backwardsCompatibility10(), this._sparse, this._catOffsets,
                this._cats, this._nums, this._useAllFactorLevels);
    }

    /** @return {@code true} for version-1.0 MOJOs whose booster type is not {@code "gbtree"} */
    private boolean backwardsCompatibility10() {
        return this._mojo_version == 1.0d && !"gbtree".equals(this._boosterType);
    }

    /**
     * Loads a pure-Java XGBoost predictor from serialized booster bytes.
     *
     * @param boosterBytes serialized booster
     * @return the loaded predictor
     * @throws IllegalStateException wrapping the {@link IOException} if loading fails
     */
    public static Predictor makePredictor(byte[] boosterBytes) {
        try (ByteArrayInputStream input = new ByteArrayInputStream(boosterBytes)) {
            return new Predictor(input);
        } catch (IOException e) {
            throw new IllegalStateException("Failed to load predictor.", e);
        }
    }

    /**
     * Builds a TreeSHAP ensemble over the trees of the first tree group of a GBTree booster.
     *
     * @throws UnsupportedOperationException if the model has more than two classes
     */
    private static TreeSHAPPredictor<FVec> makeTreeSHAPPredictor(Predictor predictor) {
        if (predictor.getNumClass() > 2) {
            throw new UnsupportedOperationException("Calculating contributions is currently not supported for multinomial models.");
        }
        // Only group 0 is read; multinomial models (the multi-group case) were rejected above.
        RegTree[] trees = ((GBTree) predictor.getBooster()).getGroupedTrees()[0];
        ArrayList<TreeSHAPPredictor<FVec>> treePredictors = new ArrayList<>(trees.length);
        for (RegTree tree : trees) {
            treePredictors.add(TreeSHAPHelper.makePredictor(tree));
        }
        return new TreeSHAPEnsemble<>(treePredictors, predictor.getBaseScore());
    }

    /**
     * Scores one row.
     *
     * @param row    raw input row; for legacy 1.0 non-gbtree MOJOs it is zero-padded to
     *               {@code _cats + _nums} values
     * @param offset row offset; must be {@code 0} unless the model was trained with an offset column
     * @param preds  output buffer handed to {@code toPreds}
     * @return the predictions produced by {@code toPreds}
     * @throws ArrayIndexOutOfBoundsException on a legacy model when the row has too many values
     * @throws UnsupportedOperationException  when a non-zero offset is given to a model without offset
     */
    @Override // hex.genmodel.GenModel
    public final double[] score0(double[] row, double offset, double[] preds) {
        double[] input = row;
        if (backwardsCompatibility10()) {
            int expectedLength = this._cats + this._nums;
            if (input.length > expectedLength) {
                throw new ArrayIndexOutOfBoundsException("Too many input values.");
            }
            if (input.length < expectedLength) {
                // Legacy encoders index every column, so missing trailing values are padded with zeros.
                double[] padded = new double[expectedLength];
                System.arraycopy(input, 0, padded, 0, input.length);
                input = padded;
            }
        }
        FVec features = this._1hotFactory.fromArray(input);
        float[] rawPredictions;
        if (this._hasOffset) {
            rawPredictions = this._predictor.predict(features, (float) offset);
        } else {
            if (offset != 0.0) {
                throw new UnsupportedOperationException("Unsupported: offset != 0");
            }
            rawPredictions = this._predictor.predict(features);
        }
        return toPreds(input, rawPredictions, preds, this._nclasses, this._priorClassDistrib, this._defaultThreshold);
    }

    /**
     * Allocates a workspace for {@link #calculateContributions}.
     *
     * @throws IllegalStateException if the TreeSHAP predictor was not built at construction time
     */
    public final Object makeContributionsWorkspace() {
        return requireTreeSHAPPredictor().makeWorkspace();
    }

    /**
     * Computes TreeSHAP contributions for one encoded row into {@code contributions}.
     *
     * @param row           encoded feature vector
     * @param contributions output buffer, filled in place
     * @param workspace     workspace from {@link #makeContributionsWorkspace()}
     * @return {@code contributions}
     * @throws IllegalStateException if the TreeSHAP predictor was not built at construction time
     */
    public final float[] calculateContributions(FVec row, float[] contributions, Object workspace) {
        requireTreeSHAPPredictor().calculateContributions(row, contributions, 0, -1, workspace);
        return contributions;
    }

    /** Returns the eagerly built TreeSHAP predictor, failing with a clear message instead of an NPE. */
    private TreeSHAPPredictor<FVec> requireTreeSHAPPredictor() {
        TreeSHAPPredictor<FVec> treeSHAPPredictor = this._treeSHAPPredictor;
        if (treeSHAPPredictor == null) {
            throw new IllegalStateException(
                    "TreeSHAP predictor is not available: the model was created without enabling contributions or has been closed.");
        }
        return treeSHAPPredictor;
    }

    /**
     * Creates a contributions predictor, reusing the eager TreeSHAP predictor when present and building
     * a new one otherwise.
     *
     * @throws UnsupportedOperationException if a new predictor is needed for a model with more than two classes
     */
    @Override // hex.genmodel.PredictContributionsFactory
    public final PredictContributions makeContributionsPredictor() {
        TreeSHAPPredictor<FVec> treeSHAPPredictor = this._treeSHAPPredictor != null
                ? this._treeSHAPPredictor
                : makeTreeSHAPPredictor(this._predictor);
        return new XGBoostContributionsPredictor(this, treeSHAPPredictor);
    }

    /** Looks up an XGBoost objective function by its name. */
    static ObjFunction getObjFunction(String name) {
        return ObjFunction.fromName(name);
    }

    /** Drops references to the predictor, TreeSHAP predictor and encoder; the model is unusable afterwards. */
    @Override // java.io.Closeable, java.lang.AutoCloseable
    public void close() {
        this._predictor = null;
        this._treeSHAPPredictor = null;
        this._1hotFactory = null;
    }

    /** Converts tree {@code treeNumber} of the booster into a {@link SharedTreeGraph}; {@code treeClass} is not used. */
    @Override // hex.genmodel.algos.tree.SharedTreeGraphConverter
    public SharedTreeGraph convert(int treeNumber, String treeClass) {
        return computeGraph(this._predictor.getBooster(), treeNumber);
    }

    /** Same as {@link #convert(int, String)}; the conversion options are ignored. */
    @Override // hex.genmodel.algos.xgboost.XGBoostMojoModel, hex.genmodel.algos.tree.SharedTreeGraphConverter
    public SharedTreeGraph convert(int treeNumber, String treeClass, ConvertTreeOptions options) {
        return convert(treeNumber, treeClass);
    }

    /** @return the booster's base score */
    @Override // hex.genmodel.algos.tree.TreeBackedMojoModel
    public double getInitF() {
        return this._predictor.getBaseScore();
    }

    /** Encodes the row once and reports both the leaf paths and the leaf node ids reached in each tree. */
    @Override // hex.genmodel.algos.tree.TreeBackedMojoModel
    public SharedTreeMojoModel.LeafNodeAssignments getLeafNodeAssignments(double[] row) {
        FVec features = this._1hotFactory.fromArray(row);
        SharedTreeMojoModel.LeafNodeAssignments assignments = new SharedTreeMojoModel.LeafNodeAssignments();
        assignments._paths = this._predictor.predictLeafPath(features);
        assignments._nodeIds = this._predictor.predictLeaf(features);
        return assignments;
    }

    /** @return the leaf path reached in each tree for the given raw row */
    @Override // hex.genmodel.algos.tree.TreeBackedMojoModel
    public String[] getDecisionPath(double[] row) {
        return this._predictor.predictLeafPath(this._1hotFactory.fromArray(row));
    }

    /**
     * Builds contribution column names: numeric columns keep their name, and each categorical column expands
     * into {@code "<col>.<level>"} for every domain level plus a trailing {@code "<col>.missing(NA)"}.
     *
     * @param model model providing features, domains and the categorical offsets
     * @return names sized {@code _nums + _catOffsets[_cats]}
     */
    public static String[] makeFeatureContributionNames(XGBoostMojoModel model) {
        String[] names = new String[model._nums + model._catOffsets[model._cats]];
        String[] features = model.features();
        int pos = 0;
        for (int col = 0; col < features.length; col++) {
            String[] domain = model._domains[col];
            if (domain == null) {
                names[pos++] = features[col];
            } else {
                for (String level : domain) {
                    names[pos++] = features[col] + "." + level;
                }
                names[pos++] = features[col] + ".missing(NA)";
            }
        }
        assert names.length == pos;
        return names;
    }
}
