package org.jpmml.evaluator;

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.BiMap;
import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.EmbeddedModel;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.LocalTransformations;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.MiningModel;
import org.dmg.pmml.Model;
import org.dmg.pmml.MultipleModelMethodType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.Segment;
import org.dmg.pmml.Segmentation;
import org.dmg.pmml.TreeModel;
import org.jpmml.evaluator.Classification;

/* loaded from: input_file:WEB-INF/lib/pmml-evaluator-1.2.0.jar:org/jpmml/evaluator/MiningModelEvaluator.class */
public class MiningModelEvaluator extends ModelEvaluator<MiningModel> implements HasEntityRegistry<Segment> {
    /**
     * Factory used to create evaluators for segment models. Stays {@code null} until
     * {@link #setEvaluatorFactory(ModelEvaluatorFactory)} is called; segment evaluation falls
     * back to {@code ModelEvaluatorFactory.newInstance()} in that case.
     */
    private ModelEvaluatorFactory evaluatorFactory;

    /** Multiple-model methods that regression evaluation lets through to value aggregation. */
    private static final Set<MultipleModelMethodType> REGRESSION_METHODS = EnumSet.of(
        MultipleModelMethodType.SUM,
        MultipleModelMethodType.MEDIAN,
        MultipleModelMethodType.AVERAGE,
        MultipleModelMethodType.WEIGHTED_AVERAGE);

    /**
     * Multiple-model methods that classification evaluation lets through to aggregation.
     * The members mirror the branches implemented by {@code evaluateClassification} and
     * {@code aggregateProbabilities}: MAX is listed because both implement it, and SUM is not
     * listed because neither has a SUM branch (it would only ever end in an
     * {@code UnsupportedFeatureException}).
     */
    private static final Set<MultipleModelMethodType> CLASSIFICATION_METHODS = EnumSet.of(
        MultipleModelMethodType.MAJORITY_VOTE,
        MultipleModelMethodType.WEIGHTED_MAJORITY_VOTE,
        MultipleModelMethodType.MAX,
        MultipleModelMethodType.MEDIAN,
        MultipleModelMethodType.AVERAGE,
        MultipleModelMethodType.WEIGHTED_AVERAGE);

    /** Multiple-model methods that clustering evaluation lets through to vote aggregation. */
    private static final Set<MultipleModelMethodType> CLUSTERING_METHODS = EnumSet.of(
        MultipleModelMethodType.MAJORITY_VOTE,
        MultipleModelMethodType.WEIGHTED_MAJORITY_VOTE);
    /**
     * Cache of segment registries keyed by mining model. Keys are held weakly; a missing entry
     * is built on demand from the segments of the model's segmentation.
     */
    private static final LoadingCache<MiningModel, BiMap<String, Segment>> entityCache = CacheBuilder.newBuilder()
        .weakKeys()
        .build(new CacheLoader<MiningModel, BiMap<String, Segment>>() {
            @Override // com.google.common.cache.CacheLoader
            public BiMap<String, Segment> load(MiningModel miningModel) {
                Segmentation segmentation = miningModel.getSegmentation();
                List<Segment> segments = segmentation.getSegments();
                return EntityUtil.buildBiMap(segments);
            }
        });

    /**
     * Creates an evaluator from a PMML document, delegating to the superclass constructor with
     * {@code MiningModel.class}.
     *
     * @param pmml the PMML document passed to the superclass
     */
    public MiningModelEvaluator(PMML pmml) {
        super(pmml, MiningModel.class);
        // No evaluator factory is configured at construction time.
        evaluatorFactory = null;
    }

    /**
     * Creates an evaluator for an explicitly supplied mining model.
     *
     * @param pmml the PMML document passed to the superclass
     * @param miningModel the mining model passed to the superclass
     */
    public MiningModelEvaluator(PMML pmml, MiningModel miningModel) {
        super(pmml, miningModel);
        // No evaluator factory is configured at construction time.
        evaluatorFactory = null;
    }

