package sklearn.ensemble.hist_gradient_boosting;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.OpType;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.regression.RegressionModel;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.Transformation;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.python.PythonObject;
import sklearn.Classifier;
import sklearn.Estimator;
import sklearn.loss.HalfBinomialLoss;
import sklearn.loss.HalfMultinomialLoss;

/* loaded from: input_file:sklearn/ensemble/hist_gradient_boosting/HistGradientBoostingClassifier.class */
/**
 * Converter for a fitted scikit-learn {@code HistGradientBoostingClassifier}.
 *
 * <p>With one tree per boosting iteration the model is encoded as a binary logistic
 * classification over a single boosted decision function. With three or more trees per iteration
 * it is encoded as one boosted decision function per target category, combined with softmax.
 */
public class HistGradientBoostingClassifier extends Classifier {

    public HistGradientBoostingClassifier(String module, String name) {
        super(module, name);
    }

    /**
     * Encodes this classifier as a PMML {@link MiningModel}.
     *
     * <p>The method carries a decompiler-assigned name; per the decompiler note it was merged with
     * the bridge method for {@code encodeModel}.
     *
     * @param schema the model schema; its label is read as a {@link CategoricalLabel}
     * @return a binary logistic classification (one tree per iteration) or a softmax
     *     classification (three or more trees per iteration)
     * @throws IllegalArgumentException if the number of trees per iteration is not 1 and not at
     *     least 3, if the label size does not match, or if the loss is not of a supported type
     */
    @Override // sklearn.Estimator
    /* renamed from: encodeModel, reason: merged with bridge method [inline-methods] */
    public MiningModel mo16encodeModel(Schema schema) {
        List<? extends Number> baselinePrediction = getBaselinePrediction();
        PythonObject loss = getLoss();
        BinMapper binMapper = getBinMapper();
        int numberOfTreesPerIteration = getNumberOfTreesPerIteration().intValue();
        List<List<TreePredictor>> predictors = getPredictors();

        // Each iteration must hold exactly one tree per decision function, and there must be
        // exactly one baseline value per decision function.
        if (!predictors.isEmpty()) {
            ClassDictUtil.checkSize(numberOfTreesPerIteration, new Collection[]{predictors.get(0), baselinePrediction});
        }

        Schema segmentSchema = schema.toAnonymousRegressorSchema(DataType.DOUBLE);
        CategoricalLabel label = schema.getLabel();

        if (numberOfTreesPerIteration == 1) {
            SchemaUtil.checkSize(2, label);
            if (!(loss instanceof BinaryCrossEntropy) && !(loss instanceof HalfBinomialLoss)) {
                throw new IllegalArgumentException("Expected a binary cross-entropy loss for binary classification, got " + describe(loss));
            }
            // A single decision function (column 0); its output field is named after the
            // second category, label.getValue(1).
            MiningModel decisionFunction = encodeDecisionFunction(predictors, binMapper, baselinePrediction, 0, label.getValue(1), segmentSchema);
            return MiningModelUtil.createBinaryLogisticClassification(decisionFunction, 1.0d, 0.0d, RegressionModel.NormalizationMethod.LOGIT, true, schema);
        }

        if (numberOfTreesPerIteration < 3) {
            throw new IllegalArgumentException("Expected 1 or at least 3 trees per iteration, got " + numberOfTreesPerIteration);
        }

        SchemaUtil.checkSize(numberOfTreesPerIteration, label);
        if (!(loss instanceof CategoricalCrossEntropy) && !(loss instanceof HalfMultinomialLoss)) {
            throw new IllegalArgumentException("Expected a categorical cross-entropy loss for multiclass classification, got " + describe(loss));
        }

        // One decision function per category; column i of the raw predictions maps to category i.
        List<MiningModel> decisionFunctions = new ArrayList<>(numberOfTreesPerIteration);
        for (int column = 0; column < numberOfTreesPerIteration; column++) {
            decisionFunctions.add(encodeDecisionFunction(predictors, binMapper, baselinePrediction, column, label.getValue(column), segmentSchema));
        }
        return MiningModelUtil.createClassification(decisionFunctions, RegressionModel.NormalizationMethod.SOFTMAX, true, schema);
    }

    /**
     * Encodes the boosted decision function for one prediction column and attaches a continuous
     * double output field named from {@link Estimator#FIELD_DECISION_FUNCTION} and {@code category}.
     */
    private static MiningModel encodeDecisionFunction(List<List<TreePredictor>> predictors, BinMapper binMapper, List<? extends Number> baselinePrediction, int column, Object category, Schema segmentSchema) {
        MiningModel model = HistGradientBoostingUtil.encodeHistGradientBoosting(predictors, binMapper, baselinePrediction, column, segmentSchema);
        return model.setOutput(ModelUtil.createPredictedOutput(FieldNameUtil.create(Estimator.FIELD_DECISION_FUNCTION, new Object[]{category}), OpType.CONTINUOUS, DataType.DOUBLE, new Transformation[0]));
    }

    /** Returns a short description of the loss object for error messages. */
    private static String describe(Object loss) {
        return (loss != null) ? loss.getClass().getName() : "null";
    }

    /**
     * Returns the {@code _baseline_prediction} attribute.
     *
     * <p>{@link #mo16encodeModel} expects exactly one value per tree per iteration.
     */
    public List<? extends Number> getBaselinePrediction() {
        return getNumberArray("_baseline_prediction");
    }

    /**
     * Returns the optional {@code _bin_mapper} attribute. This is presumably {@code null} when the
     * attribute is absent (TODO confirm the {@code getOptional} contract).
     */
    public BinMapper getBinMapper() {
        return (BinMapper) getOptional("_bin_mapper", BinMapper.class);
    }

    /**
     * Returns the fitted loss object.
     *
     * <p>If a {@code loss_} attribute is present it is returned. Otherwise {@code _loss} is read,
     * first as this package's {@code BaseLoss} and then as {@code sklearn.loss.BaseLoss}.
     * NOTE(review): {@code loss_} vs. {@code _loss} presumably reflects different scikit-learn
     * releases; confirm.
     *
     * @throws IllegalArgumentException if {@code _loss} cannot be read as either loss type; the
     *     failure of the first attempt is attached as a suppressed exception
     */
    public PythonObject getLoss() {
        if (containsKey("loss_")) {
            // Previously this value was read and discarded, so a "loss_"-only object then failed
            // on the "_loss" lookup below.
            return (PythonObject) get("loss_", BaseLoss.class);
        }
        try {
            return (PythonObject) get("_loss", BaseLoss.class);
        } catch (IllegalArgumentException firstAttempt) {
            try {
                return (PythonObject) get("_loss", sklearn.loss.BaseLoss.class);
            } catch (IllegalArgumentException secondAttempt) {
                secondAttempt.addSuppressed(firstAttempt);
                throw secondAttempt;
            }
        }
    }

    /** Returns the {@code n_trees_per_iteration_} attribute. */
    public Integer getNumberOfTreesPerIteration() {
        return getInteger("n_trees_per_iteration_");
    }

    /**
     * Returns the {@code _predictors} attribute: one inner list of tree predictors per boosting
     * iteration. Per {@link #mo16encodeModel}, each inner list holds one tree per decision
     * function.
     */
    public List<List<TreePredictor>> getPredictors() {
        return getList("_predictors", List.class);
    }
}
