package org.maochen.nlp.ml.classifier.perceptron;

import com.google.common.collect.Lists;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Pattern;
import org.maochen.nlp.ml.Tuple;
import org.maochen.nlp.ml.classifier.LabelIndexer;
import org.maochen.nlp.ml.util.ModelSerializeUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Perceptron model state: one weight row and one bias entry per label, the label-to-index mapping,
 * and the {@code learningRate} / {@code threshold} values stored alongside them.
 *
 * <p>Fields are package-private so other classes in this package can read and update them directly
 * (presumably the trainer/classifier — not visible here). No synchronization is performed, so
 * instances are not thread-safe.
 */
public class PerceptronModel {
    private static final Logger LOG = LoggerFactory.getLogger(PerceptronModel.class);

    /** Learning rate of a freshly constructed model. */
    private static final double DEFAULT_LEARNING_RATE = 0.1d;

    /** Threshold of a freshly constructed model. */
    private static final double DEFAULT_THRESHOLD = 0.5d;

    /** Marker line separating the numeric sections from the label map in the persisted form. */
    private static final String LABEL_SECTION_MARKER = "li";

    /** Separator for whitespace-delimited numeric and label lines; tolerates runs of whitespace. */
    private static final Pattern WHITESPACE = Pattern.compile("\\s+");

    double learningRate = DEFAULT_LEARNING_RATE;
    double threshold = DEFAULT_THRESHOLD;
    /** One bias per label; {@code null} until {@link #init} or {@link #load} runs. */
    double[] bias;
    /** Indexed as {@code weights[label][feature]}; {@code null} until {@link #init} or {@link #load} runs. */
    double[][] weights;
    LabelIndexer labelIndexer;

    /**
     * Creates an empty model with default hyperparameters. Call {@link #init} or {@link #load}
     * before use.
     */
    public PerceptronModel() {
        // Defaults come from the field initializers; bias, weights and labelIndexer stay null.
    }

    /**
     * Copy constructor.
     *
     * <p>Hyperparameters are copied. {@code bias} and {@code weights} are deep-copied, so modifying
     * the copy's arrays does not affect the source. The {@link LabelIndexer} is shared by reference,
     * not copied. Copying an uninitialized model is allowed: null arrays stay null.
     *
     * @param perceptronModel the model to copy; must not be {@code null}
     */
    public PerceptronModel(PerceptronModel perceptronModel) {
        this.learningRate = perceptronModel.learningRate;
        this.threshold = perceptronModel.threshold;
        this.bias = perceptronModel.bias == null ? null : perceptronModel.bias.clone();
        // Shared by reference: mutating its map through one model is visible through the other.
        this.labelIndexer = perceptronModel.labelIndexer;
        this.weights = deepCopy(perceptronModel.weights);
    }

    /** Returns a copy of {@code source} with every row cloned, or {@code null} if {@code source} is null. */
    private static double[][] deepCopy(double[][] source) {
        if (source == null) {
            return null;
        }
        double[][] copy = new double[source.length][];
        for (int row = 0; row < source.length; row++) {
            copy[row] = source[row] == null ? null : source[row].clone();
        }
        return copy;
    }

    /**
     * Builds the label index from {@code list} and allocates one weight row and one bias entry per
     * label. The width of each weight row is the vector length of the first tuple. This assumes all
     * tuples share that length — TODO confirm.
     *
     * @param list training tuples; must be non-empty
     * @param z    if {@code true}, weights are filled with {@link Math#random()} values in [0, 1);
     *             otherwise they stay 0. Biases always start at 0.
     * @throws IllegalArgumentException if {@code list} is null or empty
     */
    public void init(List<Tuple> list, boolean z) {
        if (list == null || list.isEmpty()) {
            throw new IllegalArgumentException("Cannot initialize perceptron model from empty training data.");
        }
        this.labelIndexer = new LabelIndexer(list);
        int featureSize = list.get(0).vector.getVector().length;
        int labelSize = this.labelIndexer.getLabelSize();
        this.weights = new double[labelSize][featureSize];
        this.bias = new double[labelSize];
        if (z) {
            for (double[] row : this.weights) {
                for (int f = 0; f < row.length; f++) {
                    row[f] = Math.random();
                }
            }
        }
    }

