package weka.classifiers.trees;

import com.feedzai.fos.impl.weka.exception.PMMLConversionException;
import com.feedzai.fos.impl.weka.utils.pmml.PMMLConsumer;
import com.feedzai.fos.impl.weka.utils.pmml.PMMLConversionCommons;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.StringReader;
import java.util.List;
import javax.xml.transform.stream.StreamSource;
import org.dmg.pmml.Node;
import org.dmg.pmml.PMML;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.Segment;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.TreeModel;
import org.jpmml.model.JAXBUtil;
import weka.classifiers.RandomForestUtils;
import weka.classifiers.meta.Bagging;
import weka.classifiers.trees.RandomTree;
import weka.core.Attribute;
import weka.core.Instances;

/**
 * Converts a PMML document into a Weka {@link RandomForest}.
 *
 * <p>The PMML must contain a mining model whose segmentation holds one {@link TreeModel} per segment. Each segment
 * becomes one {@link RandomTree} inside the forest's {@link Bagging} ensemble. Numeric splits must be binary, with
 * {@link SimplePredicate}s whose value is the split point. Nominal splits have one child per attribute value, each
 * selected by a {@link SimplePredicate} on that value.
 */
public class RandomForestPMMLConsumer implements PMMLConsumer<RandomForest> {

    /**
     * Parses {@code str} as PMML and converts it into a {@link RandomForest}.
     *
     * @param str the PMML document as a string
     * @return the converted forest
     * @throws PMMLConversionException if the string is not valid PMML or the model cannot be converted
     */
    @Override // com.feedzai.fos.impl.weka.utils.pmml.PMMLConsumer
    public RandomForest consume(String str) throws PMMLConversionException {
        PMML pmml;
        try {
            // Read characters directly. Encoding the string to bytes with the platform default charset
            // (String#getBytes()) could corrupt non-ASCII content or contradict the XML encoding declaration.
            pmml = JAXBUtil.unmarshalPMML(new StreamSource(new StringReader(str)));
        } catch (Exception e) {
            throw new PMMLConversionException("Failed to unmarshal PMML from string. Make sure it is a valid PMML.", e);
        }
        // Conversion failures are reported by consume(PMML) rather than being relabelled as unmarshal failures.
        return consume(pmml);
    }

    /**
     * Reads {@code file} as PMML and converts it into a {@link RandomForest}.
     *
     * @param file the PMML file to read
     * @return the converted forest
     * @throws PMMLConversionException if the file cannot be read, is not valid PMML, or cannot be converted
     */
    @Override // com.feedzai.fos.impl.weka.utils.pmml.PMMLConsumer
    public RandomForest consume(File file) throws PMMLConversionException {
        PMML pmml;
        // try-with-resources closes the stream even when unmarshalling fails.
        try (FileInputStream in = new FileInputStream(file)) {
            pmml = JAXBUtil.unmarshalPMML(new StreamSource(in));
        } catch (Exception e) {
            throw new PMMLConversionException("Failed to unmarshal PMML file '" + file + "'. Make sure the file is a valid PMML.", e);
        }
        // Conversion failures are reported by consume(PMML) rather than being relabelled as unmarshal failures.
        return consume(pmml);
    }

    /**
     * Converts an already-parsed PMML document into a {@link RandomForest} with one {@link RandomTree} per segment.
     *
     * @param pmml the parsed PMML document
     * @return the converted forest
     * @throws PMMLConversionException if the document does not have the expected structure or a tree cannot be built
     */
    @Override // com.feedzai.fos.impl.weka.utils.pmml.PMMLConsumer
    public RandomForest consume(PMML pmml) throws PMMLConversionException {
        try {
            List<Segment> segments = PMMLConversionCommons.getMiningModel(pmml).getSegmentation().getSegments();

            RandomForest randomForest = new RandomForest();
            randomForest.m_bagger = new Bagging();
            // One bagging iteration (and hence one RandomTree) per PMML segment.
            randomForest.m_bagger.setNumIterations(segments.size());
            randomForest.m_bagger.setClassifier(new RandomTree());
            RandomForestUtils.setupBaggingClassifiers(randomForest.m_bagger);

            Instances instances = PMMLConversionCommons.buildInstances(pmml.getDataDictionary());
            RandomTree[] trees = RandomForestUtils.getBaggingClassifiers(randomForest.m_bagger);
            for (int i = 0; i < trees.length; i++) {
                buildRandomTree(trees[i], instances, (TreeModel) segments.get(i).getModel());
            }
            return randomForest;
        } catch (Exception e) {
            // Errors already expressed as conversion errors are propagated untouched; anything else is wrapped.
            if (e instanceof PMMLConversionException) {
                throw (PMMLConversionException) e;
            }
            throw new PMMLConversionException("Failed to initialize bagging classifiers.", e);
        }
    }

