package rocks.vilaverde.classifier.ensemble;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import rocks.vilaverde.classifier.Classifier;
import rocks.vilaverde.classifier.Prediction;
import rocks.vilaverde.classifier.dt.DecisionTreeClassifier;
import rocks.vilaverde.classifier.dt.PredictionFactory;
import rocks.vilaverde.classifier.dt.TreeClassifier;

/* loaded from: input_file:rocks/vilaverde/classifier/ensemble/RandomForestClassifier.class */
/**
 * Random forest classifier made of independently parsed decision trees.
 *
 * <p>{@link #predict(Map)} returns the class predicted by the most trees. {@link
 * #predict_proba(Map)} returns the element-wise mean of the trees' probability vectors. When
 * {@code jobs > 0} the trees are evaluated on a fixed thread pool that is created and shut down
 * on every call. Otherwise they are evaluated on the calling thread.
 *
 * @param <T> the type of the predicted class
 */
public class RandomForestClassifier<T> implements Classifier<T> {
    private static final Logger LOG = LoggerFactory.getLogger(RandomForestClassifier.class);

    /** Worker threads per prediction; values {@code <= 0} mean sequential evaluation. */
    private final int jobs;

    /** The trees of the forest, in the order they appeared in the archive. */
    private final List<TreeClassifier<T>> forest;

    /**
     * Parses a random forest from a tar archive, evaluating trees sequentially at prediction time.
     * Equivalent to {@code parse(tarArchiveInputStream, predictionFactory, 0)}.
     *
     * @param tarArchiveInputStream archive with one exported tree per file entry; closed by this method
     * @param predictionFactory factory handed to each tree parser
     * @return the parsed forest
     * @throws Exception if reading the archive or parsing a tree fails
     */
    public static <T> Classifier<T> parse(TarArchiveInputStream tarArchiveInputStream,
                                          PredictionFactory<T> predictionFactory) throws Exception {
        return parse(tarArchiveInputStream, predictionFactory, 0);
    }

    /**
     * Parses a random forest from a tar archive. Each non-directory entry is parsed as one
     * decision tree via {@link DecisionTreeClassifier#parse}. The archive stream is closed when
     * parsing completes or fails.
     *
     * @param tarArchiveInputStream archive with one exported tree per file entry; closed by this method
     * @param predictionFactory factory handed to each tree parser
     * @param jobs threads used per prediction: {@code -1} means one per available processor,
     *             any other non-positive value means sequential evaluation
     * @return the parsed forest
     * @throws Exception if reading the archive or parsing a tree fails
     */
    public static <T> Classifier<T> parse(final TarArchiveInputStream tarArchiveInputStream,
                                          PredictionFactory<T> predictionFactory, int jobs) throws Exception {
        List<TreeClassifier<T>> trees = new ArrayList<>();
        try (TarArchiveInputStream tar = tarArchiveInputStream) {
            TarArchiveEntry entry;
            while ((entry = tar.getNextTarEntry()) != null) {
                if (entry.isDirectory()) {
                    continue;
                }
                LOG.debug("Parsing tree {}", entry.getName());
                // The tar stream yields only the current entry's bytes. The wrapper keeps the
                // tree parser from closing the whole archive before later entries are read.
                BufferedReader reader = new BufferedReader(
                        new InputStreamReader(nonClosing(tar), StandardCharsets.UTF_8));
                trees.add(DecisionTreeClassifier.parse(reader, predictionFactory));
            }
        }
        return new RandomForestClassifier<>(trees, jobs);
    }

    /** Wraps {@code in} so reads pass straight through but {@code close()} does nothing. */
    private static InputStream nonClosing(final InputStream in) {
        return new InputStream() {
            @Override
            public int read() throws IOException {
                return in.read();
            }

            @Override
            public int read(byte[] b, int off, int len) throws IOException {
                // Bulk reads avoid one call into the tar stream per byte.
                return in.read(b, off, len);
            }

            @Override
            public void close() {
                // Intentionally empty: closing would close the enclosing archive.
            }
        };
    }

