package org.deeplearning4j.models.rntn;

import java.util.Iterator;
import java.util.List;
import java.util.SortedSet;
import org.deeplearning4j.eval.ConfusionMatrix;
import org.deeplearning4j.nn.layers.feedforward.autoencoder.recursive.Tree;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/models/rntn/RNTNEval.class */
/**
 * Accumulates gold-label versus predicted-label counts for trees evaluated by an {@link RNTN}.
 *
 * <p>Counts are stored in a single {@link ConfusionMatrix} that is never reset, so repeated
 * calls to {@link #eval(RNTN, List)} on the same instance keep adding to the same totals.
 */
public class RNTNEval {
    private static final Logger log = LoggerFactory.getLogger(RNTNEval.class);

    /** Confusion matrix keyed by (gold label, predicted label). */
    private final ConfusionMatrix<Integer> cf = new ConfusionMatrix<>();

    /**
     * Forward-propagates every tree through the given network and records the
     * resulting predictions in the confusion matrix.
     *
     * @param rntn the network used to forward-propagate each tree
     * @param list the trees to evaluate
     */
    public void eval(RNTN rntn, List<Tree> list) {
        for (Tree tree : list) {
            rntn.forwardPropagateTree(tree);
            count(tree);
        }
    }

    /**
     * Recursively records one (gold label, predicted label) pair per eligible node.
     *
     * <p>Leaf nodes are not counted. A node whose prediction is {@code null} is skipped
     * together with its whole subtree, because the method returns before descending.
     *
     * @param tree root of the (sub)tree to tally
     */
    private void count(Tree tree) {
        if (tree.isLeaf() || tree.prediction() == null) {
            return;
        }
        // Children are tallied before the node itself (post-order).
        for (Tree child : tree.children()) {
            count(child);
        }
        int goldLabel = tree.goldLabel();
        // The predicted class is the index returned by BLAS iamax over the prediction vector
        // (presumably the entry with the largest magnitude -- confirm against the ND4J docs).
        int predictedLabel = Nd4j.getBlasWrapper().iamax(tree.prediction());
        this.cf.add(Integer.valueOf(goldLabel), Integer.valueOf(predictedLabel));
    }

    /**
     * Builds a human-readable summary of the confusion matrix.
     *
     * <p>One line is emitted for every (actual, predicted) class pair with a non-zero count;
     * pairs that never occurred are omitted.
     *
     * @return the summary text, which always starts with a newline
     */
    public String stats() {
        StringBuilder report = new StringBuilder().append("\n");
        SortedSet<Integer> classes = this.cf.getClasses();
        for (Integer actual : classes) {
            for (Integer predicted : classes) {
                int count = this.cf.getCount(actual, predicted);
                if (count != 0) {
                    report.append("\nActual Class ")
                            .append(actual)
                            .append(" was predicted with Predicted ")
                            .append(predicted)
                            .append(" with count ")
                            .append(count)
                            .append(" times\n");
                }
            }
        }
        return report.toString();
    }
}
