package org.fnlp.ml.classifier.bayes;

import gnu.trove.iterator.TIntFloatIterator;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Arrays;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import org.fnlp.ml.classifier.AbstractClassifier;
import org.fnlp.ml.classifier.LabelParser;
import org.fnlp.ml.classifier.Predict;
import org.fnlp.ml.feature.FeatureSelect;
import org.fnlp.ml.types.Instance;
import org.fnlp.ml.types.alphabet.AlphabetFactory;
import org.fnlp.ml.types.sv.HashSparseVector;
import org.fnlp.nlp.pipe.Pipe;
import org.fnlp.util.exception.LoadModelException;

/* loaded from: input_file:org/fnlp/ml/classifier/bayes/BayesClassifier.class */
/**
 * Bayes-style classifier over {@link HashSparseVector} feature vectors.
 *
 * <p>Per-label feature statistics are held in an {@link ItemFrequency} table ({@link #tf}). An
 * optional {@link FeatureSelect} filter ({@link #fs}) can restrict which features take part in
 * scoring. The whole object is persisted with Java serialization wrapped in a GZIP stream (see
 * {@link #saveTo(String)} and {@link #loadFrom(String)}).
 */
public class BayesClassifier extends AbstractClassifier implements Serializable {
    /** Per-label feature frequency statistics used for scoring. */
    protected ItemFrequency tf;
    /** Pipe associated with this classifier; stored and returned, not used during scoring here. */
    protected Pipe pipe;
    /** Optional feature filter applied to the input vector before scoring; {@code null} = none. */
    protected FeatureSelect fs;

    /**
     * Scores every label for the given instance.
     *
     * <p>For each label {@code t} the score is the sum, over the input's features {@code k}
     * (excluding index 0) with value {@code v}, of
     * {@code v * log((itemFrequency(k, t) + 0.1) / (typeFrequency(t) + featureSize))}.
     * No class-prior term is added.
     *
     * @param instance instance whose data is expected to be a {@link HashSparseVector}
     * @param i value forwarded to the {@link Predict} constructor (presumably the number of
     *     results to keep — TODO confirm)
     * @return the per-label scores, or {@code null} if the instance data is not a
     *     {@link HashSparseVector}
     */
    @Override // org.fnlp.ml.classifier.AbstractClassifier
    public Predict classify(Instance instance, int i) {
        int typeSize = this.tf.getTypeSize();
        Object data = instance.getData();
        if (!(data instanceof HashSparseVector)) {
            System.out.println("error 输入类型非HashSparseVector！");
            return null;
        }
        HashSparseVector vector = (HashSparseVector) data;
        if (this.fs != null) {
            vector = this.fs.select(vector);
        }

        float featureSize = this.tf.getFeatureSize();
        // The denominator depends only on the label, so compute it once per label rather than
        // once per (feature, label) pair. Widening to double keeps the value exact.
        double[] denominators = new double[typeSize];
        for (int t = 0; t < typeSize; t++) {
            denominators[t] = this.tf.getTypeFrequency(t) + featureSize;
        }

        float[] scores = new float[typeSize]; // new arrays are zero-initialised
        TIntFloatIterator it = vector.data.iterator();
        while (it.hasNext()) {
            it.advance();
            int key = it.key();
            // NOTE(review): feature index 0 is skipped — presumably a reserved index; TODO confirm.
            if (key == 0) {
                continue;
            }
            float value = it.value();
            for (int t = 0; t < typeSize; t++) {
                double smoothed = (this.tf.getItemFrequency(key, t) + 0.1d) / denominators[t];
                scores[t] = (float) (scores[t] + value * Math.log(smoothed));
            }
        }

        Predict predict = new Predict(i);
        for (int t = 0; t < typeSize; t++) {
            predict.add(Integer.valueOf(t), scores[t]);
        }
        return predict;
    }

    /**
     * Scores the instance and converts the result with {@link LabelParser#parse} using the
     * factory's default label alphabet.
     *
     * @param instance instance to classify
     * @param type label representation requested from {@link LabelParser}
     * @param i value forwarded to {@link #classify(Instance, int)}
     * @return the parsed prediction
     */
    @Override // org.fnlp.ml.classifier.AbstractClassifier
    public Predict classify(Instance instance, LabelParser.Type type, int i) {
        // NOTE(review): classify(instance, i) may return null; behavior of parse(null, ...) is
        // not visible here.
        return LabelParser.parse(classify(instance, i), this.factory.DefaultLabelAlphabet(), type);
    }

    /**
     * Looks up the label string for a label index in the default label alphabet.
     *
     * @param i label index
     * @return the label string
     */
    public String getLabel(int i) {
        return this.factory.DefaultLabelAlphabet().lookupString(i);
    }

