package sklearn2pmml.pipeline;

import com.google.common.base.Function;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import numpy.core.NDArrayUtil;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DefineFunction;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Extension;
import org.dmg.pmml.Field;
import org.dmg.pmml.Header;
import org.dmg.pmml.MiningBuildTask;
import org.dmg.pmml.Model;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMML;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.VerificationField;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.DerivedOutputField;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.ScalarLabel;
import org.jpmml.converter.ScalarLabelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.WildcardFeature;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import sklearn.Classifier;
import sklearn.Estimator;
import sklearn.HasClasses;
import sklearn.Step;
import sklearn.Transformer;
import sklearn.pipeline.SkLearnPipeline;
import sklearn2pmml.Customization;
import sklearn2pmml.CustomizationUtil;
import sklearn2pmml.HasPMMLOptions;
import sklearn2pmml.SkLearn2PMMLFields;
import sklearn2pmml.decoration.Domain;

/**
 * Java-side counterpart of the Python {@code sklearn2pmml.pipeline.PMMLPipeline} class.
 *
 * <p>In addition to the steps inherited from {@link SkLearnPipeline}, this pipeline carries
 * optional PMML-specific attributes that are consulted by {@link #encodePMML()}:
 * <ul>
 *   <li>{@code header}</li>
 *   <li>{@code active_fields} and {@code target_field}/{@code target_fields}</li>
 *   <li>{@code repr_}</li>
 *   <li>the {@code predict}/{@code predict_proba}/{@code apply} post-processing transformers</li>
 *   <li>PMML feature importances</li>
 *   <li>{@code verification} data</li>
 * </ul>
 */
public class PMMLPipeline extends SkLearnPipeline implements HasPMMLOptions<PMMLPipeline> {

    private static final Logger logger = LoggerFactory.getLogger(PMMLPipeline.class);

    /**
     * Creates a pipeline bound to {@code sklearn2pmml.pipeline.PMMLPipeline}.
     */
    public PMMLPipeline() {
        this("sklearn2pmml.pipeline", "PMMLPipeline");
    }

    /**
     * Creates a pipeline for an arbitrary Python module/class name pair.
     *
     * @param module Python module name (presumably; forwarded unchanged to {@link SkLearnPipeline})
     * @param name Python class name (presumably; forwarded unchanged to {@link SkLearnPipeline})
     */
    public PMMLPipeline(String module, String name) {
        super(module, name);
    }

