package rocks.vilaverde.classifier.ensemble;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import rocks.vilaverde.classifier.Classifier;
import rocks.vilaverde.classifier.Prediction;
import rocks.vilaverde.classifier.dt.DecisionTreeClassifier;
import rocks.vilaverde.classifier.dt.PredictionFactory;
import rocks.vilaverde.classifier.dt.TreeClassifier;
import rocks.vilaverde.classifier.util.ThrowingFunction;

/* loaded from: input_file:rocks/vilaverde/classifier/ensemble/RandomForestClassifier.class */
public class RandomForestClassifier<T> implements Classifier<T> {
    private static final Logger LOG = LoggerFactory.getLogger(RandomForestClassifier.class);
    private final ExecutorService executorService;
    private final List<TreeClassifier<T>> forest;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:rocks/vilaverde/classifier/ensemble/RandomForestClassifier$ParallelPrediction.class */
    public static class ParallelPrediction<T> implements Callable<List<Prediction<T>>> {
        private final int start;
        private final int offset;
        private final List<TreeClassifier<T>> forest;
        private final Map<String, Double> features;

        private ParallelPrediction(List<TreeClassifier<T>> list, Map<String, Double> map, int i, int i2) {
            this.offset = i2;
            this.start = i;
            this.forest = list;
            this.features = map;
        }

        @Override // java.util.concurrent.Callable
        public List<Prediction<T>> call() throws Exception {
            ArrayList arrayList = new ArrayList();
            int i = this.start;
            while (true) {
                int i2 = i;
                if (i2 >= this.forest.size()) {
                    return arrayList;
                }
                arrayList.add(this.forest.get(i2).getClassification(this.features));
                i = i2 + this.offset;
            }
        }
    }

    public static <T> Classifier<T> parse(TarArchiveInputStream tarArchiveInputStream, PredictionFactory<T> predictionFactory) throws Exception {
        return parse(tarArchiveInputStream, predictionFactory, null);
    }

    public static <T> Classifier<T> parse(final TarArchiveInputStream tarArchiveInputStream, PredictionFactory<T> predictionFactory, ExecutorService executorService) throws Exception {
        ArrayList arrayList = new ArrayList();
        while (true) {
            try {
                TarArchiveEntry nextTarEntry = tarArchiveInputStream.getNextTarEntry();
                if (nextTarEntry == null) {
                    break;
                }
                if (!nextTarEntry.isDirectory()) {
                    LOG.debug("Parsing tree {}", nextTarEntry.getName());
                    arrayList.add(DecisionTreeClassifier.parse(new BufferedReader(new InputStreamReader(new InputStream() { // from class: rocks.vilaverde.classifier.ensemble.RandomForestClassifier.1
                        @Override // java.io.InputStream
                        public int read() throws IOException {
                            return tarArchiveInputStream.read();
                        }

                        @Override // java.io.InputStream, java.io.Closeable, java.lang.AutoCloseable
                        public void close() throws IOException {
                        }
                    })), predictionFactory));
                }
            } catch (Throwable th) {
                if (tarArchiveInputStream != null) {
                    try {
                        tarArchiveInputStream.close();
                    } catch (Throwable th2) {
                        th.addSuppressed(th2);
                    }
                }
                throw th;
            }
        }
        if (tarArchiveInputStream != null) {
            tarArchiveInputStream.close();
        }
        return new RandomForestClassifier(arrayList, executorService);
    }

    private RandomForestClassifier(List<TreeClassifier<T>> list, ExecutorService executorService) {
        this.forest = list;
        this.executorService = executorService;
    }

    @Override // rocks.vilaverde.classifier.Classifier
    public T predict(Map<String, Double> map) {
        Map map2 = (Map) getPredictions(map).stream().collect(Collectors.groupingBy((v0) -> {
            return v0.get();
        }, Collectors.counting()));
        long orElse = map2.values().stream().mapToLong((v0) -> {
            return v0.longValue();
        }).max().orElse(0L);
        for (Map.Entry entry : map2.entrySet()) {
            if (((Long) entry.getValue()).longValue() == orElse) {
                return (T) entry.getKey();
            }
        }
        throw new IllegalStateException("no classification");
    }

    @Override // rocks.vilaverde.classifier.Classifier
    public double[] predict_proba(Map<String, Double> map) {
        if (this.forest.size() == 1) {
            return this.forest.get(0).getClassification(map).getProbability();
        }
        double[] dArr = null;
        Iterator<Prediction<T>> it = getPredictions(map).iterator();
        while (it.hasNext()) {
            double[] probability = it.next().getProbability();
            if (dArr == null) {
                dArr = probability;
            } else {
                for (int i = 0; i < probability.length; i++) {
                    double[] dArr2 = dArr;
                    int i2 = i;
                    dArr2[i2] = dArr2[i2] + probability[i];
                }
            }
        }
        if (dArr != null) {
            int size = this.forest.size();
            for (int i3 = 0; i3 < dArr.length; i3++) {
                double[] dArr3 = dArr;
                int i4 = i3;
                dArr3[i4] = dArr3[i4] / size;
            }
        }
        return dArr;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v29, types: [java.util.List] */
    protected List<Prediction<T>> getPredictions(Map<String, Double> map) {
        ArrayList arrayList;
        if (this.executorService != null) {
            int availableProcessors = Runtime.getRuntime().availableProcessors();
            ArrayList arrayList2 = new ArrayList(availableProcessors);
            for (int i = 0; i < availableProcessors; i++) {
                arrayList2.add(new ParallelPrediction(this.forest, map, i, availableProcessors));
            }
            try {
                arrayList = (List) this.executorService.invokeAll(arrayList2).stream().flatMap(ThrowingFunction.wrap(future -> {
                    return ((List) future.get()).stream();
                })).collect(Collectors.toList());
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        } else {
            arrayList = new ArrayList(this.forest.size());
            Iterator<TreeClassifier<T>> it = this.forest.iterator();
            while (it.hasNext()) {
                arrayList.add(it.next().getClassification(map));
            }
        }
        return arrayList;
    }

    @Override // rocks.vilaverde.classifier.Classifier
    public Set<String> getFeatureNames() {
        HashSet hashSet = new HashSet();
        Iterator<TreeClassifier<T>> it = this.forest.iterator();
        while (it.hasNext()) {
            hashSet.addAll(it.next().getFeatureNames());
        }
        return hashSet;
    }
}
