package water.util;

import java.util.Arrays;
import java.util.Iterator;
import java.util.Random;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import water.Key;
import water.MRTask;
import water.fvec.C16Chunk;
import water.fvec.CStrChunk;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.parser.BufferedString;

/* loaded from: input_file:water/util/MRUtils.class */
public class MRUtils {
    /**
     * Negation of {@code MRUtils.class.desiredAssertionStatus()}, assigned once in this class's
     * static initializer. Methods of this class test it before throwing an
     * {@link AssertionError} for a failed precondition, so those checks only fire when
     * assertions are enabled.
     */
    static final /* synthetic */ boolean $assertionsDisabled;

    /* loaded from: input_file:water/util/MRUtils$ClassDist.class */
    /**
     * Map/reduce task that accumulates one total per class index found in a label chunk.
     * With a single chunk every non-NA row adds 1 to its class; with a second chunk every
     * non-NA row adds that chunk's value for the row instead.
     */
    public static class ClassDist extends MRTask<ClassDist> {
        /** Number of classes; fixes the length of the accumulator array. */
        final int _nclass;
        /** Per-class totals; allocated in {@code map} and merged in {@code reduce}. */
        protected double[] _ys;

        /**
         * @param vec vector whose domain length is used as the number of classes
         */
        public ClassDist(Vec vec) {
            this._nclass = vec.domain().length;
        }

        /**
         * @param i number of classes
         */
        public ClassDist(int i) {
            this._nclass = i;
        }

        /**
         * @return the accumulated per-class totals (the internal array, not a copy)
         */
        public final double[] dist() {
            return this._ys;
        }

        /**
         * @return the totals divided by their sum, computed on a copy; when the sum is exactly
         *         zero the internal array is returned unchanged to avoid dividing by zero
         */
        public final double[] relDist() {
            final double total = ArrayUtils.sum(this._ys);
            if (total == 0.0d) {
                return this._ys;
            }
            return ArrayUtils.div(Arrays.copyOf(this._ys, this._ys.length), total);
        }

        /** Adds 1 to the class of every non-NA row of {@code chunk}. */
        @Override
        public void map(Chunk chunk) {
            final double[] totals = new double[this._nclass];
            this._ys = totals;
            for (int row = 0; row < chunk._len; row++) {
                if (chunk.isNA(row)) {
                    continue;
                }
                totals[(int) chunk.at8(row)] += 1.0d;
            }
        }

        /**
         * Adds the value of {@code chunk2} at the row to the class of every non-NA row of
         * {@code chunk}. Only {@code chunk} is tested for NA.
         */
        @Override
        public void map(Chunk chunk, Chunk chunk2) {
            final double[] totals = new double[this._nclass];
            this._ys = totals;
            for (int row = 0; row < chunk._len; row++) {
                if (chunk.isNA(row)) {
                    continue;
                }
                totals[(int) chunk.at8(row)] += chunk2.atd(row);
            }
        }

        /** Merges the other task's totals into this one via {@code ArrayUtils.add}. */
        @Override
        public void reduce(ClassDist classDist) {
            ArrayUtils.add(this._ys, classDist._ys);
        }
    }

    /* loaded from: input_file:water/util/MRUtils$ClassDistQuasibinomial.class */
    /**
     * Two-bucket variant of the class-distribution task: rows whose numeric value equals the
     * first domain entry (parsed as a double) go to bucket 0, all other non-NA rows to bucket 1.
     */
    public static class ClassDistQuasibinomial extends MRTask<ClassDistQuasibinomial> {
        /** Always two buckets. */
        final int _nclass = 2;
        /** Per-bucket totals; allocated in {@code map} and merged in {@code reduce}. */
        private double[] _ys;
        /** Domain as passed to the constructor. */
        private String[] _domain;
        /** Numeric value of {@code _domain[0]}; rows equal to it land in bucket 0. */
        private double _firstDoubleDomain;

        /**
         * @param strArr domain strings; the first entry must be parseable as a double
         * @throws NumberFormatException if {@code strArr[0]} is not a valid double
         */
        public ClassDistQuasibinomial(String[] strArr) {
            this._domain = strArr;
            this._firstDoubleDomain = Double.parseDouble(strArr[0]);
        }

        /**
         * @return the accumulated per-bucket totals (the internal array, not a copy)
         */
        public final double[] dist() {
            return this._ys;
        }

