package rocks.vilaverde.classifier.dt;

import java.io.BufferedReader;
import java.io.Reader;
import java.util.ArrayDeque;
import java.util.Collections;
import java.util.Deque;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.Stack;
import rocks.vilaverde.classifier.Operator;
import rocks.vilaverde.classifier.Prediction;

/* loaded from: input_file:rocks/vilaverde/classifier/dt/DecisionTreeClassifier.class */
/**
 * A decision tree classifier rebuilt from a textual tree dump.
 *
 * <p>Each non-blank input line is stripped of everything up to and including its last
 * {@code "|--- "} marker. What remains is either a decision line
 * ({@code <feature> <operator> <threshold>}) or a leaf line starting with {@code "class: "} or
 * {@code "weights: "}. NOTE(review): the format looks like scikit-learn's {@code export_text}
 * output; confirm against whatever produces the input.
 *
 * @param <T> the type of value returned by {@link #predict(Map)}
 */
public class DecisionTreeClassifier<T> implements TreeClassifier<T> {

    /** Marker that precedes a node's text once the tree-drawing indentation is removed. */
    private static final String INDENT_MARKER = "|--- ";

    /** Prefixes identifying a leaf line; anything else is parsed as a decision line. */
    private static final String WEIGHTS_PREFIX = "weights: ";
    private static final String CLASS_PREFIX = "class: ";

    /** Handed to every leaf node at creation time. */
    private final PredictionFactory<T> predictionFactory;

    /** Root of the parsed tree; assigned once by {@link #parse}. */
    private DecisionNode root;

    /** Names of every feature the tree tests; each must be present in a prediction's input map. */
    private Set<String> featureNames;

    /**
     * Parses a decision tree from {@code reader} and returns a classifier for it.
     *
     * <p>The reader is closed before this method returns, whether or not parsing succeeds.
     *
     * @param reader source of the textual tree; must not be {@code null}
     * @param predictionFactory factory passed to each leaf node when it is created
     * @return the parsed classifier
     * @throws NullPointerException if {@code reader} is {@code null}
     * @throws IllegalArgumentException if the input contains no tree or a decision line is malformed
     * @throws Exception if reading the input fails or a node cannot be created
     */
    public static <T> DecisionTreeClassifier<T> parse(Reader reader, PredictionFactory<T> predictionFactory) throws Exception {
        Objects.requireNonNull(reader, "reader");
        // load() also closes the reader via its BufferedReader wrapper; closing twice is a no-op.
        try (Reader in = reader) {
            DecisionTreeClassifier<T> classifier = new DecisionTreeClassifier<>(predictionFactory);
            classifier.load(in);
            FeatureNameVisitor featureNameVisitor = new FeatureNameVisitor();
            classifier.root.accept((AbstractDecisionTreeVisitor) featureNameVisitor);
            // Unmodifiable so callers cannot weaken validateFeature() through getFeatureNames().
            classifier.featureNames = Collections.unmodifiableSet(featureNameVisitor.getFeatureNames());
            return classifier;
        }
    }

    private DecisionTreeClassifier(PredictionFactory<T> predictionFactory) {
        this.predictionFactory = predictionFactory;
    }

    /**
     * Classifies {@code map} and returns the value held by the resulting prediction.
     *
     * @param map feature name to feature value; must contain every name in {@link #getFeatureNames()}
     * @return the prediction's value
     */
    @Override
    public T predict(Map<String, Double> map) {
        return getClassification(map).get();
    }

    /**
     * Classifies {@code map} and returns the probability array of the resulting prediction.
     *
     * @param map feature name to feature value; must contain every name in {@link #getFeatureNames()}
     * @return the prediction's probabilities
     */
    @Override
    public double[] predict_proba(Map<String, Double> map) {
        return getClassification(map).getProbability();
    }

    /**
     * Walks the tree from the root, following the first branch whose condition accepts the
     * feature value, until a leaf is reached.
     *
     * <p>A key that is present but mapped to {@code null} is evaluated as {@link Double#NaN}.
     *
     * @param map feature name to feature value
     * @return the leaf reached by the walk
     * @throws IllegalArgumentException if a required feature name is missing from {@code map}
     * @throws IllegalStateException if the walk reaches a branch that has no child node
     * @throws RuntimeException if neither branch of a decision node accepts the value
     */
    @Override
    public Prediction<T> getClassification(Map<String, Double> map) {
        validateFeature(map);
        Object node = this.root;
        while (!(node instanceof EndNode)) {
            if (node == null) {
                // Previously this busy-looped forever; fail fast on an incomplete tree instead.
                throw new IllegalStateException("decision tree is incomplete: reached a branch with no child node");
            }
            DecisionNode decisionNode = (DecisionNode) node;
            Double value = map.get(decisionNode.getFeatureName());
            if (value == null) {
                value = Double.NaN;
            }
            if (decisionNode.getLeft().eval(value)) {
                node = decisionNode.getLeft().getChild();
            } else if (decisionNode.getRight().eval(value)) {
                node = decisionNode.getRight().getChild();
            } else {
                throw new RuntimeException(String.format("no branches evaluated to true for feature '%s'", decisionNode.getFeatureName()));
            }
        }
        @SuppressWarnings("unchecked") // leaf nodes are built from this classifier's PredictionFactory<T>
        Prediction<T> prediction = (Prediction<T>) node;
        return prediction;
    }

