package ai.h2o.targetencoding;

import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import water.DKV;
import water.Iced;
import water.Key;
import water.MRTask;
import water.MemoryManager;
import water.Scope;
import water.fvec.CategoricalWrappedVec;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.fvec.task.FillNAWithLongValueTask;
import water.fvec.task.FilterByValueTask;
import water.fvec.task.IsNotNaTask;
import water.fvec.task.UniqTask;
import water.logging.Logger;
import water.logging.LoggerFactory;
import water.rapids.Rapids;
import water.rapids.Val;
import water.rapids.ast.prims.advmath.AstKFold;
import water.rapids.ast.prims.mungers.AstGroup;
import water.rapids.ast.prims.mungers.AstMelt;
import water.rapids.vals.ValFrame;
import water.rapids.vals.ValNum;
import water.rapids.vals.ValStr;
import water.rapids.vals.ValStrs;
import water.util.ArrayUtils;
import water.util.FrameUtils;

/**
 * Static helpers shared by the target-encoding implementation.
 *
 * <p>They operate on H2O {@link Frame}s and cover:
 * <ul>
 *   <li>building per-category numerator/denominator "encodings" frames;</li>
 *   <li>computing prior means;</li>
 *   <li>merging encodings into data;</li>
 *   <li>turning numerator/denominator pairs into (optionally blended) encoded values;</li>
 *   <li>leave-one-out adjustments and noise injection;</li>
 *   <li>small frame utilities (filtering, renaming, unique values, DKV registration).</li>
 * </ul>
 *
 * <p>Not meant to be instantiated.
 */
public class TargetEncoderHelper extends Iced<TargetEncoderHelper> {

    /**
     * Column holding the per-category numerator: the sum of the target (binary/regression) or the
     * per-class count of the one-hot target (multiclass).
     */
    static String NUMERATOR_COL = "numerator";
    /** Column holding the per-category row count. */
    static String DENOMINATOR_COL = "denominator";
    /** Column identifying the target class in multiclass encodings. */
    static String TARGETCLASS_COL = "targetclass";
    /** Suffix constant; not used inside this class (presumably used by callers for NA levels). */
    static String NA_POSTFIX = "_NA";

    private static final Logger logger = LoggerFactory.getLogger(TargetEncoderHelper.class);

    /**
     * Adds noise to every non-NA value of one column, in place.
     *
     * <p>Each value {@code v} becomes {@code v + (2 * r - 1) * noiseLevel}, where {@code r} is the
     * value of the companion random column in the same row. NOTE(review): if {@code r} is uniform
     * in [0, 1) the noise lies in [-noiseLevel, noiseLevel) — confirm against {@code Vec.makeRand}.
     */
    public static class AddNoiseTask extends MRTask<AddNoiseTask> {
        private int _columnIdx;
        private int _runifIdx;
        private double _noiseLevel;

        /**
         * @param columnIdx  index of the column to perturb
         * @param runifIdx   index of the random column driving the noise
         * @param noiseLevel amplitude of the noise
         */
        public AddNoiseTask(int columnIdx, int runifIdx, double noiseLevel) {
            this._columnIdx = columnIdx;
            this._runifIdx = runifIdx;
            this._noiseLevel = noiseLevel;
        }

        @Override
        public void map(Chunk[] cs) {
            Chunk column = cs[this._columnIdx];
            Chunk runif = cs[this._runifIdx];
            for (int row = 0; row < column._len; row++) {
                if (!column.isNA(row)) {
                    column.set(row, column.atd(row) + ((runif.atd(row) * 2.0d) - 1.0d) * this._noiseLevel);
                }
            }
        }
    }

    /**
     * Writes the encoded value of each row into a target column.
     *
     * <p>The value is computed from the row's numerator and denominator:
     * <ul>
     *   <li>NA if either input is NA;</li>
     *   <li>the prior mean if the denominator is zero;</li>
     *   <li>otherwise {@code numerator / denominator}, blended with the prior mean when blending
     *       parameters are supplied.</li>
     * </ul>
     */
    public static class ApplyEncodings extends MRTask<ApplyEncodings> {
        private int _encodedColIdx;
        private int _numeratorIdx;
        private int _denominatorIdx;
        private double _priorMean;
        private BlendingParams _blendingParams;

