package sklearn;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import numpy.core.NDArrayUtil;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.OutputField;
import org.jpmml.converter.DiscreteLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelEncoder;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import sklearn2pmml.HasPMMLOptions;
import sklearn2pmml.HasPMMLSegmentId;
import sklearn2pmml.SkLearn2PMMLFields;

/* loaded from: input_file:sklearn/Estimator.class */
/**
 * Base class for scikit-learn estimators that can be converted to a PMML {@link Model}.
 *
 * <p>Subclasses supply the mining function, whether a label is required, the label encoding and
 * the model encoding itself ({@link #mo7encodeModel(Schema)}). This class provides the shared
 * conversion pipeline ({@link #encode(Schema)}). It also provides access to the
 * {@code SkLearn2PMMLFields.PMML_OPTIONS} / {@code PMML_SEGMENT_ID} /
 * {@code PMML_FEATURE_IMPORTANCES} attributes and helpers for creating output fields.
 */
public abstract class Estimator extends Step implements HasNumberOfOutputs, HasPMMLOptions<Estimator>, HasPMMLSegmentId<Estimator> {
    public static final String FIELD_APPLY = "apply";
    public static final String FIELD_DECISION_FUNCTION = "decisionFunction";
    public static final String FIELD_PREDICT = "predict";
    private static final Logger logger = LoggerFactory.getLogger(Estimator.class);

    /**
     * Both arguments are forwarded unchanged to the {@link Step} constructor.
     */
    public Estimator(String module, String name) {
        super(module, name);
    }

    /**
     * Returns the PMML mining function implemented by this estimator.
     */
    public abstract MiningFunction getMiningFunction();

    /**
     * Returns {@code true} if this estimator requires a label; {@link #checkLabel(Label)} enforces it.
     */
    public abstract boolean isSupervised();

    /**
     * Encodes the label for this estimator from the given target column names.
     */
    public abstract Label encodeLabel(List<String> names, SkLearnEncoder encoder);

    /**
     * Encodes the estimator-specific part of the PMML model.
     *
     * <p>Called by {@link #encode(Schema)} after the version, label and features have been checked.
     * The name was assigned by the decompiler (originally {@code encodeModel}). It is kept as-is
     * because subclasses override it under this name.
     */
    public abstract Model mo7encodeModel(Schema schema);

    /**
     * Returns the operational type; {@link OpType#CONTINUOUS} unless overridden.
     */
    public OpType getOpType() {
        return OpType.CONTINUOUS;
    }

    /**
     * Returns the data type; {@link DataType#DOUBLE} unless overridden.
     */
    public DataType getDataType() {
        return DataType.DOUBLE;
    }

    /**
     * Returns the number of input features.
     *
     * <p>Prefers a non-null {@code N_FEATURES_IN} attribute and falls back to {@code N_FEATURES}.
     *
     * @return the feature count, or {@code -1} when neither attribute is present
     */
    public int getNumberOfFeatures() {
        if (containsKey(SkLearnFields.N_FEATURES_IN) && get(SkLearnFields.N_FEATURES_IN) != null) {
            return getInteger(SkLearnFields.N_FEATURES_IN).intValue();
        }
        if (containsKey(SkLearnFields.N_FEATURES)) {
            return getInteger(SkLearnFields.N_FEATURES).intValue();
        }
        return -1;
    }

    /**
     * Returns the {@code N_OUTPUTS} attribute, or {@code -1} when it is absent.
     */
    public int getNumberOfOutputs() {
        if (containsKey(SkLearnFields.N_OUTPUTS)) {
            return getInteger(SkLearnFields.N_OUTPUTS).intValue();
        }
        return -1;
    }

    /**
     * Returns the algorithm name written to the model; defaults to the Python class name.
     */
    public String getAlgorithmName() {
        return getClassName();
    }