        /**
         * @return the totals divided by their sum, computed on a copy; when the sum is exactly
         *         zero the internal array is returned unchanged to avoid dividing by zero
         */
        public final double[] relDist() {
            final double total = ArrayUtils.sum(this._ys);
            if (total == 0.0d) {
                return this._ys;
            }
            return ArrayUtils.div(Arrays.copyOf(this._ys, this._ys.length), total);
        }

        /**
         * @return the domain array supplied at construction
         */
        public final String[] domains() {
            return this._domain;
        }

        /** Adds 1 to the bucket of every non-NA row of {@code chunk}. */
        @Override
        public void map(Chunk chunk) {
            final double[] totals = new double[this._nclass];
            this._ys = totals;
            for (int row = 0; row < chunk._len; row++) {
                if (chunk.isNA(row)) {
                    continue;
                }
                final int bucket = chunk.atd(row) != this._firstDoubleDomain ? 1 : 0;
                totals[bucket] += 1.0d;
            }
        }

        /**
         * Adds the value of {@code chunk2} at the row to the bucket of every non-NA row of
         * {@code chunk}. Only {@code chunk} is tested for NA.
         */
        @Override
        public void map(Chunk chunk, Chunk chunk2) {
            final double[] totals = new double[this._nclass];
            this._ys = totals;
            for (int row = 0; row < chunk._len; row++) {
                if (chunk.isNA(row)) {
                    continue;
                }
                final int bucket = chunk.atd(row) != this._firstDoubleDomain ? 1 : 0;
                totals[bucket] += chunk2.atd(row);
            }
        }

        /** Merges the other task's totals into this one via {@code ArrayUtils.add}. */
        @Override
        public void reduce(ClassDistQuasibinomial classDistQuasibinomial) {
            ArrayUtils.add(this._ys, classDistQuasibinomial._ys);
        }
    }

    /* loaded from: input_file:water/util/MRUtils$Dist.class */
    /**
     * Map/reduce task that counts how often each distinct non-NA value occurs in a column.
     * {@link #keys()} and {@link #dist()} expose the values and their counts as arrays.
     */
    public static class Dist extends MRTask<Dist> {
        /** Value -> occurrence counter; created in {@code map}, merged in {@code reduce}. */
        private IcedHashMap<IcedDouble, IcedAtomicInt> _dist;

        /** Counts every non-NA value of {@code chunk}. */
        @Override
        public void map(Chunk chunk) {
            this._dist = new IcedHashMap<>();
            // Single reusable lookup key; a fresh IcedDouble is only allocated for new values.
            final IcedDouble probe = new IcedDouble(0.0d);
            for (int row = 0; row < chunk._len; row++) {
                if (chunk.isNA(row)) {
                    continue;
                }
                probe._val = chunk.atd(row);
                IcedAtomicInt counter = this._dist.get(probe);
                if (counter == null) {
                    // putIfAbsent returns null when our counter (already at 1) was installed,
                    // otherwise the counter that was there first, which still needs a bump.
                    counter = this._dist.putIfAbsent(new IcedDouble(probe._val), new IcedAtomicInt(1));
                }
                if (counter != null) {
                    counter.incrementAndGet();
                }
            }
        }

        /**
         * Merges the smaller of the two maps into the larger one, keeps the merged map in this
         * task and clears the reference held by {@code dist}.
         */
        @Override
        public void reduce(Dist dist) {
            if (this._dist == dist._dist) {
                return;
            }
            IcedHashMap<IcedDouble, IcedAtomicInt> target = this._dist;
            IcedHashMap<IcedDouble, IcedAtomicInt> source = dist._dist;
            if (target.size() < source.size()) {
                // Iterate over the smaller map.
                target = source;
                source = this._dist;
            }
            for (IcedDouble value : source.keySet()) {
                final IcedAtomicInt incoming = source.get(value);
                final IcedAtomicInt existing = target.putIfAbsent(value, incoming);
                if (existing != null) {
                    existing.addAndGet(incoming.get());
                }
            }
            this._dist = target;
            dist._dist = null;
        }

        /**
         * @return the counts, in the iteration order of the map's {@code values()}
         */
        public double[] dist() {
            final double[] counts = new double[this._dist.size()];
            int next = 0;
            for (IcedAtomicInt counter : this._dist.values()) {
                counts[next++] = counter.get();
            }
            return counts;
        }