        /**
         * @param encodedColIdx  index of the output column (overwritten)
         * @param numeratorIdx   index of the numerator column
         * @param denominatorIdx index of the denominator column
         * @param priorMean      fallback/blending prior
         * @param blendingParams blending parameters, or {@code null} to disable blending
         */
        ApplyEncodings(int encodedColIdx, int numeratorIdx, int denominatorIdx, double priorMean,
                       BlendingParams blendingParams) {
            this._encodedColIdx = encodedColIdx;
            this._numeratorIdx = numeratorIdx;
            this._denominatorIdx = denominatorIdx;
            this._priorMean = priorMean;
            this._blendingParams = blendingParams;
        }

        @Override
        public void map(Chunk[] cs) {
            Chunk numerators = cs[this._numeratorIdx];
            Chunk denominators = cs[this._denominatorIdx];
            Chunk encoded = cs[this._encodedColIdx];
            boolean blend = this._blendingParams != null;
            for (int row = 0; row < numerators._len; row++) {
                if (numerators.isNA(row) || denominators.isNA(row)) {
                    encoded.setNA(row);
                } else if (denominators.at8(row) == 0) {
                    // e.g. after leave-one-out removed the only row of a category
                    if (TargetEncoderHelper.logger.isDebugEnabled()) {
                        TargetEncoderHelper.logger.debug("Denominator is zero for column index = " + this._encodedColIdx + ". Imputing with _priorMean = " + this._priorMean);
                    }
                    encoded.set(row, this._priorMean);
                } else {
                    double posterior = numerators.atd(row) / denominators.atd(row);
                    encoded.set(row, blend
                            ? TargetEncoderHelper.getBlendedValue(posterior, this._priorMean, denominators.at8(row), this._blendingParams)
                            : posterior);
                }
            }
        }
    }

    /**
     * Removes each row's own contribution from its numerator/denominator, in place (leave-one-out).
     *
     * <p>For every row with a non-NA target:
     * <ul>
     *   <li>the denominator is decremented by 1;</li>
     *   <li>the numerator is decremented by the target value when {@code targetClass == -1};</li>
     *   <li>otherwise the numerator is decremented by 1, but only when the target equals
     *       {@code targetClass}.</li>
     * </ul>
     */
    public static class SubtractCurrentRowForLeaveOneOutTask extends MRTask<SubtractCurrentRowForLeaveOneOutTask> {
        private int _numeratorIdx;
        private int _denominatorIdx;
        private int _targetIdx;
        private int _targetClass;

        public SubtractCurrentRowForLeaveOneOutTask(int numeratorIdx, int denominatorIdx, int targetIdx, int targetClass) {
            this._numeratorIdx = numeratorIdx;
            this._denominatorIdx = denominatorIdx;
            this._targetIdx = targetIdx;
            this._targetClass = targetClass;
        }

        @Override
        public void map(Chunk[] cs) {
            Chunk numerators = cs[this._numeratorIdx];
            Chunk denominators = cs[this._denominatorIdx];
            Chunk targets = cs[this._targetIdx];
            for (int row = 0; row < numerators._len; row++) {
                if (targets.isNA(row)) {
                    continue;
                }
                double target = targets.atd(row);
                if (this._targetClass == -1) {
                    numerators.set(row, numerators.atd(row) - target);
                } else if (this._targetClass == target) {
                    numerators.set(row, numerators.atd(row) - 1.0d);
                }
                denominators.set(row, denominators.atd(row) - 1.0d);
            }
        }
    }

    private TargetEncoderHelper() {
    }

    /**
     * Appends a k-fold assignment column to {@code fr}.
     *
     * @param fr      frame to extend (modified in place)
     * @param name    name of the new column
     * @param nfolds  number of folds passed to {@link AstKFold#kfoldColumn}
     * @param seed    seed for the fold assignment; {@code -1} picks a random seed
     * @return index of the new column
     */
    public static int addKFoldColumn(Frame fr, String name, int nfolds, long seed) {
        long effectiveSeed = seed == -1 ? new Random().nextLong() : seed;
        fr.add(name, AstKFold.kfoldColumn(fr.anyVec().makeZero(), nfolds, effectiveSeed));
        return fr.numCols() - 1;
    }

