package org.fnlp.nlp.cn.anaphora.train;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.fnlp.data.reader.ListReader;
import org.fnlp.data.reader.SimpleFileReader;
import org.fnlp.ml.classifier.linear.Linear;
import org.fnlp.ml.classifier.linear.OnlineTrainer;
import org.fnlp.ml.classifier.linear.inf.LinearMax;
import org.fnlp.ml.classifier.linear.update.LinearMaxPAUpdate;
import org.fnlp.ml.feature.SFGenerator;
import org.fnlp.ml.loss.ZeroOneLoss;
import org.fnlp.ml.types.Instance;
import org.fnlp.ml.types.InstanceSet;
import org.fnlp.ml.types.alphabet.AlphabetFactory;
import org.fnlp.ml.types.alphabet.LabelAlphabet;
import org.fnlp.nlp.pipe.Pipe;
import org.fnlp.nlp.pipe.SeriesPipes;
import org.fnlp.nlp.pipe.StringArray2IndexArray;
import org.fnlp.nlp.pipe.Target2Label;

/* loaded from: input_file:org/fnlp/nlp/cn/anaphora/train/ARClassifier.class */
/**
 * Trains and sanity-checks a linear classifier for anaphora-resolution (AR) pairs.
 *
 * <p>{@link #train()} reads labelled data from {@code ../tmp/ar-train.txt}. It trains a
 * {@link Linear} model with {@link OnlineTrainer} and {@link LinearMaxPAUpdate}, and saves the
 * model to {@code ../models/ar.m}. {@link #main(String[])} trains, reloads the saved model, and
 * prints its accuracy on the same training file.
 */
public class ARClassifier {
    // Package-visible shared state. train/test/pipe/path are not used inside this class;
    // they are kept because other classes in the package may reference them.
    static InstanceSet train;
    static InstanceSet test;
    static Pipe pipe;

    /** Labelled training file; main() also evaluates against this same file. */
    private static final String DEFAULT_TRAIN_FILE = "../tmp/ar-train.txt";

    /** Where train() writes the model and main() reads it back. */
    private static final String MODEL_FILE = "../models/ar.m";

    /** Label of a pair that is an anaphoric (coreference) relation. */
    private static final String POSITIVE = "1";

    /** Label of a pair that is not an anaphoric relation. */
    private static final String NEGATIVE = "0";

    private final String trainFile = DEFAULT_TRAIN_FILE;

    /** Alphabet factory shared by the training pipes and the trainer. */
    static AlphabetFactory factory = AlphabetFactory.buildFactory();
    static LabelAlphabet al = factory.DefaultLabelAlphabet();
    static String path = null;

    /**
     * Trains the model, reloads it from disk, and prints its accuracy on the training file.
     *
     * <p>Three ratios are printed, each followed by {@code '\n'}:
     * <ul>
     *   <li>overall accuracy;</li>
     *   <li>correct "1" predictions divided by the number of gold "1" pairs, i.e. recall of the
     *       positive class;</li>
     *   <li>correct "0" predictions divided by the number of pairs whose gold label is not
     *       "1".</li>
     * </ul>
     * This is training-set accuracy, not a held-out evaluation.
     *
     * @param args ignored
     * @throws Exception propagated from training, model loading, or file reading
     */
    public static void main(String[] args) throws Exception {
        ARClassifier classifier = new ARClassifier();
        classifier.train();
        Linear model = Linear.loadFrom(MODEL_FILE);

        List<Instance> instances = readInstances(classifier.trainFile);
        String[] gold = targetsOf(instances);
        int goldPositive = count(gold, POSITIVE);

        // Run the raw data through the pipe that was saved with the model.
        InstanceSet evalSet = new InstanceSet(model.getPipe(), model.getAlphabetFactory());
        evalSet.loadThruPipes(new ListReader(dataOf(instances)));

        int correct = 0;
        int truePositive = 0;
        int trueNegative = 0;
        // NOTE(review): assumes loadThruPipes yields exactly one instance per input list, in
        // input order, so evalSet index k lines up with gold[k] -- confirm.
        for (int k = 0; k < gold.length; k++) {
            String predicted = model.getStringLabel(evalSet.getInstance(k));
            if (!predicted.equals(gold[k])) {
                continue;
            }
            correct++;
            if (POSITIVE.equals(predicted)) {
                truePositive++;
            } else if (NEGATIVE.equals(predicted)) {
                trueNegative++;
            }
        }

        // Ratios are computed in double, so an empty denominator prints NaN/Infinity
        // instead of throwing.
        System.out.print("整体正确率：" + ratio(correct, gold.length) + '\n');
        System.out.print("判断为指代关系的正确率：" + ratio(truePositive, goldPositive) + '\n');
        System.out.print("判断为非指代关系的正确率："
                + ratio(trueNegative, gold.length - goldPositive) + '\n');
    }