    /**
     * Returns a short human-readable label for the model.
     *
     * @return {@code "Random forest"} when {@code isRandomForest} accepts the model,
     *         {@code "Ensemble model"} otherwise
     */
    @Override // org.jpmml.evaluator.ModelManager, org.jpmml.evaluator.Consumer
    public String getSummary() {
        MiningModel miningModel = (MiningModel) getModel();
        if (isRandomForest(miningModel)) {
            return "Random forest";
        }
        return "Ensemble model";
    }

    /**
     * Returns the segment registry for this evaluator's model, obtained through the shared
     * {@code entityCache}.
     *
     * @return the bidirectional map between identifiers and segments
     */
    @Override // org.jpmml.evaluator.HasEntityRegistry
    public BiMap<String, Segment> getEntityRegistry() {
        BiMap<String, Segment> registry = (BiMap<String, Segment>) getValue(entityCache);
        return registry;
    }

    /**
     * Returns the data field for this model, or {@code null} when none applies.
     *
     * @return {@code null} when the model has no segmentation or when the segmentation's
     *         multiple-model method is {@code SELECT_ALL}; otherwise whatever the superclass
     *         implementation returns
     */
    @Override // org.jpmml.evaluator.ModelEvaluator
    public DataField getDataField() {
        MiningModel miningModel = (MiningModel) getModel();
        // Bind the segmentation to a local so the null check and the switch use the same value.
        Segmentation segmentation = miningModel.getSegmentation();
        if (segmentation == null) {
            return null;
        }
        MultipleModelMethodType multipleModelMethod = segmentation.getMultipleModelMethod();
        switch (multipleModelMethod) {
            case SELECT_ALL:
                return null;
            default:
                return super.getDataField();
        }
    }

    /**
     * Creates the evaluation context used by this evaluator.
     *
     * @param modelEvaluationContext the context handed to the new context's constructor
     * @return a new {@link MiningModelEvaluationContext} bound to this evaluator
     */
    @Override // org.jpmml.evaluator.ModelEvaluator
    public MiningModelEvaluationContext createContext(ModelEvaluationContext modelEvaluationContext) {
        MiningModelEvaluationContext miningContext = new MiningModelEvaluationContext(modelEvaluationContext, this);
        return miningContext;
    }

    /**
     * Evaluates the model using a generic context, which must actually be a
     * {@link MiningModelEvaluationContext}.
     *
     * @param modelEvaluationContext the context to evaluate in
     * @return the result of {@link #evaluate(MiningModelEvaluationContext)}
     * @throws ClassCastException if the context is not a {@code MiningModelEvaluationContext}
     */
    @Override // org.jpmml.evaluator.ModelEvaluator
    public Map<FieldName, ?> evaluate(ModelEvaluationContext modelEvaluationContext) {
        MiningModelEvaluationContext miningContext = (MiningModelEvaluationContext) modelEvaluationContext;
        return evaluate(miningContext);
    }

    /**
     * Validates the mining model, dispatches to the evaluation routine matching its mining
     * function, and passes the predictions through {@code OutputUtil.evaluate}.
     *
     * @param miningModelEvaluationContext the context to evaluate in
     * @return the value returned by {@code OutputUtil.evaluate} for the predictions
     * @throws InvalidResultException if the model is not scorable
     * @throws UnsupportedFeatureException if the model contains an embedded model
     * @throws InvalidFeatureException if the model has no segmentation
     */
    public Map<FieldName, ?> evaluate(MiningModelEvaluationContext miningModelEvaluationContext) {
        MiningModel ensemble = (MiningModel) getModel();
        if (!ensemble.isScorable()) {
            throw new InvalidResultException(ensemble);
        }

        // Any embedded model is rejected; the first one found is reported.
        EmbeddedModel firstEmbedded = (EmbeddedModel) Iterables.getFirst(ensemble.getEmbeddedModels(), null);
        if (firstEmbedded != null) {
            throw new UnsupportedFeatureException(firstEmbedded);
        }

        if (ensemble.getSegmentation() == null) {
            throw new InvalidFeatureException(ensemble);
        }

        Map<FieldName, ?> predictions;
        MiningFunctionType function = ensemble.getFunctionName();
        switch (function) {
            case REGRESSION:
                predictions = evaluateRegression(miningModelEvaluationContext);
                break;
            case CLASSIFICATION:
                predictions = evaluateClassification(miningModelEvaluationContext);
                break;
            case CLUSTERING:
                predictions = evaluateClustering(miningModelEvaluationContext);
                break;
            default:
                // Every other mining function goes through the non-aggregating path.
                predictions = evaluateAny(miningModelEvaluationContext);
                break;
        }

        return OutputUtil.evaluate(predictions, miningModelEvaluationContext);
    }

