package org.jpmml.evaluator;

import com.google.common.collect.BiMap;
import com.google.common.collect.Maps;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.dmg.pmml.ActivationFunctionType;
import org.dmg.pmml.Connection;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Entity;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.NeuralInput;
import org.dmg.pmml.NeuralLayer;
import org.dmg.pmml.NeuralNetwork;
import org.dmg.pmml.NeuralOutput;
import org.dmg.pmml.Neuron;
import org.dmg.pmml.NnNormalizationMethodType;
import org.dmg.pmml.NormContinuous;
import org.dmg.pmml.NormDiscrete;
import org.dmg.pmml.PMML;
import org.jpmml.manager.NeuralNetworkManager;
import org.jpmml.manager.UnsupportedFeatureException;

/**
 * Evaluator for PMML {@code NeuralNetwork} models.
 *
 * <p>Supports the {@code regression} and {@code classification} mining functions. Raw neuron
 * outputs are computed by {@link #evaluateRaw(EvaluationContext)} and then mapped onto target
 * fields through the derived field of each {@code NeuralOutput}.
 */
public class NeuralNetworkEvaluator extends NeuralNetworkManager implements Evaluator {

    /** Lazily initialized cache of {@code super.getEntityRegistry()}; null until first use. */
    private BiMap<String, Entity> entities;

    /** {@code simplemax}: each neuron output is divided by the sum of the layer's outputs. */
    private static final Normalizer SIMPLEMAX_NORMALIZER = new Normalizer() {
        @Override
        public double apply(double d) {
            return d;
        }
    };

    /** {@code softmax}: each exp(output) is divided by the sum of exp(output) over the layer. */
    private static final Normalizer SOFTMAX_NORMALIZER = new Normalizer() {
        @Override
        public double apply(double d) {
            return Math.exp(d);
        }
    };

    /**
     * Per-neuron transformation used by layer normalization: a neuron's normalized output is
     * {@code apply(x) / sum(apply(x_k))} over all neurons {@code k} of the layer.
     */
    public interface Normalizer {
        double apply(double d);
    }

    public NeuralNetworkEvaluator(PMML pmml) {
        super(pmml);
        this.entities = null;
    }

    public NeuralNetworkEvaluator(PMML pmml, NeuralNetwork neuralNetwork) {
        super(pmml, neuralNetwork);
        this.entities = null;
    }

    /**
     * Returns the entity registry, computing it from the superclass on first access and caching it.
     */
    @Override
    public BiMap<String, Entity> getEntityRegistry() {
        if (this.entities == null) {
            this.entities = super.getEntityRegistry();
        }
        return this.entities;
    }

    @Override
    public Object prepare(FieldName fieldName, Object obj) {
        return ParameterUtil.prepare(getDataField(fieldName), getMiningField(fieldName), obj);
    }

    /**
     * Scores the model for the given arguments.
     *
     * @throws InvalidResultException if the model is not scorable
     * @throws UnsupportedFeatureException if the mining function is neither regression nor
     *     classification
     */
    @Override
    public Map<FieldName, ?> evaluate(Map<FieldName, ?> arguments) {
        NeuralNetwork model = getModel();
        if (!model.isScorable()) {
            throw new InvalidResultException(model);
        }

        ModelManagerEvaluationContext context = new ModelManagerEvaluationContext(this, arguments);

        Map<FieldName, ?> predictions;
        MiningFunctionType functionName = model.getFunctionName();
        switch (functionName) {
            case REGRESSION:
                predictions = evaluateRegression(context);
                break;
            case CLASSIFICATION:
                predictions = evaluateClassification(context);
                break;
            default:
                throw new UnsupportedFeatureException(model, functionName);
        }
        return OutputUtil.evaluate(predictions, context);
    }

    /**
     * Maps output-neuron values onto continuous target fields. A {@code FieldRef} output passes
     * the raw neuron value through; a {@code NormContinuous} output is denormalized first.
     */
    private Map<FieldName, ? extends Number> evaluateRegression(ModelManagerEvaluationContext context) {
        Map<FieldName, Double> result = Maps.newLinkedHashMap();
        Map<String, Double> neuronOutputs = evaluateRaw(context);

        for (NeuralOutput neuralOutput : getOrCreateNeuralOutputs()) {
            String outputNeuron = neuralOutput.getOutputNeuron();
            Expression expression = getExpression(neuralOutput.getDerivedField());

            if (expression instanceof FieldRef) {
                result.put(((FieldRef) expression).getField(), neuronOutputs.get(outputNeuron));
            } else if (expression instanceof NormContinuous) {
                NormContinuous normContinuous = (NormContinuous) expression;
                double value = NormalizationUtil.denormalize(normContinuous, neuronOutputs.get(outputNeuron).doubleValue());
                result.put(normContinuous.getField(), Double.valueOf(value));
            } else {
                throw new UnsupportedFeatureException(expression);
            }
        }
        return TargetUtil.evaluateRegression(result, context);
    }

