package com.feedzai.fos.impl.weka.utils.pmml;

import com.feedzai.fos.impl.weka.exception.PMMLConversionException;
import com.google.common.collect.ImmutableList;
import hr.irb.fastRandomForest.FastRandomForest;
import hr.irb.fastRandomForest.FastRandomForestPMMLConsumer;
import hr.irb.fastRandomForest.FastRandomForestPMMLProducer;
import java.util.Iterator;
import java.util.List;
import org.dmg.pmml.Application;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.DataField;
import org.dmg.pmml.Extension;
import org.dmg.pmml.FieldUsageType;
import org.dmg.pmml.Header;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningModel;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.Node;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.TreeModel;
import org.dmg.pmml.Value;
import weka.classifiers.Classifier;
import weka.classifiers.trees.RandomForest;
import weka.classifiers.trees.RandomForestPMMLConsumer;
import weka.classifiers.trees.RandomForestPMMLProducer;
import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instances;
import weka.core.Utils;

/**
 * Shared helpers for converting Weka classifiers to and from PMML.
 *
 * <p>Holds the registry of supported {@link Algorithm}s plus utilities for building PMML headers
 * and node score distributions, reading them back, and rebuilding Weka {@link Instances} from a
 * PMML {@link DataDictionary}.
 */
public final class PMMLConversionCommons {
    /**
     * Name of the PMML extension element that identifies the algorithm. Not read in this class;
     * presumably written/read by the producers and consumers — confirm against their usage.
     */
    public static final String ALGORITHM_EXTENSION_ELEMENT = "algorithm";

    /** Name of the node-level PMML extension that holds the node's training proportion. */
    public static final String TRAINING_PROPORTION_ELEMENT = "trainingProportion";

    /**
     * Classifier algorithms supported for PMML conversion, each bound to its PMML consumer,
     * PMML producer and Weka classifier class.
     */
    public enum Algorithm {
        RANDOM_FOREST {
            @Override
            public PMMLConsumer getPMMLConsumer() {
                return new RandomForestPMMLConsumer();
            }

            @Override
            public PMMLProducer getPMMLProducer() {
                return new RandomForestPMMLProducer();
            }

            @Override
            public Class<? extends Classifier> getClassifierClass() {
                return RandomForest.class;
            }
        },
        FAST_RANDOM_FOREST {
            @Override
            public PMMLConsumer getPMMLConsumer() {
                return new FastRandomForestPMMLConsumer();
            }

            @Override
            public PMMLProducer getPMMLProducer() {
                return new FastRandomForestPMMLProducer();
            }

            @Override
            public Class<? extends Classifier> getClassifierClass() {
                return FastRandomForest.class;
            }
        };

        /**
         * Creates a consumer for this algorithm's PMML representation.
         *
         * @return a new {@link PMMLConsumer} instance
         */
        public abstract PMMLConsumer getPMMLConsumer();

        /**
         * Creates a producer for this algorithm's PMML representation.
         *
         * @return a new {@link PMMLProducer} instance
         */
        public abstract PMMLProducer getPMMLProducer();

        /**
         * Returns the Weka classifier class implemented by this algorithm.
         *
         * @return the classifier class
         */
        public abstract Class<? extends Classifier> getClassifierClass();

        /**
         * Finds the algorithm whose classifier class is exactly the runtime class of
         * {@code classifier}. Subclasses of a supported classifier are not matched.
         *
         * @param classifier the classifier to look up
         * @return the matching algorithm
         * @throws PMMLConversionException if no algorithm supports the classifier's class
         */
        public static Algorithm fromClassifier(Classifier classifier) throws PMMLConversionException {
            Class<?> classifierClass = classifier.getClass();
            for (Algorithm algorithm : values()) {
                if (classifierClass.equals(algorithm.getClassifierClass())) {
                    return algorithm;
                }
            }
            throw new PMMLConversionException("Unsupported classifier '" + classifierClass.getSimpleName() + "'.");
        }
    }