    /**
     * Evaluates a regression ensemble: runs the segments, and unless the multiple-model method
     * already determines the result, aggregates the segment target values into one number.
     *
     * @param miningModelEvaluationContext the context to evaluate in
     * @return the direct segmentation result if there is one, otherwise the result of
     *         {@code TargetUtil.evaluateRegression} for the aggregated value
     */
    private Map<FieldName, ?> evaluateRegression(MiningModelEvaluationContext miningModelEvaluationContext) {
        MiningModel miningModel = (MiningModel) getModel();
        List<SegmentResultMap> segmentResults = evaluateSegmentation(miningModelEvaluationContext);

        Map<FieldName, ?> direct = getSegmentationResult(REGRESSION_METHODS, segmentResults);
        if (direct != null) {
            return direct;
        }

        Double aggregate = aggregateValues(miningModel.getSegmentation(), segmentResults);
        return TargetUtil.evaluateRegression(aggregate, miningModelEvaluationContext);
    }

    /**
     * Evaluates a classification ensemble: runs the segments, and unless the multiple-model
     * method already determines the result, aggregates the segment outputs into a single
     * {@link Classification}.
     *
     * <p>The aggregate is held in a {@code Classification}-typed variable because the
     * MAX/MEDIAN branch produces a plain {@code Classification} of type {@code VOTE}, while the
     * other branches produce a {@code ProbabilityDistribution}.
     *
     * @param miningModelEvaluationContext the context to evaluate in
     * @return the direct segmentation result if there is one, otherwise the result of
     *         {@code TargetUtil.evaluateClassification} for the aggregate
     * @throws UnsupportedFeatureException if the multiple-model method has no branch here
     */
    private Map<FieldName, ?> evaluateClassification(MiningModelEvaluationContext miningModelEvaluationContext) {
        MiningModel miningModel = (MiningModel) getModel();
        List<SegmentResultMap> segmentResults = evaluateSegmentation(miningModelEvaluationContext);

        Map<FieldName, ?> direct = getSegmentationResult(CLASSIFICATION_METHODS, segmentResults);
        if (direct != null) {
            return direct;
        }

        Segmentation segmentation = miningModel.getSegmentation();
        MultipleModelMethodType multipleModelMethod = segmentation.getMultipleModelMethod();

        Classification result;
        switch (multipleModelMethod) {
            case MAJORITY_VOTE:
            case WEIGHTED_MAJORITY_VOTE: {
                ProbabilityDistribution distribution = new ProbabilityDistribution();
                distribution.putAll(aggregateVotes(segmentation, segmentResults));
                // Presumably rescales the summed votes into probabilities; confirm against
                // ProbabilityDistribution.normalizeValues().
                distribution.normalizeValues();
                result = distribution;
                break;
            }
            case MAX:
            case MEDIAN: {
                // The aggregated values are stored as a VOTE-typed classification, not as a
                // probability distribution.
                result = new Classification(Classification.Type.VOTE);
                result.putAll(aggregateProbabilities(segmentation, segmentResults));
                break;
            }
            case AVERAGE:
            case WEIGHTED_AVERAGE: {
                result = new ProbabilityDistribution();
                result.putAll(aggregateProbabilities(segmentation, segmentResults));
                break;
            }
            default:
                throw new UnsupportedFeatureException(segmentation, multipleModelMethod);
        }

        return TargetUtil.evaluateClassification(result, miningModelEvaluationContext);
    }

