package sklearn2pmml.preprocessing;

import java.util.Collections;
import java.util.List;
import org.dmg.pmml.Apply;
import org.dmg.pmml.DataType;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.HasDefaultValue;
import org.dmg.pmml.HasMapMissingTo;
import org.dmg.pmml.InvalidValueTreatmentMethod;
import org.dmg.pmml.OpType;
import org.jpmml.converter.ExpressionUtil;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FeatureUtil;
import org.jpmml.converter.TypeUtil;
import org.jpmml.converter.ValueUtil;
import org.jpmml.python.DataFrameScope;
import org.jpmml.python.TypeInfo;
import org.jpmml.sklearn.SkLearnEncoder;
import pandas.CategoricalDtypeUtil;
import pandas.core.CategoricalDtype;
import sklearn.Transformer;
import sklearn2pmml.util.EvaluatableUtil;

/* loaded from: input_file:sklearn2pmml/preprocessing/ExpressionTransformer.class */
/**
 * Converter for {@code sklearn2pmml.preprocessing.ExpressionTransformer}.
 *
 * <p>Translates the transformer's expression into a PMML {@link Expression}, evaluated against
 * the incoming features (bound to the name {@code "X"}), and exposes the result as a single
 * output feature.
 */
public class ExpressionTransformer extends Transformer {
    public ExpressionTransformer() {
        this("sklearn2pmml.preprocessing", "ExpressionTransformer");
    }

    public ExpressionTransformer(String module, String name) {
        super(module, name);
    }

