package io.cdap.mmds.data;

import com.google.common.collect.Sets;
import io.cdap.mmds.NullableMath;
import io.cdap.mmds.stats.CategoricalHisto;
import io.cdap.mmds.stats.NumericBin;
import io.cdap.mmds.stats.NumericHisto;
import java.math.RoundingMode;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:lib/mmds-model-1.6.0.jar:io/cdap/mmds/data/ColumnSplitStats.class */
/**
 * Per-column statistics for a data split, holding every value side by side for the two
 * partitions of the split (exposed by {@link SplitVal} as {@code getTrain()} and
 * {@code getTest()}).
 *
 * <p>Instances are built either from a pair of {@link NumericHisto}s or from a pair of
 * {@link CategoricalHisto}s. Statistics that do not apply to the column type are passed
 * as {@code null} by those constructors, so the corresponding getters may return {@code null}.
 *
 * <p>On construction a divergence score in {@code [0, 1]} is derived from the histogram;
 * see {@link #getDivergence()}.
 */
public class ColumnSplitStats {
    /** Scientific notation, used for bin bounds with magnitude above 1000 or below 0.001. */
    private static final DecimalFormat NOTATION_FORMAT = new DecimalFormat("0.00E0");
    /** Plain notation with at most four fraction digits, used for all other bin bounds. */
    private static final DecimalFormat DECIMAL_FORMAT = new DecimalFormat("###.####");

    static {
        // Only the scientific format overrides the rounding mode; DECIMAL_FORMAT keeps its default.
        NOTATION_FORMAT.setRoundingMode(RoundingMode.HALF_UP);
    }

    private final String field;
    private final SplitVal<Long> numTotal;
    private final SplitVal<Long> numNull;
    private final SplitVal<Long> numEmpty;
    private final SplitVal<Long> unique;
    private final SplitVal<Long> numZero;
    private final SplitVal<Long> numPositive;
    private final SplitVal<Long> numNegative;
    private final SplitVal<Double> min;
    private final SplitVal<Double> max;
    private final SplitVal<Double> mean;
    private final SplitVal<Double> stddev;
    private final List<SplitHistogramBin> histo;
    private final double divergence;

    /**
     * Creates the stats from already-computed values and derives the divergence score.
     *
     * @param str name of the column
     * @param splitVal total number of values; must not be {@code null}
     * @param splitVal2 number of null values; must not be {@code null}
     * @param splitVal3 number of empty values, may be {@code null}
     * @param splitVal4 number of unique values, may be {@code null}
     * @param splitVal5 number of zero values, may be {@code null}
     * @param splitVal6 number of positive values, may be {@code null}
     * @param splitVal7 number of negative values, may be {@code null}
     * @param splitVal8 minimum value, may be {@code null}
     * @param splitVal9 maximum value, may be {@code null}
     * @param splitVal10 mean value, may be {@code null}
     * @param splitVal11 standard deviation, may be {@code null}
     * @param list histogram bins; must not be {@code null}. The list is stored as-is, not copied.
     */
    public ColumnSplitStats(String str, SplitVal<Long> splitVal, SplitVal<Long> splitVal2, SplitVal<Long> splitVal3, SplitVal<Long> splitVal4, SplitVal<Long> splitVal5, SplitVal<Long> splitVal6, SplitVal<Long> splitVal7, SplitVal<Double> splitVal8, SplitVal<Double> splitVal9, SplitVal<Double> splitVal10, SplitVal<Double> splitVal11, List<SplitHistogramBin> list) {
        this.field = str;
        this.numTotal = splitVal;
        this.numNull = splitVal2;
        this.numEmpty = splitVal3;
        this.unique = splitVal4;
        this.numZero = splitVal5;
        this.numPositive = splitVal6;
        this.numNegative = splitVal7;
        this.min = splitVal8;
        this.max = splitVal9;
        this.mean = splitVal10;
        this.stddev = splitVal11;
        this.histo = list;
        this.divergence = computeDivergence(splitVal, splitVal2, list);
    }

    /**
     * Creates stats for a numeric column from the histograms of the two partitions.
     * The empty and unique counts are not applicable and are left {@code null}.
     *
     * @param str name of the column
     * @param numericHisto histogram of the first partition (presumably the training split,
     *     assuming {@link SplitVal}'s first constructor argument is the train value)
     * @param numericHisto2 histogram of the second partition (presumably the test split)
     * @throws IllegalArgumentException if the two histograms do not have identical bins
     */
    public ColumnSplitStats(String str, NumericHisto numericHisto, NumericHisto numericHisto2) {
        this(str,
             new SplitCountVal(numericHisto.getTotalCount(), numericHisto2.getTotalCount()),
             new SplitCountVal(numericHisto.getNullCount(), numericHisto2.getNullCount()),
             null,
             null,
             new SplitCountVal(numericHisto.getZeroCount(), numericHisto2.getZeroCount()),
             new SplitCountVal(numericHisto.getPositiveCount(), numericHisto2.getPositiveCount()),
             new SplitCountVal(numericHisto.getNegativeCount(), numericHisto2.getNegativeCount()),
             new SplitVal<>(numericHisto.getMin(), numericHisto2.getMin(),
                            NullableMath.min(numericHisto.getMin(), numericHisto2.getMin())),
             new SplitVal<>(numericHisto.getMax(), numericHisto2.getMax(),
                            NullableMath.max(numericHisto.getMax(), numericHisto2.getMax())),
             new SplitVal<>(numericHisto.getMean(), numericHisto2.getMean(),
                            NullableMath.mean(numericHisto.getMean(), numericHisto.getNonNullCount(),
                                              numericHisto2.getMean(), numericHisto2.getNonNullCount())),
             new SplitVal<>(numericHisto.getStddev(), numericHisto2.getStddev(),
                            NullableMath.stddev(numericHisto.getM2(), numericHisto.getMean(), numericHisto.getNonNullCount(),
                                                numericHisto2.getM2(), numericHisto2.getMean(), numericHisto2.getNonNullCount())),
             convert(numericHisto, numericHisto2));
    }