    /**
     * Evaluates a clustering ensemble: runs the segments, and unless the multiple-model method
     * already determines the result, sums the segment votes into a VOTE-typed
     * {@link Classification} keyed by the target field.
     *
     * @param miningModelEvaluationContext the context to evaluate in
     * @return the direct segmentation result if there is one, otherwise a single-entry map from
     *         the target field to the vote classification
     */
    private Map<FieldName, ?> evaluateClustering(MiningModelEvaluationContext miningModelEvaluationContext) {
        MiningModel miningModel = (MiningModel) getModel();
        List<SegmentResultMap> segmentResults = evaluateSegmentation(miningModelEvaluationContext);

        Map<FieldName, ?> direct = getSegmentationResult(CLUSTERING_METHODS, segmentResults);
        if (direct != null) {
            return direct;
        }

        Segmentation segmentation = miningModel.getSegmentation();
        Map<String, Double> votes = aggregateVotes(segmentation, segmentResults);

        Classification winner = new Classification(Classification.Type.VOTE);
        winner.putAll(votes);
        winner.computeResult(DataType.STRING);

        return Collections.singletonMap(getTargetField(), winner);
    }

    /**
     * Evaluates an ensemble whose mining function has no dedicated aggregation. Because the set
     * of accepted aggregation methods is empty, only the methods that
     * {@code getSegmentationResult} handles directly can succeed.
     *
     * @param miningModelEvaluationContext the context to evaluate in
     * @return whatever {@code getSegmentationResult} returns for the evaluated segments
     */
    private Map<FieldName, ?> evaluateAny(MiningModelEvaluationContext miningModelEvaluationContext) {
        Set<MultipleModelMethodType> noAggregationMethods = Collections.emptySet();
        List<SegmentResultMap> segmentResults = evaluateSegmentation(miningModelEvaluationContext);
        return getSegmentationResult(noAggregationMethods, segmentResults);
    }

