package hr.irb.fastRandomForest;

import com.feedzai.fos.impl.weka.exception.PMMLConversionException;
import com.feedzai.fos.impl.weka.utils.pmml.PMMLConsumer;
import com.feedzai.fos.impl.weka.utils.pmml.PMMLConversionCommons;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.StringReader;
import java.util.List;
import javax.xml.transform.stream.StreamSource;
import org.dmg.pmml.Extension;
import org.dmg.pmml.MiningModel;
import org.dmg.pmml.Node;
import org.dmg.pmml.PMML;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.Segment;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.TreeModel;
import org.jpmml.model.JAXBUtil;
import weka.classifiers.Classifier;
import weka.classifiers.RandomForestUtils;
import weka.core.Attribute;
import weka.core.Instances;

/**
 * Converts PMML documents into {@link FastRandomForest} instances.
 *
 * <p>The PMML is read through {@link PMMLConversionCommons#getMiningModel}; each segment of the
 * mining model's segmentation is converted into one {@link FastRandomTree}. Split nodes on numeric
 * attributes become binary splits whose threshold is the value of the first child's
 * {@link SimplePredicate}; split nodes on nominal attributes get one successor per child, placed at
 * the index of the child's predicate value within the attribute.
 */
public class FastRandomForestPMMLConsumer implements PMMLConsumer<FastRandomForest> {

    /**
     * Parses a PMML document held in a string and converts it into a forest.
     *
     * <p>The text is handed to the XML parser as characters, so the result does not depend on the
     * platform's default charset.
     *
     * @param str the PMML document text
     * @return the forest described by the document
     * @throws PMMLConversionException if the text cannot be unmarshalled or the model cannot be converted
     */
    @Override // com.feedzai.fos.impl.weka.utils.pmml.PMMLConsumer
    public FastRandomForest consume(String str) throws PMMLConversionException {
        try {
            // A StringReader holds no OS resources, so it needs no explicit close.
            return consume(JAXBUtil.unmarshalPMML(new StreamSource(new StringReader(str))));
        } catch (PMMLConversionException e) {
            // Conversion failures already carry an accurate message; don't relabel them as unmarshal errors.
            throw e;
        } catch (Exception e) {
            throw new PMMLConversionException("Failed to unmarshal PMML from string. Make sure it is a valid PMML.", e);
        }
    }

    /**
     * Parses a PMML file and converts it into a forest.
     *
     * @param file the PMML file to read
     * @return the forest described by the file
     * @throws PMMLConversionException if the file cannot be opened/unmarshalled or the model cannot be converted
     */
    @Override // com.feedzai.fos.impl.weka.utils.pmml.PMMLConsumer
    public FastRandomForest consume(File file) throws PMMLConversionException {
        try {
            return consume(unmarshal(file));
        } catch (PMMLConversionException e) {
            // Conversion failures already carry an accurate message; don't relabel them as unmarshal errors.
            throw e;
        } catch (Exception e) {
            throw new PMMLConversionException("Failed to unmarshal PMML file '" + file + "'. Make sure the file is a valid PMML.", e);
        }
    }

    /**
     * Reads and unmarshals a PMML file, closing the stream whether or not parsing succeeds.
     *
     * @param file the PMML file to read
     * @return the unmarshalled PMML document
     * @throws Exception if the file cannot be opened or parsed
     */
    private static PMML unmarshal(File file) throws Exception {
        try (FileInputStream in = new FileInputStream(file)) {
            return JAXBUtil.unmarshalPMML(new StreamSource(in));
        }
    }

    /**
     * Converts an already-unmarshalled PMML document into a forest.
     *
     * <p>The bagger is configured with one iteration per segment of the mining model. Each bagging
     * slot is then replaced by the tree built from the corresponding segment.
     *
     * @param pmml the PMML document; its segments are expected to hold {@link TreeModel}s
     * @return the forest described by the document
     * @throws PMMLConversionException if the bagging classifiers cannot be set up or a tree cannot be built
     */
    @Override // com.feedzai.fos.impl.weka.utils.pmml.PMMLConsumer
    public FastRandomForest consume(PMML pmml) throws PMMLConversionException {
        MiningModel miningModel = PMMLConversionCommons.getMiningModel(pmml);
        List<Segment> segments = miningModel.getSegmentation().getSegments();
        int treeCount = segments.size();

        Instances instances = PMMLConversionCommons.buildInstances(pmml.getDataDictionary());
        instances.setClassIndex(PMMLConversionCommons.getClassIndex(instances, miningModel.getMiningSchema()));

        FastRandomForest forest = new FastRandomForest();
        // Template classifier for the bagger; the actual trees are overwritten from the PMML below.
        FastRandomTree template = new FastRandomTree();
        template.m_MotherForest = forest;
        forest.m_bagger = new FastRfBagging();
        forest.m_bagger.setNumIterations(treeCount);
        forest.m_bagger.setClassifier(template);
        forest.m_Info = instances;

        try {
            RandomForestUtils.setupBaggingClassifiers(forest.m_bagger);
            // NOTE(review): assumes getBaggingClassifiers returns the bagger's own array, so writing
            // into it installs the trees in the forest - confirm in RandomForestUtils.
            Classifier[] trees = RandomForestUtils.getBaggingClassifiers(forest.m_bagger);
            for (int i = 0; i < trees.length; i++) {
                trees[i] = buildRandomTree(forest, (TreeModel) segments.get(i).getModel());
            }
            return forest;
        } catch (Exception e) {
            throw new PMMLConversionException("Failed to initialize bagging classifiers.", e);
        }
    }