    /**
     * Computes the prior mean of a single-target encodings frame (one without a target-class column).
     *
     * @param encodings frame containing {@link #NUMERATOR_COL} and {@link #DENOMINATOR_COL}
     * @return {@code mean(numerator) / mean(denominator)}
     */
    public static double computePriorMean(Frame encodings) {
        assert encodings.find(TARGETCLASS_COL) < 0;
        return computePriorMean(encodings, -1);
    }

    /**
     * Computes the prior mean for one target class (or for the whole frame when {@code targetClass == -1}).
     *
     * <p>When the frame has a {@link #TARGETCLASS_COL} column, the rows for {@code targetClass} are
     * selected into a temporary frame. That frame is always deleted before returning.
     *
     * @param encodings   encodings frame
     * @param targetClass class to filter on, or {@code -1} when the frame has no class column
     * @return {@code mean(numerator) / mean(denominator)} over the selected rows
     */
    static double computePriorMean(Frame encodings, int targetClass) {
        int classIdx = encodings.find(TARGETCLASS_COL);
        assert (targetClass == -1) == (classIdx < 0);
        Frame selected = null;
        try {
            selected = classIdx < 0 ? encodings : filterByValue(encodings, classIdx, targetClass);
            Vec numerator = selected.vec(NUMERATOR_COL);
            Vec denominator = selected.vec(DENOMINATOR_COL);
            assert numerator != null;
            assert denominator != null;
            return numerator.mean() / denominator.mean();
        } finally {
            if (selected != null && selected != encodings) {
                selected.delete();
            }
        }
    }

    /**
     * Builds the encodings frame by grouping {@code fr} by the encoded column (and the fold column
     * when present).
     *
     * <p>For {@code nclasses <= 2} the result has {@link #NUMERATOR_COL} (sum of the target) and
     * {@link #DENOMINATOR_COL} (row count) per group. For {@code nclasses > 2}, see
     * {@link #buildMulticlassEncodings}.
     *
     * <p>Intermediate frames are tracked in a {@link Scope} and released; the returned frame is untracked.
     *
     * @param fr            training frame
     * @param teColumnIdx   index of the column being encoded
     * @param targetIdx     index of the target column
     * @param foldColumnIdx index of an additional grouping column (presumably the fold column), or a negative value if absent
     * @param nclasses      number of target classes; values above 2 select the multiclass layout
     * @return the encodings frame
     */
    public static Frame buildEncodingsFrame(Frame fr, int teColumnIdx, int targetIdx, int foldColumnIdx, int nclasses) {
        try {
            Scope.enter();
            int[] groupByColumns = foldColumnIdx < 0
                    ? new int[]{teColumnIdx}
                    : new int[]{teColumnIdx, foldColumnIdx};
            Frame encodings;
            if (nclasses > 2) {
                encodings = buildMulticlassEncodings(fr, teColumnIdx, targetIdx, foldColumnIdx, groupByColumns);
            } else {
                encodings = new AstGroup().performGroupingWithAggregations(fr, groupByColumns, new AstGroup.AGG[]{
                        new AstGroup.AGG(AstGroup.FCN.sum, targetIdx, AstGroup.NAHandling.ALL, -1),
                        new AstGroup.AGG(AstGroup.FCN.nrow, targetIdx, AstGroup.NAHandling.ALL, -1)
                }).getFrame();
                renameColumn(encodings, "sum_" + fr.name(targetIdx), NUMERATOR_COL);
                renameColumn(encodings, "nrow", DENOMINATOR_COL);
            }
            Scope.untrack(encodings);
            return encodings;
        } finally {
            Scope.exit();
        }
    }