    /**
     * Converts this estimator into a PMML model.
     *
     * <p>The method first checks the version, label and features. It then delegates to
     * {@link #mo7encodeModel(Schema)}. After that it fills in the model name (from
     * {@code getPMMLName()}) and the algorithm name, but only where the encoded model left them
     * unset. Finally it attaches feature importances, if any are available.
     *
     * @param schema the label and features to encode against
     * @return the encoded model
     * @throws IllegalArgumentException if the label presence does not match {@link #isSupervised()}
     */
    public Model encode(Schema schema) {
        checkVersion();
        checkLabel(schema.getLabel());
        checkFeatures(schema.getFeatures());

        Model model = mo7encodeModel(schema);

        if (model.getModelName() == null) {
            String pmmlName = getPMMLName();
            if (pmmlName != null) {
                model.setModelName(pmmlName);
            }
        }
        if (model.getAlgorithmName() == null) {
            model.setAlgorithmName(getAlgorithmName());
        }

        addFeatureImportances(model, schema);
        return model;
    }

    /**
     * Encodes this estimator with a temporary PMML segment id.
     *
     * <p>The previous segment id is restored afterwards, whether encoding succeeds or throws.
     *
     * @param segmentId the segment id to use during encoding ({@code null} removes it)
     * @param schema the label and features to encode against
     * @return the encoded model
     */
    public Model encode(Object segmentId, Schema schema) {
        Object previousSegmentId = getPMMLSegmentId();
        try {
            setPMMLSegmentId(segmentId);
            return encode(schema);
        } finally {
            setPMMLSegmentId(previousSegmentId);
        }
    }

    /**
     * Verifies that a label is present exactly when {@link #isSupervised()} says it should be.
     *
     * @throws IllegalArgumentException on a missing or unexpected label
     */
    public void checkLabel(Label label) {
        if (isSupervised()) {
            if (label == null) {
                throw new IllegalArgumentException("Expected a label, got no label");
            }
        } else if (label != null) {
            throw new IllegalArgumentException("Expected no label, got " + label);
        }
    }

    /**
     * Verifies the feature count via {@code StepUtil.checkNumberOfFeatures}.
     */
    public void checkFeatures(List<? extends Feature> features) {
        StepUtil.checkNumberOfFeatures(this, features);
    }

    /**
     * Registers per-feature importances on the model, if any are available.
     *
     * <p>Explicit PMML importances take precedence over the estimator's own importances. The
     * importance list must have the same size as the schema's feature list.
     */
    public void addFeatureImportances(Model model, Schema schema) {
        List<? extends Number> featureImportances = getPMMLFeatureImportances();
        if (featureImportances == null) {
            featureImportances = getFeatureImportances();
        }
        if (featureImportances == null) {
            return;
        }

        ModelEncoder encoder = schema.getEncoder();
        List<? extends Feature> features = schema.getFeatures();

        ClassDictUtil.checkSize(new Collection<?>[]{features, featureImportances});
        for (int i = 0; i < features.size(); i++) {
            encoder.addFeatureImportance(model, features.get(i), featureImportances.get(i));
        }
    }

    /**
     * Looks up a conversion option.
     *
     * <p>The lookup order is: the PMML options dict, then a same-named attribute on the estimator
     * (logged as a warning), then {@code defaultValue}.
     */
    public Object getOption(String key, Object defaultValue) {
        Map<String, ?> pmmlOptions = getPMMLOptions();
        if (pmmlOptions != null && pmmlOptions.containsKey(key)) {
            return pmmlOptions.get(key);
        }
        if (!containsKey(key)) {
            return defaultValue;
        }
        // NOTE(review): this message is also logged when the options dict exists but lacks 'key'.
        logger.warn("Attribute '{}' is not set. Falling back to the surrogate attribute '{}'",
            ClassDictUtil.formatMember(this, SkLearn2PMMLFields.PMML_OPTIONS),
            ClassDictUtil.formatMember(this, key));
        return get(key);
    }

    /**
     * Adds or replaces a single entry in the PMML options dict.
     */
    public void putOption(String key, Object value) {
        putOptions(Collections.singletonMap(key, value));
    }

