package org.jpmml.rexp;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import org.dmg.pmml.CompoundPredicate;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ScoreFrequency;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.tree.ClassifierNode;
import org.dmg.pmml.tree.CountingBranchNode;
import org.dmg.pmml.tree.CountingLeafNode;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FeatureImportanceMap;
import org.jpmml.converter.FortranMatrixUtil;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;

/* loaded from: input_file:org/jpmml/rexp/RPartConverter.class */
/**
 * Converts an R {@code rpart} model object into a PMML {@link TreeModel}.
 *
 * <p>Regression trees ({@code method = "anova"}) and classification trees ({@code method = "class"}) are
 * supported. Missing-value handling of the generated tree is derived from {@code control$usesurrogate}.
 */
public class RPartConverter extends TreeModelConverter<RGenericVector> implements HasFeatureImportances {

    /** Value of {@code control$usesurrogate}; the constructor rejects anything other than 0, 1 or 2. */
    private final int useSurrogate;

    /** Model formula built by {@link #encodeSchema(RExpEncoder)}; used to resolve surrogate split variables by name. */
    private Formula formula = null;

    /** Value returned by {@link #getFeatureIndex} for terminal ({@code "<leaf>"}) rows of the model frame. */
    private static final int INDEX_LEAF = 0;

    /** Split-info column: row offset of a node's primary split in the {@code splits} matrix. */
    private static final int INDEX_SPLIT = 0;

    /** Split-info column: number of competitor splits recorded for a node ({@code frame$ncompete}). */
    private static final int INDEX_COMPETE = 1;

    /** Split-info column: number of surrogate splits recorded for a node ({@code frame$nsurrogate}). */
    private static final int INDEX_SURROGATE = 2;

    /** Marker value used in {@code frame$var} for terminal nodes. */
    private static final String LEAF = "<leaf>";

    /**
     * Attaches the prediction stored at a given row of the model frame to a tree node.
     */
    @FunctionalInterface
    public interface ScoreEncoder {

        /**
         * @param node the freshly created node to decorate
         * @param offset zero-based row offset into the model frame
         * @return the decorated node; may be a different instance than {@code node}
         */
        Node encode(Node node, int offset);
    }

    /**
     * @param rGenericVector the R {@code rpart} object
     * @throws IllegalArgumentException if {@code control$usesurrogate} is not 0, 1 or 2
     */
    public RPartConverter(RGenericVector rGenericVector) {
        super(rGenericVector);

        RGenericVector control = rGenericVector.getGenericElement("control");

        this.useSurrogate = ValueUtil.asInt((Number) control.getNumericElement("usesurrogate").asScalar());

        switch (this.useSurrogate) {
            case 0:
            case 1:
            case 2:
                break;
            default:
                throw new IllegalArgumentException("Expected 'usesurrogate' to be 0, 1 or 2, got " + this.useSurrogate);
        }
    }

    /**
     * Whether classification nodes carry per-class score distributions and the model gets a probability output.
     *
     * @return always {@code true} in this class (NOTE(review): presumably an override hook — confirm)
     */
    public boolean hasScoreDistribution() {
        return true;
    }

    /**
     * Builds the model formula from {@code terms}, sets the label, and registers one feature per distinct
     * primary split variable found in {@code frame$var}.
     *
     * @throws IllegalArgumentException if {@code frame$var} is neither a character nor a factor vector
     */
    @Override
    public void encodeSchema(RExpEncoder rExpEncoder) {
        RGenericVector rpart = (RGenericVector) getObject();

        RGenericVector frame = rpart.getGenericElement("frame");
        RExp terms = rpart.getElement("terms");
        RGenericVector xlevels = rpart.getGenericAttribute("xlevels", false);
        RStringVector ylevels = rpart.getStringAttribute("ylevels", false);

        RVector<?> var = frame.getVectorElement("var");

        Formula modelFormula = FormulaUtil.createFormula(terms, new XLevelsFormulaContext(xlevels), rExpEncoder);

        FormulaUtil.setLabel(modelFormula, terms, ylevels, rExpEncoder);

        List<String> featureNames;
        if (var instanceof RStringVector) {
            featureNames = getFeatureNames(((RStringVector) var).getValues());
        } else if (var instanceof RFactorVector) {
            featureNames = getFeatureNames(((RFactorVector) var).getFactorValues());
        } else {
            throw new IllegalArgumentException("Expected 'frame$var' to be a character or factor vector, got " + describe(var));
        }

        FormulaUtil.addFeatures(modelFormula, featureNames, false, rExpEncoder);

        this.formula = modelFormula;
    }