    /**
     * Builds the PMML header identifying FOS-Weka as the producing application.
     *
     * @param description text placed in the header's description
     * @return a new header
     */
    public static Header buildPMMLHeader(String description) {
        return new Header()
                .withCopyright("www.dmg.org")
                .withDescription(description)
                .withApplication(new Application("Feedzai FOS-Weka").withVersion("1.0.4"));
    }

    /**
     * Appends one {@link ScoreDistribution} per class to {@code node}.
     *
     * <p>Each entry stores the raw class weight as its confidence and the weight normalised by
     * the total as its probability. When all weights sum to zero, a uniform probability is used.
     * A {@code null} distribution leaves the node untouched.
     *
     * @param node the tree node to annotate
     * @param classDistribution per-class weights, indexed like the class attribute's values; may be null
     * @param instances dataset header providing the class attribute's value labels
     */
    public static void addScoreDistribution(Node node, double[] classDistribution, Instances instances) {
        if (classDistribution == null) {
            return;
        }
        double total = Utils.sum(classDistribution);
        for (int i = 0; i < classDistribution.length; i++) {
            String classValue = instances.classAttribute().value(i);
            // Without any weight to normalise by, fall back to a uniform distribution.
            double probability = total != 0.0
                    ? classDistribution[i] / total
                    : 1.0 / classDistribution.length;
            ScoreDistribution scoreDistribution = new ScoreDistribution(classValue, 0.0);
            scoreDistribution.withConfidence(classDistribution[i]);
            scoreDistribution.withProbability(probability);
            node.withScoreDistributions(scoreDistribution);
        }
    }

    /**
     * Returns the class label with the largest weight in {@code classDistribution}.
     *
     * <p>For a {@code null} or empty distribution, the first class value (index 0) is returned.
     *
     * @param classDistribution per-class weights; may be null
     * @param instances dataset header providing the class attribute's value labels
     * @return the label of the highest-weighted class
     */
    public static String leafScoreFromDistribution(double[] classDistribution, Instances instances) {
        // Utils.maxIndex returns 0 for an empty array, so no extra guard is needed here.
        int bestIndex = classDistribution == null ? 0 : Utils.maxIndex(classDistribution);
        return instances.classAttribute().value(bestIndex);
    }

    /**
     * Reads back the per-class weights stored on a node's score distributions.
     *
     * <p>Uses each distribution's confidence. Confidence is optional in PMML, so when it is
     * missing the (mandatory) record count is used instead.
     *
     * @param node the node to read
     * @return the per-class weights in score-distribution order, or {@code null} if the node has none
     */
    public static double[] getClassDistribution(Node node) {
        List<ScoreDistribution> scoreDistributions = node.getScoreDistributions();
        if (scoreDistributions.isEmpty()) {
            return null;
        }
        double[] classDistribution = new double[scoreDistributions.size()];
        for (int i = 0; i < classDistribution.length; i++) {
            ScoreDistribution scoreDistribution = scoreDistributions.get(i);
            Double confidence = scoreDistribution.getConfidence();
            classDistribution[i] = confidence != null ? confidence : scoreDistribution.getRecordCount();
        }
        return classDistribution;
    }

    /**
     * Reads the {@value #TRAINING_PROPORTION_ELEMENT} extension of a node.
     *
     * @param node the node to read
     * @return the parsed proportion, or {@code 0.0} if the extension is absent
     * @throws NumberFormatException if the extension value is not a valid double
     */
    public static double getNodeTrainingProportion(Node node) {
        for (Extension extension : node.getExtensions()) {
            if (TRAINING_PROPORTION_ELEMENT.equals(extension.getName())) {
                return Double.parseDouble(extension.getValue());
            }
        }
        return 0.0;
    }