    /**
     * Merges the given entries into the PMML options dict.
     *
     * <p>If no dict is set yet, an insertion-ordered one is created first.
     */
    public void putOptions(Map<String, ?> options) {
        // getPMMLOptions() exposes the dict as Map<String, ?>, which rejects puts at compile time.
        // View it as Map<String, Object> so entries can be merged into the existing dict in place.
        @SuppressWarnings("unchecked")
        Map<String, Object> pmmlOptions = (Map<String, Object>) getPMMLOptions();
        if (pmmlOptions == null) {
            pmmlOptions = new LinkedHashMap<>();
            setPMMLOptions(pmmlOptions);
        }
        pmmlOptions.putAll(options);
    }

    /**
     * Returns {@code true} if either the estimator's own or the explicit PMML importances are present.
     */
    public boolean hasFeatureImportances() {
        return containsKey(SkLearnFields.FEATURE_IMPORTANCES) || containsKey(SkLearn2PMMLFields.PMML_FEATURE_IMPORTANCES);
    }

    /**
     * Returns the {@code FEATURE_IMPORTANCES} attribute, or {@code null} when it is absent.
     */
    public List<? extends Number> getFeatureImportances() {
        if (containsKey(SkLearnFields.FEATURE_IMPORTANCES)) {
            return getNumberArray(SkLearnFields.FEATURE_IMPORTANCES);
        }
        return null;
    }

    /**
     * Returns the {@code PMML_FEATURE_IMPORTANCES} attribute, or {@code null} when it is absent.
     */
    public List<? extends Number> getPMMLFeatureImportances() {
        if (containsKey(SkLearn2PMMLFields.PMML_FEATURE_IMPORTANCES)) {
            return getNumberArray(SkLearn2PMMLFields.PMML_FEATURE_IMPORTANCES);
        }
        return null;
    }

    /**
     * Stores explicit PMML feature importances, converted via {@code NDArrayUtil.toArray}.
     */
    public Estimator setPMMLFeatureImportances(List<? extends Number> featureImportances) {
        put(SkLearn2PMMLFields.PMML_FEATURE_IMPORTANCES, NDArrayUtil.toArray(featureImportances));
        return this;
    }

    /**
     * Returns the PMML options dict, or {@code null} when the attribute is unset or {@code null}.
     */
    @Override // sklearn2pmml.HasPMMLOptions
    public Map<String, ?> getPMMLOptions() {
        if (get(SkLearn2PMMLFields.PMML_OPTIONS) == null) {
            return null;
        }
        return getDict(SkLearn2PMMLFields.PMML_OPTIONS);
    }

    /**
     * Stores the given map as the PMML options dict.
     *
     * <p>The erased bridge method required by {@code HasPMMLOptions} is generated by the compiler.
     * Declaring it by hand would clash with this method.
     */
    @Override // sklearn2pmml.HasPMMLOptions
    public Estimator setPMMLOptions(Map<String, ?> options) {
        put(SkLearn2PMMLFields.PMML_OPTIONS, options);
        return this;
    }

    /**
     * Returns the PMML segment id attribute, read via {@code getOptionalScalar}.
     */
    @Override // sklearn2pmml.HasPMMLSegmentId
    public Object getPMMLSegmentId() {
        return getOptionalScalar(SkLearn2PMMLFields.PMML_SEGMENT_ID);
    }

    /**
     * Sets the PMML segment id attribute; {@code null} removes it.
     */
    @Override // sklearn2pmml.HasPMMLSegmentId
    public Estimator setPMMLSegmentId(Object segmentId) {
        if (segmentId != null) {
            put(SkLearn2PMMLFields.PMML_SEGMENT_ID, segmentId);
        } else {
            remove(SkLearn2PMMLFields.PMML_SEGMENT_ID);
        }
        return this;
    }