    /**
     * Serializes this classifier to a GZIP-compressed file, creating parent directories first.
     *
     * @param str destination file path; may be a bare filename with no parent directory
     * @throws IOException if the file cannot be created or written
     */
    public void saveTo(String str) throws IOException {
        File parentFile = new File(str).getParentFile();
        // getParentFile() returns null for a bare filename such as "model.gz".
        if (parentFile != null && !parentFile.exists()) {
            parentFile.mkdirs();
        }
        // The file stream is its own resource so it is closed even if a wrapper constructor
        // (e.g. writing the GZIP header) throws; closing it twice is harmless.
        try (FileOutputStream fileOut = new FileOutputStream(str);
                ObjectOutputStream objectOutputStream = new ObjectOutputStream(
                        new GZIPOutputStream(new BufferedOutputStream(fileOut)))) {
            objectOutputStream.writeObject(this);
        }
    }

    /**
     * Deserializes a classifier previously written by {@link #saveTo(String)}.
     *
     * <p>Security: this uses Java-native deserialization and must only be used on trusted files.
     *
     * @param str source file path
     * @return the loaded classifier
     * @throws LoadModelException wrapping any failure (I/O, corrupt data, wrong class)
     */
    public static BayesClassifier loadFrom(String str) throws LoadModelException {
        try (FileInputStream fileIn = new FileInputStream(str);
                ObjectInputStream objectInputStream = new ObjectInputStream(
                        new GZIPInputStream(new BufferedInputStream(fileIn)))) {
            return (BayesClassifier) objectInputStream.readObject();
        } catch (Exception e) {
            throw new LoadModelException(e, str);
        }
    }

    /**
     * Short alias for {@link #featureSelectionChiSquare(float)}.
     *
     * @param f parameter forwarded to the chi-square selection (threshold/ratio — TODO confirm)
     */
    public void fS_CS(float f) {
        featureSelectionChiSquare(f);
    }

    /**
     * Installs a new {@link FeatureSelect} sized to the feature count and configures it with
     * {@code FeatureSelect.fS_CS} over the current frequency table.
     *
     * @param f parameter forwarded to {@code FeatureSelect.fS_CS} (threshold/ratio — TODO confirm)
     */
    public void featureSelectionChiSquare(float f) {
        this.fs = new FeatureSelect(this.tf.getFeatureSize());
        this.fs.fS_CS(this.tf, f);
    }

    /**
     * Short alias for {@link #featureSelectionChiSquareMax(float)}.
     *
     * @param f parameter forwarded to the selection (threshold/ratio — TODO confirm)
     */
    public void fS_CS_Max(float f) {
        featureSelectionChiSquareMax(f);
    }

    /**
     * Installs a new {@link FeatureSelect} sized to the feature count and configures it with
     * {@code FeatureSelect.fS_CS_Max} over the current frequency table.
     *
     * @param f parameter forwarded to {@code FeatureSelect.fS_CS_Max} (TODO confirm meaning)
     */
    public void featureSelectionChiSquareMax(float f) {
        this.fs = new FeatureSelect(this.tf.getFeatureSize());
        this.fs.fS_CS_Max(this.tf, f);
    }

    /**
     * Short alias for {@link #featureSelectionInformationGain(float)}.
     *
     * @param f parameter forwarded to the selection (threshold/ratio — TODO confirm)
     */
    public void fS_IG(float f) {
        featureSelectionInformationGain(f);
    }

    /**
     * Installs a new {@link FeatureSelect} sized to the feature count and configures it with
     * {@code FeatureSelect.fS_IG} over the current frequency table.
     *
     * @param f parameter forwarded to {@code FeatureSelect.fS_IG} (TODO confirm meaning)
     */
    public void featureSelectionInformationGain(float f) {
        this.fs = new FeatureSelect(this.tf.getFeatureSize());
        this.fs.fS_IG(this.tf, f);
    }

    /** Removes any installed feature filter so all features are used when scoring. */
    public void noFeatureSelection() {
        this.fs = null;
    }

    /** @return the frequency table used for scoring */
    public ItemFrequency getTf() {
        return this.tf;
    }

    /** @param itemFrequency the frequency table to use for scoring */
    public void setTf(ItemFrequency itemFrequency) {
        this.tf = itemFrequency;
    }

    /** @return the associated pipe */
    public Pipe getPipe() {
        return this.pipe;
    }

    /** @param pipe the pipe to associate with this classifier */
    public void setPipe(Pipe pipe) {
        this.pipe = pipe;
    }

    /** @param alphabetFactory the alphabet factory used for label lookup */
    public void setFactory(AlphabetFactory alphabetFactory) {
        this.factory = alphabetFactory;
    }

    /** @return the alphabet factory used for label lookup */
    public AlphabetFactory getFactory() {
        return this.factory;
    }
}