    /**
     * Returns the index in {@code instances} of the tree model's predicted (class) field.
     *
     * @param instances the dataset header
     * @param treeModel the tree model whose mining schema declares the predicted field
     * @return the class attribute index
     * @throws IllegalArgumentException if there is no predicted field or it is not in {@code instances}
     */
    public static int getClassIndex(Instances instances, TreeModel treeModel) {
        return getClassIndex(instances, treeModel.getMiningSchema());
    }

    /**
     * Returns the index in {@code instances} of the first mining field whose usage type is
     * {@link FieldUsageType#PREDICTED}.
     *
     * @param instances the dataset header
     * @param miningSchema the schema declaring the predicted field
     * @return the class attribute index
     * @throws IllegalArgumentException if there is no predicted field or it is not in {@code instances}
     */
    public static int getClassIndex(Instances instances, MiningSchema miningSchema) {
        for (MiningField miningField : miningSchema.getMiningFields()) {
            if (miningField.getUsageType() == FieldUsageType.PREDICTED) {
                String className = miningField.getName().getValue();
                Attribute classAttribute = instances.attribute(className);
                if (classAttribute == null) {
                    throw new IllegalArgumentException(
                            "Predicted field '" + className + "' not found in instances.");
                }
                return classAttribute.index();
            }
        }
        throw new IllegalArgumentException("PMML MiningSchema has no predicted field.");
    }

    /**
     * Builds an empty Weka dataset named {@code "instances"} whose attributes mirror the data
     * dictionary's fields.
     *
     * @param dataDictionary the PMML data dictionary
     * @return an empty {@link Instances} with one attribute per data field
     * @throws IllegalArgumentException if a field has an unsupported optype
     */
    public static Instances buildInstances(DataDictionary dataDictionary) {
        List<Attribute> attributes = buildAttributes(dataDictionary);
        FastVector attributeVector = new FastVector(attributes.size());
        for (Attribute attribute : attributes) {
            attributeVector.addElement(attribute);
        }
        return new Instances("instances", attributeVector, attributes.size());
    }

    /**
     * Converts every data field of the dictionary into a Weka attribute, preserving order.
     *
     * @param dataDictionary the PMML data dictionary
     * @return an immutable list of attributes
     * @throws IllegalArgumentException if a field has an unsupported optype
     */
    public static List<Attribute> buildAttributes(DataDictionary dataDictionary) {
        ImmutableList.Builder<Attribute> attributes = ImmutableList.builder();
        for (DataField dataField : dataDictionary.getDataFields()) {
            attributes.add(buildAttribute(dataField));
        }
        return attributes.build();
    }

    /**
     * Converts a single PMML data field into a Weka attribute.
     *
     * <p>Continuous fields become numeric attributes. Categorical fields become nominal
     * attributes whose values are the field's declared values, in order.
     *
     * @param dataField the field to convert
     * @return the corresponding attribute
     * @throws IllegalArgumentException if the optype is missing or not continuous/categorical
     */
    public static Attribute buildAttribute(DataField dataField) {
        OpType opType = dataField.getOptype();
        if (opType != null) {
            switch (opType) {
                case CONTINUOUS:
                    return new Attribute(dataField.getName().getValue());
                case CATEGORICAL:
                    FastVector nominalValues = new FastVector();
                    for (Value value : dataField.getValues()) {
                        nominalValues.addElement(value.getValue());
                    }
                    return new Attribute(dataField.getName().getValue(), nominalValues);
                default:
                    break;
            }
        }
        throw new IllegalArgumentException("PMML DataField OPTYPE " + opType + " not supported.");
    }

    /**
     * Returns the first {@link MiningModel} among the document's models.
     *
     * @param pmml the PMML document
     * @return the first mining model
     * @throws IllegalArgumentException if the document contains no mining model
     */
    public static MiningModel getMiningModel(PMML pmml) {
        for (Model model : pmml.getModels()) {
            if (model instanceof MiningModel) {
                return (MiningModel) model;
            }
        }
        throw new IllegalArgumentException("PMML MiningModel not found.");
    }
}