    /**
     * Builds the multiclass (long-format) encodings frame.
     *
     * <p>The target is one-hot encoded and each indicator is summed per group. The result is then
     * melted so every row holds the group keys, {@link #DENOMINATOR_COL}, a {@link #TARGETCLASS_COL}
     * value and that class's {@link #NUMERATOR_COL}.
     *
     * <p>Must be called inside an active {@link Scope}: intermediates are tracked there.
     */
    private static Frame buildMulticlassEncodings(Frame fr, int teColumnIdx, int targetIdx, int foldColumnIdx,
                                                  int[] groupByColumns) {
        String targetName = fr.name(targetIdx);
        Vec targetVec = fr.vec(targetIdx);
        Frame oneHot = new FrameUtils.CategoricalOneHotEncoder(
                new Frame(new String[]{targetName}, new Vec[]{targetVec}), new String[0]).exec().get();
        Scope.track(oneHot);
        Frame withOneHot = new Frame(fr).add(oneHot);

        int nOneHot = oneHot.numCols();
        AstGroup.AGG[] aggs = new AstGroup.AGG[nOneHot + 1];
        for (int c = 0; c < nOneHot; c++) {
            // the one-hot columns were appended after all the columns of fr
            aggs[c] = new AstGroup.AGG(AstGroup.FCN.sum, fr.numCols() + c, AstGroup.NAHandling.ALL, -1);
        }
        aggs[nOneHot] = new AstGroup.AGG(AstGroup.FCN.nrow, targetIdx, AstGroup.NAHandling.ALL, -1);
        Frame grouped = new AstGroup().performGroupingWithAggregations(withOneHot, groupByColumns, aggs).getFrame();
        Scope.track(grouped);

        String prefix = targetName + ".";
        String[] classColumns = new String[nOneHot];
        for (int c = 0; c < nOneHot; c++) {
            String oneHotName = oneHot.name(c);
            // Literal prefix strip: a regex built from the target name would break on names
            // containing regex metacharacters.
            String className = oneHotName.startsWith(prefix) ? oneHotName.substring(prefix.length()) : oneHotName;
            renameColumn(grouped, "sum_" + oneHotName, className);
            classColumns[c] = className;
        }
        renameColumn(grouped, "nrow", DENOMINATOR_COL);

        String[] idColumns = foldColumnIdx < 0
                ? new String[]{fr.name(teColumnIdx), DENOMINATOR_COL}
                : new String[]{fr.name(teColumnIdx), fr.name(foldColumnIdx), DENOMINATOR_COL};
        Frame melted = melt(grouped, idColumns, classColumns, TARGETCLASS_COL, NUMERATOR_COL, true);
        CategoricalWrappedVec.updateDomain(melted.vec(TARGETCLASS_COL), targetVec.domain());
        return melted;
    }

    /**
     * Re-aggregates an encodings frame so that {@link #NUMERATOR_COL} and {@link #DENOMINATOR_COL}
     * are summed per category.
     *
     * <p>Groups are per category, or per (category, target class) when a class column exists.
     *
     * @param encodings   encodings frame
     * @param teColumnIdx index of the encoded category column
     * @return a new grouped frame with the numerator/denominator columns keeping their names
     */
    public static Frame groupEncodingsByCategory(Frame encodings, int teColumnIdx) {
        int numeratorIdx = encodings.find(NUMERATOR_COL);
        assert numeratorIdx >= 0;
        int denominatorIdx = encodings.find(DENOMINATOR_COL);
        assert denominatorIdx >= 0;
        int classIdx = encodings.find(TARGETCLASS_COL);
        int[] groupBy = classIdx < 0 ? new int[]{teColumnIdx} : new int[]{teColumnIdx, classIdx};
        Frame grouped = new AstGroup().performGroupingWithAggregations(encodings, groupBy, new AstGroup.AGG[]{
                new AstGroup.AGG(AstGroup.FCN.sum, numeratorIdx, AstGroup.NAHandling.ALL, -1),
                new AstGroup.AGG(AstGroup.FCN.sum, denominatorIdx, AstGroup.NAHandling.ALL, -1)
        }).getFrame();
        renameColumn(grouped, "sum_" + NUMERATOR_COL, NUMERATOR_COL);
        renameColumn(grouped, "sum_" + DENOMINATOR_COL, DENOMINATOR_COL);
        return grouped;
    }

    /**
     * Groups by category when {@code group} is true.
     *
     * @return a grouped frame when {@code group} is true, otherwise a deep copy of {@code encodings}
     *         under a fresh key
     */
    public static Frame groupEncodingsByCategory(Frame encodings, int teColumnIdx, boolean group) {
        return group ? groupEncodingsByCategory(encodings, teColumnIdx) : encodings.deepCopy(Key.make().toString());
    }

