package water.util;

import hex.Interaction;
import joptsimple.internal.Strings;
import water.Iced;
import water.Job;
import water.Key;
import water.Keyed;
import water.MRTask;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.AtomicUtils;

/* loaded from: input_file:water/util/Tabulate.class */
public class Tabulate extends Keyed<Tabulate> {
    /** Frame whose columns are tabulated; {@code execImpl()} rejects {@code null}. */
    public Frame _dataset;
    /** Name of the predictor column, looked up in {@code _dataset} by {@code execImpl()}. */
    public String _predictor;
    /** Name of the response column, looked up in {@code _dataset} by {@code execImpl()}. */
    public String _response;
    /** Name of the observation-weight column; when the lookup yields no column every row gets weight 1. */
    public String _weight;
    /** Weighted co-occurrence counts indexed as [predictor bin][response bin]; allocated in CoOccurrence.setupLocal(). */
    double[][] _count_data;
    /**
     * Per predictor bin: [0] weighted sum of the response (divided by [1] in
     * CoOccurrence.postGlobal() to become the mean), [1] total weight of rows with a non-NA response.
     */
    double[][] _response_data;
    /** Table built by {@code tabulationTwoDimTable()}; assigned in {@code execImpl()}. */
    public TwoDimTable _count_table;
    /** Table built by {@code responseCharTwoDimTable()}; assigned in {@code execImpl()}. */
    public TwoDimTable _response_table;
    // Decompiler-exposed assertion flag; initialised in the static block and tested by the manual checks in bin().
    static final /* synthetic */ boolean $assertionsDisabled;
    /** Keys of the vecs actually tabulated: [0] predictor, [1] response (set in {@code execImpl()}). */
    public Key[] _vecs = new Key[2];
    /** Number of bins used for a non-categorical predictor (default 20); must be >= 1. */
    int _nbins_predictor = 20;
    /** Number of bins used for a non-categorical response (default 10); must be >= 1. */
    int _nbins_response = 10;
    /** Column summaries used for binning and labelling: [0] predictor, [1] response. */
    private final Stats[] _stats = new Stats[2];
    /** Job object created with a fresh key; it is not referenced by any other member of this class. */
    public final Job<Tabulate> _job = new Job<>(Key.make(), Tabulate.class.getName(), "Tabulate job");

    /* loaded from: input_file:water/util/Tabulate$CoOccurrence.class */
    /**
     * MRTask that fills the owning {@link Tabulate}'s accumulators: weighted
     * co-occurrence counts per (predictor bin, response bin) and, per predictor
     * bin, the weighted response sum and total weight.
     *
     * <p>Both accumulators live on the shared {@code _sp} instance and are
     * updated with atomic adds inside {@code map}.
     */
    private static class CoOccurrence extends MRTask<CoOccurrence> {
        /** Tabulate instance that owns the accumulators and the binning logic. */
        final Tabulate _sp;

        CoOccurrence(Tabulate tabulate) {
            this._sp = tabulate;
        }

        /**
         * Allocates fresh, zeroed accumulators sized by the bin resolution of
         * the predictor (rows) and response (columns).
         */
        @Override // water.MRTask
        public void setupLocal() {
            final int predictorBins = this._sp.res(0);
            this._sp._count_data = new double[predictorBins][this._sp.res(1)];
            this._sp._response_data = new double[predictorBins][2];
        }

        /** Unweighted variant: every row contributes with weight 1. */
        @Override // water.MRTask
        public void map(Chunk chunk, Chunk chunk2) {
            map(chunk, chunk2, (Chunk) null);
        }

        /**
         * Accumulates one chunk triple.
         *
         * @param chunk  predictor values
         * @param chunk2 response values
         * @param chunk3 observation weights, or {@code null} for weight 1
         */
        @Override // water.MRTask
        public void map(Chunk chunk, Chunk chunk2, Chunk chunk3) {
            for (int row = 0; row < chunk.len(); row++) {
                final int predictorBin = this._sp.bin(0, chunk.atd(row));
                final int responseBin = this._sp.bin(1, chunk2.atd(row));
                final double weight = chunk3 != null ? chunk3.atd(row) : 1.0d;
                if (Double.isNaN(weight)) {
                    // Rows with a missing weight contribute nothing.
                    continue;
                }
                AtomicUtils.DoubleArray.add(this._sp._count_data[predictorBin], responseBin, weight);
                if (!chunk2.isNA(row)) {
                    // Only rows with a present response feed the mean.
                    AtomicUtils.DoubleArray.add(this._sp._response_data[predictorBin], 0, weight * chunk2.atd(row));
                    AtomicUtils.DoubleArray.add(this._sp._response_data[predictorBin], 1, weight);
                }
            }
        }

