package org.fnlp.nlp.langmodel;

import gnu.trove.list.array.TFloatArrayList;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.map.hash.TIntIntHashMap;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.EOFException;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.LinkedList;
import java.util.List;
import java.util.Scanner;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import org.fnlp.ml.types.alphabet.LabelAlphabet;

/* loaded from: input_file:org/fnlp/nlp/langmodel/NGramModel.class */
/**
 * A word n-gram language model.
 *
 * <p>{@link #build(String...)} reads whitespace-separated tokens from UTF-8 files and counts every
 * n-gram together with its (n-1)-token history. It then turns the n-gram counts into conditional
 * probabilities {@code count(ngram) / count(history)}. {@link #getProbability(String)} scores a
 * whitespace-tokenised string as a sum of natural-log probabilities.
 *
 * <p>Words are mapped to integer ids through {@link #wordDict}. Histories and n-grams are keyed by
 * the space-joined decimal word ids built in {@link #getNgram(List)}. The id {@code -1} pads
 * positions before the start of a text.
 */
public class NGramModel {
    /** Word id used to pad positions before the first token of a text. */
    private static final int PAD = -1;

    /** Order of the model: number of tokens in each n-gram. */
    protected int n;
    /**
     * Indexed by n-gram id. Holds raw n-gram counts while counting and P(last word | history)
     * once {@link #build(String...)} has finished.
     */
    TFloatArrayList prob = new TFloatArrayList();
    /** Indexed by history id: how many times each (n-1)-token history was observed. */
    TIntArrayList n1gramCount = new TIntArrayList();
    /** Maps an n-gram id to the id of its (n-1)-token history. */
    TIntIntHashMap index = new TIntIntHashMap();
    /** Word string to word id. */
    LabelAlphabet wordDict = new LabelAlphabet();
    /** N-gram key (space-joined word ids) to n-gram id. */
    LabelAlphabet ngramDict = new LabelAlphabet();
    /** History key (space-joined word ids, trailing space) to history id. */
    LabelAlphabet n1gramDict = new LabelAlphabet();

    /**
     * Creates an empty model of the given order.
     *
     * @param i the n-gram order; must be at least 1
     * @throws IllegalArgumentException if {@code i < 1}
     */
    public NGramModel(int i) {
        if (i < 1) {
            // getNgram() would index position n-1 = -1 for an order of 0.
            throw new IllegalArgumentException("n-gram order must be >= 1, got " + i);
        }
        this.n = i;
    }

    /**
     * Counts n-grams in the given files, converts the counts to conditional probabilities and
     * freezes all three alphabets.
     *
     * <p>Each file is read independently as UTF-8, whitespace-separated tokens. Counting restarts
     * from an all-padding history at the beginning of every file.
     *
     * @param strArr paths of the training files
     * @throws IOException if a file cannot be opened or read
     * @throws IllegalStateException if an alphabet returns an id that is not the next free slot
     * @throws Exception declared for backward compatibility with existing callers
     */
    public void build(String... strArr) throws Exception {
        System.out.println("read file ...");
        for (String path : strArr) {
            countFile(path);
        }
        normalizeCounts();
        this.ngramDict.setStopIncrement(true);
        this.n1gramDict.setStopIncrement(true);
        this.wordDict.setStopIncrement(true);
        System.out.println("词表大小" + this.wordDict.size());
    }

