package org.cleartk.classifier.feature.transform.extractor;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStreamWriter;
import java.io.Serializable;
import java.net.URI;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

import org.apache.uima.jcas.JCas;
import org.apache.uima.jcas.tcas.Annotation;
import org.cleartk.classifier.Feature;
import org.cleartk.classifier.Instance;
import org.cleartk.classifier.feature.extractor.CleartkExtractorException;
import org.cleartk.classifier.feature.extractor.simple.SimpleFeatureExtractor;
import org.cleartk.classifier.feature.transform.OneToOneTrainableExtractor_ImplBase;
import org.cleartk.classifier.feature.transform.TransformableFeature;

/**
 * A trainable extractor that rescales numeric features to zero mean and unit standard deviation
 * (a "z-score" transform).
 *
 * <p>Lifecycle:
 * <ol>
 * <li>Untrained: {@link #extract(JCas, Annotation)} returns a single {@link TransformableFeature}
 * (named after this extractor) that wraps every feature produced by the sub-extractor.</li>
 * <li>{@link #train(Iterable)} scans the transformable features of the training instances and
 * records, per feature name, the mean and the population standard deviation (divide by n) of the
 * observed values.</li>
 * <li>Trained (after {@code train} or {@link #load(URI)}): each sub-extractor feature {@code f} with
 * numeric value {@code v} is replaced by a feature named {@code "ZMUS_" + f.getName()} whose value
 * is {@code (v - mean) / stddev}. A feature whose recorded stddev is 0 (constant during training)
 * is only centered, i.e. divided by 1, instead of producing NaN/Infinity.</li>
 * </ol>
 *
 * @param <OUTCOME_T> the outcome type of the training instances
 */
public class ZeroMeanUnitStddevExtractor<OUTCOME_T> extends OneToOneTrainableExtractor_ImplBase<OUTCOME_T> implements SimpleFeatureExtractor {

    /** Prefix prepended to the name of every standardized output feature. */
    private static final String FEATURE_NAME_PREFIX = "ZMUS_";

    /** Character set of the statistics file written by {@link #save} and read by {@link #load}. */
    private static final String FILE_ENCODING = "UTF-8";

    /** Source of raw features; {@code null} when built with the name-only constructor. */
    private final SimpleFeatureExtractor subExtractor;

    /** True once statistics are available, via {@link #train} or {@link #load}. */
    private boolean isTrained;

    /** Mean and standard deviation per feature name; {@code null} until trained or loaded. */
    private Map<String, MeanStddevTuple> meanStddevMap;

    /** Mean and standard deviation recorded for one feature name. */
    public static class MeanStddevTuple {
        public double mean;
        public double stddev;

        /**
         * @param mean the mean of the feature's training values
         * @param stddev the standard deviation of the feature's training values
         */
        public MeanStddevTuple(double mean, double stddev) {
            this.mean = mean;
            this.stddev = stddev;
        }
    }

    /**
     * Single-pass (Welford-style) accumulator of the mean and the sum of squared deviations from the
     * mean (called "M2" below) of a stream of doubles.
     */
    public static class MeanVarianceRunningStat implements Serializable {
        private static final long serialVersionUID = 1L;
        private int numSamples;
        private double meanOld;
        private double meanNew;
        // varOld/varNew hold the running sum of squared deviations (M2), not the variance itself;
        // variance() divides it by the sample count.
        private double varOld;
        private double varNew;

        /** Creates an empty accumulator. */
        public MeanVarianceRunningStat() {
            clear();
        }

        /**
         * Seeds the accumulator with previously computed statistics.
         *
         * @param numSamples number of samples already accounted for
         * @param mean mean of those samples
         * @param sumSquaredDeviations their sum of squared deviations from the mean (M2), so that
         *     {@code variance() == sumSquaredDeviations / numSamples}
         */
        public void init(int numSamples, double mean, double sumSquaredDeviations) {
            this.numSamples = numSamples;
            // Keep the "old" copies in sync: add() reads meanOld/varOld, so seeding only the "new"
            // values would make the next add() start from stale state.
            this.meanNew = mean;
            this.meanOld = mean;
            this.varNew = sumSquaredDeviations;
            this.varOld = sumSquaredDeviations;
        }

        /**
         * Adds one observation.
         *
         * @param value the observed value
         */
        public void add(double value) {
            this.numSamples++;
            if (this.numSamples == 1) {
                this.meanNew = value;
                this.meanOld = value;
                this.varNew = 0.0d;
                this.varOld = 0.0d;
            } else {
                this.meanNew = this.meanOld + ((value - this.meanOld) / this.numSamples);
                this.varNew = this.varOld + ((value - this.meanOld) * (value - this.meanNew));
                this.meanOld = this.meanNew;
                this.varOld = this.varNew;
            }
        }