    /**
     * Groups output-neuron values by target field. Every output must be a {@code NormDiscrete};
     * its category value and the neuron's value are recorded in that field's
     * {@code NeuronClassificationMap} together with the neuron's registered entity.
     */
    private Map<FieldName, ? extends ClassificationMap> evaluateClassification(ModelManagerEvaluationContext context) {
        Map<FieldName, NeuronClassificationMap> result = Maps.newLinkedHashMap();
        BiMap<String, Entity> entityRegistry = getEntityRegistry();
        Map<String, Double> neuronOutputs = evaluateRaw(context);

        for (NeuralOutput neuralOutput : getOrCreateNeuralOutputs()) {
            String outputNeuron = neuralOutput.getOutputNeuron();
            Expression expression = getExpression(neuralOutput.getDerivedField());
            if (!(expression instanceof NormDiscrete)) {
                throw new UnsupportedFeatureException(expression);
            }

            NormDiscrete normDiscrete = (NormDiscrete) expression;
            FieldName field = normDiscrete.getField();

            NeuronClassificationMap classificationMap = result.get(field);
            if (classificationMap == null) {
                classificationMap = new NeuronClassificationMap();
                result.put(field, classificationMap);
            }
            classificationMap.put(entityRegistry.get(outputNeuron), normDiscrete.getValue(), neuronOutputs.get(outputNeuron));
        }
        return TargetUtil.evaluateClassification(result, context);
    }

    /**
     * Returns the effective expression of a derived field. {@code FieldRef}s that resolve to
     * another derived field are followed; an unresolvable {@code FieldRef} is returned as-is.
     */
    private Expression getExpression(DerivedField derivedField) {
        Expression expression = derivedField.getExpression();
        if (!(expression instanceof FieldRef)) {
            return expression;
        }
        FieldRef fieldRef = (FieldRef) expression;
        DerivedField referenced = resolveField(fieldRef.getField());
        return (referenced != null) ? getExpression(referenced) : fieldRef;
    }

    /**
     * Runs the forward pass and returns every neuron's output keyed by neuron id, in insertion
     * order (input neurons first, then each layer in document order).
     *
     * <p>Each layer neuron computes {@code bias + sum(weight * from)} over its connections, applies
     * the layer's activation function, and then the whole layer is normalized.
     *
     * @throws MissingFieldException if a neural input's derived field evaluates to null
     */
    public Map<String, Double> evaluateRaw(EvaluationContext evaluationContext) {
        Map<String, Double> result = Maps.newLinkedHashMap();

        for (NeuralInput neuralInput : getNeuralInputs()) {
            DerivedField derivedField = neuralInput.getDerivedField();
            // Accept any numeric result rather than only Double.
            Number value = (Number) ExpressionUtil.evaluate(derivedField, evaluationContext);
            if (value == null) {
                throw new MissingFieldException(derivedField.getName(), derivedField);
            }
            result.put(neuralInput.getId(), Double.valueOf(value.doubleValue()));
        }

        for (NeuralLayer neuralLayer : getNeuralLayers()) {
            for (Neuron neuron : neuralLayer.getNeurons()) {
                // The bias attribute is optional; treat an absent bias as 0 instead of throwing an NPE.
                Double bias = neuron.getBias();
                double z = (bias != null) ? bias.doubleValue() : 0.0d;
                for (Connection connection : neuron.getConnections()) {
                    z += result.get(connection.getFrom()).doubleValue() * connection.getWeight();
                }
                result.put(neuron.getId(), Double.valueOf(activation(z, neuralLayer)));
            }
            normalizeNeuronOutputs(neuralLayer, result);
        }
        return result;
    }

