package org.jpmml.evaluator;

import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.dmg.pmml.CategoricalPredictor;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.NumericPredictor;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PredictorTerm;
import org.dmg.pmml.RegressionModel;
import org.dmg.pmml.RegressionNormalizationMethodType;
import org.dmg.pmml.RegressionTable;
import org.jpmml.manager.RegressionModelManager;
import org.jpmml.manager.UnsupportedFeatureException;

/* loaded from: input_file:org/jpmml/evaluator/RegressionModelEvaluator.class */
/**
 * {@link Evaluator} for PMML {@link RegressionModel} elements.
 *
 * <p>Supports the {@code REGRESSION} and {@code CLASSIFICATION} mining functions. Each
 * {@link RegressionTable} is scored as its intercept plus the contributions of its
 * {@link NumericPredictor}s and {@link CategoricalPredictor}s; tables containing
 * {@link PredictorTerm}s are rejected as unsupported. Raw scores are then transformed
 * according to the model's {@link RegressionNormalizationMethodType}.
 */
public class RegressionModelEvaluator extends RegressionModelManager implements Evaluator {

    public RegressionModelEvaluator(PMML pmml) {
        super(pmml);
    }

    public RegressionModelEvaluator(PMML pmml, RegressionModel regressionModel) {
        super(pmml, regressionModel);
    }

    /**
     * Creates an evaluator for the same PMML document and model as the given manager.
     */
    public RegressionModelEvaluator(RegressionModelManager regressionModelManager) {
        this(regressionModelManager.getPmml(), regressionModelManager.getModel());
    }

    /**
     * Prepares a raw input value for {@code fieldName} using that field's data-field and
     * mining-field definitions (delegates to {@link ParameterUtil#prepare}).
     */
    @Override
    public Object prepare(FieldName fieldName, Object obj) {
        return ParameterUtil.prepare(getDataField(fieldName), getMiningField(fieldName), obj);
    }

    /**
     * Evaluates the model against the given input values.
     *
     * <p>Dispatches on the model's mining function, then passes the resulting predictions
     * through {@link OutputUtil#evaluate}.
     *
     * @throws UnsupportedFeatureException if the mining function is neither
     *         {@code REGRESSION} nor {@code CLASSIFICATION}
     */
    @Override
    public Map<FieldName, ?> evaluate(Map<FieldName, ?> map) {
        RegressionModel model = getModel();
        ModelManagerEvaluationContext context = new ModelManagerEvaluationContext(this, map);

        Map<FieldName, ?> predictions;

        MiningFunctionType functionName = model.getFunctionName();
        switch (functionName) {
            case REGRESSION:
                predictions = evaluateRegression(context);
                break;
            case CLASSIFICATION:
                predictions = evaluateClassification(context);
                break;
            default:
                throw new UnsupportedFeatureException(functionName);
        }

        return OutputUtil.evaluate(predictions, context);
    }

    /**
     * Scores a regression model, which must contain exactly one regression table.
     *
     * @return a single-entry map from the target field to the normalized score; the value is
     *         {@code null} when a numeric input required by the table is missing
     * @throws EvaluationException if the model does not have exactly one regression table
     */
    public Map<FieldName, Double> evaluateRegression(EvaluationContext evaluationContext) {
        RegressionModel model = getModel();

        List<RegressionTable> regressionTables = getRegressionTables();
        if (regressionTables.size() != 1) {
            throw new EvaluationException();
        }

        Double value = evaluateRegressionTable(regressionTables.get(0), evaluationContext);

        // A missing score stays missing; normalizing it would unbox null and throw an NPE.
        if (value != null) {
            value = normalizeRegressionResult(model.getNormalizationMethod(), value.doubleValue());
        }

        return Collections.singletonMap(getTarget(), value);
    }