    /** Adds the n-gram and history counts of one UTF-8, whitespace-tokenised file. */
    private void countFile(String path) throws IOException {
        LinkedList<Integer> window = new LinkedList<Integer>();
        for (int k = 0; k < this.n; k++) {
            window.add(PAD);
        }
        Scanner scanner = new Scanner(new FileInputStream(path), "utf-8");
        try {
            while (scanner.hasNext()) {
                // Slide the n-token window one word to the right.
                window.add(this.wordDict.lookupIndex(scanner.next()));
                window.remove();
                assert window.size() == this.n;

                String[] grams = getNgram(window);
                int historyIdx = this.n1gramDict.lookupIndex(grams[0]);
                incrementCount(this.n1gramCount, historyIdx);

                int ngramIdx = this.ngramDict.lookupIndex(grams[1]);
                if (ngramIdx == this.prob.size()) {
                    this.prob.add(1.0f);
                } else if (ngramIdx < 0 || ngramIdx > this.prob.size()) {
                    throw new IllegalStateException(
                            "unexpected n-gram id " + ngramIdx + " (table size " + this.prob.size() + ")");
                } else {
                    this.prob.set(ngramIdx, this.prob.get(ngramIdx) + 1.0f);
                }

                // An n-gram key always starts with its history key, so the mapping is stable.
                if (!this.index.contains(ngramIdx)) {
                    this.index.put(ngramIdx, historyIdx);
                } else {
                    assert this.index.get(ngramIdx) == historyIdx;
                }
            }
            // Scanner swallows read errors and just reports end of input; do not train on a
            // silently truncated file.
            IOException readError = scanner.ioException();
            if (readError != null) {
                throw readError;
            }
        } finally {
            scanner.close();
        }
    }

    /**
     * Increments {@code counts[id]}, appending a new slot when {@code id} is the next unseen id.
     * Relies on the alphabet handing out ids densely, in first-seen order.
     */
    private static void incrementCount(TIntArrayList counts, int id) {
        if (id == counts.size()) {
            counts.add(1);
        } else if (id < 0 || id > counts.size()) {
            throw new IllegalStateException(
                    "unexpected history id " + id + " (table size " + counts.size() + ")");
        } else {
            counts.set(id, counts.get(id) + 1);
        }
    }

    /**
     * Turns the raw n-gram counts in {@link #prob} into maximum-likelihood conditional
     * probabilities {@code count(ngram) / count(history)}. The history of each n-gram comes from
     * {@link #index}.
     */
    private void normalizeCounts() {
        for (int ngramIdx = 0; ngramIdx < this.prob.size(); ngramIdx++) {
            int historyIdx = this.index.get(ngramIdx);
            this.prob.set(ngramIdx, this.prob.get(ngramIdx) / this.n1gramCount.get(historyIdx));
        }
    }

