package org.maochen.nlp.classifier.perceptron;

import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.maochen.nlp.classifier.IClassifier;
import org.maochen.nlp.datastructure.Tuple;
import org.maochen.nlp.utils.VectorUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/maochen/nlp/classifier/perceptron/PerceptronClassifier.class */
/**
 * Multi-class perceptron classifier.
 *
 * <p>The model holds one weight vector and one bias per label. A feature vector {@code x} is
 * scored against label {@code k} as {@code dot(x, weights[k]) + bias[k]}; the highest-scoring
 * label is the prediction. Training applies the multi-class perceptron rule: whenever the wrong
 * label wins, the gold label is moved towards the sample and the winning label away from it.
 *
 * <p>No synchronization is performed; training mutates the model's arrays in place.
 */
public class PerceptronClassifier implements IClassifier {
    private static final Logger LOG = LoggerFactory.getLogger(PerceptronClassifier.class);

    /** Upper bound on the number of full passes over the training data. */
    private static final int MAX_ITERATION = 200;

    protected PerceptronModel model;

    /** Creates a classifier backed by an empty {@link PerceptronModel}. */
    public PerceptronClassifier() {
        this.model = new PerceptronModel();
    }

    /**
     * Scores {@code x} against every label.
     *
     * @param x feature vector
     * @return map from label index to raw (pre-sigmoid) score
     */
    private Map<Integer, Double> predict(double[] x) {
        Map<Integer, Double> scores = new HashMap<>();
        for (int label = 0; label < model.weights.length; label++) {
            scores.put(label, VectorUtils.dotProduct(x, model.weights[label]) + model.bias[label]);
        }
        return scores;
    }

    /**
     * Finds the highest-scoring label for {@code x}.
     *
     * @param x feature vector
     * @return (label index, raw score) of the best label, or {@code null} if the model has no labels
     */
    private Pair<Integer, Double> predictMax(double[] x) {
        return predict(x).entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .map(best -> ImmutablePair.of(best.getKey(), best.getValue()))
                .orElse(null);
    }

    /**
     * Returns {@code weights + learningRate * direction * x} as a new array.
     *
     * @param x feature vector; its length determines the result length
     * @param weights current weight vector of one label
     * @param direction {@code +1} to move towards {@code x}, {@code -1} to move away from it
     * @return the updated weight vector
     */
    private double[] reweight(double[] x, double[] weights, double direction) {
        return IntStream.range(0, x.length)
                .mapToDouble(i -> weights[i] + (model.learningRate * direction * x[i]))
                .toArray();
    }

    /** Shifts one label's weights and bias by {@code direction * learningRate} along {@code x}. */
    private void applyUpdate(double[] x, int label, double direction) {
        model.weights[label] = reweight(x, model.weights[label], direction);
        // The bias behaves like a weight on a constant feature of 1, so it must be accumulated the
        // same way; overwriting it with +/-1 would discard every earlier correction on each mistake.
        model.bias[label] += model.learningRate * direction;
    }

    /**
     * Performs a single perceptron update for one labelled sample.
     *
     * <p>If the current model already ranks {@code labelIndex} highest, the model is unchanged.
     * Otherwise the gold label is promoted and the wrongly winning label is demoted.
     *
     * @param x feature vector of the sample
     * @param labelIndex index of the sample's gold label
     * @throws IllegalStateException if the model has no labels to score against
     */
    public void onlineTrain(double[] x, int labelIndex) {
        Pair<Integer, Double> best = predictMax(x);
        if (best == null) {
            throw new IllegalStateException("Model has no labels; cannot train");
        }
        int predicted = best.getLeft();
        if (predicted != labelIndex) {
            applyUpdate(x, labelIndex, 1.0d);
            applyUpdate(x, predicted, -1.0d);
        }
        if (LOG.isDebugEnabled()) {
            LOG.debug("New bias: {}", Arrays.toString(model.bias));
            LOG.debug("New weight: {}", Arrays.stream(model.weights)
                    .map(Arrays::toString)
                    .collect(Collectors.joining(", ")));
        }
    }

    /**
     * Builds a fresh model from {@code list} and trains it until every sample is classified
     * correctly or {@link #MAX_ITERATION} passes have run.
     *
     * <p>Each pass visits the samples in a new random order. The caller's list is not reordered;
     * a private copy is shuffled instead.
     *
     * @param list labelled training samples
     * @return this classifier
     */
    @Override
    public IClassifier train(List<Tuple> list) {
        this.model = new PerceptronModel(list);
        List<Tuple> samples = new ArrayList<>(list);
        int errors;
        int iteration = 0;
        do {
            iteration++;
            LOG.info("Iteration {}", iteration);
            Collections.shuffle(samples);
            for (Tuple tuple : samples) {
                onlineTrain(tuple.featureVector, model.labelIndexer.getIndex(tuple.label));
            }
            errors = countErrors(samples);
        } while (errors > 0 && iteration < MAX_ITERATION);
        LOG.debug("Err size: {}", errors);
        return this;
    }

    /** Counts samples whose highest-scoring label differs from their gold label. */
    private int countErrors(List<Tuple> samples) {
        return (int) samples.stream()
                .filter(t -> predictMax(t.featureVector).getLeft().intValue()
                        != model.labelIndexer.getIndex(t.label))
                .count();
    }

    /**
     * Scores {@code tuple} against every label.
     *
     * @param tuple sample to classify; only its feature vector is read
     * @return map from label name to the sigmoid of that label's raw score
     */
    @Override
    public Map<String, Double> predict(Tuple tuple) {
        return predict(tuple.featureVector).entrySet().stream()
                .collect(Collectors.toMap(
                        e -> model.labelIndexer.getLabel(e.getKey()),
                        e -> VectorUtils.sigmoid.apply(e.getValue())));
    }

    /** This classifier reads no parameters; the supplied map is ignored. */
    @Override
    public void setParameter(Map<String, String> map) {
    }

    /**
     * Demo: trains on a tiny NAND-style dataset (first feature is a constant 1), persists the
     * model, reloads it into a fresh classifier and prints the prediction for {@code [1, 1, 1]}.
     */
    public static void main(String[] strArr) throws FileNotFoundException {
        String modelPath = PerceptronClassifier.class.getResource("/").getPath() + "/perceptron_model.dat";
        System.out.println(modelPath);

        PerceptronClassifier classifier = new PerceptronClassifier();
        List<Tuple> data = new ArrayList<>();
        data.add(new Tuple(1, new double[] {1.0d, 0.0d, 0.0d}, String.valueOf(1)));
        data.add(new Tuple(2, new double[] {1.0d, 0.0d, 1.0d}, String.valueOf(1)));
        data.add(new Tuple(3, new double[] {1.0d, 1.0d, 0.0d}, String.valueOf(1)));
        data.add(new Tuple(4, new double[] {1.0d, 1.0d, 1.0d}, String.valueOf(0)));
        classifier.train(data);
        classifier.model.persist(modelPath);

        PerceptronClassifier loaded = new PerceptronClassifier();
        FileInputStream in = new FileInputStream(modelPath);
        try {
            loaded.model.load(in);
        } finally {
            // Always release the file handle; a failure to close is not fatal for the demo.
            try {
                in.close();
            } catch (IOException e) {
                LOG.warn("Failed to close model file " + modelPath, e);
            }
        }
        System.out.println(loaded.predict(new Tuple(5, new double[] {1.0d, 1.0d, 1.0d}, null)));
    }
}
