package org.tribuo.classification.fs;

import com.oracle.labs.mlrg.olcut.util.MutableLong;
import java.util.HashMap;
import org.tribuo.CategoricalIDInfo;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.RealIDInfo;
import org.tribuo.VariableIDInfo;
import org.tribuo.classification.Label;
import org.tribuo.math.la.DenseVector;
import org.tribuo.transform.Transformer;
import org.tribuo.transform.transformations.BinningTransformation;
import org.tribuo.util.infotheory.InformationTheory;
import org.tribuo.util.infotheory.impl.CachedPair;
import org.tribuo.util.infotheory.impl.CachedTriple;
import org.tribuo.util.infotheory.impl.PairDistribution;
import org.tribuo.util.infotheory.impl.TripleDistribution;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:org/tribuo/classification/fs/DenseFSMatrix.class */
/**
 * A dense feature selection matrix for classification datasets.
 * <p>
 * Every feature value of every example is replaced by an integer bin id, and every
 * example's label by its integer output id. The information theoretic quantities exposed
 * through {@link FSMatrix} are computed by counting how often each combination of ids
 * co-occurs across the examples and handing those counts to {@link InformationTheory}.
 * <p>
 * Instances are created with {@link #equalWidthBins(Dataset, int)}.
 */
public final class DenseFSMatrix implements FSMatrix {
    /** Output id of each example, indexed by example. */
    private final int[] labels;
    /** Binned feature values, indexed as {@code [feature][example]}. */
    private final int[][] features;
    /** The feature map of the dataset this matrix was built from. */
    private final ImmutableFeatureMap fmap;
    /** Number of bins requested per feature. */
    private final int numBins;
    /** Number of output labels reported by the dataset. */
    private final int numLabels;

    /**
     * Stores the supplied arrays without copying them.
     *
     * @param labels    output id per example
     * @param features  binned values, {@code [feature][example]}; each row must be as long as {@code labels}
     * @param fmap      feature map of the source dataset
     * @param numBins   number of bins per feature
     * @param numLabels number of output labels
     */
    private DenseFSMatrix(int[] labels, int[][] features, ImmutableFeatureMap fmap, int numBins, int numLabels) {
        this.labels = labels;
        this.features = features;
        this.fmap = fmap;
        this.numBins = numBins;
        this.numLabels = numLabels;
    }

    /**
     * @return the number of feature rows held by this matrix
     */
    @Override
    public int getNumFeatures() {
        return this.features.length;
    }

    /**
     * @return the number of examples (the length of the label array)
     */
    @Override
    public int getNumSamples() {
        return this.labels.length;
    }

    /**
     * @return the feature map supplied when this matrix was built
     */
    @Override
    public ImmutableFeatureMap getFeatureMap() {
        return this.fmap;
    }

    /**
     * Mutual information computed from the joint counts of
     * (binned value of the feature, label id).
     *
     * @param featureIndex row index of the feature
     * @return the value returned by {@link InformationTheory#mi} for those counts
     */
    @Override
    public double mi(int featureIndex) {
        HashMap<CachedPair<Integer, Integer>, MutableLong> counts =
                countPairs(this.features[featureIndex], this.labels);
        return InformationTheory.mi(PairDistribution.constructFromMap(counts, this.numBins, this.numLabels));
    }

    /**
     * Mutual information computed from the joint counts of the binned values of two features.
     *
     * @param firstIndex  row index of the first feature
     * @param secondIndex row index of the second feature
     * @return the value returned by {@link InformationTheory#mi} for those counts
     */
    @Override
    public double mi(int firstIndex, int secondIndex) {
        HashMap<CachedPair<Integer, Integer>, MutableLong> counts =
                countPairs(this.features[firstIndex], this.features[secondIndex]);
        return InformationTheory.mi(PairDistribution.constructFromMap(counts, this.numBins, this.numBins));
    }