    /**
     * Replaces NAs in a categorical column with a new level appended to the domain.
     *
     * <p>The NAs are filled with level index {@code cardinality}. The domain is extended with
     * {@code naLevelName} only if at least one NA was actually replaced.
     * NOTE(review): assumes the column is categorical ({@code domain()} non-null).
     */
    public static void imputeCategoricalColumn(Frame fr, int columnIdx, String naLevelName) {
        Vec vec = fr.vec(columnIdx);
        int cardinality = vec.cardinality();
        FillNAWithLongValueTask fillTask = new FillNAWithLongValueTask(columnIdx, cardinality);
        fillTask.doAll(fr);
        if (fillTask._imputationHappened) {
            String[] domain = vec.domain();
            String[] extended = new String[cardinality + 1];
            System.arraycopy(domain, 0, extended, 0, domain.length);
            extended[cardinality] = naLevelName;
            updateColumnDomain(fr, columnIdx, extended);
        }
    }

    /**
     * Sets a new domain on one column and publishes the change.
     *
     * <p>The frame is write-locked for the duration. The lock is always released, even on failure.
     */
    private static void updateColumnDomain(Frame fr, int columnIdx, String[] domain) {
        fr.write_lock();
        try {
            Vec vec = fr.vec(columnIdx);
            vec.setDomain(domain);
            DKV.put(vec);
            fr.update();
        } finally {
            fr.unlock();
        }
    }

    /**
     * Returns the distinct values of a column as longs.
     *
     * <p>The temporary vector holding the values is always removed.
     */
    public static long[] getUniqueColumnValues(Frame fr, int columnIdx) {
        Vec uniques = uniqueValuesBy(fr, columnIdx).vec(0);
        try {
            long length = uniques.length();
            assert length <= Integer.MAX_VALUE : "Number of unique values exceeded Integer.MAX_VALUE";
            int size = (int) length;
            long[] values = MemoryManager.malloc8(size);
            for (int i = 0; i < size; i++) {
                values[i] = uniques.at8(i);
            }
            return values;
        } finally {
            uniques.remove();
        }
    }

    /**
     * Blends a category's posterior with the prior, weighted by the category size.
     *
     * <p>With {@code n = count}:
     * <pre>
     *   lambda = 1 / (1 + exp((inflectionPoint - n) / smoothing))
     *   result = lambda * posterior + (1 - lambda) * prior
     * </pre>
     */
    public static double getBlendedValue(double posterior, double prior, long count, BlendingParams blendingParams) {
        double lambda = 1.0d / (1.0d + Math.exp((blendingParams.getInflectionPoint() - count) / blendingParams.getSmoothing()));
        return (lambda * posterior) + ((1.0d - lambda) * prior);
    }

    /** Joins encodings onto {@code leftFrame} by one column each, without fold columns. */
    public static Frame mergeEncodings(Frame leftFrame, Frame encodingsFrame, int leftColumnIdx, int encodingsColumnIdx) {
        return mergeEncodings(leftFrame, encodingsFrame, leftColumnIdx, -1, encodingsColumnIdx, -1, 0);
    }

    /**
     * Joins encodings onto {@code leftFrame} via {@link TargetEncoderBroadcastJoin#join}.
     *
     * <p>{@code maxFoldValue} is passed through unchanged. NOTE(review): the parameter name is
     * inferred from the join's usage; confirm.
     */
    public static Frame mergeEncodings(Frame leftFrame, Frame encodingsFrame, int leftColumnIdx, int leftFoldColumnIdx,
                                       int encodingsColumnIdx, int encodingsFoldColumnIdx, int maxFoldValue) {
        return TargetEncoderBroadcastJoin.join(leftFrame, new int[]{leftColumnIdx}, leftFoldColumnIdx,
                encodingsFrame, new int[]{encodingsColumnIdx}, encodingsFoldColumnIdx, maxFoldValue);
    }

    /**
     * Appends a column with the encoded values computed by {@link ApplyEncodings}.
     *
     * @return index of the appended column
     */
    public static int applyEncodings(Frame fr, String encodedColumnName, double priorMean, BlendingParams blendingParams) {
        int numeratorIdx = fr.find(NUMERATOR_COL);
        assert numeratorIdx >= 0;
        // The denominator is expected right after the numerator (layout presumably produced by
        // the broadcast join — confirm before relying on it elsewhere).
        int denominatorIdx = numeratorIdx + 1;
        fr.add(encodedColumnName, fr.anyVec().makeCon(0.0d));
        int encodedIdx = fr.numCols() - 1;
        new ApplyEncodings(encodedIdx, numeratorIdx, denominatorIdx, priorMean, blendingParams).doAll(fr);
        return encodedIdx;
    }