    /**
     * Creates stats for a categorical column from the histograms of the two partitions.
     * The unique count carries each partition's number of categories plus the size of the
     * union of both category sets. All numeric-only statistics are left {@code null}.
     *
     * @param str name of the column
     * @param categoricalHisto histogram of the first partition (presumably the training split)
     * @param categoricalHisto2 histogram of the second partition (presumably the test split)
     */
    public ColumnSplitStats(String str, CategoricalHisto categoricalHisto, CategoricalHisto categoricalHisto2) {
        this(str,
             new SplitCountVal(categoricalHisto.getTotalCount(), categoricalHisto2.getTotalCount()),
             new SplitCountVal(categoricalHisto.getNullCount(), categoricalHisto2.getNullCount()),
             new SplitCountVal(categoricalHisto.getEmptyCount(), categoricalHisto2.getEmptyCount()),
             new SplitVal<>(Long.valueOf(categoricalHisto.getCounts().size()),
                            Long.valueOf(categoricalHisto2.getCounts().size()),
                            Long.valueOf(Sets.union(categoricalHisto.getCounts().keySet(),
                                                    categoricalHisto2.getCounts().keySet()).size())),
             null, null, null, null, null, null, null,
             convert(categoricalHisto, categoricalHisto2));
    }

    /** @return the histogram bins; this is the list held internally, not a copy */
    public List<SplitHistogramBin> getHisto() {
        return this.histo;
    }

    /** @return the name of the column */
    public String getField() {
        return this.field;
    }

    /** @return the total number of values */
    public SplitVal<Long> getNumTotal() {
        return this.numTotal;
    }

    /** @return the number of null values */
    public SplitVal<Long> getNumNull() {
        return this.numNull;
    }

    /** @return the number of empty values; {@code null} when built from numeric histograms */
    public SplitVal<Long> getNumEmpty() {
        return this.numEmpty;
    }

    /** @return the number of unique values; {@code null} when built from numeric histograms */
    public SplitVal<Long> getUnique() {
        return this.unique;
    }

    /** @return the number of zero values; {@code null} when built from categorical histograms */
    public SplitVal<Long> getNumZero() {
        return this.numZero;
    }

    /** @return the number of positive values; {@code null} when built from categorical histograms */
    public SplitVal<Long> getNumPositive() {
        return this.numPositive;
    }

    /** @return the number of negative values; {@code null} when built from categorical histograms */
    public SplitVal<Long> getNumNegative() {
        return this.numNegative;
    }

    /** @return the minimum value; {@code null} when built from categorical histograms */
    public SplitVal<Double> getMin() {
        return this.min;
    }

    /** @return the maximum value; {@code null} when built from categorical histograms */
    public SplitVal<Double> getMax() {
        return this.max;
    }

    /** @return the mean value; {@code null} when built from categorical histograms */
    public SplitVal<Double> getMean() {
        return this.mean;
    }

    /** @return the standard deviation; {@code null} when built from categorical histograms */
    public SplitVal<Double> getStddev() {
        return this.stddev;
    }

    /**
     * @return the divergence between the test and train histograms, clamped to {@code [0, 1]};
     *     {@code 0} when the histogram has no bins
     */
    public double getDivergence() {
        return this.divergence;
    }

    /**
     * Computes {@code sum(q * ln(q / p))} over all bins, where {@code q} and {@code p} are the
     * add-one smoothed relative frequencies of a bin in the test and train partition
     * respectively (the form of a Kullback-Leibler divergence of test from train).
     * The result is clamped to {@code [0, 1]}.
     *
     * @param total total value counts per partition
     * @param nulls null value counts per partition
     * @param bins histogram bins with per-partition counts
     * @return the clamped divergence
     */
    private static double computeDivergence(SplitVal<Long> total, SplitVal<Long> nulls, List<SplitHistogramBin> bins) {
        // Add-one smoothing: every bin count is incremented by one, so the denominator grows
        // by the number of bins. This keeps both frequencies strictly positive.
        double trainDenominator = (total.getTrain().longValue() - nulls.getTrain().longValue()) + bins.size();
        double testDenominator = (total.getTest().longValue() - nulls.getTest().longValue()) + bins.size();

        double sum = 0.0d;
        for (SplitHistogramBin bin : bins) {
            double trainFraction = (1 + bin.getCount().getTrain().longValue()) / trainDenominator;
            double testFraction = (1 + bin.getCount().getTest().longValue()) / testDenominator;
            sum += testFraction * Math.log(testFraction / trainFraction);
        }
        return Math.max(0.0d, Math.min(1.0d, sum));
    }