    /**
     * Evaluates every segment whose predicate holds and collects one result map per segment.
     *
     * <p>For each selected segment this method evaluates the segment's model in a child
     * context, re-declares the model's output fields in the parent context, copies the child
     * context's warnings to the parent, and records the result under the segment's identifier.
     * With {@code SELECT_FIRST} it returns as soon as the first segment has been evaluated.
     *
     * @param miningModelEvaluationContext the parent context; it is updated as described above
     * @return the results of the evaluated segments in segment order (a singleton list for
     *         {@code SELECT_FIRST}); empty when no predicate evaluated to true
     * @throws UnsupportedFeatureException if the segmentation declares local transformations
     * @throws InvalidFeatureException if a segment lacks a predicate, a selected segment lacks
     *         a model, or a mining-function consistency check fails
     * @throws MissingFieldException if a segment model's output field has no value in the
     *         child context
     */
    private List<SegmentResultMap> evaluateSegmentation(MiningModelEvaluationContext miningModelEvaluationContext) {
        MiningModel miningModel = (MiningModel) getModel();
        List<SegmentResultMap> results = new ArrayList<SegmentResultMap>();

        Segmentation segmentation = miningModel.getSegmentation();
        LocalTransformations localTransformations = segmentation.getLocalTransformations();
        if (localTransformations != null) {
            throw new UnsupportedFeatureException(localTransformations);
        }

        // Fall back to a fresh factory when none has been configured.
        ModelEvaluatorFactory factory = getEvaluatorFactory();
        if (factory == null) {
            factory = ModelEvaluatorFactory.newInstance();
        }

        BiMap<String, Segment> entityRegistry = getEntityRegistry();
        MultipleModelMethodType multipleModelMethod = segmentation.getMultipleModelMethod();
        MiningFunctionType miningFunction = miningModel.getFunctionName();

        // Model of the most recently evaluated segment; only tracked for MODEL_CHAIN.
        Model lastChainModel = null;

        for (Segment segment : segmentation.getSegments()) {
            Predicate predicate = segment.getPredicate();
            if (predicate == null) {
                throw new InvalidFeatureException(segment);
            }

            // A null (unknown) predicate result is treated like false: the segment is skipped.
            Boolean selected = PredicateUtil.evaluate(predicate, miningModelEvaluationContext);
            if (selected == null || !selected.booleanValue()) {
                continue;
            }

            Model segmentModel = segment.getModel();
            if (segmentModel == null) {
                throw new InvalidFeatureException(segment);
            }

            switch (multipleModelMethod) {
                case MODEL_CHAIN:
                    // Chained models are exempt from the per-segment check; only the last
                    // evaluated one is checked, after the loop.
                    lastChainModel = segmentModel;
                    break;
                default:
                    if (!miningFunction.equals(segmentModel.getFunctionName())) {
                        throw new InvalidFeatureException(segmentModel);
                    }
                    break;
            }

            ModelEvaluator<? extends Model> segmentEvaluator = factory.newModelManager(getPMML(), segmentModel);
            ModelEvaluationContext segmentContext = segmentEvaluator.createContext(miningModelEvaluationContext);
            Map<FieldName, ?> segmentOutput = segmentEvaluator.evaluate(segmentContext);
            FieldName targetField = segmentEvaluator.getTargetField();

            // Re-declare the segment's output fields in the parent context (presumably so that
            // subsequent segments can reference them).
            for (FieldName outputField : segmentEvaluator.getOutputFields()) {
                FieldValue outputValue = segmentContext.getField(outputField);
                if (outputValue == null) {
                    throw new MissingFieldException(outputField, segment);
                }
                miningModelEvaluationContext.declare(outputField, outputValue);
            }

            for (String warning : segmentContext.getWarnings()) {
                miningModelEvaluationContext.addWarning(warning);
            }

            String segmentId = EntityUtil.getId(segment, entityRegistry);
            SegmentResultMap segmentResult = new SegmentResultMap(segment, targetField);
            segmentResult.putAll(segmentOutput);
            miningModelEvaluationContext.putResult(segmentId, segmentResult);

            switch (multipleModelMethod) {
                case SELECT_FIRST:
                    // The first selected segment wins; remaining segments are not evaluated.
                    return Collections.singletonList(segmentResult);
                default:
                    results.add(segmentResult);
                    break;
            }
        }

        switch (multipleModelMethod) {
            case MODEL_CHAIN:
                if (lastChainModel != null && !miningFunction.equals(lastChainModel.getFunctionName())) {
                    throw new InvalidFeatureException(lastChainModel);
                }
                break;
            default:
                break;
        }

        return results;
    }

    /**
     * Resolves the ensemble result for the multiple-model methods that need no aggregation,
     * and validates the method otherwise.
     *
     * @param set the aggregation methods the caller is able to handle
     * @param list the results of the evaluated segments
     * @return the combined map for {@code SELECT_ALL}; the last result for {@code MODEL_CHAIN}
     *         or the first for {@code SELECT_FIRST} when results exist; a single-entry map from
     *         the target field to {@code null} when there are no results; or {@code null} to
     *         signal that the caller must aggregate the results itself
     * @throws UnsupportedFeatureException if the method is neither handled here nor in
     *         {@code set}
     */
    private Map<FieldName, ?> getSegmentationResult(Set<MultipleModelMethodType> set, List<SegmentResultMap> list) {
        MiningModel miningModel = (MiningModel) getModel();
        Segmentation segmentation = miningModel.getSegmentation();
        MultipleModelMethodType method = segmentation.getMultipleModelMethod();

        boolean hasResults = !list.isEmpty();
        switch (method) {
            case SELECT_ALL:
                return selectAll(list);
            case SELECT_FIRST:
                if (hasResults) {
                    return list.get(0);
                }
                break;
            case MODEL_CHAIN:
                if (hasResults) {
                    return list.get(list.size() - 1);
                }
                break;
            default:
                if (!set.contains(method)) {
                    throw new UnsupportedFeatureException(segmentation, method);
                }
                break;
        }

        if (hasResults) {
            // Results exist but still need aggregating; the caller takes over.
            return null;
        }
        // No segment was evaluated: report a null value for the target field.
        return Collections.singletonMap(getTargetField(), null);
    }

