package org.fnlp.nlp.similarity.train;

import gnu.trove.iterator.TIntFloatIterator;
import gnu.trove.map.hash.TIntFloatHashMap;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import org.fnlp.ml.types.alphabet.LabelAlphabet;
import org.fnlp.ml.types.sv.HashSparseVector;

/* loaded from: input_file:org/fnlp/nlp/similarity/train/KMeansWordCluster.class */
/**
 * Refines predefined character classes with a k-means style pass over a training corpus.
 *
 * <p>Every character of every corpus line becomes one feature vector. For each template (a list
 * of offsets relative to the character) the strings found at those offsets are concatenated,
 * prefixed with the template index and mapped to a feature index through {@link #alphabet}. A
 * string may belong to several classes (one per line of the class file it occurs on). Clustering
 * assigns every occurrence to the nearest of its candidate classes by squared Euclidean distance.
 *
 * <p>While clustering, {@link #classCenter} holds per-class vector <em>sums</em>,
 * {@link #classCount} the number of vectors in each sum, and {@link #baseDistList} the squared L2
 * norm of each sum. Distances to the class mean can therefore be updated incrementally.
 */
public class KMeansWordCluster implements Serializable {
    private static final long serialVersionUID = 1467092492463327579L;
    /** Largest absolute template offset; also the padding width added on each side of a line. */
    private int longestTemplate;
    /** Training corpus path, read line by line as UTF-8. */
    private String trainPath;
    /** Maps "templateIndex:featureString" keys to feature indices. */
    private LabelAlphabet alphabet = new LabelAlphabet();
    /** Feature vectors grouped by their key string; filled by {@link #readData(String)}. */
    private HashMap<String, ArrayList<HashSparseVector>> trainData = new HashMap<>();
    /** Feature templates, each an array of offsets relative to the current position. */
    private ArrayList<int[]> template = new ArrayList<>();
    /** Class index (line number in the class file) to the strings listed on that line. */
    private HashMap<Integer, ArrayList<String>> classOri = new HashMap<>();
    /** String to the indices of all classes it is listed in. */
    private HashMap<String, ArrayList<Integer>> classPerString = new HashMap<>();
    /** Per-class centre vector (a running sum of classCount vectors during clustering). */
    private ArrayList<HashSparseVector> classCenter = new ArrayList<>();
    /** Number of vectors accumulated in the matching classCenter entry. */
    private ArrayList<Integer> classCount = new ArrayList<>();
    /** Squared L2 norm of the matching classCenter entry. */
    private ArrayList<Float> baseDistList = new ArrayList<>();

    /**
     * Builds a clusterer for training: loads templates and classes, then computes the initial
     * class centres from the training corpus.
     *
     * @param templatePath file with one comma-separated list of integer offsets per line
     * @param trainPath    UTF-8 training corpus
     * @param classPath    UTF-8 class file, one whitespace-separated class per line
     * @throws IllegalStateException if any of the files cannot be read
     */
    public KMeansWordCluster(String templatePath, String trainPath, String classPath) {
        this.trainPath = trainPath;
        try {
            readTemplete(templatePath);
            readClass(classPath);
            initCluster();
        } catch (IOException e) {
            // A partially loaded model is unusable; fail fast instead of returning it.
            throw new IllegalStateException("Failed to initialise KMeansWordCluster", e);
        }
    }

    /**
     * Builds a clusterer for classification from a previously saved alphabet and centre list.
     *
     * @param alphabetPath gzip-compressed serialized {@link LabelAlphabet}
     * @param centerPath   gzip-compressed serialized list of class centres
     * @param templatePath template file (same format as for training)
     * @param classPath    class file (same format as for training)
     * @throws Exception if a file cannot be read or deserialized
     */
    public KMeansWordCluster(String alphabetPath, String centerPath, String templatePath,
            String classPath) throws Exception {
        readTemplete(templatePath);
        readClass(classPath);
        LabelAlphabet labelAlphabet = (LabelAlphabet) loadObject(alphabetPath);
        @SuppressWarnings("unchecked")
        ArrayList<HashSparseVector> centers = (ArrayList<HashSparseVector>) loadObject(centerPath);
        setAlphabet(labelAlphabet);
        setClassCenter(centers);
        // Saved centres are means (see updateAverageCenter), i.e. "sums" of exactly one vector.
        addClassCount();
        initBaseDist();
    }

