package org.tribuo.transform;

import java.util.Collections;
import java.util.HashMap;
import java.util.Random;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.tribuo.FeatureMap;
import org.tribuo.MutableDataset;
import org.tribuo.impl.ArrayExample;
import org.tribuo.test.MockDataSourceProvenance;
import org.tribuo.test.MockOutput;
import org.tribuo.test.MockOutputFactory;
import org.tribuo.transform.transformations.BinningTransformation;

/* loaded from: input_file:org/tribuo/transform/BinningTest.class */
/**
 * Tests for the binning transformations created by {@link BinningTransformation}.
 *
 * <p>Each test fits the transformers on a dataset generated with seed 1, applies them to that
 * dataset and to a second dataset generated with seed 2, and then checks how many observations
 * of each feature ended up with each bin id (bin ids are 1-based). The expected counts are
 * regression values tied to the exact sequence produced by {@link Random} for those seeds.
 */
public class BinningTest {
    /** Number of examples placed in every generated dataset. */
    private static final int NUM_EXAMPLES = 10000;

    /**
     * Builds a deterministic dense dataset with two features.
     *
     * <p>{@code F0} is drawn from a Gaussian with mean 10 and standard deviation 5, and
     * {@code F1} is uniform in the interval (-20, 0]. Every example shares the same
     * {@code "UNK"} output.
     *
     * @param seed the seed for the random number generator.
     * @return a dataset containing {@value #NUM_EXAMPLES} examples.
     */
    public static MutableDataset<MockOutput> generateRandomDenseDataset(int seed) {
        MutableDataset<MockOutput> dataset = new MutableDataset<>(new MockDataSourceProvenance(), new MockOutputFactory());
        Random rng = new Random(seed);
        MockOutput output = new MockOutput("UNK");
        String[] featureNames = {"F0", "F1"};
        for (int n = 0; n < NUM_EXAMPLES; n++) {
            // The draw order (Gaussian first, then uniform) must not change:
            // the expected counts in the tests depend on this exact RNG sequence.
            double f0 = (rng.nextGaussian() * 5.0) + 10.0;
            double f1 = rng.nextDouble() * -20.0;
            dataset.add(new ArrayExample<>(output, featureNames, new double[]{f0, f1}));
        }
        return dataset;
    }

    /**
     * Asserts the per-bin observation counts of a single feature.
     *
     * @param featureMap the feature map of a transformed dataset.
     * @param featureName the feature to inspect.
     * @param expectedCounts the expected count for bin 1, bin 2, ... in that order.
     */
    private static void assertBinCounts(FeatureMap featureMap, String featureName, long... expectedCounts) {
        for (int bin = 1; bin <= expectedCounts.length; bin++) {
            Assertions.assertEquals(expectedCounts[bin - 1],
                    featureMap.get(featureName).getObservationCount((double) bin),
                    "Unexpected observation count for feature " + featureName + " in bin " + bin);
        }
    }

    /** Checks the bin occupancy produced by {@code BinningTransformation.equalWidth(5)}. */
    @Test
    public void testEqualWidthBinning() {
        TransformationMap binning = new TransformationMap(
                Collections.singletonList(BinningTransformation.equalWidth(5)), new HashMap<>());
        MutableDataset<MockOutput> train = generateRandomDenseDataset(1);
        MutableDataset<MockOutput> test = generateRandomDenseDataset(2);

        // Fit on the seed-1 data only, then apply the fitted transformers to both datasets.
        TransformerMap transformers = train.createTransformers(binning);
        FeatureMap trainFeatures = transformers.transformDataset(train).getFeatureMap();
        FeatureMap testFeatures = transformers.transformDataset(test).getFeatureMap();

        assertBinCounts(trainFeatures, "F0", 140, 2349, 5463, 1946, 102);
        assertBinCounts(trainFeatures, "F1", 2015, 1985, 2034, 1979, 1987);

        assertBinCounts(testFeatures, "F0", 154, 2364, 5523, 1867, 92);
        assertBinCounts(testFeatures, "F1", 1976, 2004, 2033, 1980, 2007);
    }

    /** Checks the bin occupancy produced by {@code BinningTransformation.equalFrequency(5)}. */
    @Test
    public void testMediansBinning() {
        TransformationMap binning = new TransformationMap(
                Collections.singletonList(BinningTransformation.equalFrequency(5)), new HashMap<>());
        MutableDataset<MockOutput> train = generateRandomDenseDataset(1);
        MutableDataset<MockOutput> test = generateRandomDenseDataset(2);

        // Fit on the seed-1 data only, then apply the fitted transformers to both datasets.
        TransformerMap transformers = train.createTransformers(binning);
        FeatureMap trainFeatures = transformers.transformDataset(train).getFeatureMap();
        FeatureMap testFeatures = transformers.transformDataset(test).getFeatureMap();

        assertBinCounts(trainFeatures, "F0", 2001, 2000, 2000, 2000, 1999);
        assertBinCounts(trainFeatures, "F1", 2001, 2000, 2000, 2000, 1999);

        assertBinCounts(testFeatures, "F0", 2024, 2061, 2038, 1955, 1922);
        assertBinCounts(testFeatures, "F1", 1962, 2019, 2001, 1995, 2023);
    }

    /**
     * Checks the bin occupancy produced by {@code BinningTransformation.stdDevs(3)}; the
     * assertions cover bin ids 1 through 6 for each feature.
     */
    @Test
    public void testStdDevBinning() {
        TransformationMap binning = new TransformationMap(
                Collections.singletonList(BinningTransformation.stdDevs(3)), new HashMap<>());
        MutableDataset<MockOutput> train = generateRandomDenseDataset(1);
        MutableDataset<MockOutput> test = generateRandomDenseDataset(2);

        // Fit on the seed-1 data only, then apply the fitted transformers to both datasets.
        TransformerMap transformers = train.createTransformers(binning);
        FeatureMap trainFeatures = transformers.transformDataset(train).getFeatureMap();
        FeatureMap testFeatures = transformers.transformDataset(test).getFeatureMap();

        assertBinCounts(trainFeatures, "F0", 221, 1383, 3391, 3415, 1365, 225);
        // The outermost bins are expected to be empty for F1 in both datasets.
        assertBinCounts(trainFeatures, "F1", 0, 2142, 2863, 2889, 2106, 0);

        assertBinCounts(testFeatures, "F0", 222, 1382, 3482, 3399, 1294, 221);
        assertBinCounts(testFeatures, "F1", 0, 2102, 2917, 2868, 2113, 0);
    }
}