    /**
     * Pairs up the bins of two numeric histograms, in iteration order, into split bins labelled
     * with the formatted bin range.
     *
     * @throws IllegalArgumentException if the histograms differ in bin count, or if any pair of
     *     bins differs in lower bound, upper bound or upper-bound inclusiveness
     */
    private static List<SplitHistogramBin> convert(NumericHisto numericHisto, NumericHisto numericHisto2) {
        if (numericHisto.getBins().size() != numericHisto2.getBins().size()) {
            throw new IllegalArgumentException("Cannot combine numeric histograms with different bins.");
        }
        List<SplitHistogramBin> bins = new ArrayList<>(numericHisto.getBins().size());
        Iterator<NumericBin> otherBins = numericHisto2.getBins().iterator();
        for (NumericBin bin : numericHisto.getBins()) {
            NumericBin other = otherBins.next();
            // NOTE(review): assumes getLo()/getHi() return primitive doubles; if they are boxed,
            // != compares references rather than values -- confirm against NumericBin.
            if (bin.getLo() != other.getLo() || bin.getHi() != other.getHi() || bin.isHiInclusive() != other.isHiInclusive()) {
                throw new IllegalArgumentException("Cannot combine numeric histograms with different bins. Bin1 = " + format(bin) + ", Bin2 = " + format(other));
            }
            bins.add(new SplitHistogramBin(format(bin), new SplitCountVal(bin.getCount(), other.getCount())));
        }
        return bins;
    }

    /**
     * Merges two categorical histograms into one list of split bins, one per category that
     * occurs in either histogram. A category missing from one histogram gets a count of zero
     * on that side.
     *
     * <p>The result is sorted by {@code getTrain()} count descending, then {@code getTest()}
     * count descending, then category name ascending.
     *
     * @param categoricalHisto histogram whose counts become the first value of each bin
     * @param categoricalHisto2 histogram whose counts become the second value of each bin
     * @return a new, sorted list of bins
     */
    public static List<SplitHistogramBin> convert(CategoricalHisto categoricalHisto, CategoricalHisto categoricalHisto2) {
        Map<String, Long> firstCounts = categoricalHisto.getCounts();
        Map<String, Long> secondCounts = categoricalHisto2.getCounts();

        List<SplitHistogramBin> bins = new ArrayList<>(firstCounts.size());
        for (Map.Entry<String, Long> entry : firstCounts.entrySet()) {
            String category = entry.getKey();
            Long otherCount = secondCounts.get(category);
            bins.add(new SplitHistogramBin(category, new SplitCountVal(entry.getValue().longValue(), otherCount == null ? 0L : otherCount.longValue())));
        }
        // Categories that occur only in the second histogram.
        for (Map.Entry<String, Long> entry : secondCounts.entrySet()) {
            String category = entry.getKey();
            if (!firstCounts.containsKey(category)) {
                bins.add(new SplitHistogramBin(category, new SplitCountVal(0L, entry.getValue().longValue())));
            }
        }

        bins.sort((left, right) -> {
            // Arguments are swapped in the count comparisons to sort in descending order.
            int byTrain = Long.compare(right.getCount().getTrain().longValue(), left.getCount().getTrain().longValue());
            if (byTrain != 0) {
                return byTrain;
            }
            int byTest = Long.compare(right.getCount().getTest().longValue(), left.getCount().getTest().longValue());
            return byTest != 0 ? byTest : left.getBin().compareTo(right.getBin());
        });
        return bins;
    }

    /** Formats a bin as {@code [lo,hi]} when the upper bound is inclusive, {@code [lo,hi)} otherwise. */
    private static String format(NumericBin numericBin) {
        String pattern = numericBin.isHiInclusive() ? "[%s,%s]" : "[%s,%s)";
        return String.format(pattern, format(numericBin.getLo()), format(numericBin.getHi()));
    }

    /**
     * Formats a bin bound: scientific notation when the magnitude is above 1000 or strictly
     * between 0 and 0.001, plain decimal notation otherwise. The formatters use the symbols of
     * the default locale in effect when this class was loaded.
     *
     * <p>Synchronized because the shared {@link DecimalFormat} instances are not safe for
     * concurrent use.
     */
    private static synchronized String format(double d) {
        double magnitude = Math.abs(d);
        boolean scientific = magnitude > 1000.0d || (magnitude < 0.001d && magnitude > 0.0d);
        return (scientific ? NOTATION_FORMAT : DECIMAL_FORMAT).format(d);
    }
}