    /**
     * Writes the model to the file at {@code str}, creating or overwriting it, encoded as UTF-8.
     *
     * <p>The file contains, in order: the learning rate line, the threshold line, the bias and weights
     * (serialized by {@link ModelSerializeUtils}), an {@code li} marker line, and finally the label map.
     * I/O errors are logged, not thrown.
     *
     * @param str destination file path
     * @throws IllegalStateException if the model has not been initialized; nothing is written in that case
     */
    public void persist(String str) {
        // Checked before opening the file so an uninitialized model cannot truncate an existing model file.
        if (bias == null || weights == null || labelIndexer == null) {
            throw new IllegalStateException("Model is not initialized; call init() or load() before persist().");
        }
        try (BufferedWriter writer = new BufferedWriter(
                new OutputStreamWriter(new FileOutputStream(new File(str)), StandardCharsets.UTF_8))) {
            writer.write(String.valueOf(learningRate));
            writer.write(System.lineSeparator());
            writer.write(String.valueOf(threshold));
            writer.write(System.lineSeparator());
            writer.write(ModelSerializeUtils.oneDimensionArraySerialize(bias));
            writer.write(ModelSerializeUtils.twoDimensionalArraySerialize(weights));
            writer.write(LABEL_SECTION_MARKER + System.lineSeparator());
            writer.write(ModelSerializeUtils.mapSerialize(labelIndexer.labelIndexer.entrySet()));
        } catch (IOException e) {
            LOG.error("Persist model err.", e);
        }
    }

    /**
     * Reads a model in the layout written by {@link #persist} and replaces this model's state.
     * {@code inputStream} is decoded as UTF-8 and is always closed.
     *
     * <p>Parsing is based on line numbers. Blank lines are counted but otherwise ignored.
     * <ul>
     *   <li>line 1: learning rate</li>
     *   <li>line 2: threshold</li>
     *   <li>line 4: whitespace-separated bias values</li>
     *   <li>line 5: weights header whose first token is the row count. That many whitespace-separated
     *       weight rows follow immediately.</li>
     *   <li>a {@code li} line (case-insensitive); every later non-blank line is {@code label index}</li>
     * </ul>
     * Line 3 and any other lines before the {@code li} marker are ignored. These are presumably headers
     * emitted by {@link ModelSerializeUtils} — confirm against its serializers. I/O errors are logged,
     * not thrown.
     *
     * @param inputStream serialized model; closed when this method returns
     * @throws NumberFormatException    if a numeric field is malformed
     * @throws IllegalArgumentException if the stream ends inside the weights section
     */
    public void load(InputStream inputStream) {
        try (BufferedReader reader =
                new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
            int lineNo = 0;
            boolean inLabelSection = false;
            String line;
            while ((line = reader.readLine()) != null) {
                lineNo++;
                String trimmed = line.trim();
                if (trimmed.isEmpty()) {
                    continue;
                }
                if (lineNo == 1) {
                    this.learningRate = Double.parseDouble(trimmed);
                } else if (lineNo == 2) {
                    this.threshold = Double.parseDouble(trimmed);
                } else if (lineNo == 4) {
                    this.bias = parseDoubles(trimmed);
                } else if (lineNo == 5) {
                    int rows = Integer.parseInt(WHITESPACE.split(trimmed)[0]);
                    this.weights = readWeightRows(reader, rows);
                    // The rows were consumed directly from the reader; keep the line count in step.
                    lineNo += rows;
                } else if (trimmed.equalsIgnoreCase(LABEL_SECTION_MARKER)) {
                    inLabelSection = true;
                    this.labelIndexer = new LabelIndexer(Lists.newArrayList());
                } else if (inLabelSection) {
                    String[] parts = WHITESPACE.split(trimmed);
                    this.labelIndexer.labelIndexer.put(parts[0], Integer.valueOf(Integer.parseInt(parts[1])));
                }
            }
        } catch (IOException e) {
            LOG.error("Load model err.", e);
        }
    }

    /** Reads exactly {@code rows} lines from {@code reader}, parsing each one as a weight row. */
    private static double[][] readWeightRows(BufferedReader reader, int rows) throws IOException {
        double[][] result = new double[rows][];
        for (int r = 0; r < rows; r++) {
            String row = reader.readLine();
            if (row == null) {
                throw new IllegalArgumentException(
                        "Model stream ended after " + r + " of " + rows + " weight rows.");
            }
            result[r] = parseDoubles(row.trim());
        }
        return result;
    }

    /** Parses a trimmed, whitespace-separated line of doubles; an empty line yields an empty array. */
    private static double[] parseDoubles(String trimmed) {
        if (trimmed.isEmpty()) {
            return new double[0];
        }
        return Arrays.stream(WHITESPACE.split(trimmed)).mapToDouble(Double::parseDouble).toArray();
    }
}