    /**
     * Ensures {@code map} has a key for every feature the tree uses.
     *
     * @throws IllegalArgumentException naming the first missing feature
     */
    private void validateFeature(Map<String, Double> map) throws IllegalArgumentException {
        for (String featureName : this.featureNames) {
            if (!map.containsKey(featureName)) {
                throw new IllegalArgumentException(String.format("expected feature named '%s' but none provided", featureName));
            }
        }
    }

    /**
     * Returns the names of all features tested by the tree.
     *
     * @return an unmodifiable set of feature names
     */
    @Override
    public Set<String> getFeatureNames() {
        return this.featureNames;
    }

    /**
     * Reads the tree line by line and stores its root in {@link #root}.
     *
     * <p>The stack holds the chain of decision nodes from the root down to the node currently
     * being built, plus at most one pending branch (ChoiceNode) on top that awaits its child.
     * The root is pushed first and is never popped by well-formed input, so it always sits at
     * the bottom of the stack.
     *
     * @throws IllegalArgumentException if the input contains no tree lines
     */
    private void load(Reader reader) throws Exception {
        Deque<TreeNode> stack = new ArrayDeque<>();
        try (BufferedReader lines = new BufferedReader(reader)) {
            String line;
            while ((line = lines.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    continue;
                }
                String content = removeIndentations(line);
                if (stack.isEmpty()) {
                    processDecisionNode(stack, content);
                } else {
                    processChildNode(stack, content);
                }
            }
        }
        if (stack.isEmpty()) {
            throw new IllegalArgumentException("no decision tree found in input");
        }
        // Take the bottom element (the root) rather than the top, which would be a
        // non-root node on malformed input.
        this.root = (DecisionNode) stack.peekLast();
    }

    /**
     * Handles any line after the first one.
     *
     * <p>A leaf line attaches an end node to the pending branch on top of the stack. It then pops
     * every completed decision node above the root, so the next line binds to the nearest
     * unfinished ancestor. Any other line is treated as a decision line.
     */
    private void processChildNode(Deque<TreeNode> stack, String line) throws Exception {
        boolean isLeaf = line.startsWith(WEIGHTS_PREFIX) || line.startsWith(CLASS_PREFIX);
        if (!isLeaf) {
            processDecisionNode(stack, line);
            return;
        }
        ((ChoiceNode) stack.pop()).addChild(EndNode.create(line, this.predictionFactory));
        while (stack.size() > 1 && ((DecisionNode) stack.peek()).isComplete()) {
            stack.pop();
        }
    }

    /**
     * Parses a {@code <feature> <operator> <threshold>} line and pushes a new branch for it.
     *
     * <p>If the stack top is a decision node on the same feature, the line is that node's second
     * (right) branch. Otherwise a new decision node is created and attached as the child of
     * the pending branch on top of the stack.
     *
     * @throws IllegalArgumentException if the line has no operator or no threshold
     */
    private void processDecisionNode(Deque<TreeNode> stack, String line) {
        int operatorIndex = getOperatorIndex(line);
        String featureName = line.substring(0, operatorIndex).trim();
        // Split on runs of whitespace. Single-character operators may be followed by extra
        // padding (e.g. ">  0.50"), which a single-space split turns into an empty token.
        String[] tokens = line.substring(operatorIndex).trim().split("\\s+");
        if (tokens.length < 2) {
            throw new IllegalArgumentException(String.format("missing threshold value in line '%s'", line));
        }
        Operator operator = Operator.from(tokens[0]);
        Double threshold = Double.valueOf(tokens[1]);

        DecisionNode decisionNode;
        if (stack.isEmpty()) {
            decisionNode = DecisionNode.create(featureName);
            stack.push(decisionNode);
        } else {
            TreeNode top = stack.peek();
            if (top instanceof DecisionNode && ((DecisionNode) top).getFeatureName().equals(featureName)) {
                decisionNode = (DecisionNode) top;
            } else {
                decisionNode = DecisionNode.create(featureName);
                ((ChoiceNode) stack.pop()).addChild(decisionNode);
                stack.push(decisionNode);
            }
        }

        ChoiceNode branch = ChoiceNode.create(operator, threshold);
        if (decisionNode.getLeft() == null) {
            decisionNode.setLeft(branch);
        } else {
            decisionNode.setRight(branch);
        }
        stack.push(branch);
    }

    /**
     * Returns the position of the first {@link Operator} (in declaration order) found in
     * {@code line}.
     *
     * @throws IllegalArgumentException if no operator occurs in the line
     */
    private int getOperatorIndex(String line) {
        for (Operator operator : Operator.values()) {
            int index = line.indexOf(operator.toString());
            if (index >= 0) {
                return index;
            }
        }
        throw new IllegalArgumentException(String.format("no comparison operator found in line '%s'", line));
    }

    /**
     * Strips the tree-drawing prefix, returning the text after the last {@code "|--- "} marker,
     * or the line unchanged if it has none.
     */
    private String removeIndentations(String line) {
        int marker = line.lastIndexOf(INDENT_MARKER);
        return marker < 0 ? line : line.substring(marker + INDENT_MARKER.length());
    }
}