        /**
         * @return the distinct values, in the iteration order of the map's {@code keySet()}
         */
        public double[] keys() {
            final double[] values = new double[this._dist.size()];
            int next = 0;
            for (IcedDouble value : this._dist.keySet()) {
                values[next++] = value._val;
            }
            return values;
        }
    }

    /**
     * Samples rows of {@code frame} at a rate of roughly {@code j / frame.numRows()}.
     * <p>
     * Every row is kept or dropped based on a random draw seeded from {@code j2}, the row's
     * index within its chunk and the chunk's start offset. If a chunk has produced no row by
     * the time its last row is reached, that last row is emitted regardless of the draw.
     *
     * @param frame frame to sample; {@code null} yields {@code null}
     * @param j     requested number of rows; a non-positive value, or one that gives a rate of
     *              at least 1, returns {@code frame} itself
     * @param j2    seed; incremented by one for the retry when the sample comes back empty
     * @return the sampled frame, or {@code frame} itself when no sampling is needed
     */
    public static Frame sampleFrame(Frame frame, long j, final long j2) {
        if (frame == null) {
            return null;
        }
        final float fraction;
        if (j > 0) {
            fraction = ((float) j) / ((float) frame.numRows());
        } else {
            fraction = 1.0f;
        }
        if (fraction >= 1.0f) {
            return frame;
        }

        // Name the result after the source key, adding ".temporary" unless it is already there.
        Key<Frame> sampleKey = null;
        if (frame._key != null) {
            final String base = frame._key.toString();
            final String infix = base.contains("temporary") ? ".sample." : ".temporary.sample.";
            sampleKey = Key.make(base + infix + PrettyPrint.formatPct(fraction).replace(" ", ""));
        }

        final Frame sampled = new MRTask() {
            @Override
            public void map(Chunk[] cs, NewChunk[] ncs) {
                final Random rng = RandomUtils.getRNG(0);
                final BufferedString scratch = new BufferedString();
                int emitted = 0;
                for (int row = 0; row < cs[0]._len; row++) {
                    // Re-seed per row so the draw depends only on the seed and the row position.
                    rng.setSeed(j2 + row + cs[0].start());
                    final boolean drawn = rng.nextFloat() < fraction;
                    final boolean forcedLast = emitted == 0 && row == cs[0]._len - 1;
                    if (!drawn && !forcedLast) {
                        continue;
                    }
                    emitted++;
                    for (int col = 0; col < ncs.length; col++) {
                        final Chunk src = cs[col];
                        final NewChunk dst = ncs[col];
                        if (src.isNA(row)) {
                            dst.addNA();
                        } else if (src instanceof CStrChunk) {
                            dst.addStr(src.atStr(scratch, row));
                        } else if (src instanceof C16Chunk) {
                            dst.addUUID(src.at16l(row), src.at16h(row));
                        } else {
                            dst.addNum(src.atd(row));
                        }
                    }
                }
            }
        }.doAll(frame.types(), frame).outputFrame(sampleKey, frame.names(), frame.domains());

        if (sampled.numRows() == 0) {
            Log.warn("You asked for " + j + " rows (out of " + frame.numRows() + "), but you got none (seed=" + j2 + ").");
            Log.warn("Let's try again. You've gotta ask yourself a question: \"Do I feel lucky?\"");
            return sampleFrame(frame, j, j2 + 1);
        }
        return sampled;
    }

    /**
     * Returns a new frame in which the rows of every chunk are permuted; rows never move
     * between chunks. The permutation for each chunk is drawn from an RNG created from
     * {@code j}.
     * <p>
     * String chunks are copied with {@code addStr}, UUID chunks with {@code addUUID} (NA rows
     * as NA), and everything else through {@code atd}/{@code addNum}.
     *
     * @param frame frame to shuffle
     * @param j     seed passed to {@code RandomUtils.getRNG}
     * @return the shuffled copy, with the names and domains of {@code frame}
     */
    public static Frame shuffleFramePerChunk(Frame frame, final long j) {
        return new MRTask() {
            @Override
            public void map(Chunk[] chunkArr, NewChunk[] newChunkArr) {
                final int[] order = new int[chunkArr[0]._len];
                for (int k = 0; k < order.length; k++) {
                    order[k] = k;
                }
                ArrayUtils.shuffleArray(order, RandomUtils.getRNG(j));
                for (int row : order) {
                    for (int col = 0; col < newChunkArr.length; col++) {
                        final Chunk src = chunkArr[col];
                        final NewChunk dst = newChunkArr[col];
                        if (src instanceof CStrChunk) {
                            dst.addStr(src, src.start() + row);
                        } else if (src instanceof C16Chunk) {
                            // UUID chunks cannot be read through atd(); copy both halves,
                            // exactly as sampleFrame does.
                            if (src.isNA(row)) {
                                dst.addNA();
                            } else {
                                dst.addUUID(src.at16l(row), src.at16h(row));
                            }
                        } else {
                            dst.addNum(src.atd(row));
                        }
                    }
                }
            }
        }.doAll(frame.types(), frame).outputFrame(frame.names(), frame.domains());
    }