    /**
     * Scores a classification model: one regression table per target category.
     *
     * <p>Each table's raw score is normalized with the model's normalization method and stored
     * in a {@link ClassificationMap} keyed by the table's target category, in table order.
     *
     * @return a single-entry map from the target field to the per-category results
     * @throws EvaluationException if the model has no regression tables, or if any table's
     *         score is missing (a numeric input is missing)
     * @throws UnsupportedFeatureException if the target field is not categorical, or the
     *         normalization method is not supported for classification
     */
    public Map<FieldName, ClassificationMap> evaluateClassification(EvaluationContext evaluationContext) {
        RegressionModel model = getModel();

        List<RegressionTable> regressionTables = getRegressionTables();
        if (regressionTables.isEmpty()) {
            throw new EvaluationException();
        }

        FieldName target = getTarget();

        OpType optype = getDataField(target).getOptype();
        if (optype != OpType.CATEGORICAL) {
            throw new UnsupportedFeatureException(optype);
        }

        RegressionNormalizationMethodType normalizationMethod = model.getNormalizationMethod();

        // First pass: raw scores of every table, plus their maximum for the stable softmax below.
        double[] scores = new double[regressionTables.size()];
        double maxScore = Double.NEGATIVE_INFINITY;
        int index = 0;
        for (RegressionTable regressionTable : regressionTables) {
            Double score = evaluateRegressionTable(regressionTable, evaluationContext);
            if (score == null) {
                // Without every category's score no meaningful result can be produced.
                throw new EvaluationException();
            }
            scores[index++] = score.doubleValue();
            maxScore = Math.max(maxScore, score.doubleValue());
        }

        // Softmax denominator shifted by the largest score, so exp() cannot overflow to
        // infinity; the shift cancels out in exp(s - max) / sum(exp(s_j - max)).
        double shiftedExpSum = 0.0d;
        for (double score : scores) {
            shiftedExpSum += Math.exp(score - maxScore);
        }

        // Second pass: normalize and store, keyed by target category in table order.
        ClassificationMap classificationMap = new ClassificationMap();
        index = 0;
        for (RegressionTable regressionTable : regressionTables) {
            double score = scores[index++];
            classificationMap.put(regressionTable.getTargetCategory(),
                    normalizeClassificationResult(normalizationMethod, score, maxScore, shiftedExpSum));
        }

        return Collections.singletonMap(target, classificationMap);
    }

    /**
     * Computes the raw linear score of one regression table.
     *
     * <p>The score is the intercept, plus {@code coefficient * x^exponent} for every numeric
     * predictor, plus {@code coefficient} for every categorical predictor whose input equals
     * the predictor's value.
     *
     * @return the score, or {@code null} if any numeric predictor's input is missing
     * @throws UnsupportedFeatureException if the table contains any {@link PredictorTerm}
     */
    private static Double evaluateRegressionTable(RegressionTable regressionTable, EvaluationContext evaluationContext) {
        // Reject unsupported terms up front, so the outcome does not depend on whether some
        // input happens to be missing (which would otherwise return null before this check).
        Iterator<PredictorTerm> predictorTerms = regressionTable.getPredictorTerms().iterator();
        if (predictorTerms.hasNext()) {
            throw new UnsupportedFeatureException(predictorTerms.next());
        }

        double result = regressionTable.getIntercept();

        for (NumericPredictor numericPredictor : regressionTable.getNumericPredictors()) {
            Object value = ExpressionUtil.evaluate(numericPredictor.getName(), evaluationContext);
            // A missing numeric input makes the whole table score missing.
            if (value == null) {
                return null;
            }
            result += numericPredictor.getCoefficient() * Math.pow(((Number) value).doubleValue(), numericPredictor.getExponent());
        }

        for (CategoricalPredictor categoricalPredictor : regressionTable.getCategoricalPredictors()) {
            Object value = ExpressionUtil.evaluate(categoricalPredictor.getName(), evaluationContext);
            // Indicator term: the coefficient applies only on a match; a missing input contributes nothing.
            if (value != null && ParameterUtil.equals(value, categoricalPredictor.getValue())) {
                result += categoricalPredictor.getCoefficient();
            }
        }

        return Double.valueOf(result);
    }

    /**
     * Applies the regression normalization method to a raw score.
     * SOFTMAX and LOGIT both map to the logistic function {@code 1 / (1 + e^-y)}.
     *
     * @throws UnsupportedFeatureException for any method other than NONE, SOFTMAX, LOGIT, EXP
     */
    private static Double normalizeRegressionResult(RegressionNormalizationMethodType regressionNormalizationMethodType, double value) {
        switch (regressionNormalizationMethodType) {
            case NONE:
                return Double.valueOf(value);
            case SOFTMAX:
            case LOGIT:
                return Double.valueOf(1.0d / (1.0d + Math.exp(-value)));
            case EXP:
                return Double.valueOf(Math.exp(value));
            default:
                throw new UnsupportedFeatureException(regressionNormalizationMethodType);
        }
    }

    /**
     * Applies the classification normalization method to one category's raw score.
     *
     * @param score         the category's raw score
     * @param maxScore      the largest raw score across all categories (softmax shift)
     * @param shiftedExpSum {@code sum(exp(s_j - maxScore))} over all categories
     * @throws UnsupportedFeatureException for EXP and any method not handled below
     */
    private static Double normalizeClassificationResult(RegressionNormalizationMethodType regressionNormalizationMethodType, double score, double maxScore, double shiftedExpSum) {
        switch (regressionNormalizationMethodType) {
            case NONE:
                return Double.valueOf(score);
            case SOFTMAX:
                return Double.valueOf(Math.exp(score - maxScore) / shiftedExpSum);
            case LOGIT:
                return Double.valueOf(1.0d / (1.0d + Math.exp(-score)));
            case CLOGLOG:
                return Double.valueOf(1.0d - Math.exp(-Math.exp(score)));
            case LOGLOG:
                return Double.valueOf(Math.exp(-Math.exp(-score)));
            case EXP:
            default:
                throw new UnsupportedFeatureException(regressionNormalizationMethodType);
        }
    }
}
