package org.fnlp.ml.classifier.knn;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import org.fnlp.ml.classifier.AbstractClassifier;
import org.fnlp.ml.classifier.LabelParser;
import org.fnlp.ml.classifier.LinkedPredict;
import org.fnlp.ml.classifier.TPredict;
import org.fnlp.ml.types.Instance;
import org.fnlp.ml.types.InstanceSet;
import org.fnlp.nlp.pipe.Pipe;
import org.fnlp.nlp.similarity.ISimilarity;
import org.fnlp.util.exception.LoadModelException;

/* loaded from: input_file:org/fnlp/ml/classifier/knn/KNN.class */
/**
 * K-nearest-neighbour classifier backed by an in-memory set of prototype instances.
 *
 * <p>Each call to {@link #classify(Instance, int)} scores the query against every
 * prototype with the configured {@link ISimilarity} and collects the candidates in a
 * {@link LinkedPredict} created with the current {@code k}.
 *
 * <p>The class is persisted with Java native serialization wrapped in GZIP
 * (see {@link #saveTo(String)} and {@link #loadFrom(String)}).
 */
public class KNN extends AbstractClassifier {
    private static final long serialVersionUID = 4459814160943364300L;

    /** Similarity measure used to compare a query with each prototype. */
    private ISimilarity sim;

    /** Value handed to the {@link LinkedPredict} constructor (presumably the neighbour count). */
    private int k;

    /** Feature pipe associated with this model; stored here, not invoked by this class. */
    protected Pipe pipe;

    /** Labelled instances that act as the neighbours. */
    protected InstanceSet prototypes;

    /** Flag forwarded to {@code LinkedPredict.mergeDuplicate}; always {@code true} in this class. */
    private boolean useScore = true;

    /**
     * Returns the similarity measure currently in use.
     *
     * @return the similarity measure, as last set by the constructor or {@link #setSim}
     */
    public ISimilarity getSim() {
        return this.sim;
    }

    /**
     * Replaces the similarity measure.
     *
     * @param iSimilarity the similarity measure to use for subsequent classifications
     */
    public void setSim(ISimilarity iSimilarity) {
        this.sim = iSimilarity;
    }

    /**
     * Builds the classifier and prints two self-evaluation figures to standard output:
     * accuracy with every prototype present ("leave-zero-out") and accuracy with the
     * evaluated prototype temporarily removed ("leave-one-out").
     *
     * <p>The given instance set is mutated temporarily during the leave-one-out pass;
     * each removed instance is put back at its original index even if classification
     * throws. A classification that fails (returns {@code null}) counts as a miss.
     *
     * @param instanceSet prototypes to search; stored by reference, not copied
     * @param pipe        feature pipe stored on the model
     * @param iSimilarity similarity measure used to compare instances
     * @param i           the {@code k} value passed to each {@link LinkedPredict}
     */
    public KNN(InstanceSet instanceSet, Pipe pipe, ISimilarity iSimilarity, int i) {
        this.prototypes = instanceSet;
        this.pipe = pipe;
        this.sim = iSimilarity;
        this.k = i;
        int leaveZeroOutHits = 0;
        int leaveOneOutHits = 0;
        int size = this.prototypes.size();
        System.out.println("实例数量：" + size);
        for (int index = 0; index < size; index++) {
            Instance instance = this.prototypes.get(index);
            if (predictsOwnTarget(instance)) {
                leaveZeroOutHits++;
            }
            this.prototypes.remove(index);
            try {
                if (predictsOwnTarget(instance)) {
                    leaveOneOutHits++;
                }
            } finally {
                // Always restore the caller's instance set, even when classification throws.
                this.prototypes.add(index, instance);
            }
        }
        System.out.println("Leave-zero-out正确率：" + ((leaveZeroOutHits * 1.0f) / size));
        System.out.println("Leave-one-out正确率：" + ((leaveOneOutHits * 1.0f) / size));
    }

    /**
     * Classifies {@code instance} against the current prototypes and reports whether the
     * top label equals the instance's own target. A {@code null} prediction or
     * {@code null} top label is treated as a miss instead of raising a
     * {@link NullPointerException}.
     */
    private boolean predictsOwnTarget(Instance instance) {
        TPredict prediction = classify(instance, 1);
        if (prediction == null) {
            return false;
        }
        Object topLabel = prediction.getLabel(0);
        return topLabel != null && topLabel.equals(instance.getTarget());
    }