    /**
     * Encodes the tree as a PMML {@link TreeModel}.
     *
     * @throws IllegalArgumentException if the frame row names contain NA, or the method is neither
     *     {@code "anova"} nor {@code "class"}
     */
    @Override
    public TreeModel encodeModel(Schema schema) {
        RGenericVector rpart = (RGenericVector) getObject();

        RGenericVector frame = rpart.getGenericElement("frame");
        RStringVector method = rpart.getStringElement("method");
        RNumberVector<?> splits = rpart.getNumericElement("splits");
        RIntegerVector csplit = rpart.getIntegerElement("csplit", false);

        RVector<?> var = frame.getVectorElement("var");
        RIntegerVector n = frame.getIntegerElement("n");
        RIntegerVector ncompete = frame.getIntegerElement("ncompete");
        RIntegerVector nsurrogate = frame.getIntegerElement("nsurrogate");

        RIntegerVector rowNames = frame.getIntegerAttribute("row.names");
        // Integer.MIN_VALUE is R's NA_integer_; node ids (row names) must be explicit for parent/child lookups
        if (rowNames.getValues().contains(Integer.MIN_VALUE)) {
            throw new IllegalArgumentException("Expected explicit (non-NA) row names in 'frame'");
        }

        List<? extends Feature> features = schema.getFeatures();

        // Row r of splitInfo holds where node r's split rows start in "splits", plus its competitor/surrogate counts.
        // The extra trailing row only receives the running offset past the last node.
        int[][] splitInfo = new int[1 + rowNames.size()][3];
        for (int row = 0; row < rowNames.size(); row++) {
            int splitVar = getFeatureIndex(var, row, features);

            splitInfo[row][INDEX_COMPETE] = ncompete.getValue(row);
            splitInfo[row][INDEX_SURROGATE] = nsurrogate.getValue(row);

            // A non-leaf node occupies one primary split row, followed by its competitor and surrogate rows
            int primaryRows = (splitVar != INDEX_LEAF) ? 1 : 0;

            splitInfo[row + 1][INDEX_SPLIT] = splitInfo[row][INDEX_SPLIT] + splitInfo[row][INDEX_COMPETE]
                + splitInfo[row][INDEX_SURROGATE] + primaryRows;
        }

        String methodName = method.asScalar();
        switch (methodName) {
            case "anova":
                return encodeRegression(frame, rowNames, var, n, splitInfo, splits, csplit, schema);
            case "class":
                return encodeClassification(frame, rowNames, var, n, splitInfo, splits, csplit, schema);
            default:
                throw new IllegalArgumentException("Unsupported rpart method: " + methodName);
        }
    }

    /**
     * Maps each schema feature to its entry in {@code variable.importance}.
     *
     * @return the importance map, or {@code null} when the model has no {@code variable.importance} element
     */
    @Override
    public FeatureImportanceMap getFeatureImportances(Schema schema) {
        RGenericVector rpart = (RGenericVector) getObject();

        RDoubleVector variableImportance = rpart.getDoubleElement("variable.importance", false);
        if (variableImportance == null) {
            return null;
        }

        List<? extends Feature> features = schema.getFeatures();

        FeatureImportanceMap result = new FeatureImportanceMap((String) null);
        for (Feature feature : features) {
            result.put(feature, (Double) variableImportance.getElement(feature.getName()));
        }

        return result;
    }

