package hex;

import java.util.Random;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import water.Key;
import water.Scope;
import water.TestUtil;
import water.fvec.Vec;
import water.runner.CloudSize;
import water.runner.H2ORunner;
import water.util.Log;
import water.util.TwoDimTable;

@CloudSize(1)
@RunWith(H2ORunner.class)
public class GainsLiftTest extends TestUtil {

    /** Number of synthetic rows generated by each randomized test. */
    private static final int N_ROWS = 100000;

    /** Fixed RNG seed so every randomized test sees reproducible data. */
    private static final long SEED = 912559L;

    @Rule
    public ExpectedException expectedException = ExpectedException.none();

    /**
     * Builds a scope-tracked categorical actuals vec (domain {"N", "Y"}) from integer labels and
     * runs {@link GainsLift} against {@code preds}. Must be called inside {@code Scope.enter()}.
     */
    private static GainsLift computeGainsLift(long[] labels, double[] preds, int groups) {
        Vec actuals = Scope.track(Vec.makeVec(labels, new String[]{"N", "Y"}, Vec.newKey()));
        return computeGainsLift(actuals, preds, groups);
    }

    /**
     * Same as {@link #computeGainsLift(long[], double[], int)} but for double-valued labels
     * (the NA tests put {@code Double.NaN} entries in here). Must be called inside
     * {@code Scope.enter()}.
     */
    private static GainsLift computeGainsLift(double[] labels, double[] preds, int groups) {
        Vec actuals = Scope.track(Vec.makeVec(labels, new String[]{"N", "Y"}, Vec.newKey()));
        return computeGainsLift(actuals, preds, groups);
    }

    /**
     * Wraps {@code preds} in a scope-tracked vec, runs {@link GainsLift} with the given number of
     * groups against {@code actuals}, logs the result and returns it.
     */
    private static GainsLift computeGainsLift(Vec actuals, double[] preds, int groups) {
        Vec predictions = Scope.track(Vec.makeVec(preds, Vec.newKey()));
        GainsLift gainsLift = new GainsLift(predictions, actuals);
        gainsLift._groups = groups;
        gainsLift.exec();
        Log.info(gainsLift);
        return gainsLift;
    }

    /** Asserts {@code response_rates[i] > min} for every group i in [from, to). */
    private static void assertRatesAbove(GainsLift gainsLift, int from, int to, double min) {
        for (int i = from; i < to; i++) {
            double rate = gainsLift.response_rates[i];
            Assert.assertTrue("response rate of group " + i + " is " + rate + ", expected > " + min, rate > min);
        }
    }

    /** Asserts {@code response_rates[i] < max} for every group i in [from, to). */
    private static void assertRatesBelow(GainsLift gainsLift, int from, int to, double max) {
        for (int i = from; i < to; i++) {
            double rate = gainsLift.response_rates[i];
            Assert.assertTrue("response rate of group " + i + " is " + rate + ", expected < " + max, rate < max);
        }
    }

    /** Asserts {@code lo < response_rates[i] < hi} (strict) for every group. */
    private static void assertAllRatesBetween(GainsLift gainsLift, double lo, double hi) {
        for (int i = 0; i < gainsLift.response_rates.length; i++) {
            double rate = gainsLift.response_rates[i];
            Assert.assertTrue("response rate of group " + i + " is " + rate + ", expected in (" + lo + ", " + hi + ")",
                    rate > lo && rate < hi);
        }
    }

    /** Asserts that every group from index {@code from} onwards has a response rate of exactly 0. */
    private static void assertRatesZeroFrom(GainsLift gainsLift, int from) {
        for (int i = from; i < gainsLift.response_rates.length; i++) {
            Assert.assertEquals("response rate of group " + i, 0.0d, gainsLift.response_rates[i], 0.0d);
        }
    }

