package org.maochen.nlp.classifier.knn;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import org.maochen.nlp.datastructure.Tuple;
import org.maochen.nlp.utils.VectorUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/maochen/nlp/classifier/knn/KNNEngine.class */
/**
 * Brute-force k-nearest-neighbour classifier for a single query {@link Tuple}.
 *
 * <p>Typical use: construct with the query tuple, the labelled training tuples and {@code k};
 * call {@link #getDistance(BiFunction)} with one of the distance functions below (or a custom
 * one) to fill in {@code distance} on every training tuple; then call {@link #getResult()} to
 * obtain the majority label among the {@code k} closest training tuples.
 *
 * <p>No synchronization is performed, and both methods mutate the supplied tuples (and
 * {@link #getResult()} sorts the supplied training list in place).
 */
final class KNNEngine {
    private static final Logger LOG = LoggerFactory.getLogger(KNNEngine.class);

    /** The query tuple; its {@code label} is overwritten by {@link #getResult()}. */
    private final Tuple predict;
    /** Labelled training tuples; re-ordered in place by {@link #getResult()}. */
    private final List<Tuple> trainingData;
    /** Number of nearest neighbours allowed to vote. */
    private final int k;

    /** Euclidean (L2) distance: {@code sqrt(sum((a_i - b_i)^2))}. */
    public BiFunction<double[], double[], Double> euclideanDistance = (a, b) ->
            Math.sqrt(Arrays.stream(VectorUtils.zip(a, b, (x, y) -> Math.pow(x - y, 2.0d))).sum());

    /**
     * Chebyshev (L-infinity) distance: {@code max(|a_i - b_i|)}. Returns 0 for zero-length
     * vectors instead of throwing.
     */
    public BiFunction<double[], double[], Double> chebyshevDistance = (a, b) ->
            Arrays.stream(VectorUtils.zip(a, b, (x, y) -> Math.abs(x - y))).max().orElse(0.0d);

    /** Manhattan (L1) distance: {@code sum(|a_i - b_i|)}. */
    public BiFunction<double[], double[], Double> manhattanDistance = (a, b) ->
            Arrays.stream(VectorUtils.zip(a, b, (x, y) -> Math.abs(x - y))).sum();

    /**
     * @param predict      the query tuple to classify
     * @param trainingData labelled training tuples (sorted in place by {@link #getResult()})
     * @param k            number of nearest neighbours that vote
     */
    public KNNEngine(Tuple predict, List<Tuple> trainingData, int k) {
        this.predict = predict;
        this.trainingData = trainingData;
        this.k = k;
    }

    /**
     * Computes the distance from the query to every training tuple and stores it in that
     * tuple's {@code distance} field.
     *
     * <p>If any training tuple's feature vector length differs from the query's, an error is
     * logged and no tuple is modified.
     *
     * @param biFunction distance function applied as {@code f(query, trainingTuple)}
     */
    public void getDistance(BiFunction<double[], double[], Double> biFunction) {
        final int dimension = this.predict.featureVector.length;
        // Validate everything first so a mismatch cannot leave a mix of fresh and stale
        // distances behind for getResult() to rank.
        for (Tuple tuple : this.trainingData) {
            if (tuple.featureVector.length != dimension) {
                LOG.error("2 Vectors must has same dimension.");
                return;
            }
        }
        for (Tuple tuple : this.trainingData) {
            tuple.distance = biFunction.apply(this.predict.featureVector, tuple.featureVector).doubleValue();
        }
    }

    /**
     * Sorts the training tuples by ascending {@code distance}, lets the closest
     * {@code min(k, trainingData.size())} of them vote, stores the winning label on the query
     * tuple and returns it.
     *
     * <p>On a tie, the tied label whose first (i.e. nearest) voter is closest wins. If no
     * neighbour votes (empty training data or {@code k <= 0}) the result is the empty string.
     *
     * @return the majority label, or {@code ""} if nobody voted
     */
    public String getResult() {
        this.trainingData.sort((a, b) -> Double.compare(a.distance, b.distance));

        // Guard against k larger than the training set (previously an IndexOutOfBoundsException).
        final int voters = Math.min(this.k, this.trainingData.size());
        // LinkedHashMap keeps labels in order of their nearest voter, making tie-breaking
        // deterministic and meaningful rather than dependent on hash order.
        Map<String, Integer> votes = new LinkedHashMap<>();
        for (int i = 0; i < voters; i++) {
            votes.merge(this.trainingData.get(i).label, 1, Integer::sum);
        }

        String winner = "";
        int maxVotes = 0;
        int tiedLabels = 0;
        for (Map.Entry<String, Integer> entry : votes.entrySet()) {
            int count = entry.getValue();
            if (count > maxVotes) {
                maxVotes = count;
                winner = entry.getKey();
                tiedLabels = 1;
            } else if (count == maxVotes) {
                tiedLabels++;
            }
        }

        if (votes.isEmpty()) {
            LOG.warn("No neighbours voted (k={}, training size={}); returning empty label.",
                    this.k, this.trainingData.size());
        } else if (tiedLabels > 1) {
            LOG.info("Equal Max Vote, take the first max!");
        }
        this.predict.label = winner;
        return winner;
    }
}