    /**
     * Trains a linear classifier on {@code trainFile} and saves it to {@code MODEL_FILE}.
     *
     * <p>Uses the shared static {@code factory} and {@code al}. At the end it calls
     * {@code factory.setStopIncrement(true)}, so the factory stays in that state afterwards.
     *
     * @throws Exception propagated from reading, training, or saving
     */
    public void train() throws Exception {
        SeriesPipes pipes = new SeriesPipes(new Pipe[] {
                new Target2Label(al),
                new StringArray2IndexArray(factory, true)});
        InstanceSet trainSet = new InstanceSet(pipes, factory);
        trainSet.loadThruStagePipes(
                new SimpleFileReader(trainFile, " ", true, SimpleFileReader.Type.LabelData));

        ZeroOneLoss loss = new ZeroOneLoss();
        LinearMax inference = new LinearMax(new SFGenerator(), factory.getLabelSize());
        // 50 and 0.005f are passed straight to OnlineTrainer. Presumably they are the
        // iteration count and the PA step parameter -- confirm against OnlineTrainer.
        OnlineTrainer trainer = new OnlineTrainer(
                inference, new LinearMaxPAUpdate(loss), loss, factory, 50, 0.005f);
        // The training set is also passed as the trainer's evaluation set.
        Linear model = trainer.train(trainSet, trainSet);

        // Presumably removes the label-mapping stage so the saved pipe accepts unlabelled
        // input; main() relies on model.getPipe() for evaluation.
        pipes.removeTargetPipe();
        model.setPipe(pipes);
        factory.setStopIncrement(true);
        model.saveTo(MODEL_FILE);
    }

    /** Reads every instance from {@code file} as-is, without applying any pipe. */
    private static List<Instance> readInstances(String file) throws Exception {
        SimpleFileReader reader = new SimpleFileReader(file, true);
        List<Instance> instances = new ArrayList<Instance>();
        while (reader.hasNext()) {
            instances.add((Instance) reader.next());
        }
        return instances;
    }

    /** Returns each instance's target, cast to String, in input order. */
    private static String[] targetsOf(List<Instance> instances) {
        String[] targets = new String[instances.size()];
        for (int k = 0; k < targets.length; k++) {
            targets[k] = (String) instances.get(k).getTarget();
        }
        return targets;
    }

    /** Returns each instance's data payload, cast to List, in the array form ListReader takes. */
    @SuppressWarnings("rawtypes")
    private static List[] dataOf(List<Instance> instances) {
        List[] data = new List[instances.size()];
        for (int k = 0; k < data.length; k++) {
            data[k] = (List) instances.get(k).getData();
        }
        return data;
    }

    /** Counts entries equal to {@code label}; null entries never match. */
    private static int count(String[] labels, String label) {
        int n = 0;
        for (String l : labels) {
            if (label.equals(l)) {
                n++;
            }
        }
        return n;
    }

    /** Floating-point ratio; a zero denominator yields NaN or Infinity. */
    private static double ratio(int numerator, int denominator) {
        return (numerator + 0.0d) / denominator;
    }
}