    /**
     * Creates one probability output field per label value.
     *
     * <p>When a PMML segment id is set, it is included in each field name.
     *
     * @throws IllegalArgumentException if this estimator does not implement {@code HasClasses}
     */
    public List<OutputField> createPredictProbaFields(DataType dataType, DiscreteLabel discreteLabel) {
        Object pmmlSegmentId = getPMMLSegmentId();
        if (!(this instanceof HasClasses)) {
            throw new IllegalArgumentException("Estimator " + getClassName() + " does not implement HasClasses");
        }

        List<?> values = discreteLabel.getValues();
        List<OutputField> result = new ArrayList<>(values.size());
        for (Object value : values) {
            String name = (pmmlSegmentId != null)
                ? FieldNameUtil.create(Classifier.FIELD_PROBABILITY, new Object[]{pmmlSegmentId, value})
                : FieldNameUtil.create(Classifier.FIELD_PROBABILITY, new Object[]{value});
            result.add(ModelUtil.createProbabilityField(name, dataType, value));
        }
        return result;
    }

    /**
     * Creates the entity-id output field named by {@code HasApplyField#getApplyField()}.
     *
     * <p>When a PMML segment id is set, it is included in the field name.
     *
     * @throws IllegalArgumentException if this estimator does not implement {@code HasApplyField}
     */
    public OutputField createApplyField(DataType dataType) {
        Object pmmlSegmentId = getPMMLSegmentId();
        if (!(this instanceof HasApplyField)) {
            throw new IllegalArgumentException("Estimator " + getClassName() + " does not implement HasApplyField");
        }

        String name = ((HasApplyField) this).getApplyField();
        if (pmmlSegmentId != null) {
            name = FieldNameUtil.create(name, new Object[]{pmmlSegmentId});
        }
        return ModelUtil.createEntityIdField(name, dataType);
    }

    /**
     * Creates the apply field and appends it to the model's output, creating the output if needed.
     */
    public OutputField encodeApplyOutput(Model model, DataType dataType) {
        OutputField applyField = createApplyField(dataType);
        ModelUtil.ensureOutput(model).getOutputFields().add(applyField);
        return applyField;
    }

    /**
     * Creates an entity-id output field for one segment of a multi-segment model.
     *
     * <p>The field name comes from {@code HasMultiApplyField#getMultiApplyField(segmentId)}. It also
     * includes the PMML segment id, if one is set. The returned field's segment id is set to
     * {@code segmentId}.
     *
     * @throws IllegalArgumentException if this estimator does not implement {@code HasMultiApplyField}
     */
    public OutputField createMultiApplyField(DataType dataType, String segmentId) {
        Object pmmlSegmentId = getPMMLSegmentId();
        if (!(this instanceof HasMultiApplyField)) {
            throw new IllegalArgumentException("Estimator " + getClassName() + " does not implement HasMultiApplyField");
        }

        String name = ((HasMultiApplyField) this).getMultiApplyField(segmentId);
        if (pmmlSegmentId != null) {
            name = FieldNameUtil.create(name, new Object[]{pmmlSegmentId});
        }
        OutputField outputField = ModelUtil.createEntityIdField(name, dataType);
        outputField.setSegmentId(segmentId);
        return outputField;
    }

    /**
     * Creates one multi-apply field per segment id and appends all of them to the model's output.
     */
    public List<OutputField> encodeMultiApplyOutput(Model model, DataType dataType, List<String> segmentIds) {
        List<OutputField> outputFields = new ArrayList<>(segmentIds.size());
        for (String segmentId : segmentIds) {
            outputFields.add(createMultiApplyField(dataType, segmentId));
        }
        ModelUtil.ensureOutput(model).getOutputFields().addAll(outputFields);
        return outputFields;
    }

    /**
     * Strips a {@code function(...)} wrapper from {@code expression}.
     *
     * <p>For example, {@code extractArguments("f", "f(a, b)")} returns {@code "a, b"}. Input that does
     * not start with {@code function + "("} and end with {@code ")"} is returned unchanged.
     * (The decompiler reports this was originally {@code protected}; it stays {@code public} for
     * compatibility.)
     */
    public static String extractArguments(String function, String expression) {
        String prefix = function + "(";
        String suffix = ")";
        if (expression.startsWith(prefix) && expression.endsWith(suffix)) {
            return expression.substring(prefix.length(), expression.length() - suffix.length());
        }
        return expression;
    }
}
