package org.tribuo.classification.fs;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.IntStream;
import org.tribuo.Dataset;
import org.tribuo.FeatureSelector;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.SelectedFeatureSet;
import org.tribuo.classification.Label;
import org.tribuo.provenance.FeatureSelectorProvenance;
import org.tribuo.provenance.FeatureSetProvenance;
import org.tribuo.provenance.impl.FeatureSelectorProvenanceImpl;

/* loaded from: input_file:org/tribuo/classification/fs/CMIM.class */
public final class CMIM implements FeatureSelector<Label> {
    private static final Logger logger = Logger.getLogger(CMIM.class.getName());

    @Config(mandatory = true, description = "Number of bins to use when discretising continuous features.")
    private int numBins;

    @Config(description = "Number of features to select, defaults to ranking all features.")
    private int k;

    @Config(description = "Number of computation threads to use.")
    private int numThreads;

    private CMIM() {
        this.k = -1;
        this.numThreads = 1;
    }

    public CMIM(int i, int i2, int i3) {
        this.k = -1;
        this.numThreads = 1;
        this.k = i;
        this.numBins = i2;
        this.numThreads = i3;
        if (i != -1 && i < 1) {
            throw new IllegalArgumentException("k must be -1 to select all features, or a positive number, found " + i);
        }
        if (i2 < 2) {
            throw new IllegalArgumentException("numBins must be >= 2, found " + i2);
        }
    }

    public void postConfig() {
        if (this.k != -1 && this.k < 1) {
            throw new PropertyException("", "k", "k must be -1 to select all features, or a positive number, found " + this.k);
        }
        if (this.numBins < 2) {
            throw new PropertyException("", "numBins", "numBins must be >= 2, found " + this.numBins);
        }
    }

    public boolean isOrdered() {
        return true;
    }

    public SelectedFeatureSet select(Dataset<Label> dataset) {
        double[] array;
        FSMatrix buildMatrix = FSMatrix.buildMatrix(dataset, this.numBins);
        ImmutableFeatureMap featureMap = buildMatrix.getFeatureMap();
        int size = this.k == -1 ? featureMap.size() : Math.min(this.k, featureMap.size());
        int size2 = featureMap.size();
        boolean[] zArr = new boolean[size2];
        Arrays.fill(zArr, true);
        int[] iArr = new int[size];
        double[] dArr = new double[size];
        int[] iArr2 = new int[size2];
        if (this.numThreads > 1) {
            ForkJoinPool forkJoinPool = new ForkJoinPool(this.numThreads);
            try {
                array = (double[]) forkJoinPool.submit(() -> {
                    IntStream parallel = IntStream.range(0, size2).parallel();
                    Objects.requireNonNull(buildMatrix);
                    return parallel.mapToDouble(buildMatrix::mi).toArray();
                }).get();
                forkJoinPool.shutdown();
            } catch (InterruptedException | ExecutionException e) {
                throw new RuntimeException(e);
            }
        } else {
            IntStream range = IntStream.range(0, size2);
            Objects.requireNonNull(buildMatrix);
            array = range.mapToDouble(buildMatrix::mi).toArray();
        }
        int i = -1;
        double d = -1.0d;
        for (int i2 = 0; i2 < size2; i2++) {
            if (array[i2] > d) {
                i = i2;
                d = array[i2];
            }
        }
        iArr[0] = i;
        dArr[0] = d;
        zArr[i] = false;
        logger.log(Level.INFO, "Itr 0: selected feature " + featureMap.get(i).getName() + ", score = " + dArr[0]);
        for (int i3 = 1; i3 < size; i3++) {
            double d2 = -1.0d;
            int i4 = -1;
            for (int i5 = 0; i5 < size2; i5++) {
                if (zArr[i5]) {
                    while (array[i5] > d2 && iArr2[i5] < i3) {
                        double cmi = buildMatrix.cmi(i5, iArr[iArr2[i5]]);
                        if (cmi < array[i5]) {
                            array[i5] = cmi;
                        }
                        int i6 = i5;
                        iArr2[i6] = iArr2[i6] + 1;
                    }
                    if (array[i5] > d2) {
                        d2 = array[i5];
                        i4 = i5;
                    }
                }
            }
            iArr[i3] = i4;
            dArr[i3] = d2;
            zArr[i4] = false;
            logger.log(Level.INFO, "Itr " + i3 + ": selected feature " + featureMap.get(i4).getName() + ", score = " + d2);
        }
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        for (int i7 = 0; i7 < size; i7++) {
            arrayList.add(featureMap.get(iArr[i7]).getName());
            arrayList2.add(Double.valueOf(dArr[i7]));
        }
        return new SelectedFeatureSet(arrayList, arrayList2, isOrdered(), new FeatureSetProvenance(SelectedFeatureSet.class.getName(), dataset.getProvenance(), m1getProvenance()));
    }

    /* renamed from: getProvenance, reason: merged with bridge method [inline-methods] */
    public FeatureSelectorProvenance m1getProvenance() {
        return new FeatureSelectorProvenanceImpl(this);
    }
}
