package ai.h2o.targetencoding;

import ai.h2o.targetencoding.TargetEncoder;
import ai.h2o.targetencoding.TargetEncoderModel;
import java.util.Map;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.Key;
import water.MRTask;
import water.Scope;
import water.TestUtil;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.TestFrameBuilder;
import water.fvec.Vec;
import water.rapids.Merge;
import water.util.IcedHashMapGeneric;

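/**
 * Tests for {@link TargetEncoderBuilder}: the encoding maps and transformed frames obtained from a
 * trained {@link TargetEncoderModel} are compared against the results of using
 * {@link TargetEncoder} directly.
 */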
public class TargetEncoderBuilderTest extends TestUtil {

    /* loaded from: input_file:ai/h2o/targetencoding/TargetEncoderBuilderTest$RowIndexTask.class */
    private static class RowIndexTask extends MRTask<RowIndexTask> {
        static String ROW_INDEX_COL = "__row_index";

        private RowIndexTask() {
        }

        public void map(Chunk chunk, NewChunk newChunk) {
            long start = chunk.start();
            for (int i = 0; i < chunk._len; i++) {
                newChunk.addNum(start + i);
            }
        }

        /* JADX INFO: Access modifiers changed from: private */
        public static void addRowIndex(Frame frame) {
            frame.insertVec(0, ROW_INDEX_COL, ((RowIndexTask) new RowIndexTask().doAll((byte) 3, new Vec[]{frame.anyVec()})).outputFrame().anyVec());
        }
    }

    /** Blocks until the test cloud has reached the size these tests require (a single node). */
    @BeforeClass
    public static void setup() {
        final int requiredCloudSize = 1;
        stall_till_cloudsize(requiredCloudSize);
    }

    /**
     * Trains a {@link TargetEncoderModel} through {@link TargetEncoderBuilder} on the titanic
     * dataset and checks that the model's encoding map is bit-identical to the map computed by
     * calling {@code TargetEncoder.prepareEncodingMap} directly on a second parse of the same
     * file (no fold column).
     */
    @Test
    public void getTargetEncodingMapByTrainingTEBuilder() {
        TargetEncoderModel targetEncoderModel = null;
        Map<String, Frame> mapFromEncoder = null;
        Map<String, Frame> mapFromBuilder = null;
        Scope.enter();
        try {
            Frame trainingFrame = parse_test_file("./smalldata/gbm_test/titanic.csv");
            Scope.track(trainingFrame);
            asFactor(trainingFrame, "survived");

            Frame.VecSpecifier[] teColumns = {
                new Frame.VecSpecifier(trainingFrame._key, "home.dest"),
                new Frame.VecSpecifier(trainingFrame._key, "embarked")
            };

            TargetEncoderModel.TargetEncoderParameters parameters = new TargetEncoderModel.TargetEncoderParameters();
            parameters._blending = false;
            parameters._response_column = "survived";
            parameters._ignored_columns = ignoredColumns(trainingFrame, "home.dest", "embarked", parameters._response_column);
            parameters.setTrain(trainingFrame._key);

            targetEncoderModel = (TargetEncoderModel) new TargetEncoderBuilder(parameters).trainModel().get();

            // Reference result: the plain encoder applied to a fresh copy of the same data.
            TargetEncoder targetEncoder = new TargetEncoder(Frame.VecSpecifier.vecNames(teColumns));
            Frame referenceFrame = parse_test_file("./smalldata/gbm_test/titanic.csv");
            asFactor(referenceFrame, "survived");
            Scope.track(referenceFrame);

            mapFromEncoder = targetEncoder.prepareEncodingMap(referenceFrame, "survived", (String) null);
            mapFromBuilder = targetEncoderModel._output._target_encoding_map;

            areEncodingMapsIdentical(mapFromEncoder, mapFromBuilder);
        } finally {
            removeEncodingMaps(mapFromEncoder, mapFromBuilder);
            // The model is still null when parsing or training failed; guarding here keeps the
            // original failure from being replaced by a NullPointerException.
            if (targetEncoderModel != null) {
                targetEncoderModel.remove();
            }
            Scope.exit();
        }
    }

    /**
     * Trains a {@link TargetEncoderModel} on a two-row frame in which only the {@code embarked}
     * column contains a missing value, and checks the per-column missing-value flags stored in
     * the model output: 0 for {@code home.dest}, 1 for {@code embarked}.
     */
    @Test
    public void teColumnNameToMissingValuesPresenceMapIsComputedCorrectly() {
        TargetEncoderModel targetEncoderModel = null;
        Scope.enter();
        try {
            Frame trainingFrame = new TestFrameBuilder()
                    .withName("testFrame")
                    .withColNames(new String[]{"home.dest", "embarked", "ColB"})
                    .withVecTypes(new byte[]{Vec.T_CAT, Vec.T_CAT, Vec.T_CAT})
                    .withDataForCol(0, ar("a", "b"))
                    .withDataForCol(1, ar("s", null))
                    .withDataForCol(2, ar("yes", "no"))
                    .build();

            TargetEncoderModel.TargetEncoderParameters parameters = new TargetEncoderModel.TargetEncoderParameters();
            parameters._blending = false;
            parameters._response_column = "ColB";
            parameters._ignored_columns = ignoredColumns(trainingFrame, "home.dest", "embarked", parameters._response_column);
            parameters.setTrain(trainingFrame._key);
            parameters._ignore_const_cols = false;

            targetEncoderModel = (TargetEncoderModel) new TargetEncoderBuilder(parameters).trainModel().get();

            Map<String, Integer> missingValuePresence = targetEncoderModel._output._column_name_to_missing_val_presence;
            // assertEquals reports the actual flag value on failure, unlike assertTrue(x == n).
            Assert.assertEquals("home.dest has no missing values", 0, missingValuePresence.get("home.dest").intValue());
            Assert.assertEquals("embarked has a missing value", 1, missingValuePresence.get("embarked").intValue());
        } finally {
            if (targetEncoderModel != null) {
                targetEncoderModel.remove();
            }
            Scope.exit();
        }
    }

    /**
     * Same comparison as the non-fold test, but with a fold column: the encoding map of a model
     * trained with {@code _fold_column} set must be bit-identical to the map returned by
     * {@code TargetEncoder.prepareEncodingMap} when given the same fold column. Both frames get
     * their fold column from {@code addKFoldColumn} with identical arguments (5, 1234L).
     */
    @Test
    public void getTargetEncodingMapByTrainingTEBuilder_KFold_scenario() {
        Map<String, Frame> mapFromEncoder = null;
        Map<String, Frame> mapFromBuilder = null;
        TargetEncoderModel targetEncoderModel = null;
        Scope.enter();
        try {
            Frame trainingFrame = parse_test_file("./smalldata/gbm_test/titanic.csv");
            TargetEncoderFrameHelper.addKFoldColumn(trainingFrame, "fold_column", 5, 1234L);
            Scope.track(trainingFrame);
            asFactor(trainingFrame, "survived");

            Frame.VecSpecifier[] teColumns = {
                new Frame.VecSpecifier(trainingFrame._key, "home.dest"),
                new Frame.VecSpecifier(trainingFrame._key, "embarked")
            };

            TargetEncoderModel.TargetEncoderParameters parameters = new TargetEncoderModel.TargetEncoderParameters();
            parameters._blending = false;
            parameters._response_column = "survived";
            parameters._fold_column = "fold_column";
            parameters._ignored_columns = ignoredColumns(trainingFrame, "home.dest", "embarked", parameters._response_column, parameters._fold_column);
            parameters.setTrain(trainingFrame._key);

            targetEncoderModel = (TargetEncoderModel) new TargetEncoderBuilder(parameters).trainModel().get();

            // Reference result: the plain encoder applied to a fresh copy of the same data.
            TargetEncoder targetEncoder = new TargetEncoder(Frame.VecSpecifier.vecNames(teColumns));
            Frame referenceFrame = parse_test_file("./smalldata/gbm_test/titanic.csv");
            TargetEncoderFrameHelper.addKFoldColumn(referenceFrame, "fold_column", 5, 1234L);
            asFactor(referenceFrame, "survived");
            Scope.track(referenceFrame);

            mapFromEncoder = targetEncoder.prepareEncodingMap(referenceFrame, "survived", "fold_column");
            mapFromBuilder = targetEncoderModel._output._target_encoding_map;

            areEncodingMapsIdentical(mapFromEncoder, mapFromBuilder);
        } finally {
            removeEncodingMaps(mapFromEncoder, mapFromBuilder);
            // The model is still null when parsing or training failed; guarding here keeps the
            // original failure from being replaced by a NullPointerException.
            if (targetEncoderModel != null) {
                targetEncoderModel.remove();
            }
            Scope.exit();
        }
    }

    /**
     * Checks that {@code TargetEncoderModel.transform} with the KFold strategy yields a frame that
     * is bit-identical to the one produced by preparing the encoding map and calling
     * {@code TargetEncoder.applyTargetEncoding} directly with the same fold column and seed.
     */
    @Test
    public void transform_KFold_scenario() {
        Map<String, Frame> mapFromEncoder = null;
        Map<String, Frame> mapFromModel = null;
        Scope.enter();
        try {
            Frame frame = parse_test_file("./smalldata/gbm_test/titanic.csv");
            TargetEncoderFrameHelper.addKFoldColumn(frame, "fold_column", 5, 1234L);
            Scope.track(frame);
            asFactor(frame, "survived");

            TargetEncoderModel.TargetEncoderParameters parameters = new TargetEncoderModel.TargetEncoderParameters();
            parameters._blending = false;
            parameters._response_column = "survived";
            parameters._fold_column = "fold_column";
            parameters._seed = 1234L;
            parameters._ignored_columns = ignoredColumns(frame, "home.dest", "embarked", parameters._response_column, parameters._fold_column);
            parameters._train = frame._key;

            TargetEncoderModel targetEncoderModel = (TargetEncoderModel) new TargetEncoderBuilder(parameters).trainModel().get();
            // The model is released by Scope.exit(), so no explicit remove() is needed here.
            Scope.track_generic(targetEncoderModel);

            TargetEncoder.DataLeakageHandlingStrategy strategy = TargetEncoder.DataLeakageHandlingStrategy.KFold;

            Frame transformedByModel = targetEncoderModel.transform(frame, strategy.getVal(), false, (BlendingParams) null, parameters._seed);
            Scope.track(transformedByModel);
            mapFromModel = targetEncoderModel._output._target_encoding_map;

            TargetEncoder targetEncoder = new TargetEncoder(new String[]{"embarked", "home.dest"});
            mapFromEncoder = targetEncoder.prepareEncodingMap(frame, "survived", "fold_column", false);
            Frame transformedByEncoder = targetEncoder.applyTargetEncoding(frame, "survived", mapFromEncoder, strategy, "fold_column", parameters._blending, false, TargetEncoder.DEFAULT_BLENDING_PARAMS, parameters._seed);
            Scope.track(transformedByEncoder);

            assertBitIdentical(transformedByModel, transformedByEncoder);
        } finally {
            // Single cleanup path: runs exactly once whether the test body passed or failed.
            removeEncodingMaps(mapFromEncoder, mapFromModel);
            Scope.exit();
        }
    }

    /**
     * Applies KFold target encoding twice with zero noise, once with the encoded columns listed as
     * (home.dest, embarked) and once as (embarked, home.dest), and checks that both the encoding
     * maps and the encoded frames are bit-identical. A row-index column is added up front so both
     * results can be sorted back into the same row order before they are compared.
     */
    @Test
    public void columnOrderHasNoEffectWhenNoiseIsZero() {
        Map<String, Frame> encodingMapA = null;
        Map<String, Frame> encodingMapB = null;
        // Entered before the try so that Scope.exit() in finally always has a matching enter.
        Scope.enter();
        try {
            Frame frame = parse_test_file("./smalldata/gbm_test/titanic.csv");
            Scope.track(frame);
            asFactor(frame, "survived");
            TargetEncoderFrameHelper.addKFoldColumn(frame, "fold_column", 5, 1234L);
            RowIndexTask.addRowIndex(frame);
            DKV.put(frame);

            TargetEncoder encoderA = new TargetEncoder(new String[]{"home.dest", "embarked"});
            encodingMapA = encoderA.prepareEncodingMap(frame, "survived", "fold_column", false);
            TargetEncoder encoderB = new TargetEncoder(new String[]{"embarked", "home.dest"});
            encodingMapB = encoderB.prepareEncodingMap(frame, "survived", "fold_column", false);
            areEncodingMapsIdentical(encodingMapA, encodingMapB);

            Frame encodedA = encoderA.applyTargetEncoding(frame, "survived", encodingMapA, TargetEncoder.DataLeakageHandlingStrategy.KFold, "fold_column", false, 0.0d, false, TargetEncoder.DEFAULT_BLENDING_PARAMS, 1234L);
            Scope.track(encodedA);
            Frame sortedA = Scope.track(Merge.sort(encodedA, encodedA.find(RowIndexTask.ROW_INDEX_COL)));

            Frame encodedB = encoderB.applyTargetEncoding(frame, "survived", encodingMapB, TargetEncoder.DataLeakageHandlingStrategy.KFold, "fold_column", false, 0.0d, false, TargetEncoder.DEFAULT_BLENDING_PARAMS, 1234L);
            Scope.track(encodedB);
            Frame sortedB = Scope.track(Merge.sort(encodedB, encodedB.find(RowIndexTask.ROW_INDEX_COL)));

            // Column order of the second result may differ, so its vecs are re-ordered to follow
            // the column names of the first result before the bitwise comparison.
            assertBitIdentical(sortedA, new Frame(sortedA.names(), sortedB.vecs(sortedA.names())));
        } finally {
            // Encoding maps are released on every path, including failures part-way through.
            removeEncodingMaps(encodingMapA, encodingMapB);
            Scope.exit();
        }
    }

    /**
     * Hands each non-null encoding map to {@code TargetEncoderFrameHelper.encodingMapCleanUp},
     * first {@code map}, then {@code map2}. Null arguments are skipped, so the method can be
     * called from cleanup code before either map has been created.
     *
     * @param map  first encoding map, may be {@code null}
     * @param map2 second encoding map, may be {@code null}
     */
    private void removeEncodingMaps(Map<String, Frame> map, Map<String, Frame> map2) {
        final boolean firstPresent = map != null;
        final boolean secondPresent = map2 != null;
        if (!firstPresent && !secondPresent) {
            return;
        }
        if (firstPresent) {
            TargetEncoderFrameHelper.encodingMapCleanUp(map);
        }
        if (secondPresent) {
            TargetEncoderFrameHelper.encodingMapCleanUp(map2);
        }
    }

    /**
     * Asserts that two encoding maps cover the same columns and that the frames stored under each
     * column are bit-identical.
     *
     * <p>The size check together with the per-key lookup makes the comparison symmetric: a column
     * present in only one of the maps fails with a descriptive message rather than being ignored
     * or surfacing as a null frame inside {@code assertBitIdentical}.
     *
     * @param map  encoding map to verify
     * @param map2 encoding map whose entries drive the comparison
     */
    private void areEncodingMapsIdentical(Map<String, Frame> map, Map<String, Frame> map2) {
        Assert.assertEquals("Encoding maps cover a different number of columns", map2.size(), map.size());
        for (Map.Entry<String, Frame> entry : map2.entrySet()) {
            String column = entry.getKey();
            Frame counterpart = map.get(column);
            Assert.assertNotNull("No encoding frame for column '" + column + "' in the first map", counterpart);
            assertBitIdentical(entry.getValue(), counterpart);
        }
    }
}