        /** Resets the accumulator to its empty state. */
        public void clear() {
            this.numSamples = 0;
            this.meanOld = 0.0d;
            this.meanNew = 0.0d;
            this.varOld = 0.0d;
            this.varNew = 0.0d;
        }

        /** @return the number of observations added (or seeded via {@link #init}) */
        public int getNumSamples() {
            return this.numSamples;
        }

        /** @return the mean of the observations, or 0 when there are none */
        public double mean() {
            if (this.numSamples > 0) {
                return this.meanNew;
            }
            return 0.0d;
        }

        /**
         * @return M2 divided by n (the population variance), or 0 when fewer than two observations
         *     have been added
         */
        public double variance() {
            if (this.numSamples > 1) {
                return this.varNew / this.numSamples;
            }
            return 0.0d;
        }

        /** @return the square root of {@link #variance()} */
        public double stddev() {
            return Math.sqrt(variance());
        }

        /**
         * Note: despite the name, this divides M2 by {@code n - 1}, i.e. it is the Bessel-corrected
         * sample variance. The name is kept for API compatibility.
         *
         * @return M2 divided by {@code n - 1}, or 0 when fewer than two observations have been added
         */
        public double variancePop() {
            if (this.numSamples > 1) {
                return this.varNew / (this.numSamples - 1);
            }
            return 0.0d;
        }

        /** @return the square root of {@link #variancePop()} */
        public double stddevPop() {
            return Math.sqrt(variancePop());
        }

        // The default field data is followed by an explicit copy of numSamples, meanNew and varNew.
        // This is redundant but kept so previously serialized streams stay readable.
        private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
            objectOutputStream.defaultWriteObject();
            objectOutputStream.writeInt(this.numSamples);
            objectOutputStream.writeDouble(this.meanNew);
            objectOutputStream.writeDouble(this.varNew);
        }

