package org.cleartk.ml.feature.selection;

import com.google.common.base.Function;
import com.google.common.base.Joiner;
import com.google.common.collect.Collections2;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Multiset;
import com.google.common.collect.Ordering;
import com.google.common.collect.Table;
import com.google.common.collect.TreeBasedTable;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.lang.Comparable;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import org.apache.uima.jcas.JCas;
import org.apache.uima.jcas.tcas.Annotation;
import org.cleartk.ml.Feature;
import org.cleartk.ml.Instance;
import org.cleartk.ml.feature.extractor.CleartkExtractorException;
import org.cleartk.ml.feature.extractor.FeatureExtractor1;
import org.cleartk.ml.feature.transform.TransformableFeature;

/* loaded from: input_file:org/cleartk/ml/feature/selection/MutualInformationFeatureSelectionExtractor.class */
public class MutualInformationFeatureSelectionExtractor<OUTCOME_T extends Comparable<?>, FOCUS_T extends Annotation> extends FeatureSelectionExtractor<OUTCOME_T> implements FeatureExtractor1<FOCUS_T> {
    /** Set to true at the end of train(...) and load(...); extract(...) only filters features once this is true. */
    protected boolean isTrained;
    /** Count statistics created by train(...); load(...) does not assign this field. */
    private MutualInformationStats<OUTCOME_T> mutualInfoStats;
    /** Extractor whose output extract(...) either wraps (untrained) or filters (trained). */
    private FeatureExtractor1<FOCUS_T> subExtractor;
    /** Upper bound on how many feature names load(...) reads from a saved model. */
    private int numFeatures;
    /** How per-outcome scores are combined into one score per feature; overwritten by load(...). */
    private CombineScoreMethod combineScoreMethod;
    /** Feature names accepted by apply(...), in the order produced by train(...) or read by load(...). */
    private List<String> selectedFeatures;
    /** Smoothing value handed to MutualInformationStats when train(...) creates the statistics. */
    private double smoothingCount;

    /* loaded from: input_file:org/cleartk/ml/feature/selection/MutualInformationFeatureSelectionExtractor$CombineScoreMethod.class */
    public enum CombineScoreMethod {
        AVERAGE,
        MAX;

        /* loaded from: input_file:org/cleartk/ml/feature/selection/MutualInformationFeatureSelectionExtractor$CombineScoreMethod$AverageScores.class */
        public static class AverageScores<OUTCOME_T> extends CombineScoreFunction<OUTCOME_T> {
            public Double apply(Map<OUTCOME_T, Double> map) {
                Collection<Double> values = map.values();
                int size = values.size();
                double d = 0.0d;
                Iterator<Double> it = values.iterator();
                while (it.hasNext()) {
                    d += it.next().doubleValue();
                }
                return Double.valueOf(d / size);
            }
        }

        /* loaded from: input_file:org/cleartk/ml/feature/selection/MutualInformationFeatureSelectionExtractor$CombineScoreMethod$CombineScoreFunction.class */
        public static abstract class CombineScoreFunction<OUTCOME_T> implements Function<Map<OUTCOME_T, Double>, Double> {
        }

        /* loaded from: input_file:org/cleartk/ml/feature/selection/MutualInformationFeatureSelectionExtractor$CombineScoreMethod$MaxScores.class */
        public static class MaxScores<OUTCOME_T> extends CombineScoreFunction<OUTCOME_T> {
            public Double apply(Map<OUTCOME_T, Double> map) {
                return (Double) Ordering.natural().max(map.values());
            }
        }
    }

    /* loaded from: input_file:org/cleartk/ml/feature/selection/MutualInformationFeatureSelectionExtractor$MutualInformationStats.class */
    public static class MutualInformationStats<OUTCOME_T extends Comparable<?>> {
        protected Multiset<OUTCOME_T> classCounts = HashMultiset.create();
        protected Table<String, OUTCOME_T, Integer> classConditionalCounts = TreeBasedTable.create();
        protected double smoothingCount;

        /* loaded from: input_file:org/cleartk/ml/feature/selection/MutualInformationFeatureSelectionExtractor$MutualInformationStats$ComputeFeatureScore.class */
        public static class ComputeFeatureScore<OUTCOME_T extends Comparable<?>> implements Function<String, Double> {
            private MutualInformationStats<OUTCOME_T> stats;
            private CombineScoreMethod.CombineScoreFunction<OUTCOME_T> combineScoreFunction;