    @Test
    public void constant() {
        Scope.enter();
        try {
            double[] preds = new double[N_ROWS];
            long[] labels = new long[N_ROWS];
            Random random = new Random(SEED);
            for (int i = 0; i < N_ROWS; i++) {
                labels[i] = random.nextDouble() > 0.8d ? 1L : 0L;
                preds[i] = 0.343424d;
            }
            GainsLift gainsLift = computeGainsLift(labels, preds, 10);
            // A constant prediction cannot rank rows, so the first group must match the overall rate.
            Assert.assertTrue("first group rate " + gainsLift.response_rates[0] + " != average rate " + gainsLift.avg_response_rate,
                    gainsLift.response_rates[0] == gainsLift.avg_response_rate);
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void good() {
        Scope.enter();
        try {
            double[] preds = new double[N_ROWS];
            long[] labels = new long[N_ROWS];
            Random random = new Random(SEED);
            for (int i = 0; i < N_ROWS; i++) {
                labels[i] = random.nextDouble() > 0.8d ? 1L : 0L;
                // Positives score in [0.5, 1), negatives in [0, 0.5): a perfect ranking.
                preds[i] = labels[i] == 0 ? 0.5d * random.nextDouble() : 0.5d + (random.nextDouble() * 0.5d);
            }
            GainsLift gainsLift = computeGainsLift(labels, preds, 10);
            // ~20% positives over 10 groups: the top two groups hold (almost) all of them.
            assertRatesAbove(gainsLift, 0, 2, 0.9d);
            assertRatesBelow(gainsLift, 2, gainsLift.response_rates.length, 0.1d);
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void bad() {
        Scope.enter();
        try {
            double[] preds = new double[N_ROWS];
            long[] labels = new long[N_ROWS];
            Random random = new Random(SEED);
            for (int i = 0; i < N_ROWS; i++) {
                labels[i] = random.nextDouble() > 0.8d ? 1L : 0L;
                // Inverted ranking: negatives score high, positives score low.
                preds[i] = labels[i] == 0 ? 0.5d + (0.5d * random.nextDouble()) : 0.5d * random.nextDouble();
            }
            GainsLift gainsLift = computeGainsLift(labels, preds, 10);
            int groups = gainsLift.response_rates.length;
            assertRatesAbove(gainsLift, groups - 2, groups, 0.9d);
            assertRatesBelow(gainsLift, 0, groups - 2, 0.1d);
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void random() {
        Scope.enter();
        try {
            double[] preds = new double[N_ROWS];
            long[] labels = new long[N_ROWS];
            Random random = new Random(SEED);
            for (int i = 0; i < N_ROWS; i++) {
                labels[i] = random.nextDouble() > 0.8d ? 1L : 0L;
                preds[i] = random.nextDouble();
            }
            GainsLift gainsLift = computeGainsLift(labels, preds, 10);
            // Uninformative predictions: every group should be close to the ~20% base rate.
            assertAllRatesBetween(gainsLift, 0.19d, 0.21d);
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void tiesNApreds() {
        Scope.enter();
        try {
            double[] preds = new double[N_ROWS];
            long[] labels = new long[N_ROWS];
            Random random = new Random(SEED);
            for (int i = 0; i < N_ROWS; i++) {
                labels[i] = random.nextDouble() > 0.8d ? 1L : 0L;
                preds[i] = random.nextDouble() > 0.5d ? 0.7d : 0.4d;
                if (random.nextDouble() > 0.85d) {
                    preds[i] = Double.NaN;
                }
            }
            GainsLift gainsLift = computeGainsLift(labels, preds, 10);
            assertAllRatesBetween(gainsLift, 0.19d, 0.21d);
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void tiesNAlabels() {
        Scope.enter();
        try {
            double[] preds = new double[N_ROWS];
            double[] labels = new double[N_ROWS];
            Random random = new Random(SEED);
            for (int i = 0; i < N_ROWS; i++) {
                labels[i] = random.nextDouble() > 0.8d ? 1.0d : 0.0d;
                preds[i] = random.nextDouble() > 0.5d ? 0.7d : 0.4d;
                if (random.nextDouble() > 0.85d) {
                    labels[i] = Double.NaN;
                }
            }
            GainsLift gainsLift = computeGainsLift(labels, preds, 10);
            assertAllRatesBetween(gainsLift, 0.19d, 0.21d);
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void tiesNAlabels_preds() {
        Scope.enter();
        try {
            double[] preds = new double[N_ROWS];
            double[] labels = new double[N_ROWS];
            Random random = new Random(SEED);
            for (int i = 0; i < N_ROWS; i++) {
                labels[i] = random.nextDouble() > 0.8d ? 1.0d : 0.0d;
                preds[i] = random.nextDouble() > 0.5d ? 0.7d : 0.4d;
                if (random.nextDouble() > 0.85d) {
                    labels[i] = Double.NaN;
                }
                if (random.nextDouble() > 0.85d) {
                    preds[i] = Double.NaN;
                }
            }
            GainsLift gainsLift = computeGainsLift(labels, preds, 10);
            assertAllRatesBetween(gainsLift, 0.19d, 0.21d);
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void imbalanced() {
        Scope.enter();
        try {
            int half = 50000;
            double[] preds = new double[2 * half];
            long[] labels = new long[2 * half];
            Random random = new Random(SEED);
            // First half: predictions squeezed into [0, 1e-7).
            for (int i = 0; i < half; i++) {
                labels[i] = random.nextDouble() > 0.8d ? 1L : 0L;
                preds[i] = random.nextDouble() * 1.0E-7d;
            }
            // Second half: predictions squeezed into [1 - 1e-7, 1).
            for (int i = half; i < 2 * half; i++) {
                labels[i] = random.nextDouble() > 0.8d ? 1L : 0L;
                preds[i] = (1.0d - 1.0E-7d) + (1.0E-7d * random.nextDouble());
            }
            GainsLift gainsLift = computeGainsLift(labels, preds, 10);
            assertAllRatesBetween(gainsLift, 0.19d, 0.21d);
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void rareEvents() {
        Scope.enter();
        try {
            double[] preds = new double[N_ROWS];
            long[] labels = new long[N_ROWS];
            Random random = new Random(SEED);
            for (int i = 0; i < N_ROWS; i++) {
                labels[i] = random.nextDouble() > 0.999d ? 1L : 0L;
                preds[i] = labels[i] == 0 ? 0.5d * random.nextDouble() : 0.5d + (random.nextDouble() * 0.5d);
            }
            GainsLift gainsLift = computeGainsLift(labels, preds, 10);
            double first = gainsLift.response_rates[0];
            Assert.assertTrue("first group rate " + first + " not in [0.009, 0.011]", first <= 0.011d && first >= 0.009d);
            assertRatesZeroFrom(gainsLift, 1);
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void rareEvents20() {
        Scope.enter();
        try {
            double[] preds = new double[N_ROWS];
            long[] labels = new long[N_ROWS];
            Random random = new Random(SEED);
            for (int i = 0; i < N_ROWS; i++) {
                labels[i] = random.nextDouble() > 0.999d ? 1L : 0L;
                preds[i] = labels[i] == 0 ? 0.5d * random.nextDouble() : 0.5d + (random.nextDouble() * 0.5d);
            }
            GainsLift gainsLift = computeGainsLift(labels, preds, 20);
            double first = gainsLift.response_rates[0];
            Assert.assertTrue("first group rate " + first + " not in [0.018, 0.022]", first <= 0.022d && first >= 0.018d);
            assertRatesZeroFrom(gainsLift, 1);
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void testAverageResponseRate() {
        Scope.enter();
        try {
            // Actuals are all 1 except row 0.
            Vec actuals = Scope.track(Vec.makeCon(1.0d, 10L));
            actuals.set(0L, 0.0d);
            Vec preds = Scope.track(Vec.makeCon(1.0d, 10L));
            // Third vec (presumably observation weights) zeroes out the only 0-labelled row.
            Vec weights = Scope.track(Vec.makeCon(1.0d, 10L));
            weights.set(0L, 0.0d);
            GainsLift gainsLift = new GainsLift(preds, actuals, weights);
            gainsLift.exec();
            Log.info(gainsLift);
            Log.info(gainsLift.avg_response_rate);
            Assert.assertEquals(1.0d, gainsLift.avg_response_rate, 0.0d);
            TwoDimTable table = gainsLift.createTwoDimTable();
            int ksColumn = 13; // position of the KS column in the gains/lift table
            Assert.assertEquals("Kolmogorov Smirnov", table.getColHeaders()[ksColumn]);
            Assert.assertEquals(Double.valueOf(1.0d), table.getCellValues()[0][ksColumn].get());
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void testActualLabelCardinality() {
        Scope.enter();
        try {
            // Both vecs are constant 1.0, so the actuals column has a single class label.
            Vec actuals = Scope.track(Vec.makeCon(1.0d, 10L));
            Vec preds = Scope.track(Vec.makeCon(1.0d, 10L));
            this.expectedException.expect(IllegalArgumentException.class);
            this.expectedException.expectMessage("Actual column must contain binary class labels, but found cardinality 1!");
            new GainsLift(preds, actuals).exec();
        } finally {
            Scope.exit();
        }
    }
}
