package org.jpmml.sklearn;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import numpy.DType;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
import org.dmg.pmml.mining.Segmentation;
import org.jpmml.converter.DerivedOutputField;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.ScalarLabel;
import org.jpmml.converter.Schema;
import org.jpmml.converter.WildcardFeature;
import org.jpmml.model.ReflectionUtil;
import org.jpmml.model.UnsupportedAttributeException;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.python.PickleUtil;
import org.jpmml.python.PythonEncoder;
import sklearn.Classifier;
import sklearn.Estimator;
import sklearn.EstimatorUtil;
import sklearn.Step;
import sklearn.ensemble.hist_gradient_boosting.TreePredictor;
import sklearn.neighbors.BinaryTree;
import sklearn.tree.Tree;
import sklearn2pmml.decoration.Alias;
import sklearn2pmml.decoration.Domain;

/* loaded from: input_file:org/jpmml/sklearn/SkLearnEncoder.class */
/**
 * PMML encoder used for Scikit-Learn objects.
 *
 * <p>In addition to the state inherited from {@link PythonEncoder}, this encoder tracks:
 * <ul>
 *   <li>{@link Domain} decorations registered per field name,</li>
 *   <li>the pipeline {@link Label} and the current list of {@link Feature features},</li>
 *   <li>a name-to-feature "memory" that can be written and read back later,</li>
 *   <li>an optional {@link Predicate} that is attached to the final segment of the encoded model chain,</li>
 *   <li>an optional final {@link Model}; once set, no further transformers may be added.</li>
 * </ul>
 */
public class SkLearnEncoder extends PythonEncoder {

    static {
        // Initialize pickle support from the bundled properties resource and register the structured
        // dtype definitions used by the neighbors/tree classes (both TreePredictor layouts).
        PickleUtil.init(SkLearnEncoder.class.getClassLoader(), "sklearn2pmml.properties");
        DType.addDefinition(BinaryTree.DTYPE_NODEDATA);
        DType.addDefinition(Tree.DTYPE_TREE);
        DType.addDefinition(TreePredictor.DTYPE_PREDICTOR_OLD);
        DType.addDefinition(TreePredictor.DTYPE_PREDICTOR_NEW);
    }

    /** Domain decorations keyed by field name; a present key marks the field as "frozen". */
    private final Map<String, Domain> domains = new LinkedHashMap<>();

    private Label label = null;

    private List<? extends Feature> features = Collections.emptyList();

    /** Features stored with {@link #memorize(String, Feature)}, keyed by caller-chosen name. */
    private final Map<String, Feature> memory = new LinkedHashMap<>();

    private Predicate predicate = null;

    private Model model = null;

    /**
     * Registers a transformer model.
     *
     * @throws IllegalStateException if a final model has already been set via {@link #setModel(Model)}
     */
    @Override
    public void addTransformer(Model model) {
        if (hasModel()) {
            throw new IllegalStateException("Model is already defined");
        }

        super.addTransformer(model);
    }

    /**
     * Encodes the model via the superclass and then, if a predicate has been set, attaches that
     * predicate to the last segment of the resulting model chain.
     *
     * <p>When a predicate is set, the encoded model is cast to {@link MiningModel}, and its segmentation
     * must use {@code MODEL_CHAIN}. If the segment models do not all share one mining function, the
     * chain's mining function is set to {@link MiningFunction#MIXED}.
     *
     * @throws UnsupportedAttributeException if the segmentation uses any other multiple-model method
     */
    @Override
    public Model encodeModel(Model model) {
        Predicate predicate = getPredicate();

        Model encodedModel = super.encodeModel(model);
        if (predicate == null) {
            return encodedModel;
        }

        // NOTE(review): assumes a predicate is only ever combined with a superclass result that is a
        // MiningModel (a ClassCastException is raised otherwise) - confirm against the superclass contract.
        MiningModel miningModel = (MiningModel) encodedModel;

        Segmentation segmentation = miningModel.requireSegmentation();
        Segmentation.MultipleModelMethod multipleModelMethod = segmentation.requireMultipleModelMethod();
        switch (multipleModelMethod) {
            case MODEL_CHAIN:
                break;
            default:
                throw new UnsupportedAttributeException(segmentation, multipleModelMethod);
        }

        List<Segment> segments = segmentation.requireSegments();

        Segment lastSegment = segments.get(segments.size() - 1);
        lastSegment.setPredicate(predicate);

        Set<MiningFunction> miningFunctions = segments.stream()
            .map(segment -> segment.requireModel().requireMiningFunction())
            .collect(Collectors.toSet());
        if (miningFunctions.size() > 1) {
            miningModel.setMiningFunction(MiningFunction.MIXED);
        }

        return miningModel;
    }