        private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
            objectInputStream.defaultReadObject();
            this.numSamples = objectInputStream.readInt();
            double mean = objectInputStream.readDouble();
            this.meanNew = mean;
            this.meanOld = mean;
            double sumSquaredDeviations = objectInputStream.readDouble();
            this.varNew = sumSquaredDeviations;
            this.varOld = sumSquaredDeviations;
        }
    }

    /**
     * Creates an extractor without a sub-extractor. Such an instance can be trained, saved, loaded
     * and used to transform features, but {@link #extract(JCas, Annotation)} cannot be called on it.
     *
     * @param name the name given to the {@link TransformableFeature} produced before training
     */
    public ZeroMeanUnitStddevExtractor(String name) {
        this(name, null);
    }

    /**
     * @param name the name given to the {@link TransformableFeature} produced before training
     * @param subExtractor the extractor whose numeric features are standardized
     */
    public ZeroMeanUnitStddevExtractor(String name, SimpleFeatureExtractor subExtractor) {
        super(name);
        this.subExtractor = subExtractor;
        this.isTrained = false;
    }

    /**
     * Standardizes one feature using the recorded statistics for its name.
     *
     * @param feature a feature with a numeric value whose name was seen during training
     * @return a feature named {@code "ZMUS_" + name} holding {@code (value - mean) / stddev}, with a
     *     zero stddev treated as 1
     * @throws IllegalStateException if no statistics have been trained or loaded
     * @throws IllegalArgumentException if the value is not a {@link Number} or the name is unknown
     */
    @Override
    protected Feature transform(Feature feature) {
        if (this.meanStddevMap == null) {
            throw new IllegalStateException("Cannot transform features before the extractor has been trained or loaded");
        }
        String featureName = feature.getName();
        Object rawValue = feature.getValue();
        if (!(rawValue instanceof Number)) {
            throw new IllegalArgumentException("Cannot normalize non-numeric feature values");
        }
        MeanStddevTuple stats = this.meanStddevMap.get(featureName);
        if (stats == null) {
            throw new IllegalArgumentException("No mean/stddev statistics for feature: " + featureName);
        }
        // A feature that was constant during training has stddev 0; dividing by it would yield
        // NaN or +/-Infinity, so such features are only centered (same convention as scaling by 1).
        double scale = stats.stddev == 0.0d ? 1.0d : stats.stddev;
        double standardized = (((Number) rawValue).doubleValue() - stats.mean) / scale;
        return new Feature(FEATURE_NAME_PREFIX + featureName, Double.valueOf(standardized));
    }

    /**
     * Extracts features with the sub-extractor. Once trained, returns each feature standardized.
     * Before training, returns a single {@link TransformableFeature} wrapping all of them.
     *
     * @throws IllegalStateException if this extractor was built without a sub-extractor
     */
    @Override
    public List<Feature> extract(JCas jCas, Annotation annotation) throws CleartkExtractorException {
        if (this.subExtractor == null) {
            throw new IllegalStateException("No sub-extractor was provided; use the two-argument constructor to extract features");
        }
        List<Feature> subFeatures = this.subExtractor.extract(jCas, annotation);
        List<Feature> result = new ArrayList<Feature>();
        if (this.isTrained) {
            for (Feature subFeature : subFeatures) {
                result.add(transform(subFeature));
            }
        } else {
            result.add(new TransformableFeature(this.name, subFeatures));
        }
        return result;
    }

    /**
     * Computes per-feature-name mean and population standard deviation over all features wrapped
     * in transformable features of the given instances, then marks this extractor as trained.
     *
     * @param instances training instances produced while this extractor was untrained
     * @throws IllegalArgumentException if a wrapped feature has a non-numeric value
     */
    @Override
    public void train(Iterable<Instance<OUTCOME_T>> instances) {
        Map<String, MeanVarianceRunningStat> statsByName = new HashMap<String, MeanVarianceRunningStat>();
        for (Instance<OUTCOME_T> instance : instances) {
            for (Feature feature : instance.getFeatures()) {
                if (!isTransformable(feature)) {
                    continue;
                }
                for (Feature subFeature : ((TransformableFeature) feature).getFeatures()) {
                    String featureName = subFeature.getName();
                    Object value = subFeature.getValue();
                    if (!(value instanceof Number)) {
                        throw new IllegalArgumentException("Cannot normalize non-numeric feature values");
                    }
                    MeanVarianceRunningStat stats = statsByName.get(featureName);
                    if (stats == null) {
                        stats = new MeanVarianceRunningStat();
                        statsByName.put(featureName, stats);
                    }
                    stats.add(((Number) value).doubleValue());
                }
            }
        }
        Map<String, MeanStddevTuple> trained = new HashMap<String, MeanStddevTuple>();
        for (Map.Entry<String, MeanVarianceRunningStat> entry : statsByName.entrySet()) {
            MeanVarianceRunningStat stats = entry.getValue();
            trained.put(entry.getKey(), new MeanStddevTuple(stats.mean(), stats.stddev()));
        }
        this.meanStddevMap = trained;
        this.isTrained = true;
    }

    /**
     * Writes the statistics as UTF-8 text, one {@code name<TAB>mean<TAB>stddev} line per feature.
     * Numbers are written with {@link Double#toString(double)}. That format is locale-independent
     * and round-trips exactly through {@link Double#parseDouble(String)}. {@code %f} was
     * locale-dependent and kept only 6 decimals, which turned tiny stddevs into 0.
     *
     * @param uri location of the file to write
     * @throws IllegalStateException if no statistics have been trained or loaded
     * @throws IOException if the file cannot be written
     */
    @Override
    public void save(URI uri) throws IOException {
        if (this.meanStddevMap == null) {
            throw new IllegalStateException("Cannot save before the extractor has been trained or loaded");
        }
        BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(new File(uri)), FILE_ENCODING));
        try {
            for (Map.Entry<String, MeanStddevTuple> entry : this.meanStddevMap.entrySet()) {
                MeanStddevTuple stats = entry.getValue();
                writer.write(entry.getKey());
                writer.write('\t');
                writer.write(Double.toString(stats.mean));
                writer.write('\t');
                writer.write(Double.toString(stats.stddev));
                writer.write('\n');
            }
        } finally {
            writer.close();
        }
    }

    /**
     * Reads statistics written by {@link #save(URI)} and marks this extractor as trained. The
     * current statistics are replaced only if the whole file parses successfully. Empty lines are
     * skipped. Mean and stddev are taken from the last two tab-separated columns, so feature names
     * containing tabs are preserved.
     *
     * @param uri location of the file to read
     * @throws IOException if the file cannot be read or a line is malformed
     */
    @Override
    public void load(URI uri) throws IOException {
        Map<String, MeanStddevTuple> loaded = new HashMap<String, MeanStddevTuple>();
        BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(new File(uri)), FILE_ENCODING));
        try {
            int lineNumber = 0;
            String line;
            while ((line = reader.readLine()) != null) {
                lineNumber++;
                if (line.length() == 0) {
                    continue;
                }
                int stddevTab = line.lastIndexOf('\t');
                int meanTab = stddevTab < 0 ? -1 : line.lastIndexOf('\t', stddevTab - 1);
                if (meanTab < 0) {
                    throw new IOException("Malformed line " + lineNumber + " in " + uri + ": expected <name>\\t<mean>\\t<stddev>");
                }
                String featureName = line.substring(0, meanTab);
                try {
                    double mean = Double.parseDouble(line.substring(meanTab + 1, stddevTab));
                    double stddev = Double.parseDouble(line.substring(stddevTab + 1));
                    loaded.put(featureName, new MeanStddevTuple(mean, stddev));
                } catch (NumberFormatException e) {
                    throw new IOException("Malformed number on line " + lineNumber + " in " + uri, e);
                }
            }
        } finally {
            reader.close();
        }
        this.meanStddevMap = loaded;
        this.isTrained = true;
    }
}