    /** Returns the feature alphabet. */
    public LabelAlphabet getAlphabet() {
        return this.alphabet;
    }

    /** Replaces the feature alphabet. */
    public void setAlphabet(LabelAlphabet labelAlphabet) {
        this.alphabet = labelAlphabet;
    }

    /** Returns the (live, mutable) list of class centres. */
    public ArrayList<HashSparseVector> getClassCenter() {
        return this.classCenter;
    }

    /** Replaces the list of class centres. */
    public void setClassCenter(ArrayList<HashSparseVector> arrayList) {
        this.classCenter = arrayList;
    }

    /** Opens {@code path} as a buffered UTF-8 text reader. */
    private static BufferedReader openUtf8Reader(String path) throws IOException {
        return new BufferedReader(new InputStreamReader(new FileInputStream(path), "UTF-8"));
    }

    /**
     * Loads the class file: line {@code i} lists, whitespace-separated, the members of class i.
     */
    private void readClass(String str) throws IOException {
        try (BufferedReader reader = openUtf8Reader(str)) {
            int classIndex = 0;
            String line;
            while ((line = reader.readLine()) != null) {
                // Trim first: a leading space would add an empty member, and a whitespace-only
                // line would split into zero tokens, leaving a gap in the class indices.
                add2Class(line.trim().split("\\s+"), classIndex);
                classIndex++;
            }
        }
        System.out.println("Finish load class!");
    }

    /** Registers every string in {@code strArr} as a member of class {@code i}. */
    private void add2Class(String[] strArr, int i) {
        for (String member : strArr) {
            add2ClassOri(i, member);
            add2ClassPerString(member, i);
        }
    }

    /** Reads a corpus and stores every generated feature vector in {@link #trainData}. */
    void readData(String str) throws IOException {
        try (BufferedReader reader = openUtf8Reader(str)) {
            int lineNo = 0;
            String line;
            while ((line = reader.readLine()) != null) {
                for (TrainInstance instance : genFeatures(line)) {
                    add2Data(instance.getKey(), instance.getVector());
                }
                printTerminal(lineNo, 10000, "read line");
                lineNo++;
            }
        }
        System.out.println("Finish load training data!");
    }

    /**
     * Builds one instance per character of {@code str}, keyed by that character. Its vector holds
     * 1.0 at the feature index produced by each template.
     */
    private ArrayList<TrainInstance> genFeatures(String str) {
        ArrayList<TrainInstance> instances = new ArrayList<>();
        String[] padded = string2StringSeq(str);
        for (int pos = this.longestTemplate; pos < this.longestTemplate + str.length(); pos++) {
            int[] featureIds = new int[this.template.size()];
            for (int t = 0; t < this.template.size(); t++) {
                featureIds[t] = this.alphabet.lookupIndex(t + ":" + perFea(pos, padded, this.template.get(t)));
            }
            HashSparseVector vector = new HashSparseVector();
            vector.put(featureIds, 1.0f);
            instances.add(new TrainInstance(padded[pos], vector));
        }
        return instances;
    }

    /** Appends {@code hashSparseVector} to the training vectors stored under {@code str}. */
    protected void add2Data(String str, HashSparseVector hashSparseVector) {
        ArrayList<HashSparseVector> vectors = this.trainData.get(str);
        if (vectors == null) {
            vectors = new ArrayList<>();
            this.trainData.put(str, vectors);
        }
        vectors.add(hashSparseVector);
    }

    /** Appends {@code str} to the member list of class {@code i}. */
    private void add2ClassOri(int i, String str) {
        ArrayList<String> members = this.classOri.get(Integer.valueOf(i));
        if (members == null) {
            members = new ArrayList<>();
            this.classOri.put(Integer.valueOf(i), members);
        }
        members.add(str);
    }

    /** Records that {@code str} belongs to class {@code i}. */
    private void add2ClassPerString(String str, int i) {
        ArrayList<Integer> classes = this.classPerString.get(str);
        if (classes == null) {
            classes = new ArrayList<>();
            this.classPerString.put(str, classes);
        }
        classes.add(Integer.valueOf(i));
    }

    /** Concatenates {@code strArr[i + offset]} for every offset in {@code iArr}. */
    private String perFea(int i, String[] strArr, int[] iArr) {
        StringBuilder feature = new StringBuilder();
        for (int offset : iArr) {
            feature.append(strArr[i + offset]);
        }
        return feature.toString();
    }