    /**
     * Adds random noise to one column of {@code fr} in place (see {@link AddNoiseTask}).
     *
     * <p>A temporary "runIf" random column is attached for the duration of the task. It is always
     * detached and deleted afterwards, so the caller's frame keeps its original layout even on failure.
     *
     * @param seed random seed; {@code -1} picks a random seed
     */
    public static void addNoise(Frame fr, int columnIdx, double noiseLevel, long seed) {
        if (seed == -1) {
            seed = new Random().nextLong();
        }
        Vec zeros = fr.anyVec().makeCon(0.0d);
        Vec randoms = null;
        int runifIdx = -1;
        try {
            randoms = zeros.makeRand(seed);
            fr.add("runIf", randoms);
            runifIdx = fr.numCols() - 1;
            new AddNoiseTask(columnIdx, runifIdx, noiseLevel).doAll(fr);
        } finally {
            if (runifIdx >= 0) {
                fr.remove(runifIdx);
            }
            if (randoms != null) {
                randoms.remove();
            }
            zeros.remove();
        }
    }

    /**
     * Applies the leave-one-out correction in place (see {@link SubtractCurrentRowForLeaveOneOutTask}).
     *
     * @param targetColumnName name of the target column
     * @param targetClass      class to subtract for, or {@code -1} to subtract the raw target value
     */
    public static void subtractTargetValueForLOO(Frame fr, String targetColumnName, int targetClass) {
        int numeratorIdx = fr.find(NUMERATOR_COL);
        int denominatorIdx = fr.find(DENOMINATOR_COL);
        int targetIdx = fr.find(targetColumnName);
        assert numeratorIdx >= 0;
        assert denominatorIdx >= 0;
        assert targetIdx >= 0;
        new SubtractCurrentRowForLeaveOneOutTask(numeratorIdx, denominatorIdx, targetIdx, targetClass).doAll(fr);
    }

    /**
     * Wide-to-long reshape via {@link AstMelt}; the result is registered in the DKV under a new key.
     *
     * @param idVars    columns kept as identifiers
     * @param valueVars columns melted into rows
     * @param varName   name of the new "variable" column
     * @param valueName name of the new "value" column
     * @param skipNA    flag forwarded to {@code AstMelt} as 1/0
     */
    static Frame melt(Frame fr, String[] idVars, String[] valueVars, String varName, String valueName, boolean skipNA) {
        Val[] args = new Val[]{
                null, // slot 0 is the AST itself and is unused here
                new ValFrame(fr),
                new ValStrs(idVars),
                new ValStrs(valueVars),
                new ValStr(varName),
                new ValStr(valueName),
                new ValNum(skipNA ? 1.0d : 0.0d)
        };
        return register(new AstMelt().exec(args).getFrame());
    }

    /**
     * Row-binds {@code bottom} under {@code top}.
     *
     * @return {@code bottom} when {@code top} is null; otherwise a new frame built with Rapids
     *         {@code rbind} and registered in the DKV
     */
    public static Frame rBind(Frame top, Frame bottom) {
        if (top == null) {
            assert bottom != null;
            return bottom;
        }
        return execRapidsAndGetFrame(String.format("(rbind %s %s)", top._key, bottom._key));
    }

    private static Frame execRapidsAndGetFrame(String expression) {
        return register(Rapids.exec(expression).getFrame());
    }

    /**
     * Appends a constant column.
     *
     * @return index of the new column
     */
    public static int addCon(Frame fr, String name, long value) {
        fr.add(name, fr.anyVec().makeCon(value));
        return fr.numCols() - 1;
    }

    /** Returns a new frame with the rows whose value in column {@code columnIdx} is NA removed. */
    public static Frame filterOutNAsInColumn(Frame fr, int columnIdx) {
        Frame predicate = new IsNotNaTask()
                .doAll(1, Vec.T_NUM, new Frame(new Vec[]{fr.vec(columnIdx)}))
                .outputFrame();
        return selectAndDeletePredicate(fr, predicate);
    }