    /**
     * Encodes the transformer's expression as exactly one feature.
     *
     * <p>The optional {@code map_missing_to}, {@code default_value} and
     * {@code invalid_value_treatment} attributes are applied to the translated expression. The
     * result type comes from {@code dtype} when set. Otherwise it is inferred from the
     * expression, falling back to {@link DataType#DOUBLE}.
     *
     * @param features the input features, exposed to the expression under the name {@code "X"}
     * @param encoder the encoder that owns the derived fields
     * @return a singleton list holding the output feature
     * @throws IllegalArgumentException if an attribute is set that the translated expression
     *     does not support, or if {@code invalid_value_treatment} has an unrecognized value
     */
    @Override // sklearn.Transformer
    public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encoder) {
        Object expr = getExpr();
        Object mapMissingTo = getMapMissingTo();
        Object defaultValue = getDefaultValue();
        InvalidValueTreatmentMethod invalidValueTreatment = parseInvalidValueTreatment(getInvalidValueTreatment());
        // getDType() returns a general TypeInfo; only some instances are CategoricalDtype.
        TypeInfo dtype = getDType();

        // NaN values are treated exactly like absent values.
        if (ValueUtil.isNaN(defaultValue)) {
            defaultValue = null;
        }
        if (ValueUtil.isNaN(mapMissingTo)) {
            mapMissingTo = null;
        }

        DataFrameScope scope = new DataFrameScope("X", features, encoder);
        Expression expression = EvaluatableUtil.translateExpression(expr, scope);

        // If the expression only references an existing derived field, reuse that field.
        // NOTE: the attributes below are then set on that field's own expression, so the
        // existing derived field is modified in place.
        DerivedField derivedField = null;
        if (expression instanceof FieldRef) {
            derivedField = encoder.getDerivedField(((FieldRef) expression).requireField());
            if (derivedField != null) {
                expression = derivedField.getExpression();
            }
        }

        if (mapMissingTo != null) {
            if (!(expression instanceof HasMapMissingTo)) {
                throw unsupportedAttribute(expression, "map_missing_to");
            }
            ((HasMapMissingTo) expression).setMapMissingTo(mapMissingTo);
        }
        if (defaultValue != null) {
            if (!(expression instanceof HasDefaultValue)) {
                throw unsupportedAttribute(expression, "default_value");
            }
            ((HasDefaultValue) expression).setDefaultValue(defaultValue);
        }
        if (invalidValueTreatment != null) {
            if (!(expression instanceof Apply)) {
                throw unsupportedAttribute(expression, "invalid_value_treatment");
            }
            ((Apply) expression).setInvalidValueTreatment(invalidValueTreatment);
        }

        DataType dataType;
        if (dtype != null) {
            dataType = dtype.getDataType();
        } else {
            dataType = ExpressionUtil.getDataType(expression, scope);
            if (dataType == null) {
                dataType = DataType.DOUBLE;
            }
        }
        OpType opType = TypeUtil.getOpType(dataType);

        // A plain column reference whose field already has the requested types is passed
        // through unchanged, so no redundant derived field is created.
        if (expression instanceof FieldRef && mapMissingTo == null) {
            Feature resolvedFeature = scope.resolveFeature(((FieldRef) expression).requireField());
            if (resolvedFeature != null) {
                Field field = resolvedFeature.getField();
                if (field.requireOpType() == opType && field.requireDataType() == dataType) {
                    return Collections.singletonList(refineCategorical(resolvedFeature, dtype, encoder));
                }
            }
        }

        if (derivedField != null) {
            derivedField.setOpType(opType).setDataType(dataType);
        } else {
            derivedField = encoder.createDerivedField(createFieldName("eval", EvaluatableUtil.toString(expr)), opType, dataType, expression);
        }

        Feature feature = FeatureUtil.createFeature(derivedField, encoder);
        return Collections.singletonList(refineCategorical(feature, dtype, encoder));
    }

    /** Returns the optional {@code default_value} scalar. */
    public Object getDefaultValue() {
        return getOptionalScalar("default_value");
    }

    public ExpressionTransformer setDefaultValue(Object obj) {
        put("default_value", obj);
        return this;
    }

    /** Returns the dtype, preferring the {@code dtype_} key over {@code dtype} when present. */
    public TypeInfo getDType() {
        return containsKey("dtype_") ? super.getOptionalDType("dtype_", true) : super.getOptionalDType("dtype", true);
    }

    public ExpressionTransformer setDType(Object obj) {
        put("dtype", obj);
        return this;
    }

    /**
     * Returns the expression, preferring the {@code expr_} key (read as a string) over
     * {@code expr} (read as an arbitrary object) when present.
     */
    public Object getExpr() {
        return containsKey("expr_") ? getString("expr_") : getObject("expr");
    }

    public ExpressionTransformer setExpr(String str) {
        put("expr", str);
        return this;
    }

    /** Returns the optional {@code invalid_value_treatment} string. */
    public String getInvalidValueTreatment() {
        return getOptionalString("invalid_value_treatment");
    }

    public ExpressionTransformer setInvalidValueTreatment(String str) {
        put("invalid_value_treatment", str);
        return this;
    }

    /** Returns the optional {@code map_missing_to} scalar. */
    public Object getMapMissingTo() {
        return getOptionalScalar("map_missing_to");
    }

    public ExpressionTransformer setMapMissingTo(Object obj) {
        put("map_missing_to", obj);
        return this;
    }

    /** Applies categorical-dtype refinement when {@code dtype} is a {@link CategoricalDtype}. */
    private static Feature refineCategorical(Feature feature, TypeInfo dtype, SkLearnEncoder encoder) {
        if (dtype instanceof CategoricalDtype) {
            return CategoricalDtypeUtil.refineFeature(feature, (CategoricalDtype) dtype, encoder);
        }
        return feature;
    }

    /** Builds the error raised when an attribute is set on an expression that cannot hold it. */
    private static IllegalArgumentException unsupportedAttribute(Expression expression, String attribute) {
        String type = (expression == null) ? "null" : expression.getClass().getSimpleName();
        return new IllegalArgumentException("Expression of type " + type + " does not support the '" + attribute + "' attribute");
    }

    /**
     * Maps the Python-side {@code invalid_value_treatment} string to its PMML enum constant.
     *
     * @param invalidValueTreatment {@code "return_invalid"}, {@code "as_missing"}, or null
     * @return the matching constant, or null when the input is null
     * @throws IllegalArgumentException for any other value
     */
    private static InvalidValueTreatmentMethod parseInvalidValueTreatment(String invalidValueTreatment) {
        if (invalidValueTreatment == null) {
            return null;
        }
        switch (invalidValueTreatment) {
            case "return_invalid":
                return InvalidValueTreatmentMethod.RETURN_INVALID;
            case "as_missing":
                return InvalidValueTreatmentMethod.AS_MISSING;
            default:
                throw new IllegalArgumentException(invalidValueTreatment);
        }
    }
}