    /**
     * Builds the history key and the n-gram key for the first {@code n} word ids of {@code list}.
     *
     * @param list word ids; at least {@code n} elements are read
     * @return a two-element array. Element 0 is the first {@code n-1} ids, each followed by a
     *     space. Element 1 is that string followed by the {@code n}-th id.
     */
    public String[] getNgram(List<Integer> list) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < this.n - 1; i++) {
            sb.append(list.get(i)).append(' ');
        }
        String history = sb.toString();
        sb.append(list.get(this.n - 1));
        return new String[] {history, sb.toString()};
    }

    /**
     * Writes the model to a gzip-compressed Java serialisation file.
     *
     * <p>The order {@code n} is written first, so readers that expect the legacy layout (which
     * holds only {@code n}) can still read the file. It is followed by the alphabets and the
     * count/probability tables.
     *
     * @param str path of the output file
     * @throws IOException if the file cannot be written
     */
    public void save(String str) throws IOException {
        System.out.println("save ...");
        ObjectOutputStream out = new ObjectOutputStream(
                new BufferedOutputStream(new GZIPOutputStream(new FileOutputStream(str))));
        try {
            out.writeObject(Integer.valueOf(this.n));
            // NOTE(review): assumes LabelAlphabet is Serializable and the Trove collections are
            // Externalizable; confirm against the library versions in use.
            out.writeObject(this.wordDict);
            out.writeObject(this.ngramDict);
            out.writeObject(this.n1gramDict);
            out.writeObject(this.n1gramCount);
            out.writeObject(this.prob);
            out.writeObject(this.index);
        } finally {
            out.close();
        }
        System.out.println("OK");
    }

    /**
     * Loads a model written by {@link #save(String)}.
     *
     * <p>Legacy files that contain only the order {@code n} are still accepted. For those, only
     * {@code n} is replaced and the in-memory tables are left untouched.
     *
     * @param str path of the model file
     * @throws IOException if the file cannot be read or is truncated after the tables started
     * @throws ClassNotFoundException if a serialised class cannot be resolved
     */
    public void load(String str) throws IOException, ClassNotFoundException {
        System.out.println("加载N-Gram模型:" + str);
        System.out.println("load ...");
        System.out.println(str);
        ObjectInputStream in = new ObjectInputStream(
                new BufferedInputStream(new GZIPInputStream(new FileInputStream(str))));
        try {
            this.n = ((Integer) in.readObject()).intValue();
            System.out.println("ngram " + this.n);
            readTables(in);
        } finally {
            in.close();
        }
        System.out.println("load end");
    }

    /**
     * Reads the alphabets and tables that follow {@code n} in the stream, if any. All fields are
     * assigned only after every object has been read, so a truncated file cannot leave the model
     * half-replaced.
     */
    private void readTables(ObjectInputStream in) throws IOException, ClassNotFoundException {
        LabelAlphabet words;
        try {
            words = (LabelAlphabet) in.readObject();
        } catch (EOFException legacyFormat) {
            // Files written by the old save() end right after n.
            return;
        }
        LabelAlphabet ngrams = (LabelAlphabet) in.readObject();
        LabelAlphabet histories = (LabelAlphabet) in.readObject();
        TIntArrayList historyCounts = (TIntArrayList) in.readObject();
        TFloatArrayList probabilities = (TFloatArrayList) in.readObject();
        TIntIntHashMap ngramToHistory = (TIntIntHashMap) in.readObject();

        this.wordDict = words;
        this.ngramDict = ngrams;
        this.n1gramDict = histories;
        this.n1gramCount = historyCounts;
        this.prob = probabilities;
        this.index = ngramToHistory;
    }

    /**
     * Scores a whitespace-tokenised string as a sum of natural-log n-gram probabilities.
     *
     * <p>The window starts with one padding id followed by the first {@code n-1} tokens. Those
     * tokens act only as context; each later token is scored given its {@code n-1} predecessors.
     * There are three cases:
     * <ul>
     *   <li>Unknown history: probability 0, so the result becomes negative infinity.</li>
     *   <li>Known history, unseen n-gram: {@code 1 / (count(history) + vocabularySize)}.</li>
     *   <li>Seen n-gram: the probability stored by {@link #build(String...)}.</li>
     * </ul>
     *
     * <p>NOTE(review): if this is called before the alphabets are frozen, {@code lookupIndex} may
     * add new entries; confirm callers always build or load first.
     *
     * @param str the text to score, tokens separated by whitespace
     * @return the summed log-probability, or {@code 0.0} if the text has fewer than {@code n-1}
     *     tokens
     */
    public double getProbability(String str) {
        String[] tokens = str.split("\\s+");
        if (tokens.length < this.n - 1) {
            return 0.0d;
        }
        LinkedList<Integer> window = new LinkedList<Integer>();
        window.add(PAD);
        int i = 0;
        for (; i < this.n - 1; i++) {
            window.add(this.wordDict.lookupIndex(tokens[i]));
        }

        double logProb = 0.0d;
        for (; i < tokens.length; i++) {
            window.add(this.wordDict.lookupIndex(tokens[i]));
            window.remove();
            assert window.size() == this.n;

            String[] grams = getNgram(window);
            double p = 0.0d;
            int historyIdx = this.n1gramDict.lookupIndex(grams[0]);
            if (historyIdx != -1) {
                int ngramIdx = this.ngramDict.lookupIndex(grams[1]);
                p = ngramIdx == -1
                        ? 1.0d / (this.n1gramCount.get(historyIdx) + this.wordDict.size())
                        : this.prob.get(ngramIdx);
            }
            logProb += Math.log(p);
        }
        return logProb;
    }

    /**
     * Returns {@link #getProbability(String)} divided by the character length of {@code str}.
     * The divisor counts characters, including separators, not tokens.
     *
     * @param str the text to score
     * @return the length-normalised log-probability, or {@code 0.0} for an empty string
     */
    public double normalise(String str) {
        if (str.length() == 0) {
            // Avoid 0/0 = NaN.
            return 0.0d;
        }
        return getProbability(str) / str.length();
    }
}