    /**
     * Normalizes one layer's outputs in place. The layer's own normalization method takes
     * precedence; when it is absent, the model-level method is used.
     */
    private void normalizeNeuronOutputs(NeuralLayer neuralLayer, Map<String, Double> map) {
        NeuralNetwork model = getModel();

        NnNormalizationMethodType normalizationMethod = neuralLayer.getNormalizationMethod();
        boolean inherited = (normalizationMethod == null);
        if (inherited) {
            normalizationMethod = model.getNormalizationMethod();
        }

        switch (normalizationMethod) {
            case NONE:
                return;
            case SIMPLEMAX:
                normalizeNeuronOutputs(neuralLayer, SIMPLEMAX_NORMALIZER, 0.0d, map);
                return;
            case SOFTMAX:
                // Softmax is shift-invariant; subtracting the layer max keeps Math.exp from overflowing.
                normalizeNeuronOutputs(neuralLayer, SOFTMAX_NORMALIZER, softmaxOffset(neuralLayer, map), map);
                return;
            default:
                // Report the element that actually declared the unsupported method.
                if (inherited) {
                    throw new UnsupportedFeatureException(model, normalizationMethod);
                }
                throw new UnsupportedFeatureException(neuralLayer, normalizationMethod);
        }
    }

    /**
     * Replaces each neuron output {@code x} in the layer with
     * {@code normalizer(x - offset) / sum_k normalizer(x_k - offset)}.
     */
    private void normalizeNeuronOutputs(NeuralLayer neuralLayer, Normalizer normalizer, double offset, Map<String, Double> map) {
        List<Neuron> neurons = neuralLayer.getNeurons();

        double sum = 0.0d;
        for (Neuron neuron : neurons) {
            sum += normalizer.apply(map.get(neuron.getId()).doubleValue() - offset);
        }
        for (Neuron neuron : neurons) {
            double value = normalizer.apply(map.get(neuron.getId()).doubleValue() - offset);
            map.put(neuron.getId(), Double.valueOf(value / sum));
        }
    }

    /**
     * Returns the maximum output of the layer, for use as the softmax shift. Returns 0 when the
     * maximum is not finite (including an empty layer), so degenerate inputs normalize exactly
     * as an unshifted softmax would.
     */
    private static double softmaxOffset(NeuralLayer neuralLayer, Map<String, Double> map) {
        double max = Double.NEGATIVE_INFINITY;
        for (Neuron neuron : neuralLayer.getNeurons()) {
            max = Math.max(max, map.get(neuron.getId()).doubleValue());
        }
        return (Double.isInfinite(max) || Double.isNaN(max)) ? 0.0d : max;
    }

    /**
     * Applies the activation function to a neuron's weighted input {@code z}. The layer's own
     * activation function (and threshold) take precedence over the model-level ones.
     */
    private double activation(double z, NeuralLayer neuralLayer) {
        NeuralNetwork model = getModel();

        ActivationFunctionType activationFunction = neuralLayer.getActivationFunction();
        boolean inherited = (activationFunction == null);
        if (inherited) {
            activationFunction = model.getActivationFunction();
        }

        switch (activationFunction) {
            case THRESHOLD: {
                Double threshold = neuralLayer.getThreshold();
                if (threshold == null) {
                    threshold = Double.valueOf(model.getThreshold());
                }
                return (z > threshold.doubleValue()) ? 1.0d : 0.0d;
            }
            case LOGISTIC:
                return 1.0d / (1.0d + Math.exp(-z));
            case TANH:
                // Math.tanh avoids the inf/inf = NaN of the explicit exp-based formula for large |z|.
                return Math.tanh(z);
            case IDENTITY:
                return z;
            case EXPONENTIAL:
                return Math.exp(z);
            case RECIPROCAL:
                return 1.0d / z;
            case SQUARE:
                return z * z;
            case GAUSS:
                return Math.exp(-(z * z));
            case SINE:
                return Math.sin(z);
            case COSINE:
                return Math.cos(z);
            case ELLIOTT:
                return z / (1.0d + Math.abs(z));
            case ARCTAN:
                // PMML defines arctan activation as 2 * arctan(Z) / Pi, which ranges over (-1, 1).
                return 2.0d * Math.atan(z) / Math.PI;
            default:
                // Report the element that actually declared the unsupported function.
                if (inherited) {
                    throw new UnsupportedFeatureException(model, activationFunction);
                }
                throw new UnsupportedFeatureException(neuralLayer, activationFunction);
        }
    }
}
