package org.tribuo.classification.fs;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.function.IntToDoubleFunction;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.IntStream;
import org.tribuo.Dataset;
import org.tribuo.FeatureSelector;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.SelectedFeatureSet;
import org.tribuo.classification.Label;
import org.tribuo.provenance.FeatureSelectorProvenance;
import org.tribuo.provenance.FeatureSetProvenance;
import org.tribuo.provenance.impl.FeatureSelectorProvenanceImpl;

/* loaded from: input_file:org/tribuo/classification/fs/JMI.class */
public final class JMI implements FeatureSelector<Label> {
    private static final Logger logger = Logger.getLogger(JMI.class.getName());

    @Config(mandatory = true, description = "Number of bins to use when discretising continuous features.")
    private int numBins;

    @Config(description = "Number of features to select, defaults to ranking all features.")
    private int k;

    @Config(description = "Number of computation threads to use.")
    private int numThreads;

    private JMI() {
        this.k = -1;
        this.numThreads = 1;
    }

    public JMI(int i, int i2, int i3) {
        this.k = -1;
        this.numThreads = 1;
        this.k = i;
        this.numBins = i2;
        this.numThreads = i3;
        if (i != -1 && i < 1) {
            throw new IllegalArgumentException("k must be -1 to select all features, or a positive number, found " + i);
        }
        if (i2 < 2) {
            throw new IllegalArgumentException("numBins must be >= 2, found " + i2);
        }
    }

    public void postConfig() {
        if (this.k != -1 && this.k < 1) {
            throw new PropertyException("", "k", "k must be -1 to select all features, or a positive number, found " + this.k);
        }
        if (this.numBins < 2) {
            throw new PropertyException("", "numBins", "numBins must be >= 2, found " + this.numBins);
        }
    }

    public boolean isOrdered() {
        return true;
    }

    public SelectedFeatureSet select(Dataset<Label> dataset) {
        Pair pair;
        Pair pair2;
        FSMatrix buildMatrix = FSMatrix.buildMatrix(dataset, this.numBins);
        ImmutableFeatureMap featureMap = buildMatrix.getFeatureMap();
        int size = this.k == -1 ? featureMap.size() : Math.min(this.k, featureMap.size());
        int size2 = featureMap.size();
        boolean[] zArr = new boolean[size2];
        Arrays.fill(zArr, true);
        int[] iArr = new int[size];
        double[] dArr = new double[size];
        ForkJoinPool forkJoinPool = null;
        if (this.numThreads > 1) {
            forkJoinPool = new ForkJoinPool(this.numThreads);
            try {
                pair = (Pair) forkJoinPool.submit(() -> {
                    return (Pair) IntStream.range(0, size2).parallel().mapToObj(i -> {
                        return new Pair(Integer.valueOf(i), Double.valueOf(buildMatrix.mi(i)));
                    }).max(Comparator.comparingDouble((v0) -> {
                        return v0.getB();
                    })).get();
                }).get();
            } catch (InterruptedException | ExecutionException e) {
                throw new RuntimeException(e);
            }
        } else {
            pair = (Pair) IntStream.range(0, size2).mapToObj(i -> {
                return new Pair(Integer.valueOf(i), Double.valueOf(buildMatrix.mi(i)));
            }).max(Comparator.comparingDouble((v0) -> {
                return v0.getB();
            })).get();
        }
        int intValue = ((Integer) pair.getA()).intValue();
        iArr[0] = intValue;
        zArr[intValue] = false;
        dArr[0] = ((Double) pair.getB()).doubleValue();
        logger.log(Level.INFO, "Itr 0: selected feature " + featureMap.get(intValue).getName() + ", score = " + dArr[0]);
        double[] dArr2 = new double[size2];
        for (int i2 = 1; i2 < size; i2++) {
            if (this.numThreads > 1) {
                int i3 = iArr[i2 - 1];
                try {
                    double[] dArr3 = (double[]) forkJoinPool.submit(() -> {
                        return IntStream.range(0, size2).parallel().mapToDouble(i4 -> {
                            if (zArr[i4]) {
                                return buildMatrix.jmi(i4, i3);
                            }
                            return 0.0d;
                        }).toArray();
                    }).get();
                    for (int i4 = 0; i4 < dArr2.length; i4++) {
                        int i5 = i4;
                        dArr2[i5] = dArr2[i5] + dArr3[i4];
                    }
                    pair2 = (Pair) forkJoinPool.submit(() -> {
                        return (Pair) IntStream.range(0, size2).parallel().filter(i6 -> {
                            return zArr[i6];
                        }).mapToObj(i7 -> {
                            return new Pair(Integer.valueOf(i7), Double.valueOf(dArr2[i7]));
                        }).max(Comparator.comparingDouble((v0) -> {
                            return v0.getB();
                        })).get();
                    }).get();
                } catch (InterruptedException | ExecutionException e2) {
                    throw new RuntimeException(e2);
                }
            } else {
                int i6 = -1;
                double d = Double.NEGATIVE_INFINITY;
                for (int i7 = 0; i7 < size2; i7++) {
                    if (zArr[i7]) {
                        int i8 = i7;
                        dArr2[i8] = dArr2[i8] + buildMatrix.jmi(i7, iArr[i2 - 1]);
                        if (dArr2[i7] > d) {
                            d = dArr2[i7];
                            i6 = i7;
                        }
                    }
                }
                pair2 = new Pair(Integer.valueOf(i6), Double.valueOf(d));
            }
            int intValue2 = ((Integer) pair2.getA()).intValue();
            iArr[i2] = intValue2;
            zArr[intValue2] = false;
            dArr[i2] = ((Double) pair2.getB()).doubleValue() / i2;
            logger.log(Level.INFO, "Itr " + i2 + ": selected feature " + featureMap.get(intValue2).getName() + ", score = " + pair2.getB() + ", average score = " + dArr[i2]);
        }
        if (forkJoinPool != null) {
            forkJoinPool.shutdown();
        }
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        for (int i9 = 0; i9 < size; i9++) {
            arrayList.add(featureMap.get(iArr[i9]).getName());
            arrayList2.add(Double.valueOf(dArr[i9]));
        }
        return new SelectedFeatureSet(arrayList, arrayList2, isOrdered(), new FeatureSetProvenance(SelectedFeatureSet.class.getName(), dataset.getProvenance(), m3getProvenance()));
    }

    /* renamed from: getProvenance, reason: merged with bridge method [inline-methods] */
    public FeatureSelectorProvenance m3getProvenance() {
        return new FeatureSelectorProvenanceImpl(this);
    }
}
