package org.tribuo.evaluation;

import java.util.ArrayList;
import java.util.List;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.tribuo.test.MockDataSource;

/**
 * Checks the sizes of the train and test partitions produced by
 * {@link TrainTestSplitter} for empty, single-element, evenly divisible and
 * unevenly divisible data sources.
 */
public class TrainTestSplitterTest {
    /** Fixed seed handed to every splitter so each run sees the same split. */
    private static final long SEED = 0L;

    /** An empty source must produce two empty partitions. */
    @Test
    public void testSplitter_emptyDataSource() {
        assertSplitSizes(0, 0.7d, 0, 0);
    }

    /** With a train proportion of 0.7, a single example is expected in the test partition. */
    @Test
    public void testSplitter_singletonDataSource() {
        assertSplitSizes(1, 0.7d, 0, 1);
    }

    /** Ten examples at a proportion of 0.5 are expected to split five and five. */
    @Test
    public void testSplitter() {
        assertSplitSizes(10, 0.5d, 5, 5);
    }

    /** Eleven examples at a proportion of 0.5: five expected in train, six in test. */
    @Test
    public void testSplitter_indivisibleTrainProportion() {
        assertSplitSizes(11, 0.5d, 5, 6);
    }

    /**
     * Splits a {@link MockDataSource} of the given size and asserts the total
     * size and the number of elements in each partition.
     *
     * @param sourceSize      number of examples requested from the mock source
     * @param trainProportion proportion passed to the splitter
     * @param expectedTrain   expected number of elements iterated from the train partition
     * @param expectedTest    expected number of elements iterated from the test partition
     */
    private static void assertSplitSizes(int sourceSize, double trainProportion,
                                         int expectedTrain, int expectedTest) {
        // Wildcard + diamond avoids the raw type without naming the mock output class.
        TrainTestSplitter<?> splitter =
                new TrainTestSplitter<>(new MockDataSource(sourceSize), trainProportion, SEED);
        Assertions.assertEquals(sourceSize, splitter.totalSize());
        Assertions.assertEquals(expectedTrain, sizeOf(splitter.getTrain()));
        Assertions.assertEquals(expectedTest, sizeOf(splitter.getTest()));
    }

    /**
     * Copies every element of an iterable into a new mutable list, in iteration order.
     *
     * @param iterable the elements to copy
     * @param <T>      the element type
     * @return a new list holding the iterated elements
     */
    private static <T> List<T> list(Iterable<T> iterable) {
        List<T> items = new ArrayList<>();
        iterable.forEach(items::add);
        return items;
    }

    /**
     * Counts the elements of an iterable by walking it once.
     *
     * @param iterable the elements to count
     * @return the number of elements returned by the iterator
     */
    private static int sizeOf(Iterable<?> iterable) {
        int count = 0;
        for (Object ignored : iterable) {
            count++;
        }
        return count;
    }
}
