package ai.h2o.targetencoding;

import water.MRTask;
import water.MemoryManager;
import water.fvec.CategoricalWrappedVec;
import water.fvec.Chunk;
import water.fvec.Frame;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:ai/h2o/targetencoding/TargetEncoderBroadcastJoin.class */
public class TargetEncoderBroadcastJoin {
    // Decompiler-exposed copy of the flag javac normally synthesizes for `assert` statements.
    // It is assigned once in this class's static initializer to the negation of
    // TargetEncoderBroadcastJoin.class.desiredAssertionStatus(), so `true` means assertions are off.
    // It is read by the explicit `if (!$assertionsDisabled && ...)` guards in this class.
    static final /* synthetic */ boolean $assertionsDisabled;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:ai/h2o/targetencoding/TargetEncoderBroadcastJoin$BroadcastJoiner.class */
    public static class BroadcastJoiner extends MRTask<BroadcastJoiner> {
        int _categoricalColumnIdx;
        int _foldColumnIdx;
        int _maxKnownCatLevel;
        double[][] _encodingDataArray;
        int[][] _levelMappings;
        static final /* synthetic */ boolean $assertionsDisabled;

        BroadcastJoiner(int[] iArr, int i, double[][] dArr, int[][] iArr2, int i2) {
            if (!$assertionsDisabled && iArr.length != 1) {
                throw new AssertionError("Only single column target encoding (i.e. one categorical column is used to produce its encodings) is supported for now");
            }
            if (!$assertionsDisabled && iArr2.length != 1) {
                throw new AssertionError();
            }
            this._categoricalColumnIdx = iArr[0];
            this._foldColumnIdx = i;
            this._encodingDataArray = dArr;
            this._levelMappings = iArr2;
            this._maxKnownCatLevel = i2;
        }

        public void map(Chunk[] chunkArr) {
            int[] iArr = this._levelMappings[0];
            Chunk chunk = chunkArr[this._categoricalColumnIdx];
            Chunk chunk2 = chunkArr[chunkArr.length - 2];
            Chunk chunk3 = chunkArr[chunkArr.length - 1];
            for (int i = 0; i < chunk2.len(); i++) {
                int at8 = (int) chunk.at8(i);
                if (at8 >= iArr.length) {
                    setEncodingComponentsToNAs(chunk2, chunk3, i);
                } else {
                    int i2 = iArr[at8];
                    double[] dArr = this._encodingDataArray[this._foldColumnIdx >= 0 ? (int) chunkArr[this._foldColumnIdx].at8(i) : 0];
                    if (i2 > this._maxKnownCatLevel) {
                        setEncodingComponentsToNAs(chunk2, chunk3, i);
                    } else {
                        double d = dArr[(2 * i2) + 1];
                        if (d == 0.0d) {
                            setEncodingComponentsToNAs(chunk2, chunk3, i);
                        } else {
                            chunk2.set(i, dArr[2 * i2]);
                            chunk3.set(i, d);
                        }
                    }
                }
            }
        }

        private void setEncodingComponentsToNAs(Chunk chunk, Chunk chunk2, int i) {
            chunk.setNA(i);
            chunk2.setNA(i);
        }