        /**
         * Merges the accumulators of another task copy into this one.
         *
         * <p>When both copies share the same array instance the values were
         * already combined by the atomic adds in {@code map}, so that array is
         * skipped. The co-occurrence counts are merged as well as the response
         * sums; previously only the response sums were merged, which dropped
         * the counts gathered by any copy holding its own arrays.
         */
        @Override // water.MRTask
        public void reduce(CoOccurrence coOccurrence) {
            final Tabulate mine = this._sp;
            final Tabulate theirs = coOccurrence._sp;
            if (mine._response_data != theirs._response_data) {
                ArrayUtils.add(mine._response_data, theirs._response_data);
            }
            if (mine._count_data != theirs._count_data) {
                ArrayUtils.add(mine._count_data, theirs._count_data);
            }
        }

        /**
         * Turns each weighted response sum into a weighted mean. A predictor
         * bin with zero total weight yields NaN (0.0 / 0.0).
         */
        @Override // water.MRTask
        public void postGlobal() {
            for (int i = 0; i < this._sp._response_data.length; i++) {
                final double[] sumAndWeight = this._sp._response_data[i];
                sumAndWeight[0] = sumAndWeight[0] / sumAndWeight[1];
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:water/util/Tabulate$Stats.class */
    /**
     * Immutable snapshot of the column properties needed for binning and
     * labelling, captured once from a {@link Vec}.
     */
    public static class Stats extends Iced {
        /** Column minimum as reported by the vec. */
        final double _min;
        /** Column maximum as reported by the vec. */
        final double _max;
        /** Whether the vec reports itself as categorical. */
        final boolean _isCategorical;
        /** Whether the vec reports itself as integer-valued. */
        final boolean _isInt;
        /** Cardinality as reported by the vec. */
        final int _cardinality;
        /** 1 when the column contains at least one NA, otherwise 0. */
        final int _missing;
        /** Domain (level labels) as reported by the vec. */
        final String[] _domain;

        /**
         * Captures the summary of the given column.
         *
         * @param vec column to summarise
         */
        Stats(Vec vec) {
            _min = vec.min();
            _max = vec.max();
            _isCategorical = vec.isCategorical();
            _isInt = vec.isInt();
            _cardinality = vec.cardinality();
            if (vec.naCnt() > 0) {
                _missing = 1;
            } else {
                _missing = 0;
            }
            _domain = vec.domain();
        }
    }

    /**
     * Returns the configured number of bins for a column.
     *
     * @param i column index: 1 selects the response, any other value the predictor
     * @return {@code _nbins_response} for index 1, otherwise {@code _nbins_predictor}
     */
    private int bins(int i) {
        if (i == 1) {
            return this._nbins_response;
        }
        return this._nbins_predictor;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Returns the number of table slots for a column: its cardinality when
     * categorical, otherwise its configured bin count, plus one extra slot
     * when the column contains missing values.
     *
     * @param i column index (0 = predictor, 1 = response)
     * @return number of slots including the optional missing-value slot
     */
    public int res(int i) {
        final Stats stats = this._stats[i];
        final int levels = stats._isCategorical ? stats._cardinality : bins(i);
        return levels + stats._missing;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Maps a raw column value to its slot index.
     *
     * <p>NaN always maps to slot 0. Categorical values use their integer level;
     * other values are placed in equal-width bins over [min, max], with the
     * maximum clamped into the last bin. When the column has missing values,
     * every non-NaN slot is shifted up by one.
     *
     * @param i column index (0 = predictor, 1 = response)
     * @param d raw value read from the column
     * @return slot index for the value
     */
    public int bin(int i, double d) {
        if (Double.isNaN(d)) {
            return 0;
        }
        final Stats stats = this._stats[i];
        final int shift = stats._missing;
        if (stats._isCategorical) {
            final int level = (int) d;
            // Manual assertion: a categorical value must be a whole number.
            if (!$assertionsDisabled && level != d) {
                throw new AssertionError();
            }
            return level + shift;
        }
        final int nbins = bins(i);
        final double width = (stats._max - stats._min) / nbins;
        final int raw = (int) ((d - stats._min) / width);
        // Manual assertion: raw == nbins is legal (d == max) and is clamped below.
        if (!$assertionsDisabled && (raw < 0 || raw > nbins)) {
            throw new AssertionError();
        }
        return Math.min(raw, nbins - 1) + shift;
    }

    /**
     * Builds the display label for a slot of a column.
     *
     * @param i  column index (0 = predictor, 1 = response)
     * @param i2 slot index as produced by {@code bin}
     * @return "missing(NA)" for the missing-value slot, the domain label for a
     *         categorical column, the integer value when an integer column has
     *         at most one value per bin, otherwise the bin midpoint formatted
     *         with "%5f"
     */
    private String labelForBin(int i, int i2) {
        final Stats stats = this._stats[i];
        final boolean hasMissingSlot = stats._missing == 1;
        if (hasMissingSlot && i2 == 0) {
            return "missing(NA)";
        }
        // Undo the shift introduced by the missing-value slot.
        final int slot = hasMissingSlot ? i2 - 1 : i2;
        if (stats._isCategorical) {
            return stats._domain[slot];
        }
        final int nbins = bins(i);
        final double span = stats._max - stats._min;
        if (!stats._isInt || span + 1.0d > ((double) nbins)) {
            final double midpoint = stats._min + ((slot + 0.5d) * (span / nbins));
            return String.format("%5f", Double.valueOf(midpoint));
        }
        return Integer.toString((int) (stats._min + slot));
    }

    /**
     * Runs the tabulation: validates the inputs, prepares the predictor and
     * response columns, accumulates the (weighted) counts with
     * {@code CoOccurrence}, then builds and logs both result tables.
     *
     * @return the Tabulate instance carried by the finished task, whose count
     *         and response data are populated
     * @throws H2OIllegalArgumentException if the dataset is missing, a bin
     *         count is below 1, the predictor or response column is not found,
     *         or the weight column is non-numeric or contains a negative value
     */
    public Tabulate execImpl() {
        if (this._dataset == null) {
            throw new H2OIllegalArgumentException("Dataset not found");
        }
        if (this._nbins_predictor < 1) {
            throw new H2OIllegalArgumentException("Number of bins for predictor must be >= 1");
        }
        if (this._nbins_response < 1) {
            throw new H2OIllegalArgumentException("Number of bins for response must be >= 1");
        }

        Vec predictor = this._dataset.vec(this._predictor);
        if (predictor == null) {
            throw new H2OIllegalArgumentException("Predictor column " + this._predictor + " not found");
        }
        predictor = coarsen(predictor, this._predictor, this._nbins_predictor);

        Vec response = this._dataset.vec(this._response);
        if (response == null) {
            throw new H2OIllegalArgumentException("Response column " + this._response + " not found");
        }
        response = coarsen(response, this._response, this._nbins_response);
        if (response != null && response.cardinality() > 2) {
            Log.warn("Response column has more than two factor levels - mean response depends on lexicographic order of factors!");
        }

        Vec weights = this._dataset.vec(this._weight);
        // Reject a weight column that is non-numeric OR has a negative value. The
        // previous '&&' let negative numeric weights through, contradicting the message.
        if (weights != null && (!weights.isNumeric() || weights.min() < 0.0d)) {
            throw new H2OIllegalArgumentException("Observation weights must be numeric with values >= 0");
        }

        if (predictor != null) {
            this._vecs[0] = predictor._key;
            this._stats[0] = new Stats(predictor);
        }
        if (response != null) {
            this._vecs[1] = response._key;
            this._stats[1] = new Stats(response);
        }

        final CoOccurrence task = new CoOccurrence(this);
        final Tabulate tabulate = weights != null
                ? task.doAll(predictor, response, weights)._sp
                : task.doAll(predictor, response)._sp;
        this._count_table = tabulate.tabulationTwoDimTable();
        this._response_table = tabulate.responseCharTwoDimTable();
        Log.info(this._count_table.toString(2, false));
        Log.info(this._response_table.toString(2, false));
        return tabulate;
    }

    /**
     * Prepares a column for tabulation with at most {@code nbins} slots.
     *
     * <p>If the column's cardinality exceeds {@code nbins}, it is replaced by
     * the first vec of the frame produced by an {@link Interaction} over that
     * single column with {@code _max_factors = nbins - 1} (presumably this caps
     * the number of factor levels; confirm against Interaction). Otherwise an
     * integer column whose value range fits in {@code nbins} is converted to a
     * categorical vec. Any other column is returned unchanged.
     *
     * @param column the vec looked up from the dataset (non-null)
     * @param name   the column's name in the dataset
     * @param nbins  maximum number of bins allowed for this column
     * @return the vec to tabulate
     */
    private Vec coarsen(Vec column, String name, int nbins) {
        if (column.cardinality() > nbins) {
            Interaction interaction = new Interaction();
            interaction._source_frame = this._dataset._key;
            interaction._factor_columns = new String[]{name};
            interaction._max_factors = nbins - 1;
            interaction.execImpl(null);
            return interaction._job._result.get().anyVec();
        }
        if (column.isInt() && (column.max() - column.min()) + 1.0d <= nbins) {
            return column.toCategoricalVec();
        }
        return column;
    }

    /**
     * Builds the long-format co-occurrence table: one row per
     * (predictor slot, response slot) pair with the columns predictor label,
     * response label and (weighted) count. The row for predictor slot
     * {@code p} and response slot {@code r} is {@code r * predictorSlots + p}.
     *
     * @return the populated table, or {@code null} when the accumulators have
     *         not been computed yet
     */
    public TwoDimTable tabulationTwoDimTable() {
        // Guard both arrays: the counts are the data actually read below.
        if (this._response_data == null || this._count_data == null) {
            return null;
        }
        final int predictorSlots = this._count_data.length;
        final int responseSlots = this._count_data[0].length;
        final String tableHeader = "(Weighted) co-occurrence counts of '" + this._predictor + "' and '" + this._response + "'";

        // Column metadata is declared together so all three arrays have the same length.
        final String[] colHeaders = {this._predictor, this._response, "counts"};
        final String[] colTypes = {"string", "string", "double"};
        final String[] colFormats = {"%s", "%s", "%f"};
        final String[] rowHeaders = new String[predictorSlots * responseSlots];

        TwoDimTable twoDimTable = new TwoDimTable(tableHeader, null, rowHeaders, colHeaders, colTypes, colFormats, null);
        for (int p = 0; p < predictorSlots; p++) {
            final String predictorLabel = labelForBin(0, p);
            for (int r = 0; r < responseSlots; r++) {
                final int row = (r * predictorSlots) + p;
                twoDimTable.set(row, 0, predictorLabel);
                twoDimTable.set(row, 1, labelForBin(1, r));
                twoDimTable.set(row, 2, Double.valueOf(this._count_data[p][r]));
            }
        }
        return twoDimTable;
    }

    /**
     * Builds the per-predictor summary table: one row per predictor slot with
     * the columns predictor label, mean response and (weighted) count of rows
     * that had a non-missing response.
     *
     * @return the populated table, or {@code null} when the accumulators have
     *         not been computed yet
     */
    public TwoDimTable responseCharTwoDimTable() {
        if (this._response_data == null) {
            return null;
        }
        final String tableHeader = "Mean value of '" + this._response + "' and (weighted) counts for '" + this._predictor + "' values";
        // Size by the array that is guarded and read here; it has the same row count as the count matrix.
        final int predictorSlots = this._response_data.length;

        // Column metadata is declared together so all three arrays have the same length.
        final String[] colHeaders = {this._predictor, "mean " + this._response, "counts"};
        final String[] colTypes = {"string", "double", "double"};
        final String[] colFormats = {"%s", "%f", "%f"};
        final String[] rowHeaders = new String[predictorSlots];

        TwoDimTable twoDimTable = new TwoDimTable(tableHeader, null, rowHeaders, colHeaders, colTypes, colFormats, null);
        for (int p = 0; p < predictorSlots; p++) {
            final double[] meanAndWeight = this._response_data[p];
            twoDimTable.set(p, 0, labelForBin(0, p));
            twoDimTable.set(p, 1, Double.valueOf(meanAndWeight[0]));
            twoDimTable.set(p, 2, Double.valueOf(meanAndWeight[1]));
        }
        return twoDimTable;
    }

    // Initialises the explicit assertion flag declared on this class: it is true when
    // assertions are NOT enabled for Tabulate, which makes the manual checks in bin()
    // skip their AssertionError.
    static {
        $assertionsDisabled = !Tabulate.class.desiredAssertionStatus();
    }
}
