package sklearn2pmml.expression;

import com.google.common.collect.Iterables;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.converter.regression.RegressionModelUtil;
import org.jpmml.python.DataFrameScope;
import sklearn.Classifier;
import sklearn2pmml.util.EvaluatableUtil;
import sklearn2pmml.util.Expression;

/* loaded from: input_file:sklearn2pmml/expression/ExpressionClassifier.class */
/**
 * Classifier whose per-class scores are defined by user-supplied expressions.
 *
 * <p>The {@code class_exprs} attribute maps each class label to an {@link Expression}.
 * Every expression is translated to PMML and wrapped into a {@link RegressionTable}
 * that has a single term (coefficient {@code 1.0}) and intercept {@code 0.0}. The tables
 * are then combined into a classification {@link RegressionModel}. The
 * {@code normalization_method} attribute names the normalization method, which is one of
 * {@code "logit"}, {@code "simplemax"} or {@code "softmax"}.
 */
public class ExpressionClassifier extends Classifier {

    /**
     * Creates a new instance. Both arguments are passed through unchanged to the
     * {@link Classifier} constructor.
     */
    public ExpressionClassifier(String str, String str2) {
        super(str, str2);
    }

    /**
     * Encodes this classifier as a PMML classification {@link RegressionModel}.
     *
     * <p>One regression table is built per entry of {@link #getClassExprs()}. Each
     * expression is translated within a {@link DataFrameScope} named {@code "X"} that is
     * built over the schema features. How the tables are assembled depends on the
     * normalization method:
     * <ul>
     *   <li>{@code logit}: requires exactly one class expression and exactly two label
     *   values. The expression's table scores its own class. The other class gets an
     *   empty table (no terms, {@code null} intercept).</li>
     *   <li>{@code simplemax} / {@code softmax}: requires exactly one expression per label
     *   value. Tables are emitted in label-value order.</li>
     * </ul>
     *
     * @param schema schema supplying the encoder, the features and a categorical label
     * @return the encoded model, after {@code encodePredictProbaOutput} has been applied to it
     * @throws IllegalArgumentException if the class expressions do not fit the label values
     *     as the normalization method requires, or if the normalization method is unknown
     */
    @Override // sklearn.Estimator
    public RegressionModel mo7encodeModel(Schema schema) {
        Map<?, Expression> classExprs = getClassExprs();
        RegressionModel.NormalizationMethod normalizationMethod = parseNormalizationMethod(getNormalizationMethod());

        PMMLEncoder encoder = schema.getEncoder();
        CategoricalLabel label = (CategoricalLabel) schema.getLabel();

        DataFrameScope scope = new DataFrameScope("X", schema.getFeatures(), encoder);

        // Insertion-ordered, so tables follow the iteration order of class_exprs
        Map<Object, RegressionTable> categoryTables = new LinkedHashMap<>();
        for (Map.Entry<?, Expression> entry : classExprs.entrySet()) {
            Object category = entry.getKey();

            RegressionTable table = RegressionModelUtil.createRegressionTable(
                    Collections.singletonList(ExpressionUtil.toFeature(
                            FieldNameUtil.create("expression", new Object[]{category}),
                            EvaluatableUtil.translateExpression(entry.getValue(), scope),
                            encoder)),
                    Collections.singletonList(1.0d),
                    0.0d);

            categoryTables.put(category, table);
        }

        List<?> values = label.getValues();

        List<RegressionTable> regressionTables;
        switch (normalizationMethod) {
            case LOGIT: {
                if (categoryTables.size() != 1 || values.size() != 2) {
                    throw new IllegalArgumentException("Logit normalization requires exactly one class expression and exactly two class labels, got "
                            + categoryTables.size() + " expression(s) and " + values.size() + " label(s)");
                }

                Object category = Iterables.getOnlyElement(categoryTables.keySet());

                // The class without an expression is the complementary value of the binary label
                Object otherCategory;
                switch (values.indexOf(category)) {
                    case 0:
                        otherCategory = values.get(1);
                        break;
                    case 1:
                        otherCategory = values.get(0);
                        break;
                    default:
                        throw new IllegalArgumentException("Class expression key " + category + " is not one of the class labels " + values);
                }

                regressionTables = Arrays.asList(
                        categoryTables.get(category).setTargetCategory(category),
                        RegressionModelUtil.createRegressionTable(Collections.emptyList(), Collections.emptyList(), (Number) null)
                                .setTargetCategory(otherCategory));
                break;
            }
            case SIMPLEMAX:
            case SOFTMAX: {
                if (categoryTables.size() != values.size() || !categoryTables.keySet().containsAll(values)) {
                    throw new IllegalArgumentException("Expected exactly one class expression per class label " + values
                            + ", got expressions for " + categoryTables.keySet());
                }

                // Emit tables in label-value order rather than class_exprs order
                regressionTables = values.stream()
                        .map(value -> categoryTables.get(value).setTargetCategory(value))
                        .collect(Collectors.toList());
                break;
            }
            default:
                throw new IllegalArgumentException("Unsupported normalization method " + normalizationMethod);
        }

        RegressionModel regressionModel = new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(label), regressionTables)
                .setNormalizationMethod(normalizationMethod);

        encodePredictProbaOutput(regressionModel, DataType.DOUBLE, label);

        return regressionModel;
    }

    /**
     * Returns the {@code class_exprs} attribute: a mapping from class label to expression.
     */
    public Map<?, Expression> getClassExprs() {
        return getDict("class_exprs");
    }

    /**
     * Returns the {@code normalization_method} attribute.
     */
    public String getNormalizationMethod() {
        return getString("normalization_method");
    }

    /**
     * Maps a normalization method name to the PMML enum constant.
     *
     * @param str one of {@code "logit"}, {@code "simplemax"} or {@code "softmax"}
     * @return the matching {@link RegressionModel.NormalizationMethod}
     * @throws IllegalArgumentException if {@code str} is not a supported name
     */
    private static RegressionModel.NormalizationMethod parseNormalizationMethod(String str) {
        switch (str) {
            case "logit":
                return RegressionModel.NormalizationMethod.LOGIT;
            case "simplemax":
                return RegressionModel.NormalizationMethod.SIMPLEMAX;
            case "softmax":
                return RegressionModel.NormalizationMethod.SOFTMAX;
            default:
                throw new IllegalArgumentException(str);
        }
    }
}