    /**
     * Encodes this pipeline as a PMML document.
     *
     * <p>When the pipeline has no final estimator, the encoder's PMML is produced without a model
     * element. Otherwise the final estimator is encoded into a model, which is then enriched
     * with the following:
     * <ul>
     *   <li>User-supplied feature importances, when the estimator does not provide its own.</li>
     *   <li>Output fields produced by the optional predict, predict_proba and apply transformers.</li>
     *   <li>Model verification data, for supervised estimators. A warning is logged when it is missing.</li>
     *   <li>Estimator-specific customizations.</li>
     * </ul>
     *
     * @return the PMML document
     * @throws IllegalArgumentException if a user-supplied feature importance refers to an
     *     undefined field
     * @throws RuntimeException wrapping any exception raised while applying customizations
     */
    @Override
    public PMML encodePMML() {
        SkLearnEncoder encoder = new SkLearnEncoder();

        List<String> activeFields = getActiveFields();
        List<String> targetFields = getTargetFields();
        Map<?, ?> header = getHeader();
        String repr = getRepr();
        Transformer predictTransformer = getPredictTransformer();
        Transformer predictProbaTransformer = getPredictProbaTransformer();
        Transformer applyTransformer = getApplyTransformer();
        Verification verification = getVerification();

        Estimator estimator = null;
        List<? extends Customization> customizations = null;

        if (hasFinalEstimator()) {
            estimator = getFinalEstimator();
            targetFields = initLabel(targetFields, encoder);
            customizations = estimator.getPMMLCustomizations();
        }

        List<String> features = initFeatures(activeFields, encoder);

        if (estimator == null) {
            return encodePMML(header, null, repr, encoder);
        }

        Schema schema = encoder.createSchema();
        Model model = estimator.encode(schema);
        encoder.setModel(model);

        // User-supplied importances are only used when the estimator has none of its own
        if (!estimator.hasFeatureImportances()) {
            List<? extends Number> featureImportances = getPMMLFeatureImportances();
            if (featureImportances != null) {
                encodeFeatureImportances(model, features, featureImportances, encoder);
            }
        }

        if (predictTransformer != null || predictProbaTransformer != null || applyTransformer != null) {
            // Post-processing output fields are attached to MiningModelUtil.getFinalModel(model)
            Model finalModel = MiningModelUtil.getFinalModel(model);
            encoder.setModel(finalModel);

            Label label = schema.getLabel();
            Output output = ModelUtil.ensureOutput(finalModel);

            if (predictTransformer != null) {
                List<OutputField> predictFields = new ArrayList<>();
                for (ScalarLabel scalarLabel : ScalarLabelUtil.toScalarLabels(label)) {
                    OutputField predictField = ModelUtil.createPredictedField(FieldNameUtil.create("predict", scalarLabel.getName()), scalarLabel.getOpType(), scalarLabel.getDataType())
                        .setFinalResult(false);
                    output.addOutputFields(predictField);
                    predictFields.add(predictField);
                }
                encodeOutput(output, predictFields, predictTransformer, encoder);
            }

            if (predictProbaTransformer != null) {
                encodeOutput(output, estimator.createPredictProbaFields(DataType.DOUBLE, label), predictProbaTransformer, encoder);
            }

            if (applyTransformer != null) {
                encodeOutput(output, Collections.singletonList(estimator.createApplyField(DataType.INTEGER)), applyTransformer, encoder);
            }

            // Restore the top-level model on the encoder
            encoder.setModel(model);
        }

        if (estimator.isSupervised()) {
            if (verification == null) {
                logger.warn("Model verification data is not set. Use the '" + ClassDictUtil.formatMember(this, "verify(X)") + "' method to correct this deficiency");
            } else {
                encodeVerification(model, estimator, schema.getLabel(), features, targetFields, verification, encoder);
            }
        }

        if (customizations != null && !customizations.isEmpty()) {
            try {
                CustomizationUtil.customize(model, customizations);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        return encodePMML(header, model, repr, encoder);
    }

    /**
     * Registers user-supplied feature importances with the encoder, one per feature.
     *
     * @param model the model the importances belong to
     * @param features feature (data field) names, aligned with {@code featureImportances}
     * @param featureImportances importance values, same length as {@code features}
     * @param encoder the encoder that owns the data fields
     * @throws IllegalArgumentException if a feature name has no data field in the encoder
     */
    private void encodeFeatureImportances(Model model, List<String> features, List<? extends Number> featureImportances, SkLearnEncoder encoder) {
        ClassDictUtil.checkSize(new Collection<?>[]{features, featureImportances});

        for (int i = 0; i < features.size(); i++) {
            String feature = features.get(i);
            Number featureImportance = featureImportances.get(i);

            DataField dataField = encoder.getDataField(feature);
            if (dataField == null) {
                throw new IllegalArgumentException("Field " + feature + " is undefined");
            }

            encoder.addFeatureImportance(model, new WildcardFeature(encoder, dataField), featureImportance);
        }
    }

    /**
     * Attaches model verification data to {@code model}.
     *
     * <p>The verification records are built in two parts.
     * <ul>
     *   <li>Active values are split into one column per feature.</li>
     *   <li>Expected outputs are the per-class probability columns when probability values are
     *       available. For {@link HasClasses} estimators, this additionally requires the estimator
     *       to report a probability distribution. Otherwise the target value columns are used.</li>
     * </ul>
     * Probability columns, and target columns of type {@code double}/{@code float}, carry the
     * configured precision and zero threshold.
     *
     * @param model the model to attach the verification to
     * @param estimator the final estimator
     * @param label the schema label of the encoded model
     * @param features feature (data field) names, one per active value column
     * @param targetFields target field names, one per target value column
     * @param verification the verification data
     * @param encoder the encoder used to look up field domains
     */
    private void encodeVerification(Model model, Estimator estimator, Label label, List<String> features, List<String> targetFields, Verification verification, SkLearnEncoder encoder) {
        List<?> activeValues = verification.getActiveValues();
        int[] activeValuesShape = verification.getActiveValuesShape();

        ClassDictUtil.checkShapes(1, features.size(), new int[][]{activeValuesShape});

        int numberOfRows = activeValuesShape[0];

        Map<VerificationField, List<?>> data = new LinkedHashMap<>();

        for (int column = 0; column < features.size(); column++) {
            VerificationField verificationField = ModelUtil.createVerificationField(features.get(column));
            Domain domain = encoder.getDomain(verificationField.requireField());

            data.put(verificationField, CMatrixUtil.getColumn(cleanValues(domain, activeValues), numberOfRows, features.size(), column));
        }

        Number precision = verification.getPrecision();
        Number zeroThreshold = verification.getZeroThreshold();

        List<ScalarLabel> scalarLabels = ScalarLabelUtil.toScalarLabels(label);

        boolean probabilistic = verification.hasProbabilityValues();
        if (estimator instanceof HasClasses) {
            // Non-short-circuit on purpose: hasProbabilityDistribution() is always queried
            probabilistic &= ((HasClasses) estimator).hasProbabilityDistribution();
        }

        if (probabilistic) {
            List<? extends Number> probabilityValues = verification.getProbabilityValues();
            int[] probabilityValuesShape = verification.getProbabilityValuesShape();

            ClassDictUtil.checkShapes(0, new int[][]{activeValuesShape, probabilityValuesShape});
            ClassDictUtil.checkSize(1, new Collection<?>[]{scalarLabels});

            List<String> probabilityFields = initProbabilityFields((CategoricalLabel) scalarLabels.get(0));

            ClassDictUtil.checkShapes(1, probabilityFields.size(), new int[][]{probabilityValuesShape});

            for (int column = 0; column < probabilityFields.size(); column++) {
                VerificationField verificationField = ModelUtil.createVerificationField(probabilityFields.get(column))
                    .setPrecision(precision)
                    .setZeroThreshold(zeroThreshold);

                data.put(verificationField, CMatrixUtil.getColumn(cleanValues(null, probabilityValues), numberOfRows, probabilityFields.size(), column));
            }
        } else {
            List<?> targetValues = verification.getTargetValues();

            ClassDictUtil.checkShapes(0, new int[][]{activeValuesShape, verification.getTargetValuesShape()});
            ClassDictUtil.checkSize(new Collection<?>[]{targetFields, scalarLabels});

            for (int column = 0; column < targetFields.size(); column++) {
                VerificationField verificationField = ModelUtil.createVerificationField(targetFields.get(column));

                switch (scalarLabels.get(column).getDataType()) {
                    case DOUBLE:
                    case FLOAT:
                        verificationField.setPrecision(precision).setZeroThreshold(zeroThreshold);
                        break;
                    default:
                        break;
                }

                Domain domain = encoder.getDomain(verificationField.requireField());

                data.put(verificationField, CMatrixUtil.getColumn(cleanValues(domain, targetValues), numberOfRows, targetFields.size(), column));
            }
        }

        model.setModelVerification(ModelUtil.createModelVerification(data));
    }

    /**
     * Wraps the encoder's PMML output with pipeline-level metadata.
     *
     * @param header optional map. Its {@code copyright}, {@code description} and
     *     {@code modelVersion} entries are cast to {@code String} and copied into the PMML header.
     * @param model the encoded model, or {@code null} when the pipeline has no final estimator
     * @param repr optional value of the {@code repr_} attribute. When present, it is stored as a
     *     {@code repr} extension of the mining build task.
     * @param encoder the encoder holding the collected PMML state
     * @return the PMML document
     */
    private PMML encodePMML(Map<?, ?> header, Model model, String repr, SkLearnEncoder encoder) {
        PMML pmml = encoder.encodePMML(model);

        if (header != null) {
            Header pmmlHeader = pmml.requireHeader();
            pmmlHeader.setCopyright((String) header.get("copyright"));
            pmmlHeader.setDescription((String) header.get("description"));
            pmmlHeader.setModelVersion((String) header.get("modelVersion"));
        }

        if (repr != null) {
            Extension reprExtension = PMMLUtil.createExtension("repr", repr);
            pmml.setMiningBuildTask(new MiningBuildTask().addExtensions(reprExtension));
        }

        return pmml;
    }

    /**
     * Encodes a post-processing transformer on top of the given model output fields.
     *
     * <p>Each of {@code outputFields} is exposed to {@code transformer} as a data field of a
     * throwaway encoder. That encoder shares the parent encoder's memory (and model, if set),
     * reports every field as frozen, and throws on {@code addTransformer(Model)} and
     * {@code getDomains()}.
     *
     * <p>After encoding, the following happens:
     * <ul>
     *   <li>Transformed features without a field in the throwaway encoder become {@code xref}
     *       output fields.</li>
     *   <li>Every derived field of the throwaway encoder becomes an output field.</li>
     *   <li>Output fields are re-ordered by transformed-feature position.</li>
     *   <li>Define functions are copied back into {@code encoder}.</li>
     * </ul>
     *
     * @param output the output element receiving the new fields
     * @param outputFields the model output fields fed into the transformer
     * @param transformer the post-processing transformer
     * @param encoder the parent encoder
     */
    private void encodeOutput(Output output, List<OutputField> outputFields, Transformer transformer, final SkLearnEncoder encoder) {
        SkLearnEncoder outputEncoder = new SkLearnEncoder() {

            @Override
            public void addTransformer(Model model) {
                throw new UnsupportedOperationException();
            }

            @Override
            public boolean isFrozen(String name) {
                return true;
            }

            @Override
            public Map<String, Domain> getDomains() {
                throw new UnsupportedOperationException();
            }

            @Override
            public Map<String, Feature> getMemory() {
                return encoder.getMemory();
            }
        };

        Model model = encoder.getModel();
        if (model != null) {
            outputEncoder.setModel(model);
        }

        List<Feature> features = new ArrayList<>();
        for (OutputField outputField : outputFields) {
            DataField dataField = outputEncoder.createDataField(outputField.requireName(), outputField.requireOpType(), outputField.requireDataType());
            features.add(new WildcardFeature(outputEncoder, dataField));
        }

        List<Feature> transformedFeatures = transformer.encode(features, outputEncoder);

        // Transformed feature name -> position; drives the ordering of output fields below
        final Map<String, Integer> featureIndices = new LinkedHashMap<>();

        for (Feature transformedFeature : transformedFeatures) {
            String name = transformedFeature.getName();
            Field<?> field = transformedFeature.getField();

            try {
                outputEncoder.getField(field.requireName());
            } catch (IllegalArgumentException iae) {
                // Field unknown to the throwaway encoder: expose it through a cross-reference output field
                OutputField xrefField = new OutputField(FieldNameUtil.create("xref", transformedFeature), field.requireOpType(), field.requireDataType())
                    .setResultFeature(ResultFeature.TRANSFORMED_VALUE)
                    .setFinalResult(true)
                    .setExpression(transformedFeature.ref());
                output.addOutputFields(xrefField);
            }

            featureIndices.put(name, featureIndices.size());
        }

        for (DerivedField derivedField : outputEncoder.getDerivedFields().values()) {
            OutputField outputField;

            if (derivedField instanceof DerivedOutputField) {
                outputField = ((DerivedOutputField) derivedField).getOutputField();
            } else {
                String name = derivedField.requireName();

                outputField = new OutputField(name, derivedField.requireOpType(), derivedField.requireDataType())
                    .setResultFeature(ResultFeature.TRANSFORMED_VALUE)
                    .setFinalResult(featureIndices.containsKey(name))
                    .setExpression(derivedField.requireExpression());
            }

            output.addOutputFields(outputField);
        }

        // Fields not backing a transformed feature get index -1 and therefore sort first.
        // The sort is stable, so their relative order is kept.
        output.getOutputFields().sort(Comparator.comparingInt((OutputField outputField) -> featureIndices.getOrDefault(outputField.requireName(), -1)));

        for (DefineFunction defineFunction : outputEncoder.getDefineFunctions().values()) {
            encoder.addDefineFunction(defineFunction);
        }
    }

    @Override
    public List<Object[]> getSteps() {
        return super.getSteps();
    }

    @Override
    public PMMLPipeline setSteps(List<Object[]> steps) {
        return (PMMLPipeline) super.setSteps(steps);
    }

    /**
     * Returns the final estimator's PMML options.
     *
     * @return the final estimator's PMML options, or {@code null} when there is no final estimator
     */
    @Override
    public Map<String, ?> getPMMLOptions() {
        if (hasFinalEstimator()) {
            return getFinalEstimator().getPMMLOptions();
        }

        return null;
    }

    /**
     * Forwards {@code pmmlOptions} to the final estimator.
     *
     * <p>This is a no-op when there is no final estimator.
     *
     * @param pmmlOptions the options to forward
     * @return this pipeline
     */
    @Override
    public PMMLPipeline setPMMLOptions(Map<String, ?> pmmlOptions) {
        if (hasFinalEstimator()) {
            getFinalEstimator().setPMMLOptions(pmmlOptions);
        }

        return this;
    }

    /** Returns the optional {@code header} attribute. */
    public Map<?, ?> getHeader() {
        return (Map) getOptional("header", Map.class);
    }

    /** Returns the user-supplied PMML feature importances, or {@code null} when absent. */
    public List<? extends Number> getPMMLFeatureImportances() {
        if (containsKey(SkLearn2PMMLFields.PMML_FEATURE_IMPORTANCES)) {
            return getNumberArray(SkLearn2PMMLFields.PMML_FEATURE_IMPORTANCES);
        }

        return null;
    }

    /** Returns the optional {@code predict_transformer} attribute. */
    public Transformer getPredictTransformer() {
        return getTransformer("predict_transformer");
    }

    /** Returns the optional {@code predict_proba_transformer} attribute. */
    public Transformer getPredictProbaTransformer() {
        return getTransformer("predict_proba_transformer");
    }

    /** Returns the optional {@code apply_transformer} attribute. */
    public Transformer getApplyTransformer() {
        return getTransformer("apply_transformer");
    }

    private Transformer getTransformer(String name) {
        return (Transformer) getOptional(name, Transformer.class);
    }

    /** Returns the {@code active_fields} attribute, or {@code null} when absent. */
    public List<String> getActiveFields() {
        if (containsKey("active_fields")) {
            return getListLike("active_fields", String.class);
        }

        return null;
    }

    public PMMLPipeline setActiveFields(List<String> activeFields) {
        put("active_fields", NDArrayUtil.toArray(activeFields));

        return this;
    }

    /**
     * Returns the target field names.
     *
     * <p>A singular {@code target_field} attribute takes precedence over {@code target_fields}.
     *
     * @return the target field names, or {@code null} when neither attribute is present
     */
    public List<String> getTargetFields() {
        if (containsKey("target_field")) {
            return Collections.singletonList(getOptionalString("target_field"));
        }

        if (containsKey("target_fields")) {
            return getListLike("target_fields", String.class);
        }

        return null;
    }

    public PMMLPipeline setTargetFields(List<String> targetFields) {
        put("target_fields", NDArrayUtil.toArray(targetFields));

        return this;
    }

    /** Returns the optional {@code repr_} attribute. */
    public String getRepr() {
        return getOptionalString("repr_");
    }

    public PMMLPipeline setRepr(String repr) {
        put("repr_", repr);

        return this;
    }

    /** Returns the optional {@code verification} attribute. */
    public Verification getVerification() {
        return (Verification) getOptional("verification", Verification.class);
    }

    public PMMLPipeline setVerification(Verification verification) {
        put("verification", verification);

        return this;
    }

    /**
     * Delegates to the superclass and logs a warning that {@code target_fields} was not set.
     */
    @Override
    public List<String> initTargetFields(Estimator estimator) {
        List<String> targetFields = super.initTargetFields(estimator);

        logger.warn("Attribute '" + ClassDictUtil.formatMember(this, "target_fields") + "' is not set. Assuming {} as the name(s) of the target field(s)", targetFields);

        return targetFields;
    }

    /**
     * Delegates to the superclass and logs a warning that {@code active_fields} was not set.
     */
    @Override
    public List<String> initActiveFields(Step step) {
        List<String> activeFields = super.initActiveFields(step);

        logger.warn("Attribute '" + ClassDictUtil.formatMember(this, "active_fields") + "' is not set. Assuming {} as the names of active fields", activeFields);

        return activeFields;
    }

    /** Builds one probability field name per category of {@code categoricalLabel}. */
    private List<String> initProbabilityFields(CategoricalLabel categoricalLabel) {
        List<String> probabilityFields = new ArrayList<>();

        for (Object value : categoricalLabel.getValues()) {
            probabilityFields.add(FieldNameUtil.create(Classifier.FIELD_PROBABILITY, value));
        }

        return probabilityFields;
    }

    /**
     * Returns a lazy view of {@code values} in which NaN elements are replaced with {@code null}.
     *
     * <p>Every element is passed through {@code Domain.checkValue(..)} when it is read.
     * NOTE(review): the {@code domain} argument is currently not consulted. Confirm whether it
     * should take part in value cleaning.
     */
    private static List<?> cleanValues(Domain domain, List<?> values) {
        return Lists.transform(values, new Function<Object, Object>() {

            @Override
            public Object apply(Object value) {
                Domain.checkValue(value);

                return ValueUtil.isNaN(value) ? null : value;
            }
        });
    }
}
