package sklearn.calibration;

import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.converter.regression.RegressionModelUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.Calibrator;
import sklearn.Classifier;
import sklearn.Estimator;
import sklearn.HasEstimator;
import sklearn.linear_model.LinearClassifier;

/* loaded from: input_file:sklearn/calibration/CalibratedClassifier.class */
/**
 * Converter for a calibrated classifier: a base classifier (the {@code estimator} attribute)
 * whose per-class decision function values are passed through one calibrator each
 * (the {@code calibrators} attribute) and then combined into a final classification model.
 *
 * <p>The result of {@link #mo5encodeModel(Schema)} is a model chain made of the base model(s)
 * followed by one regression model that consumes the calibrated values.
 */
public class CalibratedClassifier extends Classifier implements HasEstimator<Classifier> {

    public CalibratedClassifier(String str, String str2) {
        super(str, str2);
    }

    /**
     * Encodes the base estimator and its calibrators as a chain of models.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>Reject any {@code method} other than {@code "isotonic"} or {@code "sigmoid"}.</li>
     *   <li>Encode the base estimator and collect one "decision function" feature per calibrator:
     *     <ul>
     *       <li>non-linear estimator: the probability output fields of the encoded model are
     *           renamed and reused (for a two-category label only the second one is kept);</li>
     *       <li>linear estimator encoded as a {@link MiningModel}: the regression-type segments
     *           are pulled out, their single output field is renamed and normalization is
     *           switched off;</li>
     *       <li>linear estimator encoded as a {@link RegressionModel}: every regression table
     *           (only the first one for a two-category label) is wrapped into its own
     *           regression model.</li>
     *     </ul>
     *   </li>
     *   <li>Apply each calibrator to its feature and build the final regression model.</li>
     * </ol>
     *
     * @param schema the schema; its label is cast to {@link CategoricalLabel}
     * @return the model chain created by {@link MiningModelUtil#createModelChain}
     * @throws IllegalArgumentException if the method is unknown, a required {@link Output} is
     *         missing, the encoded linear model has an unexpected type, or the number of
     *         calibrators is neither one nor at least three
     */
    @Override // sklearn.Estimator
    /* renamed from: encodeModel */
    public Model mo5encodeModel(Schema schema) {
        List<? extends Calibrator> calibrators = getCalibrators();
        // Result is unused; the call is retained in case the attribute lookup validates the object.
        getClasses();
        Classifier estimator = getEstimator();
        String method = getMethod();

        switch (method) {
            case "isotonic":
            case "sigmoid":
                break;
            default:
                throw new IllegalArgumentException(method);
        }

        SkLearnEncoder encoder = (SkLearnEncoder) schema.getEncoder();
        CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
        // Result is unused; retained to keep the original call sequence.
        schema.getFeatures();

        Model model = estimator.encode(schema);

        // Models that make up the chain, and the decision function features fed to the calibrators.
        List<Model> models = new ArrayList<>();
        List<Feature> features = new ArrayList<>();

        if (estimator instanceof LinearClassifier) {
            if (model instanceof MiningModel) {
                MiningModel miningModel = (MiningModel) model;

                for (Segment segment : miningModel.requireSegmentation().requireSegments()) {
                    Model segmentModel = segment.requireModel();

                    // Only the regression-type segments carry decision function values.
                    if (segmentModel.requireMiningFunction() != MiningFunction.REGRESSION) {
                        continue;
                    }

                    RegressionModel regressionModel = (RegressionModel) segmentModel;

                    Output segmentOutput = regressionModel.getOutput();
                    if (segmentOutput == null) {
                        throw new IllegalArgumentException("Regression segment model has no Output element");
                    }

                    OutputField decisionFunctionField = (OutputField) Iterables.getOnlyElement(segmentOutput.getOutputFields());
                    decisionFunctionField.setName(getDecisionFunctionField(decisionFunctionField.getName()));

                    regressionModel.setNormalizationMethod(RegressionModel.NormalizationMethod.NONE);

                    models.add(regressionModel);
                    features.add(new ContinuousFeature(encoder, decisionFunctionField));
                }
            } else if (model instanceof RegressionModel) {
                List<RegressionTable> regressionTables = ((RegressionModel) model).getRegressionTables();

                if (categoricalLabel.size() == 2) {
                    regressionTables = regressionTables.subList(0, 1);
                }

                for (RegressionTable regressionTable : regressionTables) {
                    // The target category must be read before it is cleared below.
                    String name = getDecisionFunctionField(regressionTable.requireTargetCategory());

                    OutputField decisionFunctionField = ModelUtil.createPredictedField(name, OpType.CONTINUOUS, DataType.DOUBLE)
                        .setFinalResult(false);

                    Output tableOutput = new Output()
                        .addOutputFields(decisionFunctionField);

                    regressionTable.setTargetCategory((Object) null);

                    Model tableModel = new RegressionModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema((Label) null), (List<RegressionTable>) null)
                        .setNormalizationMethod(RegressionModel.NormalizationMethod.NONE)
                        .addRegressionTables(regressionTable)
                        .setOutput(tableOutput);

                    models.add(tableModel);
                    features.add(new ContinuousFeature(encoder, decisionFunctionField));
                }
            } else {
                throw new IllegalArgumentException("Expected the linear classifier to be encoded as a MiningModel or a RegressionModel");
            }
        } else {
            Output output = model.getOutput();
            if (output == null) {
                throw new IllegalArgumentException("Estimator model has no Output element");
            }

            for (OutputField outputField : output.getOutputFields()) {
                if (outputField.getResultFeature() == ResultFeature.PROBABILITY) {
                    outputField.setName(getDecisionFunctionField(outputField.getValue()));

                    features.add(new ContinuousFeature(encoder, outputField));
                }
            }

            models.add(model);

            if (categoricalLabel.size() == 2) {
                SchemaUtil.checkSize(2, features);

                // Keep the second of the two probability features only.
                features = features.subList(1, 2);
            }
        }