    /**
     * Returns the factory used to create evaluators for segment models.
     *
     * @return the configured factory, or {@code null} if none has been set
     */
    public ModelEvaluatorFactory getEvaluatorFactory() {
        return evaluatorFactory;
    }

    /**
     * Sets the factory used to create evaluators for segment models.
     *
     * @param modelEvaluatorFactory the factory to use; may be {@code null}, in which case
     *        segment evaluation creates a factory via {@code ModelEvaluatorFactory.newInstance()}
     */
    public void setEvaluatorFactory(ModelEvaluatorFactory modelEvaluatorFactory) {
        evaluatorFactory = modelEvaluatorFactory;
    }

    /**
     * Combines the numeric target values of the segment results according to the
     * segmentation's multiple-model method (SUM, MEDIAN, AVERAGE or WEIGHTED_AVERAGE).
     *
     * @param segmentation supplies the multiple-model method and is reported in errors
     * @param list the segment results whose target values are combined
     * @return the sum, median or (weighted) average computed by the aggregator
     * @throws UnsupportedFeatureException for any other multiple-model method
     */
    private static Double aggregateValues(Segmentation segmentation, List<SegmentResultMap> list) {
        RegressionAggregator aggregator = new RegressionAggregator();
        MultipleModelMethodType method = segmentation.getMultipleModelMethod();

        // Number of values for AVERAGE, sum of weights for WEIGHTED_AVERAGE.
        double denominator = 0.0d;

        for (SegmentResultMap segmentResult : list) {
            Object decoded = EvaluatorUtil.decode(segmentResult.getTargetValue());
            Double value = (Double) TypeUtil.parseOrCast(DataType.DOUBLE, decoded);

            switch (method) {
                case SUM:
                case MEDIAN:
                    aggregator.add(value);
                    break;
                case AVERAGE:
                    aggregator.add(value);
                    denominator += 1.0d;
                    break;
                case WEIGHTED_AVERAGE:
                    double weight = segmentResult.getWeight();
                    aggregator.add(Double.valueOf(value.doubleValue() * weight));
                    denominator += weight;
                    break;
                default:
                    throw new UnsupportedFeatureException(segmentation, method);
            }
        }

        switch (method) {
            case SUM:
                return aggregator.sum();
            case MEDIAN:
                return aggregator.median();
            case AVERAGE:
            case WEIGHTED_AVERAGE:
                return aggregator.average(denominator);
            default:
                throw new UnsupportedFeatureException(segmentation, method);
        }
    }

    /**
     * Sums one vote per segment result, keyed by the decoded target value. Each vote counts as
     * 1 for MAJORITY_VOTE and as the segment weight for WEIGHTED_MAJORITY_VOTE.
     *
     * @param segmentation supplies the multiple-model method and is reported in errors
     * @param list the segment results that cast the votes
     * @return the aggregator's sum map
     * @throws UnsupportedFeatureException for any other multiple-model method
     */
    private static Map<String, Double> aggregateVotes(Segmentation segmentation, List<SegmentResultMap> list) {
        VoteAggregator aggregator = new VoteAggregator();
        MultipleModelMethodType method = segmentation.getMultipleModelMethod();

        for (SegmentResultMap segmentResult : list) {
            String category = (String) EvaluatorUtil.decode(segmentResult.getTargetValue());

            double vote;
            switch (method) {
                case MAJORITY_VOTE:
                    vote = 1.0d;
                    break;
                case WEIGHTED_MAJORITY_VOTE:
                    vote = segmentResult.getWeight();
                    break;
                default:
                    throw new UnsupportedFeatureException(segmentation, method);
            }
            aggregator.add(category, Double.valueOf(vote));
        }

        return aggregator.sumMap();
    }

