package sklearn2pmml.ensemble;

import com.google.common.collect.Iterables;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.OpType;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.regression.RegressionModel;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.Transformation;
import org.jpmml.converter.mining.MiningModelUtil;
import sklearn.Classifier;
import sklearn.HasMultiDecisionFunctionField;
import sklearn.linear_model.LinearClassifier;
import sklearn.preprocessing.MultiOneHotEncoder;

/* loaded from: input_file:sklearn2pmml/ensemble/GBDTLRClassifier.class */
public class GBDTLRClassifier extends Classifier implements HasMultiDecisionFunctionField {

    /**
     * Creates the classifier wrapper.
     *
     * @param str  forwarded unchanged to {@link Classifier}; presumably the Python module name — confirm
     * @param str2 forwarded unchanged to {@link Classifier}; presumably the Python class name — confirm
     */
    public GBDTLRClassifier(String str, String str2) {
        super(str, str2);
    }

    /**
     * Returns the class labels of the gradient-boosting stage (attribute {@code gbdt_}).
     */
    @Override // sklearn.Classifier, sklearn.HasClasses
    public List<?> getClasses() {
        return getGBDT().getClasses();
    }

    /**
     * Reports whether probabilities can be produced. The answer comes from the
     * logistic-regression stage (attribute {@code lr_}).
     */
    @Override // sklearn.Classifier, sklearn.HasClasses
    public boolean hasProbabilityDistribution() {
        return getLR().hasProbabilityDistribution();
    }

    /**
     * Encodes the GBDT + one-hot-encoder + linear-classifier pipeline as a binary logistic
     * classification {@link MiningModel}.
     *
     * <p>The segment model is built by {@code GBDTUtil.encodeModel} from the GBDT, the
     * encoder and the LR's coefficients and single intercept. It is built against an
     * anonymous copy of {@code schema}. It exposes the decision-function value for the
     * second class label as a continuous double output. That segment is then wrapped by
     * {@code MiningModelUtil.createBinaryLogisticClassification} using LOGIT
     * normalization. A predict-proba output is added only if the LR stage reports a
     * probability distribution.
     *
     * @param schema the model schema; its label must be a categorical label with exactly two values
     * @return the encoded mining model
     * @throws ClassCastException if the schema label is not a {@link CategoricalLabel}
     * @throws java.util.NoSuchElementException if the LR has no intercept (Guava {@code getOnlyElement})
     * @throws IllegalArgumentException if the LR has more than one intercept (Guava {@code getOnlyElement})
     */
    @Override // sklearn.Estimator
    /* renamed from: encodeModel, reason: merged with bridge method [inline-methods] */
    public MiningModel mo7encodeModel(Schema schema) {
        Classifier gbdt = getGBDT();
        MultiOneHotEncoder ohe = getOHE();
        LinearClassifier lr = getLR();

        // Explicit narrowing: this encoder only works with a categorical target.
        CategoricalLabel label = (CategoricalLabel) schema.getLabel();

        // Only the binary case is supported.
        SchemaUtil.checkSize(2, label);

        List<Number> coef = lr.getCoef();
        List<Number> intercept = lr.getIntercept();

        // A binary LR is expected to carry exactly one intercept term.
        Number onlyIntercept = (Number) Iterables.getOnlyElement(intercept);

        Schema segmentSchema = schema.toAnonymousSchema();

        // The segment exposes the raw decision function for label.getValue(1) (the second class).
        // Coefficient 1.0 / intercept 0.0 presumably pass that value through unchanged into the
        // LOGIT normalization — confirm against MiningModelUtil.
        MiningModel miningModel = MiningModelUtil.createBinaryLogisticClassification(
            GBDTUtil.encodeModel(gbdt, ohe, coef, onlyIntercept, segmentSchema)
                .setOutput(ModelUtil.createPredictedOutput(
                    getMultiDecisionFunctionField(label.getValue(1)),
                    OpType.CONTINUOUS,
                    DataType.DOUBLE,
                    new Transformation[0])),
            1.0d,
            0.0d,
            RegressionModel.NormalizationMethod.LOGIT,
            false,
            schema);

        if (lr.hasProbabilityDistribution()) {
            encodePredictProbaOutput(miningModel, DataType.DOUBLE, label);
        }
        return miningModel;
    }

    /**
     * Returns the fitted gradient-boosting stage stored under attribute {@code gbdt_}.
     */
    public Classifier getGBDT() {
        return (Classifier) get("gbdt_", Classifier.class);
    }

    /**
     * Returns the fitted linear-classifier stage stored under attribute {@code lr_}.
     */
    public LinearClassifier getLR() {
        return (LinearClassifier) get("lr_", LinearClassifier.class);
    }

    /**
     * Returns the fitted one-hot encoder stored under attribute {@code ohe_}.
     */
    public MultiOneHotEncoder getOHE() {
        return (MultiOneHotEncoder) get("ohe_", MultiOneHotEncoder.class);
    }
}
