package org.tribuo;

import java.util.SplittableRandom;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:org/tribuo/CategoricalInfoTest.class */
/**
 * Tests the frequency-based sampling support of {@link CategoricalInfo}.
 *
 * <p>The expectations below encode that any part of the total observation count which was
 * not explicitly observed is attributed to the value {@code 0.0}. For example, 25 observations
 * of {@code 5.0} sampled against a total of 50 give {@code 0.0} a probability of
 * {@code (50 - 25) / 50 = 0.5}.
 */
public class CategoricalInfoTest {
    /** Absolute tolerance for comparisons that are expected to be numerically exact. */
    private static final double DELTA = 1.0E-12d;

    /** Number of draws used when checking the empirical sampling distribution. */
    private static final int NUM_SAMPLES = 5000;

    /** Loose tolerance for the empirical (randomised) checks over {@link #NUM_SAMPLES} draws. */
    private static final double SAMPLING_TOLERANCE = 0.1d;

    /**
     * Builds an info named "test" that has observed each of -1, 2, 3 and 4 five times
     * (20 observations in total).
     *
     * @return the populated info
     */
    public static CategoricalInfo generateFullInfo() {
        CategoricalInfo info = new CategoricalInfo("test");
        for (int round = 0; round < 5; round++) {
            info.observe(-1.0d);
            info.observe(2.0d);
            info.observe(3.0d);
            info.observe(4.0d);
        }
        return info;
    }

    /**
     * Builds an info named "empty" that has not observed any value.
     *
     * @return the empty info
     */
    public static CategoricalInfo generateEmptyInfo() {
        return new CategoricalInfo("empty");
    }

    /**
     * Builds an info named "one-value" that has observed the value 5 twenty-five times.
     *
     * @return the populated info
     */
    public static CategoricalInfo generateOneValueInfo() {
        CategoricalInfo info = new CategoricalInfo("one-value");
        for (int count = 0; count < 25; count++) {
            info.observe(5.0d);
        }
        return info;
    }

    /**
     * Asserts that {@code value} occurs exactly once in the info's sampling values and that
     * the probability mass assigned to it by the CDF equals {@code probability}.
     *
     * <p>The info's {@code values} and {@code cdf} arrays are read directly, so they must
     * already have been built by the caller (the tests here request a sample first).
     *
     * @param categoricalInfo the info whose sampling arrays are inspected
     * @param value           the value to look up (matched within {@link #DELTA})
     * @param probability     the expected probability mass of {@code value}
     */
    public void checkValueAndProb(CategoricalInfo categoricalInfo, double value, double probability) {
        int foundIdx = -1;
        for (int idx = 0; idx < categoricalInfo.values.length; idx++) {
            if (Math.abs(categoricalInfo.values[idx] - value) < DELTA) {
                // Use >= 0 so that a first match at index 0 is also recognised as a duplicate.
                if (foundIdx >= 0) {
                    Assertions.fail("Found value " + value + " at " + foundIdx + " and " + idx);
                }
                foundIdx = idx;
            }
        }
        // Fail with a readable message instead of indexing cdf[-1] below.
        if (foundIdx < 0) {
            Assertions.fail("Value " + value + " was not found in the sampling values");
        }

        // The mass of an entry is the difference between consecutive CDF entries;
        // the first entry has no predecessor, so its mass is the CDF value itself.
        double actualProbability = foundIdx == 0
                ? categoricalInfo.cdf[0]
                : categoricalInfo.cdf[foundIdx] - categoricalInfo.cdf[foundIdx - 1];
        Assertions.assertEquals(probability, actualProbability, DELTA);
    }