    /**
     * Combines the probability-bearing target values of the segment results according to the
     * segmentation's multiple-model method (MAX, MEDIAN, AVERAGE or WEIGHTED_AVERAGE).
     *
     * @param segmentation supplies the multiple-model method and is reported in errors
     * @param list the segment results whose target values are combined
     * @return the aggregator's max, median or average map
     * @throws TypeCheckException if a target value does not implement {@code HasProbability}
     * @throws UnsupportedFeatureException for any other multiple-model method
     */
    private static Map<String, Double> aggregateProbabilities(Segmentation segmentation, List<SegmentResultMap> list) {
        ProbabilityAggregator aggregator = new ProbabilityAggregator();
        MultipleModelMethodType method = segmentation.getMultipleModelMethod();

        // Number of results for AVERAGE, sum of weights for WEIGHTED_AVERAGE.
        double denominator = 0.0d;

        for (SegmentResultMap segmentResult : list) {
            Object targetValue = segmentResult.getTargetValue();
            if (!(targetValue instanceof HasProbability)) {
                throw new TypeCheckException((Class<?>) HasProbability.class, targetValue);
            }
            HasProbability probabilities = (HasProbability) targetValue;

            switch (method) {
                case MAX:
                case MEDIAN:
                    aggregator.add(probabilities);
                    break;
                case AVERAGE:
                    aggregator.add(probabilities);
                    denominator += 1.0d;
                    break;
                case WEIGHTED_AVERAGE:
                    double weight = segmentResult.getWeight();
                    aggregator.add(probabilities, weight);
                    denominator += weight;
                    break;
                default:
                    throw new UnsupportedFeatureException(segmentation, method);
            }
        }

        switch (method) {
            case MAX:
                return aggregator.maxMap();
            case MEDIAN:
                return aggregator.medianMap();
            case AVERAGE:
            case WEIGHTED_AVERAGE:
                return aggregator.averageMap(denominator);
            default:
                throw new UnsupportedFeatureException(segmentation, method);
        }
    }

    /**
     * Merges all segment results into one map whose values are the per-segment values of each
     * field, in segment order.
     *
     * @param list the segment results to merge
     * @return a map from field name to the collection of that field's values; empty when
     *         {@code list} is empty
     * @throws EvaluationException if a segment result's key set differs from that of the first
     *         segment result
     */
    private static Map<FieldName, ?> selectAll(List<SegmentResultMap> list) {
        ArrayListMultimap<FieldName, Object> merged = ArrayListMultimap.create();

        // Key set of the first segment result; every later result must match it.
        Set<FieldName> expectedKeys = null;

        for (SegmentResultMap segmentResult : list) {
            if (expectedKeys == null) {
                expectedKeys = new LinkedHashSet<FieldName>(segmentResult.keySet());
            }

            // Matching key sets guarantee that every field ends up with one value per segment.
            if (!expectedKeys.equals(segmentResult.keySet())) {
                throw new EvaluationException();
            }

            for (FieldName key : expectedKeys) {
                merged.put(key, segmentResult.get(key));
            }
        }

        return merged.asMap();
    }

    /**
     * Heuristically decides whether a mining model looks like a random forest: it must have a
     * segmentation with more than three segments, and every segment's model must be a
     * {@link TreeModel}.
     *
     * @param miningModel the model to inspect
     * @return {@code true} if both conditions hold, {@code false} otherwise (including when
     *         there is no segmentation)
     */
    private static boolean isRandomForest(MiningModel miningModel) {
        Segmentation segmentation = miningModel.getSegmentation();
        if (segmentation == null) {
            return false;
        }

        List<Segment> segments = segmentation.getSegments();
        boolean forest = segments.size() > 3;
        for (Segment segment : segments) {
            Model segmentModel = segment.getModel();
            if (!(segmentModel instanceof TreeModel)) {
                forest = false;
            }
        }
        return forest;
    }
}