            /* JADX WARN: Failed to find 'out' block for switch in B:2:0x0011. Please report as an issue. */
            public ComputeFeatureScore(MutualInformationStats<OUTCOME_T> mutualInformationStats, CombineScoreMethod combineScoreMethod) {
                this.stats = mutualInformationStats;
                switch (combineScoreMethod) {
                    case AVERAGE:
                        this.combineScoreFunction = new CombineScoreMethod.AverageScores();
                    case MAX:
                        this.combineScoreFunction = new CombineScoreMethod.MaxScores();
                        return;
                    default:
                        return;
                }
            }

            /* JADX WARN: Multi-variable type inference failed */
            public Double apply(String str) {
                Set<Comparable> columnKeySet = this.stats.classConditionalCounts.columnKeySet();
                HashMap newHashMap = Maps.newHashMap();
                for (Comparable comparable : columnKeySet) {
                    newHashMap.put(comparable, Double.valueOf(this.stats.mutualInformation(str, comparable)));
                }
                return (Double) this.combineScoreFunction.apply(newHashMap);
            }
        }

        public MutualInformationStats(double d) {
            this.smoothingCount += d;
        }

        public void update(String str, OUTCOME_T outcome_t, int i) {
            Integer num = (Integer) this.classConditionalCounts.get(str, outcome_t);
            if (num == null) {
                num = 0;
            }
            this.classConditionalCounts.put(str, outcome_t, Integer.valueOf(num.intValue() + i));
            this.classCounts.add(outcome_t, i);
        }

        public double mutualInformation(String str, OUTCOME_T outcome_t) {
            int[][] iArr = new int[2][2];
            int size = this.classCounts.size();
            int[] iArr2 = {size - iArr2[1], sum(this.classConditionalCounts.row(str).values())};
            int[] iArr3 = {size - iArr3[1], this.classCounts.count(outcome_t)};
            iArr[1][1] = this.classConditionalCounts.contains(str, outcome_t) ? ((Integer) this.classConditionalCounts.get(str, outcome_t)).intValue() : 0;
            iArr[1][0] = iArr2[1] - iArr[1][1];
            iArr[0][1] = iArr3[1] - iArr[1][1];
            iArr[0][0] = ((size - iArr2[1]) - iArr3[1]) + iArr[1][1];
            double d = 0.0d;
            for (int i = 0; i <= 1; i++) {
                for (int i2 = 0; i2 <= 1; i2++) {
                    iArr[i][i2] = (int) (r0[r1] + this.smoothingCount);
                    d += (iArr[i][i2] / size) * Math.log((size * iArr[i][i2]) / (iArr2[i] * iArr3[i2]));
                }
            }
            return d;
        }

        private int sum(Collection<Integer> collection) {
            int i = 0;
            Iterator<Integer> it = collection.iterator();
            while (it.hasNext()) {
                i += it.next().intValue();
            }
            return i;
        }