    /** Encodes a regression tree: every node scores {@code frame$yval} and records {@code frame$n}. */
    private TreeModel encodeRegression(RGenericVector frame, RIntegerVector rowNames, RVector<?> var, RIntegerVector n, int[][] splitInfo, RNumberVector<?> splits, RIntegerVector csplit, Schema schema) {
        RNumberVector<?> yval = frame.getNumericElement("yval");

        ScoreEncoder scoreEncoder = (node, offset) -> {
            Number score = (Number) yval.getValue(offset);

            node.setScore(score)
                .setRecordCount(n.getValue(offset));

            return node;
        };

        TreeModel treeModel = new TreeModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()),
            encodeNode(True.INSTANCE, 1, rowNames, var, n, splitInfo, splits, csplit, scoreEncoder, schema));

        return configureTreeModel(treeModel);
    }

    /**
     * Encodes a classification tree from the {@code frame$yval2} matrix.
     *
     * <p>{@code yval2} is read as a column-major matrix with {@code 1 + 2k + 1} columns (k = number of classes).
     * Column 0 holds the 1-based predicted class, and columns 1..k are used as per-class record counts. The remaining
     * columns are not read here (NOTE(review): presumably class probabilities and node probability — confirm).
     */
    private TreeModel encodeClassification(RGenericVector frame, RIntegerVector rowNames, RVector<?> var, RIntegerVector n, int[][] splitInfo, RNumberVector<?> splits, RIntegerVector csplit, Schema schema) {
        RDoubleVector yval2 = frame.getDoubleElement("yval2");

        CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

        List<?> values = categoricalLabel.getValues();

        boolean withScoreDistribution = hasScoreDistribution();

        int rows = rowNames.size();
        int columns = 1 + (2 * values.size()) + 1;

        List<Integer> classes = new ArrayList<>(ValueUtil.asIntegers(FortranMatrixUtil.getColumn(yval2.getValues(), rows, columns, 0)));

        List<List<? extends Number>> recordCounts = new ArrayList<>();
        if (withScoreDistribution) {
            for (int k = 0; k < values.size(); k++) {
                List<Double> counts = new ArrayList<>(FortranMatrixUtil.getColumn(yval2.getValues(), rows, columns, 1 + k));

                recordCounts.add(counts);
            }
        }

        ScoreEncoder scoreEncoder = (node, offset) -> {
            // Predicted classes are 1-based indices into the label values
            node.setScore(values.get(classes.get(offset) - 1))
                .setRecordCount(n.getValue(offset));

            if (!withScoreDistribution) {
                return node;
            }

            ClassifierNode classifierNode = new ClassifierNode(node);

            List<? super ScoreFrequency> scoreDistributions = classifierNode.getScoreDistributions();
            for (int k = 0; k < values.size(); k++) {
                ScoreFrequency scoreFrequency = new ScoreFrequency();
                scoreFrequency.setValue(values.get(k));
                scoreFrequency.setRecordCount(recordCounts.get(k).get(offset));

                scoreDistributions.add(scoreFrequency);
            }

            return classifierNode;
        };

        TreeModel treeModel = new TreeModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(schema.getLabel()),
            encodeNode(True.INSTANCE, 1, rowNames, var, n, splitInfo, splits, csplit, scoreEncoder, schema));

        if (withScoreDistribution) {
            treeModel.setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));
        }

        return configureTreeModel(treeModel);
    }

    /**
     * Applies tree-level strategies: no-true-child returns the last prediction. The missing-value strategy is
     * NULL_PREDICTION for {@code usesurrogate = 0} and LAST_PREDICTION for {@code 1}. For {@code 2} it is left
     * unset, because that case is handled by default children (see {@link #makeDefault(Node)}).
     */
    private TreeModel configureTreeModel(TreeModel treeModel) {
        TreeModel.MissingValueStrategy missingValueStrategy;

        switch (this.useSurrogate) {
            case 0:
                missingValueStrategy = TreeModel.MissingValueStrategy.NULL_PREDICTION;
                break;
            case 1:
                missingValueStrategy = TreeModel.MissingValueStrategy.LAST_PREDICTION;
                break;
            case 2:
                missingValueStrategy = null;
                break;
            default:
                throw new IllegalArgumentException("Expected 'usesurrogate' to be 0, 1 or 2, got " + this.useSurrogate);
        }

        treeModel.setNoTrueChildStrategy(TreeModel.NoTrueChildStrategy.RETURN_LAST_PREDICTION)
            .setMissingValueStrategy(missingValueStrategy);

        return treeModel;
    }

    /**
     * Recursively encodes the subtree rooted at the node whose frame row name is {@code rowName}.
     *
     * @param predicate predicate that routes records into this node
     * @param rowName rpart node number; the children of node k are 2k (left) and 2k + 1 (right)
     * @param rowNames frame row names (node numbers)
     * @param var frame split variable per row
     * @param n frame observation count per row
     * @param splitInfo table built in {@link #encodeModel(Schema)}
     * @param splits the {@code splits} matrix
     * @param csplit the {@code csplit} matrix, may be {@code null} when there are no categorical splits
     * @param scoreEncoder decorates every created node with its prediction
     * @param schema conversion schema
     * @return the encoded node
     */
    private Node encodeNode(Predicate predicate, int rowName, RIntegerVector rowNames, RVector<?> var, RIntegerVector n, int[][] splitInfo, RNumberVector<?> splits, RIntegerVector csplit, ScoreEncoder scoreEncoder, Schema schema) {
        int offset = getIndex(rowNames, rowName);

        Integer id = Integer.valueOf(rowName);

        List<? extends Feature> features = schema.getFeatures();

        int splitVar = getFeatureIndex(var, offset, features);
        if (splitVar == INDEX_LEAF) {
            return scoreEncoder.encode(new CountingLeafNode((Object) null, predicate).setId(id), offset);
        }

        int leftRowName = rowName * 2;
        int rightRowName = (rowName * 2) + 1;

        // Sign tells which child received more observations; only consulted when useSurrogate == 2
        int majorityDir = 0;
        if (this.useSurrogate == 2) {
            int leftCount = n.getValue(getIndex(rowNames, leftRowName));
            int rightCount = n.getValue(getIndex(rowNames, rightRowName));

            majorityDir = Integer.compare(leftCount, rightCount);
        }

        Feature feature = features.get(splitVar - 1);

        int splitOffset = splitInfo[offset][INDEX_SPLIT];
        int competeCount = splitInfo[offset][INDEX_COMPETE];
        int surrogateCount = splitInfo[offset][INDEX_SURROGATE];

        List<Predicate> predicates = encodePredicates(feature, splitOffset, splits, csplit);

        Predicate leftPredicate = predicates.get(0);
        Predicate rightPredicate = predicates.get(1);

        if (this.useSurrogate > 0 && surrogateCount > 0) {
            CompoundPredicate leftSurrogate = new CompoundPredicate(CompoundPredicate.BooleanOperator.SURROGATE, (List<Predicate>) null)
                .addPredicates(new Predicate[]{leftPredicate});
            CompoundPredicate rightSurrogate = new CompoundPredicate(CompoundPredicate.BooleanOperator.SURROGATE, (List<Predicate>) null)
                .addPredicates(new Predicate[]{rightPredicate});

            RStringVector splitVarNames = splits.dimnames(0);

            for (int k = 0; k < surrogateCount; k++) {
                // Surrogate rows follow the primary split row and its competitor rows
                int surrogateOffset = splitOffset + 1 + competeCount + k;

                Feature surrogateFeature = getFeature(splitVarNames.getValue(surrogateOffset));

                List<Predicate> surrogatePredicates = encodePredicates(surrogateFeature, surrogateOffset, splits, csplit);

                leftSurrogate.addPredicates(new Predicate[]{surrogatePredicates.get(0)});
                rightSurrogate.addPredicates(new Predicate[]{surrogatePredicates.get(1)});
            }

            leftPredicate = leftSurrogate;
            rightPredicate = rightSurrogate;
        }

        Node leftChild = encodeNode(leftPredicate, leftRowName, rowNames, var, n, splitInfo, splits, csplit, scoreEncoder, schema);
        Node rightChild = encodeNode(rightPredicate, rightRowName, rowNames, var, n, splitInfo, splits, csplit, scoreEncoder, schema);

        if (this.useSurrogate == 2) {
            if (majorityDir < 0) {
                makeDefault(rightChild);
            } else if (majorityDir > 0) {
                makeDefault(leftChild);

                // Keep the default (majority) child last; swap through a temporary so neither subtree is lost
                Node tmp = leftChild;
                leftChild = rightChild;
                rightChild = tmp;
            }
        }

        Node branch = new CountingBranchNode((Object) null, predicate)
            .setId(id)
            .addNodes(leftChild, rightChild);

        return scoreEncoder.encode(branch, offset);
    }

    /**
     * Encodes the left and right predicates of the split stored at row {@code splitOffset} of {@code splits}.
     *
     * <p>Column 1 ({@code ncat}) of {@code splits} is +/-1 for a continuous split; column 3 then holds the cutpoint.
     * With {@code ncat == -1}, values below the cutpoint go left. Otherwise the split is categorical: column 3 is a
     * 1-based row of {@code csplit}, whose per-level entries equal to 1 go left and equal to 3 go right.
     *
     * @return a two-element list {@code [left, right]}
     * @throws IllegalArgumentException if a categorical split is found but {@code csplit} is absent
     */
    private List<Predicate> encodePredicates(Feature feature, int splitOffset, RNumberVector<?> splits, RIntegerVector csplit) {
        RIntegerVector splitsDim = splits.dim();

        int rows = splitsDim.getValue(0);
        int columns = splitsDim.getValue(1);

        List<? extends Number> ncatColumn = FortranMatrixUtil.getColumn(splits.getValues(), rows, columns, 1);
        List<? extends Number> indexColumn = FortranMatrixUtil.getColumn(splits.getValues(), rows, columns, 3);

        int ncat = ValueUtil.asInt((Number) ncatColumn.get(splitOffset));
        Number index = (Number) indexColumn.get(splitOffset);

        Predicate leftPredicate;
        Predicate rightPredicate;

        if (Math.abs(ncat) == 1) {
            SimplePredicate.Operator leftOperator = (ncat == -1) ? SimplePredicate.Operator.LESS_THAN : SimplePredicate.Operator.GREATER_OR_EQUAL;
            SimplePredicate.Operator rightOperator = (ncat == -1) ? SimplePredicate.Operator.GREATER_OR_EQUAL : SimplePredicate.Operator.LESS_THAN;

            leftPredicate = createSimplePredicate(feature, leftOperator, index);
            rightPredicate = createSimplePredicate(feature, rightOperator, index);
        } else {
            if (csplit == null) {
                throw new IllegalArgumentException("Categorical split at row " + splitOffset + " but the model has no 'csplit' element");
            }

            CategoricalFeature categoricalFeature = (CategoricalFeature) feature;

            RIntegerVector csplitDim = csplit.dim();

            List<Integer> directions = FortranMatrixUtil.getRow(csplit.getValues(), csplitDim.getValue(0), csplitDim.getValue(1), ValueUtil.asInt(index) - 1);

            List<?> levels = categoricalFeature.getValues();

            leftPredicate = createPredicate(categoricalFeature, selectValues(levels, directions, 1));
            rightPredicate = createPredicate(categoricalFeature, selectValues(levels, directions, 3));
        }

        return Arrays.asList(leftPredicate, rightPredicate);
    }

    /**
     * Turns {@code node} into a default child by appending {@link True} as the last alternative of a SURROGATE
     * predicate, so it matches when all preceding alternatives are missing. A predicate that is not already a
     * SURROGATE compound (including AND/OR compounds, where appending True would change the meaning) is wrapped first.
     */
    private static void makeDefault(Node node) {
        Predicate predicate = node.requirePredicate();

        CompoundPredicate surrogate;
        if (predicate instanceof CompoundPredicate
            && ((CompoundPredicate) predicate).getBooleanOperator() == CompoundPredicate.BooleanOperator.SURROGATE) {
            surrogate = (CompoundPredicate) predicate;
        } else {
            surrogate = new CompoundPredicate(CompoundPredicate.BooleanOperator.SURROGATE, (List<Predicate>) null)
                .addPredicates(new Predicate[]{predicate});

            node.setPredicate(surrogate);
        }

        surrogate.addPredicates(new Predicate[]{True.INSTANCE});
    }

    /** Resolves a (surrogate) split variable by name through the model formula. */
    private Feature getFeature(String name) {
        return this.formula.resolveComplexFeature(name);
    }

    /** Returns the distinct split variable names in encounter order, ignoring the leaf marker. */
    private static List<String> getFeatureNames(List<String> names) {
        return names.stream()
            .filter(name -> !LEAF.equals(name))
            .distinct()
            .collect(Collectors.toList());
    }

    /**
     * Returns the 1-based position in {@code features} of the split variable at {@code index}, or
     * {@link #INDEX_LEAF} for a terminal row.
     *
     * @throws IllegalArgumentException if the variable is unknown or {@code var} has an unsupported type
     */
    private static int getFeatureIndex(RVector<?> var, int index, List<? extends Feature> features) {
        if (var instanceof RStringVector) {
            String name = ((RStringVector) var).getValue(index);
            if (LEAF.equals(name)) {
                return INDEX_LEAF;
            }

            for (int pos = 0; pos < features.size(); pos++) {
                if (features.get(pos).getName().equals(name)) {
                    return pos + 1;
                }
            }

            throw new IllegalArgumentException("Split variable '" + name + "' does not match any feature");
        } else if (var instanceof RFactorVector) {
            // NOTE(review): assumes factor level 1 is "<leaf>" and level k + 1 corresponds to features.get(k - 1),
            // i.e. that feature order matches level order — confirm against encodeSchema's feature ordering
            return ((RFactorVector) var).getValue(index) - 1;
        }

        throw new IllegalArgumentException("Expected 'frame$var' to be a character or factor vector, got " + describe(var));
    }

    /** Returns the zero-based position of node {@code rowName} among the frame row names. */
    private static int getIndex(RIntegerVector rowNames, int rowName) {
        int index = rowNames.indexOf(Integer.valueOf(rowName));
        if (index < 0) {
            throw new IllegalArgumentException("Node " + rowName + " not found in 'frame' row names");
        }

        return index;
    }

    /** Returns the elements of {@code values} whose positional entry in {@code directions} equals {@code direction}. */
    private static <E> List<E> selectValues(List<E> values, List<Integer> directions, int direction) {
        List<E> result = new ArrayList<>(values.size());

        for (int pos = 0; pos < values.size(); pos++) {
            if (directions.get(pos) == direction) {
                result.add(values.get(pos));
            }
        }

        return result;
    }

    /** Short description of an R vector's runtime type for error messages. */
    private static String describe(Object value) {
        return (value != null) ? value.getClass().getName() : "null";
    }
}
