package weka.classifiers.trees;

import com.feedzai.fos.impl.weka.exception.PMMLConversionException;
import com.feedzai.fos.impl.weka.utils.pmml.PMMLConversionCommons;
import com.feedzai.fos.impl.weka.utils.pmml.PMMLProducer;
import java.io.File;
import java.io.FileOutputStream;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;
import javax.xml.transform.stream.StreamResult;
import org.dmg.pmml.DataDictionary;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Extension;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldUsageType;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.MiningModel;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.MultipleModelMethodType;
import org.dmg.pmml.Node;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.Segment;
import org.dmg.pmml.Segmentation;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.TreeModel;
import org.dmg.pmml.True;
import org.dmg.pmml.Value;
import org.jpmml.model.JAXBUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.RandomForestUtils;
import weka.classifiers.trees.RandomTree;
import weka.core.Attribute;
import weka.core.Instances;

/* loaded from: input_file:weka/classifiers/trees/RandomForestPMMLProducer.class */
public class RandomForestPMMLProducer implements PMMLProducer<RandomForest> {
    private static final Logger logger;
    private static final String ALGORITHM_NAME;
    private static final String MODEL_NAME;
    static final /* synthetic */ boolean $assertionsDisabled;

    @Override // com.feedzai.fos.impl.weka.utils.pmml.PMMLProducer
    public void produce(RandomForest randomForest, File file) throws PMMLConversionException {
        PMML produce = produce(randomForest);
        try {
            FileOutputStream fileOutputStream = new FileOutputStream(file);
            Throwable th = null;
            try {
                try {
                    JAXBUtil.marshalPMML(produce, new StreamResult(fileOutputStream));
                    if (fileOutputStream != null) {
                        if (0 != 0) {
                            try {
                                fileOutputStream.close();
                            } catch (Throwable th2) {
                                th.addSuppressed(th2);
                            }
                        } else {
                            fileOutputStream.close();
                        }
                    }
                } finally {
                }
            } finally {
            }
        } catch (Exception e) {
            throw new PMMLConversionException("Failed to marshal the PMML to the given file.", e);
        }
    }

    @Override // com.feedzai.fos.impl.weka.utils.pmml.PMMLProducer
    public PMML produce(RandomForest randomForest) {
        Instances instances = RandomForestUtils.getBaggingClassifiers(randomForest.m_bagger)[0].m_Info;
        PMML pmml = new PMML(PMMLConversionCommons.buildPMMLHeader("Weka RandomForest as PMML."), new DataDictionary(), "4.2");
        DataDictionary dataDictionary = new DataDictionary();
        MiningSchema miningSchema = new MiningSchema();
        if (instances != null) {
            for (int i = 0; i < instances.numAttributes(); i++) {
                Attribute attribute = instances.attribute(i);
                DataField dataField = new DataField(new FieldName(attribute.name()), attribute.isNominal() ? OpType.CATEGORICAL : OpType.CONTINUOUS, attribute.isNumeric() ? DataType.DOUBLE : DataType.STRING);
                if (attribute.isNominal()) {
                    Enumeration enumerateValues = attribute.enumerateValues();
                    while (enumerateValues.hasMoreElements()) {
                        dataField.withValues(new Value[]{new Value(String.valueOf(enumerateValues.nextElement()))});
                    }
                }
                dataDictionary.withDataFields(new DataField[]{dataField});
                MiningField miningField = new MiningField(new FieldName(attribute.name()));
                if (instances.classIndex() == i) {
                    miningField.withUsageType(FieldUsageType.PREDICTED);
                } else {
                    miningField.withUsageType(FieldUsageType.ACTIVE);
                }
                miningSchema.withMiningFields(new MiningField[]{miningField});
            }
        }
        pmml.withDataDictionary(dataDictionary);
        Model miningModel = new MiningModel(miningSchema, MiningFunctionType.CLASSIFICATION);
        miningModel.withModelName(MODEL_NAME);
        pmml.withModels(new Model[]{miningModel});
        Segmentation segmentation = new Segmentation(MultipleModelMethodType.MAJORITY_VOTE);
        miningModel.withSegmentation(segmentation);
        if (randomForest.m_bagger != null) {
            int i2 = 1;
            for (RandomTree randomTree : RandomForestUtils.getBaggingClassifiers(randomForest.m_bagger)) {
                int i3 = i2;
                i2++;
                segmentation.withSegments(new Segment[]{buildSegment(miningSchema, i3, randomTree)});
            }
        }
        return pmml;
    }