    /**
     * Encodes the label of the given estimator and stores it as this encoder's label.
     *
     * @throws IllegalStateException if features have already been set
     */
    public Label initLabel(Estimator estimator, List<String> names) {
        if (!getFeatures().isEmpty()) {
            throw new IllegalStateException();
        }

        Label label = estimator.encodeLabel(names, this);
        setLabel(label);

        return label;
    }

    /**
     * Creates one data field (with the step's op type and data type) plus a {@link WildcardFeature}
     * per name, stores the resulting list as this encoder's features, and returns it.
     */
    public List<Feature> initFeatures(Step step, List<String> names) {
        OpType opType = step.getOpType();
        DataType dataType = step.getDataType();

        List<Feature> result = new ArrayList<>();
        for (String name : names) {
            DataField dataField = createDataField(name, opType, dataType);

            result.add(new WildcardFeature(this, dataField));
        }

        setFeatures(result);

        return result;
    }

    /** Creates a schema from the current label and features. */
    public Schema createSchema() {
        return new Schema(this, getLabel(), getFeatures());
    }

    /** Single-name convenience overload of {@link #export(Model, List)}. */
    public List<Feature> export(Model model, String name) {
        return export(model, Collections.singletonList(name));
    }

    /**
     * Exports the named output fields of the model's final {@link Output} as features.
     *
     * <p>For each name, if the named field is a prediction-type field (see {@link #isPrediction(OutputField)}),
     * every still-present prediction-type field up to and including it is turned into a derived field, in order.
     * Otherwise only the named field is. The feature for the named field is returned, and the consumed
     * output fields are removed from the Output.
     *
     * @throws IllegalArgumentException if the model has no final Output, or a name does not match any output field
     */
    public List<Feature> export(Model model, List<String> names) {
        Output finalOutput = EstimatorUtil.getFinalOutput(model);
        if (finalOutput == null) {
            throw new IllegalArgumentException();
        }

        List<OutputField> outputFields = finalOutput.getOutputFields();

        List<Feature> result = new ArrayList<>();
        for (String name : names) {
            List<OutputField> selectedOutputFields = selectOutputFields(name, outputFields);

            // The selection is never empty and always ends with the named field, so the last
            // derived field created here is the one that corresponds to the requested name.
            DerivedOutputField derivedOutputField = null;
            for (OutputField outputField : selectedOutputFields) {
                derivedOutputField = createDerivedField(model, outputField, true);
            }

            result.add(derivedOutputField.toFeature(this));

            outputFields.removeAll(selectedOutputFields);
        }

        return result;
    }

    /** Exports the model's prediction as a feature named {@code predict(<label name>)}. */
    public Feature exportPrediction(Model model, ScalarLabel scalarLabel) {
        return exportPrediction(model, FieldNameUtil.create("predict", scalarLabel.getName()), scalarLabel);
    }

    /** Exports the model's prediction as a non-final predicted output field with the given name. */
    public Feature exportPrediction(Model model, String name, ScalarLabel scalarLabel) {
        OutputField predictedField = ModelUtil.createPredictedField(name, scalarLabel.getOpType(), scalarLabel.getDataType())
            .setFinalResult(false);

        return createDerivedField(model, predictedField, false).toFeature(this);
    }

    /** Exports the probability of the given target category under a name derived from {@link Classifier#FIELD_PROBABILITY}. */
    public Feature exportProbability(Model model, Object value) {
        return exportProbability(model, FieldNameUtil.create(Classifier.FIELD_PROBABILITY, value), value);
    }

    /** Exports the probability of the given target category as a non-final double output field. */
    public Feature exportProbability(Model model, String name, Object value) {
        OutputField probabilityField = ModelUtil.createProbabilityField(name, DataType.DOUBLE, value)
            .setFinalResult(false);

        return createDerivedField(model, probabilityField, false).toFeature(this);
    }

    /** Creates a continuous double data field. */
    public DataField createDataField(String name) {
        return createDataField(name, OpType.CONTINUOUS, DataType.DOUBLE);
    }

    /** Creates a continuous double derived field. */
    public DerivedField createDerivedField(String name, Expression expression) {
        return createDerivedField(name, OpType.CONTINUOUS, DataType.DOUBLE, expression);
    }