        static {
            $assertionsDisabled = !TargetEncoderBroadcastJoin.class.desiredAssertionStatus();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:ai/h2o/targetencoding/TargetEncoderBroadcastJoin$FrameWithEncodingDataToArray.class */
    public static class FrameWithEncodingDataToArray extends MRTask<FrameWithEncodingDataToArray> {
        final double[][] _encodingDataPerNode;
        final int _categoricalColumnIdx;
        final int _foldColumnIdx;
        final int _numeratorIdx;
        final int _denominatorIdx;
        final int _cardinalityOfCatCol;
        static final /* synthetic */ boolean $assertionsDisabled;

        FrameWithEncodingDataToArray(int i, int i2, int i3, int i4, int i5, int i6) {
            this._categoricalColumnIdx = i;
            this._foldColumnIdx = i2;
            this._numeratorIdx = i3;
            this._denominatorIdx = i4;
            this._cardinalityOfCatCol = i5;
            if (i2 == -1) {
                this._encodingDataPerNode = MemoryManager.malloc8d(1, this._cardinalityOfCatCol * 2);
                return;
            }
            if (!$assertionsDisabled && i6 < 1) {
                throw new AssertionError("There should be at least two folds in the fold column");
            }
            if (!$assertionsDisabled && (this._cardinalityOfCatCol <= 0 || this._cardinalityOfCatCol >= 1073741823)) {
                throw new AssertionError("Cardinality of categ. column should be within range (0, Integer.MAX_VALUE / 2 )");
            }
            this._encodingDataPerNode = MemoryManager.malloc8d(i6 + 1, this._cardinalityOfCatCol * 2);
        }

        public void map(Chunk[] chunkArr) {
            Chunk chunk = chunkArr[this._categoricalColumnIdx];
            Chunk chunk2 = chunkArr[this._numeratorIdx];
            Chunk chunk3 = chunkArr[this._denominatorIdx];
            for (int i = 0; i < chunk.len(); i++) {
                int at8 = (int) chunk.at8(i);
                double[] dArr = this._encodingDataPerNode[this._foldColumnIdx != -1 ? (int) chunkArr[this._foldColumnIdx].at8(i) : 0];
                dArr[2 * at8] = chunk2.atd(i);
                dArr[(2 * at8) + 1] = chunk3.at8(i);
            }
        }

        public void reduce(FrameWithEncodingDataToArray frameWithEncodingDataToArray) {
            double[][] encodingDataArray = getEncodingDataArray();
            double[][] encodingDataArray2 = frameWithEncodingDataToArray.getEncodingDataArray();
            if (encodingDataArray != encodingDataArray2) {
                for (int i = 0; i < encodingDataArray.length; i++) {
                    for (int i2 = 0; i2 < encodingDataArray[i].length; i2++) {
                        encodingDataArray[i][i2] = Math.max(encodingDataArray[i][i2], encodingDataArray2[i][i2]);
                    }
                }
            }
        }

        double[][] getEncodingDataArray() {
            return this._encodingDataPerNode;
        }

        static {
            $assertionsDisabled = !TargetEncoderBroadcastJoin.class.desiredAssertionStatus();
        }
    }

    // Package-private no-op constructor; the rest of this class consists of static
    // methods and static nested classes.
    TargetEncoderBroadcastJoin() {
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* JADX WARN: Type inference failed for: r0v15, types: [int[], int[][]] */
    public static Frame join(Frame frame, int[] iArr, int i, Frame frame2, int[] iArr2, int i2, int i3) {
        int find = frame2.find("numerator");
        int find2 = frame2.find("denominator");
        if (!$assertionsDisabled && iArr.length != 1) {
            throw new AssertionError();
        }
        if (!$assertionsDisabled && iArr2.length != 1) {
            throw new AssertionError();
        }
        int i4 = iArr[0];
        int i5 = iArr2[0];
        int cardinality = frame2.vec(i5).cardinality();
        if (i2 != -1 && frame2.vec(i2).max() > 2.147483647E9d) {
            throw new IllegalArgumentException("Fold value should be a non-negative integer (i.e. should belong to [0, Integer.MAX_VALUE] range)");
        }
        ?? r0 = {CategoricalWrappedVec.computeMap(frame.vec(i4).domain(), frame2.vec(i5).domain())};
        double[][] encodingsToArray = encodingsToArray(frame2, i5, i2, find, find2, cardinality, i3);
        Frame frame3 = new Frame(frame);
        frame3.add("numerator", frame3.anyVec().makeCon(0.0d));
        frame3.add("denominator", frame3.anyVec().makeCon(0.0d));
        new BroadcastJoiner(iArr, i, encodingsToArray, r0, cardinality - 1).doAll(frame3);
        return frame3;
    }

    /**
     * Runs a {@link FrameWithEncodingDataToArray} task over the given frame and returns
     * the array it collected.
     *
     * @param encodingsFrame       frame holding the encodings
     * @param categoricalColumnIdx index of the categorical column
     * @param foldColumnIdx        index of the fold column, or -1 if there is none
     * @param numeratorIdx         index of the numerator column
     * @param denominatorIdx       index of the denominator column
     * @param cardinalityOfCatCol  number of levels of the categorical column
     * @param maxFoldValue         largest fold value
     * @return the encodings array produced by the task
     */
    static double[][] encodingsToArray(Frame encodingsFrame, int categoricalColumnIdx, int foldColumnIdx, int numeratorIdx, int denominatorIdx, int cardinalityOfCatCol, int maxFoldValue) {
        FrameWithEncodingDataToArray collector = new FrameWithEncodingDataToArray(
                categoricalColumnIdx, foldColumnIdx, numeratorIdx, denominatorIdx, cardinalityOfCatCol, maxFoldValue);
        FrameWithEncodingDataToArray finished = collector.doAll(encodingsFrame);
        return finished.getEncodingDataArray();
    }

    // Caches the assertion status of this class once at class-initialization time;
    // $assertionsDisabled ends up true when assertions are NOT enabled for this class.
    static {
        $assertionsDisabled = !TargetEncoderBroadcastJoin.class.desiredAssertionStatus();
    }
}