    /**
     * Convenience overload of stratified sampling: forwards every argument unchanged to the
     * nine-argument variant and passes {@code null} as its final {@code String[]} argument.
     */
    public static Frame sampleFrameStratified(Frame frame, Vec vec, Vec vec2, float[] fArr, long j, long j2, boolean z, boolean z2) {
        return sampleFrameStratified(frame, vec, vec2, fArr, j, j2, z, z2, null);
    }

    /**
     * Computes per-class sampling ratios for stratified sampling and then delegates the actual
     * row sampling to the ratio-based overload.
     * <p>
     * Steps: compute the class distribution of {@code vec} (weighted by {@code vec2} when
     * given), derive ratios if none (or only zeros) were supplied, cap them at 1 unless
     * oversampling is allowed, and scale them down so the expected row count does not exceed
     * {@code j}.
     *
     * @param frame  frame to sample; {@code null} yields {@code null}
     * @param vec    label vector; must be categorical (checked when assertions are enabled)
     * @param vec2   optional vector whose values are summed per class instead of row counts
     * @param fArr   optional per-class sampling ratios; {@code null} or all-zero means derive them
     * @param j      upper bound on the resulting number of rows; raised to the number of
     *               labels when smaller
     * @param j2     seed handed on to the ratio-based overload
     * @param z      whether ratios above 1 (oversampling) are allowed
     * @param z2     whether to log per-class details
     * @param strArr when non-null, the domain used to compute a two-bucket distribution
     * @return the frame produced by the ratio-based overload
     * @throws IllegalArgumentException if the expected number of sampled rows is NaN
     */
    public static Frame sampleFrameStratified(Frame frame, Vec vec, Vec vec2, float[] fArr, long j, long j2, boolean z, boolean z2, String[] strArr) {
        if (frame == null) {
            return null;
        }
        if (!$assertionsDisabled && !vec.isCategorical()) {
            throw new AssertionError();
        }
        if (j < vec.domain().length) {
            Log.warn("Attempting to do stratified sampling to fewer samples than there are class labels - automatically increasing to #rows == #labels (" + vec.domain().length + ").");
            j = vec.domain().length;
        }

        // Class distribution of the input, optionally weighted.
        final double[] classTotals;
        if (strArr != null) {
            if (vec2 != null) {
                classTotals = new ClassDistQuasibinomial(strArr).doAll(vec, vec2).dist();
            } else {
                classTotals = new ClassDistQuasibinomial(strArr).doAll(vec).dist();
            }
        } else {
            if (vec2 != null) {
                classTotals = new ClassDist(vec).doAll(vec, vec2).dist();
            } else {
                classTotals = new ClassDist(vec).doAll(vec).dist();
            }
        }
        if (!$assertionsDisabled && classTotals.length <= 0) {
            throw new AssertionError();
        }
        Log.info("Doing stratified sampling for data set containing " + frame.numRows() + " rows from " + classTotals.length + " classes. Oversampling: " + (z ? "on" : "off"));
        if (z2) {
            for (int cls = 0; cls < classTotals.length; cls++) {
                Log.info("Class " + vec.factor(cls) + ": count: " + classTotals[cls] + " prior: " + (((float) classTotals[cls]) / ((float) frame.numRows())));
            }
        }

        // Work on a private copy so the caller's array is never modified.
        final float[] ratios;
        if (fArr == null) {
            ratios = new float[classTotals.length];
        } else {
            ratios = fArr.clone();
        }
        if (!$assertionsDisabled && ratios.length != classTotals.length) {
            throw new AssertionError();
        }

        // All-zero ratios mean "not specified": aim for equally sized classes, then normalise
        // so that the smallest ratio becomes 1.
        if (ArrayUtils.minValue(ratios) == 0.0f && ArrayUtils.maxValue(ratios) == 0.0f) {
            for (int cls = 0; cls < classTotals.length; cls++) {
                ratios[cls] = (((float) frame.numRows()) / vec.domain().length) / ((float) classTotals[cls]);
            }
            final float smallest = ArrayUtils.minValue(ratios);
            if (!(Float.isNaN(smallest) || Float.isInfinite(smallest))) {
                ArrayUtils.div(ratios, smallest);
            }
        }

        // Without oversampling no class may be sampled at a rate above 1.
        if (!z) {
            for (int cls = 0; cls < ratios.length; cls++) {
                ratios[cls] = Math.min(1.0f, ratios[cls]);
            }
        }

        // Expected number of rows under the current ratios.
        float expectedRows = 0.0f;
        for (int cls = 0; cls < ratios.length; cls++) {
            expectedRows = (float) (expectedRows + (ratios[cls] * classTotals[cls]));
        }
        if (Float.isNaN(expectedRows)) {
            Log.err("Total number of sampled rows was NaN. Sampling ratios: " + Arrays.toString(ratios) + "; Dist: " + Arrays.toString(classTotals));
            throw new IllegalArgumentException("Error during sampling - too few points?");
        }

        final long targetRows = Math.min(j, Math.round(expectedRows));
        if (!$assertionsDisabled && targetRows < 0) {
            throw new AssertionError();
        }
        final String suffix = ((float) targetRows) < expectedRows ? " (limited by max_after_balance_size)." : ".";
        Log.info("Stratified sampling to a total of " + String.format("%,d", Long.valueOf(targetRows)) + " rows" + suffix);

        // Rescale all ratios when the target differs from the expectation.
        if (((float) targetRows) != expectedRows) {
            final float scale = ((float) targetRows) / expectedRows;
            ArrayUtils.mult(ratios, scale);
            if (z2) {
                Log.info("Downsampling majority class by " + scale + " to limit number of rows to " + String.format("%,d", Long.valueOf(j)));
            }
        }
        for (int cls = 0; cls < vec.domain().length; cls++) {
            Log.info("Class '" + vec.domain()[cls] + "' sampling ratio: " + ratios[cls]);
        }
        return sampleFrameStratified(frame, vec, vec2, ratios, j2, z2, strArr);
    }