    /**
     * Registers a derived field. Any runtime failure from the superclass is rethrown (with the original
     * exception as cause) as an {@link IllegalArgumentException} that points the user to the
     * {@link Alias} wrapper for resolving duplicate names.
     */
    @Override
    public void addDerivedField(DerivedField derivedField) {
        try {
            super.addDerivedField(derivedField);
        } catch (RuntimeException e) {
            throw new IllegalArgumentException("Field " + derivedField.requireName() + " is already defined. Please refactor the pipeline so that it would not contain duplicate field declarations, or use the " + Alias.class.getName() + " wrapper class to override the default name with a custom name (eg. " + Alias.formatAliasExample() + ")", e);
        }
    }

    /**
     * Renames a feature that is backed by a derived field: the derived field is removed, both the
     * feature and the field get the new name, and the field is registered again.
     *
     * @throws IllegalArgumentException if the feature is backed by a user input (data) field
     */
    public void renameFeature(Feature feature, String renamedName) {
        String name = feature.getName();

        if (getField(name) instanceof DataField) {
            throw new IllegalArgumentException("User input field " + name + " cannot be renamed");
        }

        DerivedField derivedField = removeDerivedField(name);

        // The feature's name is rewritten reflectively (presumably because Feature exposes no setter - verify).
        try {
            ReflectionUtil.setFieldValue(Feature.class.getDeclaredField("name"), feature, renamedName);
        } catch (ReflectiveOperationException e) {
            throw new RuntimeException(e);
        }

        derivedField.setName(renamedName);
        addDerivedField(derivedField);
    }

    /**
     * Renames features pairwise with {@link #renameFeature(Feature, String)}; both lists must have the
     * same size (as checked by {@link ClassDictUtil#checkSize}).
     */
    public void renameFeatures(List<Feature> features, List<String> renamedNames) {
        ClassDictUtil.checkSize(renamedNames.size(), new Collection<?>[]{features});

        for (int i = 0; i < features.size(); i++) {
            renameFeature(features.get(i), renamedNames.get(i));
        }
    }

    /** Returns {@code true} if a {@link Domain} has been registered for the named field. */
    public boolean isFrozen(String name) {
        return this.domains.containsKey(name);
    }

    /** Returns the {@link Domain} registered for the named field, or {@code null}. */
    public Domain getDomain(String name) {
        return this.domains.get(name);
    }

    /** Registers the {@link Domain} for the named field; a {@code null} domain removes any registration. */
    public void setDomain(String name, Domain domain) {
        if (domain != null) {
            this.domains.put(name, domain);
        } else {
            this.domains.remove(name);
        }
    }

    public Label getLabel() {
        return this.label;
    }

    public void setLabel(Label label) {
        this.label = label;
    }

    public List<? extends Feature> getFeatures() {
        return this.features;
    }

    /**
     * @throws NullPointerException if {@code features} is {@code null}
     */
    public void setFeatures(List<? extends Feature> features) {
        this.features = Objects.requireNonNull(features);
    }

    /** Stores a feature under the given name, replacing any previous entry. */
    public void memorize(String name, Feature feature) {
        this.memory.put(name, feature);
    }

    /** Returns the feature stored under the given name, or {@code null}. */
    public Feature recall(String name) {
        return this.memory.get(name);
    }

    public Predicate getPredicate() {
        return this.predicate;
    }

    public void setPredicate(Predicate predicate) {
        this.predicate = predicate;
    }

    public boolean hasModel() {
        return getModel() != null;
    }

    public Model getModel() {
        return this.model;
    }

    public void setModel(Model model) {
        this.model = model;
    }

    /**
     * Returns {@code true} if the output field's result feature is {@code PREDICTED_VALUE},
     * {@code TRANSFORMED_VALUE} or {@code DECISION}.
     */
    public static boolean isPrediction(OutputField outputField) {
        ResultFeature resultFeature = outputField.getResultFeature();

        switch (resultFeature) {
            case PREDICTED_VALUE:
            case TRANSFORMED_VALUE:
            case DECISION:
                return true;
            default:
                return false;
        }
    }

    /**
     * Selects the output fields to export for {@code name}.
     *
     * <p>Scans {@code outputFields} in order until the field named {@code name} is found. If that field is a
     * prediction-type field, all prediction-type fields seen so far (including it) are returned. Otherwise a
     * singleton list with just that field is returned. The result is never empty, and the named field is
     * always its last element.
     *
     * @throws IllegalArgumentException if no field has the given name
     */
    private static List<OutputField> selectOutputFields(String name, List<OutputField> outputFields) {
        List<OutputField> predictionFields = new ArrayList<>();

        for (OutputField outputField : outputFields) {
            boolean prediction = isPrediction(outputField);
            if (prediction) {
                predictionFields.add(outputField);
            }

            if (Objects.equals(name, outputField.requireName())) {
                return prediction ? predictionFields : Collections.singletonList(outputField);
            }
        }

        throw new IllegalArgumentException(name);
    }
}