    /**
     * Builds one tree of the forest from a PMML tree model, starting at its root node.
     */
    private static FastRandomTree buildRandomTree(FastRandomForest forest, TreeModel treeModel) {
        return buildRandomTreeNode(forest, treeModel.getNode());
    }

    /**
     * Recursively converts a PMML tree node (and its subtree) into a {@link FastRandomTree} node.
     *
     * <p>A node without children becomes a leaf that carries only its class distribution. Otherwise
     * the split attribute is taken from the first child's {@link SimplePredicate} field. Numeric
     * attributes produce exactly two successors (the split point is the first child's predicate
     * value). Nominal attributes produce one successor per child, indexed by attribute value.
     *
     * @param forest the forest the node belongs to; its {@code m_Info} supplies the attribute schema
     * @param node the PMML node to convert
     * @return the converted node
     * @throws IllegalArgumentException if a split references an unknown attribute or nominal value,
     *     or a split predicate is not a {@link SimplePredicate}
     * @throws RuntimeException if the split attribute is neither numeric nor nominal
     */
    private static FastRandomTree buildRandomTreeNode(FastRandomForest forest, Node node) {
        FastRandomTree tree = new FastRandomTree();
        tree.m_MotherForest = forest;
        tree.m_ClassProbs = getClassDistribution(node);

        List<Node> children = node.getNodes();
        if (children.isEmpty()) {
            return tree; // leaf
        }

        // All children of a split node test the same field, so read it from the first child.
        String fieldName = getSimplePredicate(children.get(0)).getField().getValue();
        Attribute attribute = forest.m_Info.attribute(fieldName);
        if (attribute == null) {
            throw new IllegalArgumentException("PMML tree splits on unknown attribute '" + fieldName + "'");
        }
        tree.m_Attribute = attribute.index();

        if (attribute.isNumeric()) {
            assert children.size() == 2 : "Numeric attributes must have exactly 2 children";
            Node left = children.get(0);
            Node right = children.get(1);
            Predicate leftPredicate = left.getPredicate();
            Predicate rightPredicate = right.getPredicate();
            assert leftPredicate instanceof SimplePredicate && leftPredicate.getClass().equals(rightPredicate.getClass())
                    : "Numeric attribute's nodes must have the same simple predicate.";
            tree.m_SplitPoint = Double.parseDouble(getSimplePredicate(left).getValue());
            tree.m_Successors = new FastRandomTree[]{buildRandomTreeNode(forest, left), buildRandomTreeNode(forest, right)};
            tree.m_Prop = new double[]{getNodeTrainingProportion(left), getNodeTrainingProportion(right)};
        } else if (attribute.isNominal()) {
            FastRandomTree[] successors = new FastRandomTree[children.size()];
            double[] proportions = new double[successors.length];
            for (Node child : children) {
                String value = getSimplePredicate(child).getValue();
                int valueIndex = attribute.indexOfValue(value);
                // indexOfValue returns -1 for unknown values; also guard against more values than children.
                if (valueIndex < 0 || valueIndex >= successors.length) {
                    throw new IllegalArgumentException("Value '" + value + "' of nominal attribute '" + attribute.name()
                            + "' maps to index " + valueIndex + ", outside the " + successors.length + " child slots");
                }
                successors[valueIndex] = buildRandomTreeNode(forest, child);
                proportions[valueIndex] = getNodeTrainingProportion(child);
            }
            tree.m_Successors = successors;
            tree.m_Prop = proportions;
        } else {
            throw new RuntimeException("Attribute type not supported: " + attribute);
        }
        return tree;
    }

    /**
     * Returns the node's predicate as a {@link SimplePredicate}, the only kind supported on splits.
     *
     * @throws IllegalArgumentException if the node carries any other predicate type
     */
    private static SimplePredicate getSimplePredicate(Node node) {
        Predicate predicate = node.getPredicate();
        if (!(predicate instanceof SimplePredicate)) {
            throw new IllegalArgumentException("Expected a SimplePredicate on tree node but found: " + predicate);
        }
        return (SimplePredicate) predicate;
    }

    /**
     * Extracts the confidences of a node's score distributions, in document order.
     *
     * @return one confidence per {@link ScoreDistribution}, or {@code null} if the node has none
     */
    private static double[] getClassDistribution(Node node) {
        List<ScoreDistribution> scoreDistributions = node.getScoreDistributions();
        if (scoreDistributions.isEmpty()) {
            return null;
        }
        double[] distribution = new double[scoreDistributions.size()];
        for (int i = 0; i < distribution.length; i++) {
            // NOTE(review): an absent confidence attribute (null) would throw here - confirm producers always set it.
            distribution[i] = scoreDistributions.get(i).getConfidence();
        }
        return distribution;
    }

    /**
     * Reads the training proportion stored in the node's extension named
     * {@link PMMLConversionCommons#TRAINING_PROPORTION_ELEMENT}.
     *
     * @return the parsed proportion, or {@code 0.0} if the node has no such extension
     */
    private static double getNodeTrainingProportion(Node node) {
        for (Extension extension : node.getExtensions()) {
            if (PMMLConversionCommons.TRAINING_PROPORTION_ELEMENT.equals(extension.getName())) {
                return Double.parseDouble(extension.getValue());
            }
        }
        return 0.0d;
    }
}