    /** Returns a new frame with the rows whose column {@code columnIdx} is NOT equal to {@code value}. */
    public static Frame filterNotByValue(Frame fr, int columnIdx, double value) {
        return filterByValueBase(fr, columnIdx, value, true);
    }

    /** Returns a new frame with the rows whose column {@code columnIdx} equals {@code value}. */
    public static Frame filterByValue(Frame fr, int columnIdx, double value) {
        return filterByValueBase(fr, columnIdx, value, false);
    }

    private static Frame filterByValueBase(Frame fr, int columnIdx, double value, boolean negate) {
        Frame predicate = new FilterByValueTask(value, negate)
                .doAll(1, Vec.T_NUM, new Frame(new Vec[]{fr.vec(columnIdx)}))
                .outputFrame();
        return selectAndDeletePredicate(fr, predicate);
    }

    /** Selects rows of {@code fr} by {@code predicate}, always deleting the predicate frame afterwards. */
    private static Frame selectAndDeletePredicate(Frame fr, Frame predicate) {
        try {
            return selectByPredicate(fr, predicate);
        } finally {
            predicate.delete();
        }
    }

    /**
     * Deep-selects the rows of {@code fr} picked by the predicate's single column.
     *
     * @return a new frame under a fresh key, keeping the original names and domains
     */
    private static Frame selectByPredicate(Frame fr, Frame predicate) {
        Vec[] vecsWithPredicate = (Vec[]) ArrayUtils.append(fr.vecs(), new Vec[]{predicate.anyVec()});
        return new Frame.DeepSelect()
                .doAll(fr.types(), vecsWithPredicate)
                .outputFrame(Key.make(), fr._names, fr.domains());
    }

    /**
     * Returns a single-column frame with the distinct values of column {@code columnIdx}.
     *
     * <ul>
     *   <li>Categorical columns: the result is a sequence vec built with {@code Vec.makeSeq} over
     *       the domain size, tagged with the column's domain and put into the DKV.</li>
     *   <li>Other columns: the distinct values come from {@link UniqTask}.</li>
     * </ul>
     */
    static Frame uniqueValuesBy(Frame fr, int columnIdx) {
        Vec vec = fr.vec(columnIdx);
        Vec uniques;
        if (vec.isCategorical()) {
            uniques = Vec.makeSeq(0L, vec.domain().length, true);
            uniques.setDomain(vec.domain());
            DKV.put(uniques);
        } else {
            UniqTask uniqTask = new UniqTask().doAll(new Vec[]{vec});
            int size = uniqTask._uniq.size();
            final AstGroup.G[] groups = uniqTask._uniq.keySet().toArray(new AstGroup.G[size]);
            uniques = Vec.makeZero(size, vec.get_type());
            new MRTask() {
                @Override
                public void map(Chunk chunk) {
                    int start = (int) chunk.start();
                    for (int row = 0; row < chunk._len; row++) {
                        chunk.set(row, groups[row + start]._gs[0]);
                    }
                }
            }.doAll(new Vec[]{uniques});
        }
        return new Frame(new Vec[]{uniques});
    }

    /** Renames the column at {@code columnIdx}. */
    static void renameColumn(Frame fr, int columnIdx, String newName) {
        String[] names = fr.names();
        names[columnIdx] = newName;
        fr.setNames(names);
    }

    /** Renames the column currently called {@code oldName}. */
    static void renameColumn(Frame fr, String oldName, String newName) {
        renameColumn(fr, fr.find(oldName), newName);
    }

    /** Maps each column name of {@code fr} to its index. */
    public static Map<String, Integer> nameToIndex(Frame fr) {
        return nameToIndex(fr.names());
    }

    /** Maps each name to its position; on duplicates the last position wins. */
    public static Map<String, Integer> nameToIndex(String[] names) {
        Map<String, Integer> index = new HashMap<>(names.length);
        for (int i = 0; i < names.length; i++) {
            index.put(names[i], i);
        }
        return index;
    }

    /** Assigns {@code fr} a fresh key, puts it into the DKV and returns it. */
    public static Frame register(Frame fr) {
        fr._key = Key.make();
        DKV.put(fr);
        return fr;
    }

    /** Debugging helper: prints every row of {@code fr} to stdout. */
    static void printFrame(Frame fr) {
        System.out.println(fr.toTwoDimTable(0L, (int) fr.numRows(), false).toString(2, true));
    }
}
