package org.cleartk.summarization;

import com.google.common.annotations.Beta;
import com.google.common.base.Function;
import com.google.common.collect.LinkedHashMultiset;
import com.google.common.collect.Multiset;
import com.google.common.collect.Ordering;
import com.lexicalscope.jewel.cli.CliFactory;
import com.lexicalscope.jewel.cli.Option;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectOutputStream;
import java.net.URI;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.cleartk.ml.Feature;
import org.cleartk.ml.Instance;
import org.cleartk.ml.feature.transform.InstanceStream;

/**
 * Summarization model produced by a greedy, SumBasic-style sentence selection.
 *
 * <p>{@link #trainAndWriteModel(File, String...)} reads training instances from a directory. Each
 * instance's feature list is treated as one candidate sentence; each feature's name is a term and
 * its value an {@link Integer} count. Sentences are then chosen one at a time by a
 * {@link CompositionFunction} over per-word probabilities. After a sentence is chosen, all of its
 * words are marked as "seen" and from then on score with a fixed seen-word probability instead of
 * their corpus probability. The selected sentences and their scores at selection time form the
 * model.
 */
@Beta
public class SumBasicModel extends SummarizationModel_ImplBase {
    private static final long serialVersionUID = -354873594945022087L;
    public static final String MODEL_NAME = "model.sumbasic";

    /** Scores a sentence by the mean of its per-word probabilities (0 for an empty sentence). */
    private static class AverageCF extends CompositionFunction {
        private final SumCF sumcf;

        public AverageCF(Double d, TermFrequencyMap termFrequencyMap, Set<String> set) {
            super(d, termFrequencyMap, set);
            this.sumcf = new SumCF(d, termFrequencyMap, set);
        }

        @Override
        public Double apply(List<Feature> list) {
            int size = list.size();
            return size == 0 ? 0.0d : this.sumcf.apply(list).doubleValue() / size;
        }
    }

    /**
     * Base class for functions that combine the per-word probabilities of a sentence (its feature
     * list) into a single score.
     *
     * <p>The {@code seenWords} set is shared with the trainer and mutated between selections, so a
     * function's result for the same sentence can change over time.
     */
    public static abstract class CompositionFunction implements Function<List<Feature>, Double> {
        protected Double seenWordProbability;
        protected TermFrequencyMap tfMap;
        protected Set<String> seenWords;

        /**
         * @param d probability assigned to any word already in {@code set}
         * @param termFrequencyMap corpus term frequencies used for unseen words
         * @param set words from already-selected sentences (shared, mutable)
         */
        public CompositionFunction(Double d, TermFrequencyMap termFrequencyMap, Set<String> set) {
            this.seenWordProbability = d;
            this.tfMap = termFrequencyMap;
            this.seenWords = set;
        }

        /**
         * Returns the probability used for {@code term}: the fixed seen-word probability if the term
         * has already been covered by a selected sentence, otherwise its corpus probability.
         */
        protected double wordProbability(String term) {
            return this.seenWords.contains(term)
                    ? this.seenWordProbability.doubleValue()
                    : this.tfMap.getProbability(term);
        }
    }

    /** How per-word probabilities are combined into a sentence score. */
    public enum CompositionFunctionType {
        AVERAGE,
        PRODUCT,
        SUM
    }

    /** Command-line options accepted by {@link SumBasicModel#trainAndWriteModel(File, String...)}. */
    public interface Options {
        @Option(longName = {"max-num-sentences"}, description = "Specifies the maximum number of sentences to extract in the summary", defaultValue = {"10"})
        int getMaxNumSentences();

        @Option(longName = {"seen-words-prob"}, description = "Specify the probability for seen words.", defaultValue = {"0.0001"})
        double getSeenWordsProbability();

        @Option(longName = {"composition-function"}, description = "Specifies how word probabilities are combined (AVERAGE|SUM|PRODUCT, default=AVERAGE)", defaultValue = {"AVERAGE"})
        CompositionFunctionType getCFType();
    }

    /** Scores a sentence by the product of its per-word probabilities (1 for an empty sentence). */
    private static class ProductCF extends CompositionFunction {
        public ProductCF(Double d, TermFrequencyMap termFrequencyMap, Set<String> set) {
            super(d, termFrequencyMap, set);
        }

        @Override
        public Double apply(List<Feature> list) {
            double product = 1.0d;
            for (Feature feature : list) {
                product *= wordProbability(feature.getName());
            }
            return product;
        }
    }

    /** Scores a sentence by the sum of its per-word probabilities (0 for an empty sentence). */
    public static class SumCF extends CompositionFunction {
        public SumCF(Double d, TermFrequencyMap termFrequencyMap, Set<String> set) {
            super(d, termFrequencyMap, set);
        }

        @Override
        public Double apply(List<Feature> list) {
            double sum = 0.0d;
            for (Feature feature : list) {
                sum += wordProbability(feature.getName());
            }
            return sum;
        }
    }

    /**
     * Insertion-ordered term counts with relative-frequency lookup and a simple
     * {@code <term>\t<count>} text serialization.
     */
    public static class TermFrequencyMap {
        private final Multiset<String> termFrequencies = LinkedHashMultiset.create();

        /** Adds {@code i} occurrences of term {@code str}. */
        public void add(String str, int i) {
            this.termFrequencies.add(str, i);
        }