    /** Concatenates {@code str.charAt(i + offset)} for every offset in {@code iArr}. */
    String perFea(int i, String str, int[] iArr) {
        StringBuilder feature = new StringBuilder();
        for (int offset : iArr) {
            feature.append(str.charAt(i + offset));
        }
        return feature.toString();
    }

    /**
     * Splits {@code str} into one-character strings padded with {@code longestTemplate}
     * markers on each side: "Begin0".. at the start and "End0".. at the end. The markers with
     * the highest index sit next to the text.
     */
    private String[] string2StringSeq(String str) {
        int length = (this.longestTemplate * 2) + str.length();
        String[] seq = new String[length];
        for (int k = 0; k < this.longestTemplate; k++) {
            seq[k] = "Begin" + k;
            seq[(length - k) - 1] = "End" + k;
        }
        for (int k = 0; k < str.length(); k++) {
            seq[k + this.longestTemplate] = String.valueOf(str.charAt(k));
        }
        return seq;
    }

    /** Loads templates (one comma-separated offset list per line) and sets longestTemplate. */
    private void readTemplete(String str) throws IOException {
        try (BufferedReader reader = openUtf8Reader(str)) {
            String line;
            while ((line = reader.readLine()) != null) {
                String[] tokens = line.split(",");
                int[] offsets = new int[tokens.length];
                for (int k = 0; k < tokens.length; k++) {
                    offsets[k] = parseOffset(tokens[k]);
                }
                this.template.add(offsets);
            }
        }
        setLongestTemplate();
        System.out.println("Finish load template!");
    }

    /** Parses one template offset, ignoring surrounding whitespace; reports and uses 0 on bad input. */
    private static int parseOffset(String token) {
        try {
            return Integer.parseInt(token.trim());
        } catch (NumberFormatException e) {
            // Best-effort loading, as before: keep going but report the malformed token.
            System.err.println("Invalid template offset '" + token + "', using 0");
            return 0;
        }
    }

    /** Sets longestTemplate to the largest absolute offset found in any template. */
    private void setLongestTemplate() {
        int longest = 0;
        for (int[] offsets : this.template) {
            for (int offset : offsets) {
                longest = Math.max(longest, Math.abs(offset));
            }
        }
        this.longestTemplate = longest;
    }

    /** Computes initial class centres as the mean feature vector of each class's members. */
    private void initCluster() throws IOException {
        initVector();
        try (BufferedReader reader = openUtf8Reader(this.trainPath)) {
            int lineNo = 0;
            String line;
            while ((line = reader.readLine()) != null) {
                initAddInstanceList(genFeatures(line));
                printTerminal(lineNo, 10000, "init line");
                lineNo++;
            }
        }
        normalClassCenter();
        initBaseDist();
    }

    /** Creates an empty centre and a zero count for every class. */
    private void initVector() {
        for (int i = 0; i < this.classOri.size(); i++) {
            this.classCenter.add(new HashSparseVector());
            this.classCount.add(0);
        }
        System.out.println("All cluster centers have been created!");
    }

    /** Sets a count of 1 for every class (used when centres are loaded as means). */
    private void addClassCount() {
        for (int i = 0; i < this.classOri.size(); i++) {
            this.classCount.add(1);
        }
    }

    /** Adds every instance to the centre sums of all its candidate classes. */
    private void initAddInstanceList(ArrayList<TrainInstance> arrayList) {
        for (TrainInstance instance : arrayList) {
            initAddInstance(instance);
        }
    }

    /** Adds one instance to every class its key belongs to; unknown keys are ignored. */
    private void initAddInstance(TrainInstance trainInstance) {
        ArrayList<Integer> classes = this.classPerString.get(trainInstance.getKey());
        if (classes == null) {
            return;
        }
        HashSparseVector vector = trainInstance.getVector();
        for (Integer c : classes) {
            int idx = c.intValue();
            this.classCount.set(idx, Integer.valueOf(this.classCount.get(idx).intValue() + 1));
            this.classCenter.get(idx).plus(vector);
        }
    }