    /**
     * Populates {@code randomTree} from a PMML tree model.
     *
     * @param randomTree the (already instantiated) tree to fill in
     * @param instances  the dataset header built from the PMML data dictionary
     * @param treeModel  the PMML tree model for this tree
     * @return {@code randomTree}, for convenience
     */
    private static RandomTree buildRandomTree(RandomTree randomTree, Instances instances, TreeModel treeModel) {
        // Each tree gets its own copy so that setting the class index does not touch the shared header.
        Instances header = new Instances(instances);
        header.setClassIndex(PMMLConversionCommons.getClassIndex(instances, treeModel));
        randomTree.m_Info = header;
        randomTree.m_Tree = buildRandomTreeNode(randomTree, treeModel.getNode());
        return randomTree;
    }

    /**
     * Recursively converts a PMML {@link Node} (and its children) into a {@link RandomTree.Tree} node.
     *
     * @param randomTree the owning tree; its {@code m_Info} header must already be set
     * @param node       the PMML node to convert
     * @return the converted tree node
     * @throws IllegalArgumentException if the node's split is malformed or references unknown fields/values
     * @throws RuntimeException if the split attribute is neither numeric nor nominal
     */
    private static RandomTree.Tree buildRandomTreeNode(RandomTree randomTree, Node node) {
        RandomTree.Tree tree = randomTree.new Tree();
        tree.m_ClassDistribution = PMMLConversionCommons.getClassDistribution(node);

        List<Node> children = node.getNodes();
        if (children.isEmpty()) {
            // Leaf: only the class distribution is needed.
            return tree;
        }

        // The split attribute is identified by the field of the first child's predicate.
        Attribute attribute = getSplitAttribute(randomTree.m_Info, children.get(0));
        tree.m_Attribute = attribute.index();

        if (attribute.isNumeric()) {
            buildNumericSplit(randomTree, tree, children);
        } else if (attribute.isNominal()) {
            buildNominalSplit(randomTree, tree, attribute, children);
        } else {
            throw new RuntimeException("Attribute type not supported: " + attribute);
        }
        return tree;
    }

    /** Resolves the attribute named by {@code firstChild}'s predicate field, failing clearly if it is unknown. */
    private static Attribute getSplitAttribute(Instances instances, Node firstChild) {
        String fieldName = getSimplePredicate(firstChild).getField().getValue();
        Attribute attribute = instances.attribute(fieldName);
        if (attribute == null) {
            throw new IllegalArgumentException("Unknown split field '" + fieldName + "': not found among the model's attributes.");
        }
        return attribute;
    }

    /** Fills a binary numeric split; the split point is read from the first child's predicate value. */
    private static void buildNumericSplit(RandomTree randomTree, RandomTree.Tree tree, List<Node> children) {
        if (children.size() != 2) {
            throw new IllegalArgumentException("Numeric attributes must have exactly 2 children");
        }
        Node first = children.get(0);
        Node second = children.get(1);
        if (!(first.getPredicate() instanceof SimplePredicate) || !(second.getPredicate() instanceof SimplePredicate)) {
            throw new IllegalArgumentException("Numeric attribute's nodes must have the same simple predicate.");
        }
        tree.m_SplitPoint = Double.parseDouble(((SimplePredicate) first.getPredicate()).getValue());
        // Successors and proportions follow the children's document order.
        tree.m_Successors = new RandomTree.Tree[] {
                buildRandomTreeNode(randomTree, first),
                buildRandomTreeNode(randomTree, second)};
        tree.m_Prop = new double[] {
                PMMLConversionCommons.getNodeTrainingProportion(first),
                PMMLConversionCommons.getNodeTrainingProportion(second)};
    }

    /** Fills a nominal split, placing each child at the index of the attribute value its predicate selects. */
    private static void buildNominalSplit(RandomTree randomTree, RandomTree.Tree tree, Attribute attribute,
                                          List<Node> children) {
        tree.m_Successors = new RandomTree.Tree[children.size()];
        tree.m_Prop = new double[tree.m_Successors.length];
        for (Node child : children) {
            String value = getSimplePredicate(child).getValue();
            int valueIndex = attribute.indexOfValue(value);
            // indexOfValue returns -1 for unknown values; the upper bound guards against fewer children than values.
            if (valueIndex < 0 || valueIndex >= tree.m_Successors.length) {
                throw new IllegalArgumentException("Value '" + value + "' of attribute " + attribute.name()
                        + " does not map to one of the node's " + tree.m_Successors.length + " children.");
            }
            tree.m_Successors[valueIndex] = buildRandomTreeNode(randomTree, child);
            tree.m_Prop[valueIndex] = PMMLConversionCommons.getNodeTrainingProportion(child);
        }
    }

    /** Returns {@code node}'s predicate as a {@link SimplePredicate}, failing clearly if it is another kind. */
    private static SimplePredicate getSimplePredicate(Node node) {
        Predicate predicate = node.getPredicate();
        if (!(predicate instanceof SimplePredicate)) {
            throw new IllegalArgumentException("Expected a SimplePredicate but found: " + predicate);
        }
        return (SimplePredicate) predicate;
    }
}