    /**
     * Joint mutual information computed from counts of the triple
     * (first feature, second feature, label id), in that order.
     * <p>
     * NOTE(review): presumably I(X_first, X_second; Y), with the third triple element as the
     * target - confirm against {@code InformationTheory.jointMI}.
     *
     * @param firstIndex  row index of the first feature
     * @param secondIndex row index of the second feature
     * @return the value returned by {@code InformationTheory.jointMI} for those counts
     */
    @Override
    public double jmi(int firstIndex, int secondIndex) {
        HashMap<CachedTriple<Integer, Integer, Integer>, MutableLong> counts =
                countTriples(this.features[firstIndex], this.features[secondIndex], this.labels);
        return InformationTheory.jointMI(TripleDistribution.constructFromMap(counts));
    }

    /**
     * Joint mutual information computed from counts of the triple
     * (first feature, second feature, third feature), in that order.
     * <p>
     * NOTE(review): presumably I(X_first, X_second; X_third) - confirm against
     * {@code InformationTheory.jointMI}.
     *
     * @param firstIndex  row index of the first feature
     * @param secondIndex row index of the second feature
     * @param thirdIndex  row index of the third feature
     * @return the value returned by {@code InformationTheory.jointMI} for those counts
     */
    @Override
    public double jmi(int firstIndex, int secondIndex, int thirdIndex) {
        HashMap<CachedTriple<Integer, Integer, Integer>, MutableLong> counts =
                countTriples(this.features[firstIndex], this.features[secondIndex], this.features[thirdIndex]);
        return InformationTheory.jointMI(TripleDistribution.constructFromMap(counts));
    }

    /**
     * Conditional mutual information computed from counts of the triple
     * (first feature, label id, second feature), in that order.
     * <p>
     * NOTE(review): presumably I(X_first; Y | X_second), with the third triple element as the
     * conditioning variable - confirm against {@code InformationTheory.conditionalMI}.
     *
     * @param firstIndex  row index of the feature placed first in the triple
     * @param secondIndex row index of the feature placed last in the triple
     * @return the value returned by {@code InformationTheory.conditionalMI} for those counts
     */
    @Override
    public double cmi(int firstIndex, int secondIndex) {
        HashMap<CachedTriple<Integer, Integer, Integer>, MutableLong> counts =
                countTriples(this.features[firstIndex], this.labels, this.features[secondIndex]);
        return InformationTheory.conditionalMI(TripleDistribution.constructFromMap(counts));
    }

    /**
     * Conditional mutual information computed from counts of the triple
     * (first feature, second feature, third feature), in that order.
     * <p>
     * NOTE(review): presumably I(X_first; X_second | X_third) - confirm against
     * {@code InformationTheory.conditionalMI}.
     *
     * @param firstIndex  row index of the first feature
     * @param secondIndex row index of the second feature
     * @param thirdIndex  row index of the third feature
     * @return the value returned by {@code InformationTheory.conditionalMI} for those counts
     */
    @Override
    public double cmi(int firstIndex, int secondIndex, int thirdIndex) {
        HashMap<CachedTriple<Integer, Integer, Integer>, MutableLong> counts =
                countTriples(this.features[firstIndex], this.features[secondIndex], this.features[thirdIndex]);
        return InformationTheory.conditionalMI(TripleDistribution.constructFromMap(counts));
    }

    /**
     * Counts how many examples share each (first, second) combination of ids.
     * <p>
     * Every array handed to this helper is either the label array or one feature row; all of
     * them are allocated with the same length in {@link #equalWidthBins(Dataset, int)}.
     */
    private static HashMap<CachedPair<Integer, Integer>, MutableLong> countPairs(int[] first, int[] second) {
        HashMap<CachedPair<Integer, Integer>, MutableLong> counts = new HashMap<>();
        for (int sample = 0; sample < first.length; sample++) {
            CachedPair<Integer, Integer> key = new CachedPair<>(first[sample], second[sample]);
            counts.computeIfAbsent(key, k -> new MutableLong()).increment();
        }
        return counts;
    }

    /**
     * Counts how many examples share each (first, second, third) combination of ids.
     * <p>
     * The position of each array in the argument list is the position of its values in the
     * triple key, which is what distinguishes the callers from one another.
     */
    private static HashMap<CachedTriple<Integer, Integer, Integer>, MutableLong> countTriples(
            int[] first, int[] second, int[] third) {
        HashMap<CachedTriple<Integer, Integer, Integer>, MutableLong> counts = new HashMap<>();
        for (int sample = 0; sample < first.length; sample++) {
            CachedTriple<Integer, Integer, Integer> key =
                    new CachedTriple<>(first[sample], second[sample], third[sample]);
            counts.computeIfAbsent(key, k -> new MutableLong()).increment();
        }
        return counts;
    }