    /**
     * Checks the sampling arrays and sampled values for an empty info, a single-valued info
     * and a multi-valued info, each against different total observation counts.
     */
    @Test
    public void samplingTest() {
        // Fixed seed keeps the empirical checks below reproducible.
        SplittableRandom rng = new SplittableRandom(1L);

        // A sample is requested before each inspection of values/cdf.
        // NOTE(review): presumably the arrays are built lazily on first sample - confirm
        // against CategoricalInfo.

        // Empty info: nothing observed, so 0.0 is the only value and has probability 1.
        CategoricalInfo emptyInfo = generateEmptyInfo();
        emptyInfo.frequencyBasedSample(rng, 50L);
        Assertions.assertEquals(1, emptyInfo.values.length);
        Assertions.assertEquals(0.0d, emptyInfo.values[0], DELTA);
        Assertions.assertEquals(1, emptyInfo.cdf.length);
        Assertions.assertEquals(1.0d, emptyInfo.cdf[0], DELTA);
        for (int draw = 0; draw < 50; draw++) {
            Assertions.assertEquals(0.0d, emptyInfo.frequencyBasedSample(rng, 50L), DELTA);
        }

        // One value, 25 of 50 observations: 0.0 and 5.0 are equally likely.
        CategoricalInfo halfInfo = generateOneValueInfo();
        halfInfo.frequencyBasedSample(rng, 50L);
        checkOneValueArrays(halfInfo, 0.5d);

        double sampleSum = 0.0d;
        int nonZeroCount = 0;
        for (int draw = 0; draw < NUM_SAMPLES; draw++) {
            double sample = halfInfo.frequencyBasedSample(rng, 50L);
            if (sample > DELTA) {
                nonZeroCount++;
            }
            sampleSum += sample;
        }
        // Expect about half the draws to be 5.0, giving a mean of about 2.5.
        Assertions.assertEquals(0.5d, nonZeroCount / (double) NUM_SAMPLES, SAMPLING_TOLERANCE);
        Assertions.assertEquals(2.5d, sampleSum / NUM_SAMPLES, SAMPLING_TOLERANCE);

        // One value, 25 of 1000 observations: 0.0 has probability 975 / 1000.
        CategoricalInfo sparseInfo = generateOneValueInfo();
        sparseInfo.frequencyBasedSample(rng, 1000L);
        checkOneValueArrays(sparseInfo, 0.975d);

        // Four values, 5 observations each, out of 50: 0.0 gets 30 / 50, the rest 5 / 50 each.
        CategoricalInfo fullInfo = generateFullInfo();
        fullInfo.frequencyBasedSample(rng, 50L);
        checkFullInfoArrays(fullInfo, 0.6d, 0.1d);

        // Same observations out of 100: 0.0 gets 80 / 100, the rest 5 / 100 each.
        CategoricalInfo dilutedInfo = generateFullInfo();
        dilutedInfo.frequencyBasedSample(rng, 100L);
        checkFullInfoArrays(dilutedInfo, 0.8d, 0.05d);
    }

    /**
     * Asserts the sampling arrays of an info built by {@link #generateOneValueInfo()}:
     * values {@code [0.0, 5.0]} with CDF {@code [zeroProbability, 1.0]}.
     */
    private static void checkOneValueArrays(CategoricalInfo info, double zeroProbability) {
        Assertions.assertEquals(2, info.values.length);
        Assertions.assertEquals(0.0d, info.values[0], DELTA);
        Assertions.assertEquals(5.0d, info.values[1], DELTA);
        Assertions.assertEquals(2, info.cdf.length);
        Assertions.assertEquals(zeroProbability, info.cdf[0], DELTA);
        Assertions.assertEquals(1.0d, info.cdf[1], DELTA);
    }

    /**
     * Asserts the sampling arrays of an info built by {@link #generateFullInfo()}:
     * five entries, with {@code 0.0} carrying {@code zeroProbability} and each of the four
     * observed values carrying {@code observedProbability}.
     */
    private void checkFullInfoArrays(CategoricalInfo info, double zeroProbability, double observedProbability) {
        Assertions.assertEquals(5, info.values.length);
        Assertions.assertEquals(5, info.cdf.length);
        checkValueAndProb(info, 0.0d, zeroProbability);
        checkValueAndProb(info, -1.0d, observedProbability);
        checkValueAndProb(info, 2.0d, observedProbability);
        checkValueAndProb(info, 3.0d, observedProbability);
        checkValueAndProb(info, 4.0d, observedProbability);
    }
}