    /** Turns the initial centre sums into means and resets each count to 1. */
    private void normalClassCenter() {
        for (int i = 0; i < this.classOri.size(); i++) {
            int count = this.classCount.get(i).intValue();
            // A class with no observed instances has an empty sum; avoid dividing it by zero.
            if (count > 0) {
                this.classCenter.get(i).scaleDivide(count);
            }
            this.classCount.set(i, 1);
        }
    }

    /**
     * Appends the mean of all {@link #trainData} vectors of class {@code i}'s members as a new
     * centre, with count 1. A class with no data (or an unknown index) gets an empty centre.
     */
    protected void initPerClassCenter(int i) {
        HashSparseVector center = new HashSparseVector();
        int total = 0;
        ArrayList<String> members = this.classOri.get(Integer.valueOf(i));
        if (members != null) {
            for (String member : members) {
                ArrayList<HashSparseVector> vectors = this.trainData.get(member);
                if (vectors == null) {
                    continue;
                }
                total += vectors.size();
                for (HashSparseVector vector : vectors) {
                    center.plus(vector);
                }
            }
        }
        if (total > 0) {
            center.scaleDivide(total);
        }
        this.classCenter.add(center);
        this.classCount.add(1);
    }

    /**
     * Runs one clustering pass over the training corpus, then converts the centre sums to means.
     * Every 10000 lines the running sums are checkpointed to "tmpdata/classCenterTemp".
     */
    public void cluster() throws IOException {
        try (BufferedReader reader = openUtf8Reader(this.trainPath)) {
            int lineNo = 0;
            String line;
            while ((line = reader.readLine()) != null) {
                clusterList(genFeatures(line));
                printTerminal(lineNo, 10000, "cluster line");
                if ((lineNo + 1) % 10000 == 0) {
                    saveObject("tmpdata/classCenterTemp", this.classCenter);
                }
                lineNo++;
            }
        }
        updateAverageCenter();
    }

    /**
     * Divides every centre sum by its count. Afterwards the count is reset to 1 and the stored norm
     * is recomputed, so the instance stays consistent if it is used for classification.
     */
    private void updateAverageCenter() {
        for (int i = 0; i < this.classCount.size(); i++) {
            this.classCenter.get(i).scaleDivide(this.classCount.get(i).intValue());
            this.classCount.set(i, 1);
            this.baseDistList.set(i, Float.valueOf(getBaseDist(i)));
        }
    }

    /** Assigns every instance with a known key to its nearest candidate class. */
    private void clusterList(ArrayList<TrainInstance> arrayList) {
        for (TrainInstance instance : arrayList) {
            String key = instance.getKey();
            if (this.classPerString.containsKey(key)) {
                HashSparseVector vector = instance.getVector();
                updateCenter(minClass(key, vector), vector);
            }
        }
    }

    private int minClass(TrainInstance trainInstance) {
        return minClass(trainInstance.getKey(), trainInstance.getVector());
    }

    /**
     * Returns the candidate class of {@code str} whose mean is nearest to the vector, or -1 if
     * {@code str} belongs to no class.
     */
    private int minClass(String str, HashSparseVector hashSparseVector) {
        ArrayList<Integer> candidates = this.classPerString.get(str);
        if (candidates == null || candidates.isEmpty()) {
            return -1;
        }
        // Start from a real candidate so the result is always one of str's classes.
        int best = candidates.get(0).intValue();
        float bestDist = Float.MAX_VALUE;
        for (Integer candidate : candidates) {
            int c = candidate.intValue();
            float dist = distanceEuclidean(c, hashSparseVector, this.baseDistList.get(c).floatValue());
            if (dist < bestDist) {
                bestDist = dist;
                best = c;
            }
        }
        return best;
    }

    /**
     * Squared Euclidean distance between {@code hashSparseVector} and the mean of class {@code i}.
     * The class centre holds a sum S of n vectors and {@code f} must be ||S||^2. The distance
     * starts from ||S/n||^2 and corrects only the dimensions present in the vector.
     */
    private float distanceEuclidean(int i, HashSparseVector hashSparseVector, float f) {
        TIntFloatHashMap centerData = this.classCenter.get(i).data;
        int n = this.classCount.get(i).intValue();
        // Square in double: n * n in int overflows once n exceeds ~46340.
        float dist = (float) (f / ((double) n * n));
        TIntFloatIterator it = hashSparseVector.data.iterator();
        while (it.hasNext()) {
            it.advance();
            int key = it.key();
            float value = it.value();
            if (centerData.containsKey(key)) {
                float mean = centerData.get(key) / n;
                dist = (dist - (mean * mean)) + ((value - mean) * (value - mean));
            } else {
                dist += value * value;
            }
        }
        return dist;
    }

