package hex.genmodel.algos.targetencoder;

import hex.genmodel.easy.DomainMapConstructor;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.RowToRawDataConverter;
import hex.genmodel.easy.error.VoidErrorConsumer;
import hex.genmodel.easy.exception.PredictException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Random;
import org.junit.Assert;
import org.junit.Test;

/**
 * Unit tests for {@link TargetEncoderMojoModel}: the lambda / blending helpers, scoring with and
 * without blending, handling of unseen and missing categorical levels, and ordering of encoding
 * maps by column index.
 */
public class TargetEncoderMojoModelTest {

    /** Absolute tolerance used for every floating-point comparison in this class. */
    private static final double DELTA = 1.0E-5d;

    /** Name of the single target-encoded (categorical) column used by the scoring tests. */
    private static final String TE_COLUMN = "categ_var1";

    /** Column layout shared by the scoring tests; index i is the raw-data position of column i. */
    private static final String[] COLUMN_NAMES = {"numerical_col1", TE_COLUMN, "numerical_col2"};

    @Test
    public void computeLambda() {
        // NOTE(review): the expected values are consistent with the sigmoid
        // lambda = 1 / (1 + exp((inflectionPoint - count) / smoothing)); e.g. count == inflection -> 0.5.
        Assert.assertEquals(0.5d, TargetEncoderMojoModel.computeLambda(5, 5.0d, 1.0d), DELTA);
        Assert.assertEquals(0.0d, TargetEncoderMojoModel.computeLambda(1, 15.0d, 1.0d), DELTA);
        Assert.assertEquals(1.0d, TargetEncoderMojoModel.computeLambda(20, 5.0d, 1.0d), DELTA);
    }

    @Test
    public void computeBlendedEncoding() {
        // Seeded so that any failure is reproducible; several draws give better coverage than one.
        Random random = new Random(42L);
        for (int i = 0; i < 100; i++) {
            double blended = TargetEncoderMojoModel.computeBlendedEncoding(
                    random.nextDouble(), random.nextDouble(), random.nextDouble());
            Assert.assertTrue("blended encoding outside [0, 1]: " + blended,
                    blended >= 0.0d && blended <= 1.0d);
        }
        // NOTE(review): the expected values below match lambda * posterior + (1 - lambda) * prior,
        // e.g. 0.1 * 0.8 + 0.9 * 0.6 = 0.62.
        Assert.assertEquals(0.5d, TargetEncoderMojoModel.computeBlendedEncoding(0.5d, 1.0d, 0.0d), DELTA);
        Assert.assertEquals(1.0d, TargetEncoderMojoModel.computeBlendedEncoding(0.5d, 1.0d, 1.0d), DELTA);
        Assert.assertEquals(0.0d, TargetEncoderMojoModel.computeBlendedEncoding(0.5d, 0.0d, 0.0d), DELTA);
        Assert.assertEquals(0.62d, TargetEncoderMojoModel.computeBlendedEncoding(0.1d, 0.8d, 0.6d), DELTA);
    }

    @Test
    public void transformWithBlending() throws PredictException {
        EncodingMap encodingMap = new EncodingMap();
        encodingMap.put(0, new int[]{2, 5});
        encodingMap.put(1, new int[]{3, 6});
        encodingMap.put(2, new int[]{4, 7});
        TargetEncoderMojoModel model = newModel(encodingMap);
        model._withBlending = true;
        model._inflectionPoint = 5.0d;
        model._smoothing = 1.0d;
        model._priorMean = 0.5d;
        RowToRawDataConverter converter = newConverter(model);

        // "A" -> {2, 5}: posterior 2/5 = 0.4; with lambda = 0.5 the blend is 0.5*0.4 + 0.5*0.5 = 0.45.
        Assert.assertEquals(0.45d, score(model, converter, row("A")), DELTA);

        // "B" -> {3, 6}: posterior 0.5 equals the prior, so the result is 0.5 whatever lambda is.
        model._inflectionPoint = 12.0d;
        Assert.assertEquals(0.5d, score(model, converter, row("B")), DELTA);
    }

    @Test
    public void transformWithoutBlending() throws PredictException {
        EncodingMap encodingMap = new EncodingMap();
        encodingMap.put(0, new int[]{2, 5});
        encodingMap.put(1, new int[]{3, 7});
        TargetEncoderMojoModel model = newModel(encodingMap);
        model._withBlending = false;
        RowToRawDataConverter converter = newConverter(model);

        // Without blending the encoding is the raw ratio for "A": 2 / 5.
        Assert.assertEquals(0.4d, score(model, converter, row("A")), DELTA);
    }

    @Test
    public void transform_unknown_categories_when_training_data_had_missing_or_unexpected_values() throws PredictException {
        EncodingMap encodingMap = new EncodingMap();
        encodingMap.put(0, new int[]{2, 5});
        encodingMap.put(1, new int[]{3, 7});
        encodingMap.put(2, new int[]{6, 8});
        TargetEncoderMojoModel model = newModel(encodingMap);
        model._withBlending = false;
        model._teColumnNameToMissingValuesPresence = new HashMap<>();
        model._teColumnNameToMissingValuesPresence.put(TE_COLUMN, 1);
        RowToRawDataConverter converter = newConverter(model);

        // An unseen level ("C") is expected to get the encoding stored at index 2: 6 / 8.
        Assert.assertEquals(0.75d, score(model, converter, row("C")), DELTA);

        // A missing value must be checked against its OWN prediction, not the previous one.
        Assert.assertEquals(0.75d, score(model, converter, row(Double.NaN)), DELTA);
    }