    /**
     * Stratified sampling with caller-supplied per-class sampling ratios. Forwards every
     * argument to the private overload and passes {@code 0} for its {@code int} argument,
     * which that overload increments each time it re-draws the sample.
     */
    public static Frame sampleFrameStratified(Frame frame, Vec vec, Vec vec2, float[] fArr, long j, boolean z, String[] strArr) {
        return sampleFrameStratified(frame, vec, vec2, fArr, j, z, 0, strArr);
    }

    /**
     * Samples every row whose label is not NA according to the sampling ratio of its class,
     * then shuffles the result within each chunk.
     * <p>
     * A ratio {@code r} emits the row {@code (int) r} times plus one more time with
     * probability equal to the fractional part of {@code r}. If some class ends up with a
     * zero total, the sample is discarded and drawn again with {@code j + 1}, until
     * {@code i} reaches 10.
     *
     * @param frame  frame to sample; {@code null} yields {@code null}
     * @param vec    label vector; must be categorical and part of {@code frame}
     *               (checked when assertions are enabled)
     * @param vec2   optional weights vector, looked up in {@code frame}
     * @param fArr   per-class sampling ratios, one per label level
     * @param j      seed
     * @param z      whether to log the resulting class distribution
     * @param i      number of re-draws performed so far
     * @param strArr when non-null, the domain used to compute a two-bucket distribution
     * @return the sampled and shuffled frame, or {@code frame} itself when no class
     *         distribution could be computed for the sample
     */
    private static Frame sampleFrameStratified(Frame frame, Vec vec, Vec vec2, final float[] fArr, final long j, boolean z, int i, String[] strArr) {
        if (frame == null) {
            return null;
        }
        if (!$assertionsDisabled && !vec.isCategorical()) {
            throw new AssertionError();
        }
        if (!$assertionsDisabled && (fArr == null || fArr.length != vec.domain().length)) {
            throw new AssertionError();
        }
        final int find = frame.find(vec);
        if (!$assertionsDisabled && find < 0) {
            throw new AssertionError();
        }
        int find2 = frame.find(vec2);
        // The anonymous task reads MRUtils.$assertionsDisabled directly: an anonymous class
        // must not declare its own static field/initializer on Java versions before 16.
        Frame outputFrame = new MRTask() {
            @Override
            public void map(Chunk[] chunkArr, NewChunk[] newChunkArr) {
                Random rng = RandomUtils.getRNG(j);
                for (int row = 0; row < chunkArr[0]._len; row++) {
                    if (chunkArr[find].isNA(row)) {
                        continue;
                    }
                    // Re-seed per row so the draw depends only on the seed and the row position.
                    rng.setSeed(chunkArr[0].start() + row + j);
                    int label = (int) chunkArr[find].at8(row);
                    if (!MRUtils.$assertionsDisabled && (fArr.length <= label || label < 0)) {
                        throw new AssertionError();
                    }
                    // Integer part of the ratio is guaranteed; the fractional part is a coin flip.
                    final float ratio = fArr[label];
                    final int whole = (int) ratio;
                    int copies = whole + (rng.nextFloat() < ratio - ((float) whole) ? 1 : 0);
                    for (int col = 0; col < newChunkArr.length; col++) {
                        final Chunk src = chunkArr[col];
                        final NewChunk dst = newChunkArr[col];
                        if (src instanceof CStrChunk) {
                            for (int c = 0; c < copies; c++) {
                                dst.addStr(src, chunkArr[0].start() + row);
                            }
                        } else if (src instanceof C16Chunk) {
                            // UUID chunks cannot be read through atd(); copy both halves,
                            // exactly as sampleFrame does.
                            for (int c = 0; c < copies; c++) {
                                if (src.isNA(row)) {
                                    dst.addNA();
                                } else {
                                    dst.addUUID(src.at16l(row), src.at16h(row));
                                }
                            }
                        } else {
                            for (int c = 0; c < copies; c++) {
                                dst.addNum(src.atd(row));
                            }
                        }
                    }
                }
            }
        }.doAll(frame.types(), frame).outputFrame(frame.names(), frame.domains());

        // Class distribution of the sample, to verify that every class is represented.
        Vec sampledLabels = outputFrame.vecs()[find];
        Vec sampledWeights = find2 != -1 ? outputFrame.vecs()[find2] : null;
        final double[] dist;
        if (strArr != null) {
            if (sampledWeights != null) {
                dist = new ClassDistQuasibinomial(strArr).doAll(sampledLabels, sampledWeights).dist();
            } else {
                dist = new ClassDistQuasibinomial(strArr).doAll(sampledLabels).dist();
            }
        } else {
            if (sampledWeights != null) {
                dist = new ClassDist(sampledLabels).doAll(sampledLabels, sampledWeights).dist();
            } else {
                dist = new ClassDist(sampledLabels).doAll(sampledLabels).dist();
            }
        }
        if (dist == null) {
            // Nothing usable was sampled: hand back the input, but release the intermediate
            // frame first, as every other exit of this method does.
            outputFrame.remove();
            return frame;
        }
        if (z) {
            double sum = ArrayUtils.sum(dist);
            Log.info("After stratified sampling: " + sum + " rows.");
            for (int cls = 0; cls < dist.length; cls++) {
                Log.info("Class " + outputFrame.vecs()[find].factor(cls) + ": count: " + dist[cls] + " sampling ratio: " + fArr[cls] + " actual relative frequency: " + ((((float) dist[cls]) / sum) * dist.length));
            }
        }
        // Accept the sample when every class is present, or after 10 re-draws.
        if (ArrayUtils.minValue(dist) != 0.0d || i >= 10) {
            Frame shuffled = shuffleFramePerChunk(outputFrame, j + 92339987);
            outputFrame.remove();
            return shuffled;
        }
        Log.info("Re-doing stratified sampling because not all classes were represented (unlucky draw).");
        outputFrame.remove();
        return sampleFrameStratified(frame, vec, vec2, fArr, j + 1, z, i + 1, strArr);
    }

    // Caches whether assertions are disabled for MRUtils; the manual AssertionError checks in
    // this class are skipped when this flag is true.
    static {
        $assertionsDisabled = !MRUtils.class.desiredAssertionStatus();
    }
}