        /**
         * Returns count(term) / total count of all terms, or 0 when the map is empty.
         *
         * <p>Computed in floating point; integer division would truncate almost every
         * probability to 0.
         */
        public double getProbability(String str) {
            int total = this.termFrequencies.size();
            if (total == 0) {
                return 0.0d;
            }
            return (double) this.termFrequencies.count(str) / total;
        }

        /** Writes one {@code <term>\t<count>} line per distinct term to the file at {@code uri}. */
        public void save(URI uri) throws IOException {
            try (BufferedWriter writer = new BufferedWriter(new FileWriter(new File(uri)))) {
                for (Multiset.Entry<String> entry : this.termFrequencies.entrySet()) {
                    // Plain concatenation instead of %d keeps digits locale-independent so that
                    // load()'s Integer.parseInt can always read them back.
                    writer.write(entry.getElement() + "\t" + entry.getCount() + "\n");
                }
            }
        }

        /**
         * Reads {@code <term>\t<count>} lines from the file at {@code uri} and adds them to the
         * counts already present (existing counts are not cleared). Blank lines are skipped.
         *
         * @throws IOException if the file cannot be read or a line has no tab separator
         * @throws NumberFormatException if a count is not an integer
         */
        public void load(URI uri) throws IOException {
            try (BufferedReader reader = new BufferedReader(new FileReader(new File(uri)))) {
                String line;
                while ((line = reader.readLine()) != null) {
                    if (line.isEmpty()) {
                        continue;
                    }
                    // Split on the LAST tab so the count is always the final field.
                    int tab = line.lastIndexOf('\t');
                    if (tab < 0) {
                        throw new IOException("Malformed term-frequency line (expected <term>\\t<count>): " + line);
                    }
                    this.termFrequencies.add(line.substring(0, tab), Integer.parseInt(line.substring(tab + 1)));
                }
            }
        }
    }

    /**
     * Builds the model from the instances stored in {@code file} and serializes it there under
     * {@link #MODEL_NAME}.
     *
     * <p>Sentences are selected greedily. Each iteration picks the highest-scoring remaining
     * sentence, records its score, removes it from the candidates and marks its words as seen.
     * Selection stops after {@code max-num-sentences} picks or when no candidates remain.
     *
     * @param file directory containing the training instances; the model file is written here too
     * @param strArr command-line options, see {@link Options}
     * @throws Exception if option parsing, instance loading, or writing the model fails
     */
    public static void trainAndWriteModel(File file, String... strArr) throws Exception {
        Options options = (Options) CliFactory.parseArguments(Options.class, strArr);
        int maxNumSentences = options.getMaxNumSentences();
        double seenWordsProbability = options.getSeenWordsProbability();

        List<List<Feature>> candidates = new ArrayList<>();
        TermFrequencyMap termFrequencyMap = new TermFrequencyMap();
        for (Instance<?> instance : InstanceStream.loadFromDirectory(file)) {
            List<Feature> features = instance.getFeatures();
            candidates.add(features);
            for (Feature feature : features) {
                // Feature values are expected to be Integer term counts.
                termFrequencyMap.add(feature.getName(), ((Integer) feature.getValue()).intValue());
            }
        }

        Set<String> seenWords = new HashSet<>();
        CompositionFunction scorer = createCompositionFunction(
                options.getCFType(), seenWordsProbability, termFrequencyMap, seenWords);
        // The ordering calls the scorer on every comparison, so it always reflects the current
        // contents of seenWords.
        Ordering<List<Feature>> byScore = Ordering.<Double>natural().onResultOf(scorer);

        Map<List<Feature>, Double> selected = new HashMap<>();
        for (int i = 0; i < maxNumSentences && !candidates.isEmpty(); i++) {
            List<Feature> best = byScore.max(candidates);
            selected.put(best, scorer.apply(best));
            // Remove the pick so it cannot be selected (and its score overwritten) again.
            candidates.remove(best);
            for (Feature feature : best) {
                seenWords.add(feature.getName());
            }
        }

        SumBasicModel model = new SumBasicModel(selected);
        File modelFile = new File(file, model.getModelName());
        try (ObjectOutputStream out = new ObjectOutputStream(
                new BufferedOutputStream(new FileOutputStream(modelFile)))) {
            out.writeObject(model);
        }
    }

    /** Creates the composition function for {@code type}; AVERAGE is the fallback. */
    private static CompositionFunction createCompositionFunction(
            CompositionFunctionType type, double seenWordsProbability,
            TermFrequencyMap termFrequencyMap, Set<String> seenWords) {
        Double seenProbability = Double.valueOf(seenWordsProbability);
        switch (type) {
            case PRODUCT:
                return new ProductCF(seenProbability, termFrequencyMap, seenWords);
            case SUM:
                return new SumCF(seenProbability, termFrequencyMap, seenWords);
            case AVERAGE:
            default:
                return new AverageCF(seenProbability, termFrequencyMap, seenWords);
        }
    }

    /** Creates a model from already-selected sentences mapped to their scores. */
    public SumBasicModel(Map<List<Feature>, Double> map) {
        super(map);
    }

    /** Deserializes a model from {@code inputStream} (delegates to the base class). */
    public SumBasicModel(InputStream inputStream) throws IOException {
        super(inputStream);
    }

    @Override // org.cleartk.summarization.SummarizationModel_ImplBase
    public String getModelName() {
        return MODEL_NAME;
    }
}
