package org.maochen.nlp.ml.util;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.maochen.nlp.ml.IClassifier;
import org.maochen.nlp.ml.Tuple;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/maochen/nlp/ml/util/CrossValidation.class */
public class CrossValidation {
    private static final Logger LOG = LoggerFactory.getLogger(CrossValidation.class);
    private int round;
    private IClassifier classifier;
    private Set<String> labels;
    private Set<Score> scores = new HashSet();
    private boolean shuffledata;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/maochen/nlp/ml/util/CrossValidation$Score.class */
    public static class Score {
        int round;
        String label;
        int truePos = 0;
        int trueNeg = 0;
        int falsePos = 0;
        int falseNeg = 0;

        Score() {
        }

        public double getF1() {
            double precision = getPrecision();
            double recall = getRecall();
            return ((2.0d * precision) * recall) / (precision + recall);
        }

        public double getPrecision() {
            return this.truePos / (this.truePos + this.falsePos);
        }

        public double getRecall() {
            return this.truePos / (this.truePos + this.falseNeg);
        }

        public double getAccurancy() {
            return (this.trueNeg + this.truePos) / (((this.truePos + this.trueNeg) + this.falseNeg) + this.falsePos);
        }
    }

    public void run(List<Tuple> list) {
        ArrayList arrayList = new ArrayList(list);
        this.labels = (Set) list.parallelStream().map(tuple -> {
            return tuple.label;
        }).collect(Collectors.toSet());
        if (this.shuffledata) {
            Collections.shuffle(arrayList);
        }
        int size = list.size() / this.round;
        int size2 = list.size() % size;
        for (int size3 = list.size() - 1; size3 > (list.size() - 1) - size2; size3--) {
            LOG.info("Dropping the tail id: " + list.get(size3).id);
        }
        for (int i = 0; i < this.round; i++) {
            List<Tuple> subList = list.subList(i, i + size);
            List<Tuple> subList2 = list.subList(0, i);
            subList2.addAll(list.subList(i + size, list.size()));
            eval(subList2, subList, i);
        }
    }

    private void eval(List<Tuple> list, List<Tuple> list2, int i) {
        this.classifier.train(list);
        for (Tuple tuple : list2) {
            updateScore(tuple, (String) this.classifier.predict(tuple).entrySet().stream().max((entry, entry2) -> {
                return ((Double) entry.getValue()).compareTo((Double) entry2.getValue());
            }).map((v0) -> {
                return v0.getKey();
            }).orElse(""), i);
        }
    }

    private void updateScore(Tuple tuple, String str, int i) {
        this.labels.stream().forEach(str2 -> {
            Score score = new Score();
            score.round = i;
            score.label = str2;
            this.scores.add(score);
        });
        if (tuple.label.equals(str)) {
            this.scores.stream().filter(score -> {
                return score.round == i;
            }).filter(score2 -> {
                return score2.label.equals(tuple.label);
            }).forEach(score3 -> {
                score3.truePos++;
            });
            this.scores.stream().filter(score4 -> {
                return score4.round == i;
            }).filter(score5 -> {
                return !score5.label.equals(tuple.label);
            }).forEach(score6 -> {
                score6.trueNeg++;
            });
        } else {
            String str3 = tuple.label;
            this.scores.stream().filter(score7 -> {
                return score7.round == i;
            }).filter(score8 -> {
                return score8.label.equals(str);
            }).forEach(score9 -> {
                score9.falsePos++;
            });
            this.scores.stream().filter(score10 -> {
                return score10.round == i;
            }).filter(score11 -> {
                return !score11.label.equals(str3);
            }).forEach(score12 -> {
                score12.falseNeg++;
            });
            this.scores.stream().filter(score13 -> {
                return score13.round == i;
            }).filter(score14 -> {
                return (score14.label.equals(str3) || score14.label.equals(str)) ? false : true;
            }).forEach(score15 -> {
                score15.trueNeg++;
            });
        }
    }

    public CrossValidation(int i, IClassifier iClassifier, boolean z) {
        this.round = i;
        this.classifier = iClassifier;
        this.shuffledata = z;
    }
}