    private static Segment buildSegment(MiningSchema miningSchema, int i, RandomTree randomTree) {
        Node withPredicate = new Node().withId(String.valueOf(1)).withPredicate(new True());
        TreeModel withSplitCharacteristic = new TreeModel(miningSchema, withPredicate, MiningFunctionType.CLASSIFICATION).withAlgorithmName(ALGORITHM_NAME).withModelName(MODEL_NAME).withSplitCharacteristic(TreeModel.SplitCharacteristic.MULTI_SPLIT);
        buildTreeNode(randomTree, randomTree.m_Tree, 1, withPredicate);
        Segment segment = new Segment();
        segment.withId(String.valueOf(i));
        segment.withModel(withSplitCharacteristic);
        return segment;
    }

    private static int buildTreeNode(RandomTree randomTree, RandomTree.Tree tree, int i, Node node) {
        PMMLConversionCommons.addScoreDistribution(node, tree.m_ClassDistribution, randomTree.m_Info);
        if (tree.m_Attribute == -1) {
            node.withScore(PMMLConversionCommons.leafScoreFromDistribution(tree.m_ClassDistribution, randomTree.m_Info));
            return i;
        }
        Attribute attribute = randomTree.m_Info.attribute(tree.m_Attribute);
        if (attribute.isNominal()) {
            return buildNominalNode(randomTree, attribute, tree, i, node);
        }
        if (attribute.isNumeric()) {
            return buildNumericNode(randomTree, attribute, tree, i, node);
        }
        throw new RuntimeException("Unsupported attribute type for: " + attribute);
    }

    private static int buildNominalNode(RandomTree randomTree, Attribute attribute, RandomTree.Tree tree, int i, Node node) {
        ArrayList arrayList = new ArrayList();
        Enumeration enumerateValues = attribute.enumerateValues();
        while (enumerateValues.hasMoreElements()) {
            arrayList.add(enumerateValues.nextElement());
        }
        if (!$assertionsDisabled && arrayList.size() != tree.m_Successors.length) {
            throw new AssertionError("Number of successors expected to be the same as the number of attribute values");
        }
        ArrayList arrayList2 = new ArrayList();
        for (int i2 = 0; i2 < arrayList.size(); i2++) {
            int i3 = i + 1;
            Node withPredicate = new Node().withId(String.valueOf(i3)).withPredicate(new SimplePredicate(new FieldName(attribute.name()), SimplePredicate.Operator.EQUAL).withValue(String.valueOf(arrayList.get(i2))));
            i = buildTreeNode(randomTree, tree.m_Successors[i2], i3, withPredicate);
            withPredicate.withExtensions(new Extension[]{new Extension().withName(PMMLConversionCommons.TRAINING_PROPORTION_ELEMENT).withValue(String.valueOf(tree.m_Prop[i2]))});
            arrayList2.add(withPredicate);
        }
        node.withNodes(arrayList2);
        return i;
    }

    private static int buildNumericNode(RandomTree randomTree, Attribute attribute, RandomTree.Tree tree, int i, Node node) {
        SimplePredicate withValue = new SimplePredicate(new FieldName(attribute.name()), SimplePredicate.Operator.LESS_THAN).withValue(String.valueOf(tree.m_SplitPoint));
        SimplePredicate withValue2 = new SimplePredicate(new FieldName(attribute.name()), SimplePredicate.Operator.GREATER_OR_EQUAL).withValue(String.valueOf(tree.m_SplitPoint));
        int i2 = i + 1;
        Node withId = new Node().withId(String.valueOf(i2));
        withId.withPredicate(withValue);
        int buildTreeNode = buildTreeNode(randomTree, tree.m_Successors[0], i2, withId) + 1;
        Node withId2 = new Node().withId(String.valueOf(buildTreeNode));
        withId2.withPredicate(withValue2);
        int buildTreeNode2 = buildTreeNode(randomTree, tree.m_Successors[1], buildTreeNode, withId2);
        withId.withExtensions(new Extension[]{new Extension().withName(PMMLConversionCommons.TRAINING_PROPORTION_ELEMENT).withValue(String.valueOf(tree.m_Prop[0]))});
        withId2.withExtensions(new Extension[]{new Extension().withName(PMMLConversionCommons.TRAINING_PROPORTION_ELEMENT).withValue(String.valueOf(tree.m_Prop[1]))});
        node.withNodes(new Node[]{withId, withId2});
        return buildTreeNode2;
    }

    static {
        $assertionsDisabled = !RandomForestPMMLProducer.class.desiredAssertionStatus();
        logger = LoggerFactory.getLogger(RandomForestPMMLProducer.class);
        ALGORITHM_NAME = "weka:" + RandomForest.class.getName();
        MODEL_NAME = ALGORITHM_NAME + "_Model";
    }
}