        /* JADX WARN: Multi-variable type inference failed */
        public void save(URI uri) throws IOException {
            BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File(uri)));
            bufferedWriter.append((CharSequence) "Mutual Information Data\n");
            bufferedWriter.append((CharSequence) "Feature\t");
            bufferedWriter.append((CharSequence) Joiner.on("\t").join(this.classConditionalCounts.columnKeySet()));
            bufferedWriter.append((CharSequence) "\n");
            for (String str : this.classConditionalCounts.rowKeySet()) {
                bufferedWriter.append((CharSequence) str);
                for (Comparable comparable : this.classConditionalCounts.columnKeySet()) {
                    bufferedWriter.append((CharSequence) "\t");
                    bufferedWriter.append((CharSequence) String.format(Locale.ROOT, "%f", Double.valueOf(mutualInformation(str, comparable))));
                }
                bufferedWriter.append((CharSequence) "\n");
            }
            bufferedWriter.append((CharSequence) "\n");
            bufferedWriter.append((CharSequence) this.classConditionalCounts.toString());
            bufferedWriter.close();
        }

        public ComputeFeatureScore<OUTCOME_T> getScoreFunction(CombineScoreMethod combineScoreMethod) {
            return new ComputeFeatureScore<>(this, combineScoreMethod);
        }
    }

    /**
     * Builds the name under which a feature is counted and selected.
     *
     * @param feature feature to name
     * @return the feature's name alone when its value is a Number, otherwise "name:value"
     */
    public String nameFeature(Feature feature) {
        if (feature.getValue() instanceof Number) {
            return feature.getName();
        }
        return feature.getName() + ":" + feature.getValue();
    }

    /**
     * Creates an extractor using CombineScoreMethod.MAX, a smoothing count of 1.0 and a limit of 10 features.
     *
     * @param str value passed to the superclass constructor
     * @param featureExtractor1 extractor whose features are selected from
     */
    public MutualInformationFeatureSelectionExtractor(String str, FeatureExtractor1<FOCUS_T> featureExtractor1) {
        this(str, featureExtractor1, CombineScoreMethod.MAX, 1.0d, 10);
    }

    /**
     * Creates an extractor using CombineScoreMethod.MAX and a smoothing count of 1.0.
     *
     * @param str value passed to the superclass constructor
     * @param featureExtractor1 extractor whose features are selected from
     * @param i feature limit stored in numFeatures
     */
    public MutualInformationFeatureSelectionExtractor(String str, FeatureExtractor1<FOCUS_T> featureExtractor1, int i) {
        this(str, featureExtractor1, CombineScoreMethod.MAX, 1.0d, i);
    }

    /**
     * Creates a fully configured extractor.
     *
     * @param str value passed to the superclass constructor
     * @param featureExtractor1 extractor whose features are selected from
     * @param combineScoreMethod how per-outcome scores are combined into one score per feature
     * @param d smoothing count used when the statistics are created in train(...)
     * @param i feature limit stored in numFeatures
     */
    public MutualInformationFeatureSelectionExtractor(String str, FeatureExtractor1<FOCUS_T> featureExtractor1, CombineScoreMethod combineScoreMethod, double d, int i) {
        super(str);
        this.init(featureExtractor1, combineScoreMethod, d, i);
    }

    /** Stores the configuration shared by all constructors. */
    private void init(FeatureExtractor1<FOCUS_T> featureExtractor1, CombineScoreMethod combineScoreMethod, double d, int i) {
        // Selection parameters.
        this.numFeatures = i;
        this.smoothingCount = d;
        this.combineScoreMethod = combineScoreMethod;
        // Source of the features that are selected from.
        this.subExtractor = featureExtractor1;
    }

    /**
     * Runs the sub-extractor and post-processes its features.
     *
     * <p>Before training, all extracted features are wrapped in a single TransformableFeature
     * carrying this extractor's name. After training or loading, only the features accepted by
     * apply(...) are returned.
     *
     * @param jCas CAS handed to the sub-extractor
     * @param focus_t annotation handed to the sub-extractor
     * @return a new mutable list with either the single wrapper feature or the filtered features
     * @throws CleartkExtractorException if the sub-extractor fails
     */
    @Override // org.cleartk.ml.feature.extractor.FeatureExtractor1
    public List<Feature> extract(JCas jCas, FOCUS_T focus_t) throws CleartkExtractorException {
        List<Feature> extracted = this.subExtractor.extract(jCas, focus_t);
        // Typed list instead of a raw ArrayList so the compiler checks what is put into the result.
        List<Feature> result = new ArrayList<>();
        if (!this.isTrained) {
            result.add(new TransformableFeature(this.name, extracted));
            return result;
        }
        result.addAll(Collections2.filter(extracted, this));
        return result;
    }

    /**
     * Builds fresh count statistics from the given instances and ranks every seen feature name by
     * its combined mutual information score, highest first.
     *
     * <p>Only the features nested inside transformable features are counted, each with a count of 1.
     * Note that the ranking keeps all feature names; the numFeatures limit is applied by load(...),
     * not here.
     *
     * @param iterable training instances
     */
    @Override // org.cleartk.ml.feature.transform.TrainableExtractor
    public void train(Iterable<Instance<OUTCOME_T>> iterable) {
        this.mutualInfoStats = new MutualInformationStats<>(this.smoothingCount);

        for (Instance<OUTCOME_T> instance : iterable) {
            OUTCOME_T outcome = instance.getOutcome();
            for (Feature feature : instance.getFeatures()) {
                if (!isTransformable(feature)) {
                    continue;
                }
                for (Feature nested : ((TransformableFeature) feature).getFeatures()) {
                    this.mutualInfoStats.update(nameFeature(nested), outcome, 1);
                }
            }
        }

        MutualInformationStats.ComputeFeatureScore<OUTCOME_T> scoreFunction = this.mutualInfoStats.getScoreFunction(this.combineScoreMethod);
        Set<String> featureNames = this.mutualInfoStats.classConditionalCounts.rowKeySet();
        this.selectedFeatures = Ordering.natural().onResultOf(scoreFunction).reverse().immutableSortedCopy(featureNames);
        this.isTrained = true;
    }

    /**
     * Writes the combine method and the selected feature names with their scores to a file.
     *
     * <p>The first line is "CombineScoreType", a tab and the combine method; every following line
     * is a feature name, a tab and its score.
     *
     * @param uri location of the file to write
     * @throws IOException if the extractor has no training statistics to score with, or the file
     *     cannot be written
     */
    @Override // org.cleartk.ml.feature.transform.TrainableExtractor
    public void save(URI uri) throws IOException {
        if (!this.isTrained) {
            throw new IOException("MutualInformationFeatureExtractor: Cannot save before training.");
        }
        // load(...) marks the extractor as trained without creating statistics; fail with a clear
        // message instead of a NullPointerException in that state.
        if (this.mutualInfoStats == null) {
            throw new IOException("MutualInformationFeatureExtractor: Cannot save without training statistics.");
        }
        MutualInformationStats.ComputeFeatureScore<OUTCOME_T> scoreFunction = this.mutualInfoStats.getScoreFunction(this.combineScoreMethod);
        // try-with-resources closes the writer even when a write fails part-way through.
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(new File(uri)))) {
            writer.append("CombineScoreType\t");
            writer.append(this.combineScoreMethod.toString());
            writer.append("\n");
            for (String featureName : this.selectedFeatures) {
                writer.append(String.format(Locale.ROOT, "%s\t%f\n", featureName, scoreFunction.apply(featureName)));
            }
        }
    }

    /**
     * Reads the combine method and at most numFeatures feature names from a file in the format
     * written by save(...), then marks the extractor as trained.
     *
     * <p>The extractor's state is only changed once the file has been read successfully.
     *
     * @param uri location of the file to read
     * @throws IOException if the file cannot be read, is empty, or has a malformed header line
     */
    @Override // org.cleartk.ml.feature.transform.TrainableExtractor
    public void load(URI uri) throws IOException {
        List<String> loadedFeatures = Lists.newArrayList();
        CombineScoreMethod loadedMethod;
        // try-with-resources closes the reader even when parsing fails.
        try (BufferedReader reader = new BufferedReader(new FileReader(new File(uri)))) {
            String header = reader.readLine();
            if (header == null) {
                throw new IOException("MutualInformationFeatureExtractor: Cannot load from empty file: " + uri);
            }
            String[] headerFields = header.split("\\t");
            if (headerFields.length < 2) {
                throw new IOException("MutualInformationFeatureExtractor: Malformed header line: " + header);
            }
            try {
                loadedMethod = CombineScoreMethod.valueOf(headerFields[1]);
            } catch (IllegalArgumentException e) {
                throw new IOException("MutualInformationFeatureExtractor: Unknown combine score method: " + headerFields[1], e);
            }
            String line;
            while (loadedFeatures.size() < this.numFeatures && (line = reader.readLine()) != null) {
                // The feature name is everything before the first tab; the score after it is ignored.
                int tab = line.indexOf('\t');
                loadedFeatures.add(tab < 0 ? line : line.substring(0, tab));
            }
        }
        this.combineScoreMethod = loadedMethod;
        this.selectedFeatures = loadedFeatures;
        this.isTrained = true;
    }

    /**
     * Tests whether a feature is among the selected ones.
     *
     * @param feature feature to test
     * @return true if the name built by nameFeature(...) is contained in the selected feature list
     */
    public boolean apply(Feature feature) {
        String featureName = nameFeature(feature);
        return this.selectedFeatures.contains(featureName);
    }

    /**
     * @return the list of selected feature names as held by this extractor (not a copy); it is
     *     assigned by train(...) and load(...)
     */
    public final List<String> getSelectedFeatures() {
        return this.selectedFeatures;
    }
}