    /** Squared Euclidean distance between two sparse vectors. */
    float distanceEuclidean(HashSparseVector hashSparseVector, HashSparseVector hashSparseVector2) {
        // Hash iterators are unordered, so compare by lookup rather than a sorted merge.
        TIntFloatHashMap a = hashSparseVector.data;
        TIntFloatHashMap b = hashSparseVector2.data;
        float sum = 0.0f;
        TIntFloatIterator it = a.iterator();
        while (it.hasNext()) {
            it.advance();
            float diff = b.containsKey(it.key()) ? it.value() - b.get(it.key()) : it.value();
            sum += diff * diff;
        }
        TIntFloatIterator it2 = b.iterator();
        while (it2.hasNext()) {
            it2.advance();
            if (!a.containsKey(it2.key())) {
                sum += it2.value() * it2.value();
            }
        }
        return sum;
    }

    /** Adds the vector to class {@code i}'s running sum, updating its count and stored norm. */
    private void updateCenter(int i, HashSparseVector hashSparseVector) {
        this.classCount.set(i, Integer.valueOf(this.classCount.get(i).intValue() + 1));
        // The norm update reads the centre before the addition, so it must run first.
        updateBaseDist(i, hashSparseVector);
        this.classCenter.get(i).plus(hashSparseVector);
    }

    /** Updates ||S||^2 of class {@code i} to ||S + v||^2 before v is added to the sum S. */
    private void updateBaseDist(int i, HashSparseVector hashSparseVector) {
        float norm = this.baseDistList.get(i).floatValue();
        TIntFloatHashMap centerData = this.classCenter.get(i).data;
        TIntFloatIterator it = hashSparseVector.data.iterator();
        while (it.hasNext()) {
            it.advance();
            float value = it.value();
            if (centerData.containsKey(it.key())) {
                float old = centerData.get(it.key());
                float updated = old + value;
                norm = (norm - (old * old)) + (updated * updated);
            } else {
                norm += value * value;
            }
        }
        this.baseDistList.set(i, Float.valueOf(norm));
    }

    /** Returns the squared L2 norm of class {@code i}'s centre. */
    private float getBaseDist(int i) {
        float norm = 0.0f;
        TIntFloatIterator it = this.classCenter.get(i).data.iterator();
        while (it.hasNext()) {
            it.advance();
            norm += it.value() * it.value();
        }
        return norm;
    }

    /** Appends the squared norm of every centre to {@link #baseDistList}. */
    private void initBaseDist() {
        for (int i = 0; i < this.classCenter.size(); i++) {
            this.baseDistList.add(Float.valueOf(getBaseDist(i)));
        }
        System.out.println("Finish init base distance list");
    }