    private RandomForestClassifier(List<TreeClassifier<T>> list, int jobs) {
        this.forest = list;
        // Resolve "-1 = all processors" once so no prediction call mutates shared state.
        this.jobs = jobs == -1 ? Runtime.getRuntime().availableProcessors() : jobs;
    }

    /**
     * Returns the class predicted by the most trees. Ties go to whichever tied class the vote
     * map iterates first.
     *
     * @param map feature name to feature value
     * @return the majority-vote class
     * @throws IllegalStateException if the forest produced no predictions
     */
    @Override
    public T predict(Map<String, Double> map) {
        Map<T, Long> votes = getPredictions(map).stream()
                .collect(Collectors.groupingBy(Prediction::get, Collectors.counting()));
        return votes.entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .map(Map.Entry::getKey)
                .orElseThrow(() -> new IllegalStateException("no classification"));
    }

    /**
     * Returns the element-wise mean of every tree's probability vector. The returned array is
     * always a fresh copy, so no tree's own array is modified or exposed.
     *
     * @param map feature name to feature value
     * @return averaged class probabilities
     * @throws IllegalStateException if the forest produced no predictions
     */
    @Override
    public double[] predict_proba(Map<String, Double> map) {
        List<Prediction<T>> predictions = getPredictions(map);
        if (predictions.isEmpty()) {
            throw new IllegalStateException("no classification");
        }
        double[] mean = null;
        for (Prediction<T> prediction : predictions) {
            double[] probability = prediction.getProbability();
            if (mean == null) {
                // Copy: the tree may return its internal (model) array, which must not be summed into.
                mean = probability.clone();
            } else {
                for (int k = 0; k < probability.length; k++) {
                    mean[k] += probability[k];
                }
            }
        }
        int count = predictions.size();
        for (int k = 0; k < mean.length; k++) {
            mean[k] /= count;
        }
        return mean;
    }

    /**
     * Evaluates every tree against {@code map}. Results are returned in forest order.
     *
     * @param map feature name to feature value
     * @return one prediction per tree
     * @throws IllegalStateException if interrupted (the interrupt flag is restored) or if a tree
     *         failed with a checked exception; unchecked failures are rethrown as-is
     */
    protected List<Prediction<T>> getPredictions(Map<String, Double> map) {
        if (jobs <= 0 || forest.size() <= 1) {
            List<Prediction<T>> predictions = new ArrayList<>(forest.size());
            for (TreeClassifier<T> tree : forest) {
                predictions.add(tree.getClassification(map));
            }
            return predictions;
        }

        List<Callable<Prediction<T>>> tasks = new ArrayList<>(forest.size());
        for (TreeClassifier<T> tree : forest) {
            tasks.add(() -> tree.getClassification(map));
        }
        // No point starting more threads than there are trees.
        ExecutorService pool = Executors.newFixedThreadPool(Math.min(jobs, forest.size()));
        try {
            List<Prediction<T>> predictions = new ArrayList<>(tasks.size());
            // invokeAll waits for every task, so get() below never blocks. It surfaces task
            // failures instead of silently dropping that tree's vote.
            for (Future<Prediction<T>> future : pool.invokeAll(tasks)) {
                predictions.add(future.get());
            }
            return predictions;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while searching trees", e);
        } catch (ExecutionException e) {
            throw unwrap(e.getCause());
        } finally {
            pool.shutdownNow();
        }
    }

    /** Returns an unchecked task failure unchanged (throwing Errors) or wraps a checked one. */
    private static RuntimeException unwrap(Throwable cause) {
        if (cause instanceof RuntimeException) {
            return (RuntimeException) cause;
        }
        if (cause instanceof Error) {
            throw (Error) cause;
        }
        return new IllegalStateException("failed while searching trees", cause);
    }

    /**
     * Returns the union of the feature names used by all trees.
     *
     * @return a new, mutable set of feature names
     */
    @Override
    public Set<String> getFeatureNames() {
        Set<String> names = new HashSet<>();
        for (TreeClassifier<T> tree : forest) {
            names.addAll(tree.getFeatureNames());
        }
        return names;
    }
}