    /**
     * Replaces the feature pipe stored on this model.
     *
     * @param pipe the pipe to store
     */
    public void setPipe(Pipe pipe) {
        this.pipe = pipe;
    }

    /**
     * Scores {@code instance} against every prototype and returns the merged prediction.
     *
     * <p>Each prototype contributes its target (as label), its similarity to the query,
     * and its source to a {@link LinkedPredict} created with the current {@code k}.
     * The result is then passed through {@code mergeDuplicate(useScore)} and
     * {@code assertSize(i)}.
     *
     * @param instance the instance to classify
     * @param i        size argument forwarded to {@code assertSize}
     *                 (presumably the number of labels the caller wants; confirm in LinkedPredict)
     * @return the merged prediction, or {@code null} if the similarity computation or the
     *         casts of a prototype's target/source to {@code String} throw; the stack
     *         trace is printed in that case
     */
    @Override // org.fnlp.ml.classifier.AbstractClassifier
    public TPredict classify(Instance instance, int i) {
        LinkedPredict candidates = new LinkedPredict(this.k);
        for (int index = 0; index < this.prototypes.size(); index++) {
            Instance prototype = this.prototypes.get(index);
            try {
                candidates.add(
                        (String) prototype.getTarget(),
                        this.sim.calc(instance.getData(), prototype.getData()),
                        (String) prototype.getSource());
            } catch (Exception e) {
                // Kept as a best-effort contract: callers receive null on failure.
                e.printStackTrace();
                return null;
            }
        }
        LinkedPredict merged = candidates.mergeDuplicate(this.useScore);
        merged.assertSize(i);
        return merged;
    }

    /**
     * Same as {@link #classify(Instance, int)}; the label-parser type is ignored.
     *
     * @param instance the instance to classify
     * @param type     ignored by this classifier
     * @param i        size argument forwarded to {@link #classify(Instance, int)}
     * @return the prediction, or {@code null} on failure (see {@link #classify(Instance, int)})
     */
    @Override // org.fnlp.ml.classifier.AbstractClassifier
    public TPredict classify(Instance instance, LabelParser.Type type, int i) {
        return classify(instance, i);
    }

    /**
     * Serializes this classifier to a GZIP-compressed file, creating missing parent
     * directories first.
     *
     * @param str path of the file to write; may be a bare file name with no parent directory
     * @throws IOException if the file cannot be opened or the object cannot be written
     */
    public void saveTo(String str) throws IOException {
        File parentFile = new File(str).getParentFile();
        // getParentFile() is null for a path without a directory component.
        if (parentFile != null && !parentFile.exists()) {
            parentFile.mkdirs();
        }
        // The file stream is its own resource so it is closed even if a wrapper constructor fails.
        try (FileOutputStream fileOut = new FileOutputStream(str);
                ObjectOutputStream objectOut = new ObjectOutputStream(
                        new GZIPOutputStream(new BufferedOutputStream(fileOut)))) {
            objectOut.writeObject(this);
        }
    }

    /**
     * Loads a classifier previously written by {@link #saveTo(String)}.
     *
     * <p>NOTE(review): this uses Java native deserialization; only load model files from
     * trusted locations.
     *
     * @param str path of the GZIP-compressed model file
     * @return the deserialized classifier
     * @throws LoadModelException if the file cannot be read, is not valid GZIP, or does
     *                            not contain a {@code KNN} object
     */
    public static KNN loadFrom(String str) throws LoadModelException {
        // The file stream is its own resource: GZIPInputStream's constructor reads the
        // header and may throw before the wrapper chain is fully built.
        try (FileInputStream fileIn = new FileInputStream(str);
                ObjectInputStream objectIn = new ObjectInputStream(
                        new GZIPInputStream(new BufferedInputStream(fileIn)))) {
            return (KNN) objectIn.readObject();
        } catch (Exception e) {
            throw new LoadModelException(e, str);
        }
    }

    /**
     * Sets the {@code k} value used for subsequent classifications.
     *
     * @param i the new {@code k}
     */
    public void setK(int i) {
        this.k = i;
    }
}
