package com.top.knn;

import com.top.constants.OrderEnum;
import com.top.matrix.Matrix;
import com.top.utils.MatrixUtil;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;

/* loaded from: input_file:com/top/knn/KNN.class */
public class KNN {

    /**
     * Classifies every row of {@code matrix} by a majority vote among its {@code i} nearest
     * training rows (Euclidean distance after normalising training and query data together).
     *
     * @param matrix  samples to classify, one sample per row; must have the same column count
     *                as {@code matrix2}
     * @param matrix2 training samples, one sample per row
     * @param matrix3 training labels; column 0 of row r is the label of training row r
     * @param i       number of neighbours (k); must satisfy {@code 1 <= i <= matrix2 rows}
     * @return a column matrix with one predicted label per row of {@code matrix}
     * @throws IllegalArgumentException if the dimensions are inconsistent or {@code i} is out of range
     * @throws Exception propagated from the underlying {@code Matrix} operations
     */
    public static Matrix classify(Matrix matrix, Matrix matrix2, Matrix matrix3, int i) throws Exception {
        if (matrix2.getMatrixRowCount() != matrix3.getMatrixRowCount()) {
            throw new IllegalArgumentException("矩阵训练集与标签维度不一致");
        }
        if (matrix.getMatrixColCount() != matrix2.getMatrixColCount()) {
            throw new IllegalArgumentException("待分类矩阵列数与训练集列数不一致");
        }
        if (matrix2.getMatrixRowCount() < i) {
            throw new IllegalArgumentException("训练集样本数小于k");
        }
        // With k <= 0 there would be no votes at all and every sample would silently get label 0.0.
        if (i <= 0) {
            throw new IllegalArgumentException("k必须为正整数");
        }

        int trainRows = matrix2.getMatrixRowCount();
        int queryRows = matrix.getMatrixRowCount();

        // Normalise training and query rows together so both share the same scale.
        // splice(2, ...) appends the query rows BELOW the training rows.
        Matrix normalized = (Matrix) MatrixUtil.normalize(matrix2.splice(2, matrix), 0.0d, 1.0d).get("res");
        Matrix trainSet = normalized.subMatrix(0, trainRows, 0, normalized.getMatrixColCount());

        Matrix result = new Matrix(new double[queryRows][1]);
        for (int q = 0; q < queryRows; q++) {
            // Query rows start right after the training rows in the spliced matrix
            // (the original code wrongly read them from row 0, i.e. from the training set).
            Matrix query = normalized.getRowOfIdx(trainRows + q);

            // Column 0: distance from the query to each training row; column 1: that row's label.
            Matrix neighbours = trainSet.subtract(query.extend(2, trainRows))
                    .square()
                    .sumRow()
                    .pow(0.5d)
                    .splice(1, matrix3);
            // Sort ascending on column 0 so the nearest neighbours come first (sorts in place).
            neighbours.sort(0, OrderEnum.ASC);

            // Insertion order follows distance order, so on a tie the label whose nearest
            // neighbour is closest wins (see getKeyOfMaxValue).
            Map<Double, Integer> votes = new LinkedHashMap<>();
            for (int n = 0; n < i; n++) {
                votes.merge(Double.valueOf(neighbours.getValOfIdx(n, 1)), 1, Integer::sum);
            }
            result.setValue(q, 0, getKeyOfMaxValue(votes).doubleValue());
        }
        return result;
    }

    /**
     * Returns the key with the largest count. On ties, the key encountered first in the map's
     * iteration order is kept.
     *
     * @param map label-to-count map; may be {@code null}
     * @return the most frequent key, {@code null} if {@code map} is {@code null},
     *         or {@code 0.0} if the map is empty
     */
    private static Double getKeyOfMaxValue(Map<Double, Integer> map) {
        if (map == null) {
            return null;
        }
        Double bestKey = Double.valueOf(0.0d);
        int bestCount = 0;
        for (Map.Entry<Double, Integer> entry : map.entrySet()) {
            int count = entry.getValue();
            // Strict '>' keeps the earliest key among equal counts.
            if (count > bestCount) {
                bestKey = entry.getKey();
                bestCount = count;
            }
        }
        return bestKey;
    }
}