    @Test
    public void transform_unknown_categories_when_training_data_does_not_have_missing_values() throws PredictException {
        EncodingMap encodingMap = new EncodingMap();
        encodingMap.put(0, new int[]{2, 5});
        encodingMap.put(1, new int[]{3, 7});
        TargetEncoderMojoModel model = newModel(encodingMap);
        model._withBlending = false;
        model._teColumnNameToMissingValuesPresence = new HashMap<>();
        model._teColumnNameToMissingValuesPresence.put(TE_COLUMN, 0);
        // (2 + 3) / (5 + 7) = 5 / 12
        model._priorMean = 0.4166666666666667d;
        RowToRawDataConverter converter = newConverter(model);

        // Unseen and missing levels are both expected to fall back to the prior mean.
        Assert.assertEquals(0.4166666666666667d, score(model, converter, row("C")), DELTA);
        Assert.assertEquals(0.4166666666666667d, score(model, converter, row(Double.NaN)), DELTA);
    }

    @Test
    public void sortEncodingMapByIndex() {
        TargetEncoderMojoModel model = new TargetEncoderMojoModel(new String[0], new String[0][0], (String) null);
        EncodingMap encodingMap = new EncodingMap();
        encodingMap.put(0, new int[]{2, 5});
        encodingMap.put(1, new int[]{3, 7});
        EncodingMaps encodingMaps = new EncodingMaps();
        encodingMaps.put("categ_var1", encodingMap);
        encodingMaps.put("categ_var2", encodingMap);
        encodingMaps.put("categ_var3", encodingMap);
        model._teColumnNameToIdx = new HashMap<>();
        model._teColumnNameToIdx.put("categ_var1", 42);
        model._teColumnNameToIdx.put("categ_var2", 100);
        model._teColumnNameToIdx.put("categ_var3", 7);

        Map<String, ?> sorted = model.sortByColumnIndex(encodingMaps.encodingMap());

        // Keys must come out in ascending column-index order: 7, 42, 100.
        Assert.assertArrayEquals(new String[]{"categ_var3", "categ_var1", "categ_var2"},
                sorted.keySet().toArray(new String[0]));
    }

    /**
     * Creates a model over {@link #COLUMN_NAMES}. Only {@code categ_var1} is categorical (levels A, B)
     * and it is target-encoded with {@code encodingMap}. Blending and prior settings are left
     * to the caller.
     */
    private static TargetEncoderMojoModel newModel(EncodingMap encodingMap) {
        String[] columnNames = COLUMN_NAMES.clone();
        String[][] domains = {null, {"A", "B"}, null};
        TargetEncoderMojoModel model = new TargetEncoderMojoModel(columnNames, domains, (String) null);
        model._nfeatures = columnNames.length;

        EncodingMaps encodingMaps = new EncodingMaps();
        encodingMaps.put(TE_COLUMN, encodingMap);
        model._targetEncodingMap = encodingMaps;

        model._teColumnNameToIdx = new HashMap<>();
        model._teColumnNameToIdx.put(TE_COLUMN, 1);
        return model;
    }

    /**
     * Builds a row converter for {@code model}. The converter maps each column name to its position
     * and has both "unknown categorical level to NA" and "invalid number to NA" enabled.
     */
    private static RowToRawDataConverter newConverter(TargetEncoderMojoModel model) {
        HashMap<String, Integer> columnNameToIndex = new HashMap<>();
        for (int i = 0; i < COLUMN_NAMES.length; i++) {
            columnNameToIndex.put(COLUMN_NAMES[i], i);
        }
        return new RowToRawDataConverter(model, columnNameToIndex,
                new DomainMapConstructor(model).create(), new VoidErrorConsumer(), true, true);
    }

    /** Builds an input row with fixed numeric columns and the given value for the encoded column. */
    private static RowData row(Object categValue) {
        RowData row = new RowData();
        row.put("numerical_col1", 42.0d);
        row.put(TE_COLUMN, categValue);
        row.put("numerical_col2", 10.0d);
        return row;
    }

    /** Converts {@code row} to raw data and returns the single value that {@code score0} writes. */
    private static double score(TargetEncoderMojoModel model, RowToRawDataConverter converter, RowData row)
            throws PredictException {
        double[] rawData = converter.convert(row, nanArray(COLUMN_NAMES.length));
        double[] preds = new double[1];
        model.score0(rawData, preds);
        return preds[0];
    }

    /** Returns a new array of {@code length} elements, all set to {@link Double#NaN}. */
    private static double[] nanArray(int length) {
        double[] values = new double[length];
        for (int i = 0; i < length; i++) {
            values[i] = Double.NaN;
        }
        return values;
    }
}