    /**
     * Builds a matrix from the dataset by pushing every feature value through an equal-width
     * binning transformer and recording each example's output id.
     *
     * @param dataset the dataset to convert
     * @param numBins the number of bins per feature; must be at least 1
     * @return the binned matrix
     * @throws IllegalArgumentException if {@code numBins} is less than 1
     * @throws IllegalStateException    if a feature's info is neither categorical nor real
     */
    public static DenseFSMatrix equalWidthBins(Dataset<Label> dataset, int numBins) {
        // Zero bins would divide the range by zero and a negative count cannot size an array;
        // reject both up front with a descriptive message.
        if (numBins < 1) {
            throw new IllegalArgumentException("numBins must be at least 1, found " + numBins);
        }
        ImmutableFeatureMap featureMap = dataset.getFeatureIDMap();
        ImmutableOutputInfo<Label> outputInfo = dataset.getOutputIDInfo();
        int numFeatures = featureMap.size();
        int numExamples = dataset.size();
        int numLabels = dataset.getOutputInfo().size();

        // One transformer per feature, looked up by feature id.
        Transformer[] binners = new Transformer[numFeatures];
        for (int f = 0; f < numFeatures; f++) {
            binners[f] = makeBinningTransformer(featureMap.get(f), numExamples, numBins);
        }

        int[][] binned = new int[numFeatures][numExamples];
        int[] labelIds = new int[numExamples];
        for (int e = 0; e < numExamples; e++) {
            Example<Label> example = dataset.getExample(e);
            DenseVector dense = DenseVector.createDenseVector(example, featureMap, false);
            for (int f = 0; f < numFeatures; f++) {
                // The transformer returns a double; the matrix stores it truncated to int.
                binned[f][e] = (int) binners[f].transform(dense.get(f));
            }
            labelIds[e] = outputInfo.getID(example.getOutput());
        }
        return new DenseFSMatrix(labelIds, binned, featureMap, numBins, numLabels);
    }

    /**
     * Creates an equal-width binning transformer covering the observed range of one feature.
     * <p>
     * The range [min, max] is split into {@code numBins} intervals of equal width; the i-th
     * upper edge is {@code min + (i + 1) * width} and is paired with the output value
     * {@code i + 1}, so the output values run from 1 to {@code numBins}.
     *
     * @param info        the feature's variable info
     * @param numExamples the number of examples in the dataset
     * @param numBins     the number of bins to create
     * @return the binning transformer
     * @throws IllegalStateException if {@code info} is neither categorical nor real
     */
    private static Transformer makeBinningTransformer(VariableIDInfo info, int numExamples, int numBins) {
        int observedCount = info.getCount();
        double min = Double.POSITIVE_INFINITY;
        double max = Double.NEGATIVE_INFINITY;
        if (info instanceof CategoricalIDInfo) {
            for (double value : ((CategoricalIDInfo) info).getValues()) {
                min = Math.min(min, value);
                max = Math.max(max, value);
            }
        } else if (info instanceof RealIDInfo) {
            RealIDInfo realInfo = (RealIDInfo) info;
            min = realInfo.getMin();
            max = realInfo.getMax();
        } else {
            throw new IllegalStateException("Unknown variable info subclass " + info.getClass());
        }
        // NOTE(review): presumably getCount() is the number of examples in which the feature was
        // present; when it differs from the dataset size the remaining examples are densified
        // as 0, so 0 has to lie inside the binned range - confirm against VariableInfo.
        if (numExamples != observedCount) {
            min = Math.min(min, 0.0d);
            max = Math.max(max, 0.0d);
        }
        double width = Math.abs(max - min) / numBins;
        double[] upperEdges = new double[numBins];
        double[] binValues = new double[numBins];
        for (int b = 0; b < upperEdges.length; b++) {
            upperEdges[b] = min + ((b + 1) * width);
            binValues[b] = b + 1;
        }
        return new BinningTransformation.BinningTransformer(
                BinningTransformation.BinningType.EQUAL_WIDTH, upperEdges, binValues);
    }
}