        SchemaUtil.checkSize(calibrators.size(), features);

        RegressionModel calibratedModel;

        if (calibrators.size() == 1) {
            SchemaUtil.checkSize(2, categoricalLabel);

            Feature calibratedFeature = calibrate(calibrators.get(0), models.get(0), features.get(0), encoder);

            calibratedModel = RegressionModelUtil.createBinaryLogisticClassification(Collections.singletonList(calibratedFeature), Collections.singletonList(Double.valueOf(1.0d)), (Number) null, RegressionModel.NormalizationMethod.NONE, false, schema);
        } else {
            if (calibrators.size() < 3) {
                throw new IllegalArgumentException("Expected one or at least three calibrators, got " + calibrators.size());
            }

            SchemaUtil.checkSize(calibrators.size(), categoricalLabel);

            List<RegressionTable> calibratedTables = new ArrayList<>();

            for (int i = 0; i < calibrators.size(); i++) {
                // A single base model serves all calibrators; otherwise there is one model per calibrator.
                Model calibratorModel = (models.size() == 1 ? models.get(0) : models.get(i));

                Feature calibratedFeature = calibrate(calibrators.get(i), calibratorModel, features.get(i), encoder);

                RegressionTable calibratedTable = RegressionModelUtil.createRegressionTable(Collections.singletonList(calibratedFeature), Collections.singletonList(Double.valueOf(1.0d)), (Number) null)
                    .setTargetCategory(categoricalLabel.getValue(i));

                calibratedTables.add(calibratedTable);
            }

            calibratedModel = new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), calibratedTables)
                .setNormalizationMethod(RegressionModel.NormalizationMethod.SIMPLEMAX);
        }

        encodePredictProbaOutput(calibratedModel, DataType.DOUBLE, categoricalLabel);

        models.add(calibratedModel);

        return MiningModelUtil.createModelChain(models, Segmentation.MissingPredictionTreatment.RETURN_MISSING);
    }

    /**
     * @return the value of the {@code classes} attribute
     */
    @Override // sklearn.Classifier, sklearn.HasClasses
    public List<?> getClasses() {
        return getListLike("classes");
    }

    /**
     * @return the value of the {@code calibrators} attribute, as a list of {@link Calibrator} objects
     */
    public List<? extends Calibrator> getCalibrators() {
        return getList("calibrators", Calibrator.class);
    }

    /**
     * @return the value of the {@code estimator} attribute, i.e. the classifier being calibrated
     */
    @Override // sklearn.HasEstimator
    public Classifier getEstimator() {
        return (Classifier) get("estimator", Classifier.class);
    }

    /**
     * @return the value of the {@code method} attribute
     */
    public String getMethod() {
        return getString("method");
    }

    /**
     * Builds the name of a decision function field.
     *
     * <p>A {@code String} argument is first passed through {@code extractArguments}; any other
     * argument is used as-is. When this object has a PMML segment identifier, the identifier
     * is placed in front of the argument.
     *
     * @param obj the existing field name, or the category value that the field relates to
     * @return the field name created by {@link FieldNameUtil#create}
     */
    private String getDecisionFunctionField(Object obj) {
        Object segmentId = getPMMLSegmentId();

        Object argument = obj;
        if (argument instanceof String) {
            argument = extractArguments(Estimator.FIELD_DECISION_FUNCTION, (String) argument);
        }

        return FieldNameUtil.create(Estimator.FIELD_DECISION_FUNCTION, segmentId != null ? Arrays.asList(segmentId, argument) : Arrays.asList(argument));
    }

    /**
     * Applies a calibrator to a decision function feature and attaches the result to the model
     * as a non-final, transformed-value output field.
     *
     * <p>The calibrator encodes its result as a derived field on the encoder. That derived field
     * is removed from the encoder again, and its name, operational type, data type and expression
     * are copied into an output field of {@code model}.
     *
     * @param calibrator the calibrator to apply
     * @param model the model that produces {@code feature}
     * @param feature the decision function feature to calibrate
     * @param skLearnEncoder the encoder
     * @return a continuous feature backed by the new output field
     */
    private static Feature calibrate(Calibrator calibrator, Model model, Feature feature, SkLearnEncoder skLearnEncoder) {
        skLearnEncoder.export(model, feature.getName());

        Feature calibratedFeature = (Feature) Iterables.getOnlyElement(calibrator.encodeFeatures(Collections.singletonList(feature), skLearnEncoder));

        DerivedField derivedField = skLearnEncoder.removeDerivedField(calibratedFeature.getName());

        OutputField outputField = new OutputField(derivedField.requireName(), derivedField.requireOpType(), derivedField.requireDataType())
            .setResultFeature(ResultFeature.TRANSFORMED_VALUE)
            .setExpression(derivedField.requireExpression())
            .setFinalResult(false);

        return new ContinuousFeature(skLearnEncoder, skLearnEncoder.createDerivedField(model, outputField, true));
    }
}