    /**
     * Reads one Java-serialized object from a gzip file.
     * NOTE(review): native deserialization must only be used on trusted files.
     */
    protected Object loadObject(String str) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(
                new BufferedInputStream(new GZIPInputStream(new FileInputStream(str))))) {
            return in.readObject();
        }
    }

    /** Writes {@code obj} Java-serialized into a gzip file at {@code str}. */
    protected void saveObject(String str, Object obj) throws IOException {
        try (ObjectOutputStream out = new ObjectOutputStream(
                new BufferedOutputStream(new GZIPOutputStream(new FileOutputStream(str))))) {
            out.writeObject(obj);
        }
    }

    /** Prints "{@code str} {@code i+1}" every {@code i2} iterations. */
    private void printTerminal(int i, int i2, String str) {
        if ((i + 1) % i2 == 0) {
            System.out.println(str + " " + (i + 1));
        }
    }

    /**
     * Classifies the middle character of a three-token window, for example "abc",
     * "Begin0ab" or "abEnd0".
     *
     * @return the nearest class index, or -1 if the middle character belongs to no class
     */
    public int classifier(String str) {
        return minClass(genFeaForClassifier(str));
    }

    /**
     * Builds the feature vector for the window's middle position (index 1). Templates must
     * therefore only use offsets in [-1, 1].
     */
    private TrainInstance genFeaForClassifier(String str) {
        int[] featureIds = new int[this.template.size()];
        String[] window = string2StringSeqWithBE(str);
        for (int t = 0; t < this.template.size(); t++) {
            featureIds[t] = this.alphabet.lookupIndex(t + ":" + perFea(1, window, this.template.get(t)));
        }
        HashSparseVector vector = new HashSparseVector();
        vector.put(featureIds, 1.0f);
        return new TrainInstance(window[1], vector);
    }

    /**
     * Splits a classifier window into three tokens. A "Begin0" prefix or "End0" suffix is kept
     * as one token, mirroring the padding of string2StringSeq when longestTemplate is 1.
     *
     * @throws IllegalArgumentException if the window has too few characters after "Begin0" or
     *     more than three plain characters
     */
    private String[] string2StringSeqWithBE(String str) {
        String[] window = new String[3];
        if (str.startsWith("Begin0")) {
            int start = "Begin0".length();
            if (str.length() < start + 2) {
                throw new IllegalArgumentException("Expected two characters after Begin0: " + str);
            }
            window[0] = "Begin0";
            window[1] = String.valueOf(str.charAt(start));
            window[2] = String.valueOf(str.charAt(start + 1));
        } else if (str.endsWith("End0")) {
            window[0] = String.valueOf(str.charAt(0));
            window[1] = String.valueOf(str.charAt(1));
            window[2] = "End0";
        } else {
            if (str.length() > window.length) {
                throw new IllegalArgumentException("Expected at most three characters: " + str);
            }
            for (int k = 0; k < str.length(); k++) {
                window[k] = String.valueOf(str.charAt(k));
            }
        }
        return window;
    }

    /**
     * Entry point.
     * Five arguments train a model: template, train, classes, centre output, alphabet output.
     * Four arguments load a model (alphabet, centres, template, classes) and run sample queries.
     * With no arguments, a model is loaded from ./exp/featureCluster.
     */
    public static void main(String[] strArr) throws Exception {
        if (strArr.length == 5) {
            KMeansWordCluster kMeansWordCluster = new KMeansWordCluster(strArr[0], strArr[1], strArr[2]);
            kMeansWordCluster.saveObject(strArr[4], kMeansWordCluster.getAlphabet());
            kMeansWordCluster.cluster();
            kMeansWordCluster.saveObject(strArr[3], kMeansWordCluster.getClassCenter());
        }
        if (strArr.length == 4) {
            KMeansWordCluster kMeansWordCluster2 = new KMeansWordCluster(strArr[0], strArr[1], strArr[2], strArr[3]);
            System.out.println(kMeansWordCluster2.classifier("123"));
            System.out.println(kMeansWordCluster2.classifier("sdf"));
            System.out.println(kMeansWordCluster2.classifier("gjl"));
            System.out.println(kMeansWordCluster2.classifier("打日本"));
            System.out.println(kMeansWordCluster2.classifier("中日韩"));
            System.out.println(kMeansWordCluster2.classifier("几日呢"));
            System.out.println(kMeansWordCluster2.classifier("Begin0几日"));
            System.out.println(kMeansWordCluster2.classifier("几日End0"));
        }
        if (strArr.length == 0) {
            KMeansWordCluster kMeansWordCluster3 = new KMeansWordCluster("./exp/featureCluster/alphabet", "./exp/featureCluster/clusterCenter", "./exp/featureCluster/template", "./exp/featureCluster/charsynset.txt");
            System.out.println(kMeansWordCluster3.classifier("１２３"));
            System.out.println(kMeansWordCluster3.classifier("123"));
            System.out.println(kMeansWordCluster3.classifier("sdf"));
            System.out.println(kMeansWordCluster3.classifier("ＡＢＢ"));
            System.out.println(kMeansWordCluster3.classifier("gjl"));
            System.out.println(kMeansWordCluster3.classifier("打日本"));
            System.out.println(kMeansWordCluster3.classifier("中日韩"));
            System.out.println(kMeansWordCluster3.classifier("几日呢"));
            System.out.println(kMeansWordCluster3.classifier("Begin0几日"));
            System.out.println(kMeansWordCluster3.classifier("几日End0"));
        }
        if (strArr.length != 0 && strArr.length != 4 && strArr.length != 5) {
            System.err.println("Usage: KMeansWordCluster <template> <train> <classes> <centerOut> <alphabetOut>"
                    + " | <alphabet> <centers> <template> <classes>");
        }
    }
}
