package org.tribuo.evaluation;

import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.DoubleProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import org.tribuo.DataSource;
import org.tribuo.Example;
import org.tribuo.Output;
import org.tribuo.datasource.ListDataSource;
import org.tribuo.provenance.DataSourceProvenance;

/* loaded from: input_file:org/tribuo/evaluation/TrainTestSplitter.class */
/**
 * Splits the examples of a {@link DataSource} into a training partition and a testing partition.
 * <p>
 * All examples are read from the supplied source once, at construction time, shuffled using a
 * {@link Random} created from the seed, and then cut in two at an index derived from the train
 * proportion. Each partition is exposed as a {@link ListDataSource} whose provenance is a
 * {@link SplitDataSourceProvenance} recording the original source's provenance, the proportion,
 * the seed, the total number of examples, and which side of the split it describes.
 *
 * @param <T> The output type of the examples.
 */
public class TrainTestSplitter<T extends Output<T>> {
    /** The first {@code (int) (trainProportion * size)} shuffled examples. */
    private final DataSource<T> train;
    /** The remaining shuffled examples. */
    private final DataSource<T> test;
    /** Provenance of the data source that was split. */
    private final DataSourceProvenance originalProvenance;
    /** Seed used to construct the shuffling {@link Random}. */
    private final long seed;
    /** Requested fraction of examples placed in the training partition. */
    private final double trainProportion;
    /** Total number of examples read from the original data source. */
    private final int size;

    /**
     * Provenance describing one side (train or test) of a split produced by
     * {@link TrainTestSplitter}.
     */
    public static class SplitDataSourceProvenance implements DataSourceProvenance {
        private static final long serialVersionUID = 1L;

        // Keys under which each field is published by iterator() and looked up by the Map constructor.
        private static final String CLASS_NAME_KEY = "class-name";
        private static final String SOURCE = "source";
        private static final String TRAIN_PROPORTION = "train-proportion";
        private static final String SEED = "seed";
        private static final String SIZE = "size";
        private static final String IS_TRAIN = "is-train";

        private final StringProvenance className;
        private final DataSourceProvenance innerSourceProvenance;
        private final DoubleProvenance trainProportion;
        private final LongProvenance seed;
        private final IntProvenance size;
        private final BooleanProvenance isTrain;

        /**
         * Captures the provenance of one partition of the supplied splitter.
         * <p>
         * This reads the splitter's original provenance, train proportion, seed and size, so
         * those fields must already be assigned when it is called.
         *
         * @param splitter The splitter that produced the partition.
         * @param isTrain  {@code true} for the training partition, {@code false} for the testing one.
         * @param <T>      The output type of the splitter.
         */
        <T extends Output<T>> SplitDataSourceProvenance(TrainTestSplitter<T> splitter, boolean isTrain) {
            this.className = new StringProvenance(CLASS_NAME_KEY, splitter.getClass().getName());
            this.innerSourceProvenance = splitter.originalProvenance;
            this.trainProportion = new DoubleProvenance(TRAIN_PROPORTION, splitter.trainProportion);
            this.seed = new LongProvenance(SEED, splitter.seed);
            this.size = new IntProvenance(SIZE, splitter.size);
            this.isTrain = new BooleanProvenance(IS_TRAIN, isTrain);
        }

        /**
         * Rebuilds the provenance from a map keyed by field name.
         * <p>
         * Every entry is fetched through {@link ObjectProvenance#checkAndExtractProvenance};
         * how a missing or wrongly typed entry is reported is decided by that method.
         *
         * @param map The provenance entries, keyed by the names emitted from {@link #iterator()}.
         */
        public SplitDataSourceProvenance(Map<String, Provenance> map) {
            String provenanceName = SplitDataSourceProvenance.class.getSimpleName();
            this.className = ObjectProvenance.checkAndExtractProvenance(map, CLASS_NAME_KEY, StringProvenance.class, provenanceName);
            this.innerSourceProvenance = ObjectProvenance.checkAndExtractProvenance(map, SOURCE, DataSourceProvenance.class, provenanceName);
            this.trainProportion = ObjectProvenance.checkAndExtractProvenance(map, TRAIN_PROPORTION, DoubleProvenance.class, provenanceName);
            this.seed = ObjectProvenance.checkAndExtractProvenance(map, SEED, LongProvenance.class, provenanceName);
            this.size = ObjectProvenance.checkAndExtractProvenance(map, SIZE, IntProvenance.class, provenanceName);
            this.isTrain = ObjectProvenance.checkAndExtractProvenance(map, IS_TRAIN, BooleanProvenance.class, provenanceName);
        }

        /**
         * Returns the recorded class name.
         *
         * @return The value of the class-name provenance.
         */
        @Override
        public String getClassName() {
            return this.className.getValue();
        }

        /**
         * Iterates over the six provenance fields as (key, provenance) pairs, in the order
         * class name, source, train proportion, seed, size, is-train.
         *
         * @return An iterator over a freshly built list of the pairs.
         */
        @Override
        public Iterator<Pair<String, Provenance>> iterator() {
            List<Pair<String, Provenance>> entries = new ArrayList<>(6);
            entries.add(new Pair<>(CLASS_NAME_KEY, this.className));
            entries.add(new Pair<>(SOURCE, this.innerSourceProvenance));
            entries.add(new Pair<>(TRAIN_PROPORTION, this.trainProportion));
            entries.add(new Pair<>(SEED, this.seed));
            entries.add(new Pair<>(SIZE, this.size));
            entries.add(new Pair<>(IS_TRAIN, this.isTrain));
            return entries.iterator();
        }

        /**
         * Two split provenances are equal when all six of their fields are equal.
         */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof SplitDataSourceProvenance)) {
                return false;
            }
            SplitDataSourceProvenance other = (SplitDataSourceProvenance) obj;
            return this.className.equals(other.className)
                    && this.innerSourceProvenance.equals(other.innerSourceProvenance)
                    && this.trainProportion.equals(other.trainProportion)
                    && this.seed.equals(other.seed)
                    && this.size.equals(other.size)
                    && this.isTrain.equals(other.isTrain);
        }

        /**
         * Hashes the same six fields that {@link #equals(Object)} compares.
         */
        @Override
        public int hashCode() {
            return Objects.hash(this.className, this.innerSourceProvenance, this.trainProportion, this.seed, this.size, this.isTrain);
        }

        @Override
        public String toString() {
            return "SplitDataSourceProvenance(className=" + this.className
                    + ",innerSourceProvenance=" + this.innerSourceProvenance
                    + ",trainProportion=" + this.trainProportion
                    + ",seed=" + this.seed
                    + ",size=" + this.size
                    + ",isTrain=" + this.isTrain + ')';
        }
    }

    /**
     * Splits the data source using a seed of 1 and a train proportion of 0.7.
     *
     * @param dataSource The data source to split.
     */
    public TrainTestSplitter(DataSource<T> dataSource) {
        this(dataSource, 1L);
    }

    /**
     * Splits the data source using the supplied seed and a train proportion of 0.7.
     *
     * @param dataSource The data source to split.
     * @param seed       The seed for the shuffling {@link Random}.
     */
    public TrainTestSplitter(DataSource<T> dataSource, long seed) {
        this(dataSource, 0.7d, seed);
    }

    /**
     * Reads every example from the data source, shuffles them with a {@link Random} built from
     * {@code seed}, and places the first {@code (int) (trainProportion * n)} examples in the
     * training partition and the rest in the testing partition, where {@code n} is the number
     * of examples read. The cast truncates, so any fractional example goes to the test side.
     *
     * @param dataSource      The data source to split.
     * @param trainProportion The fraction of examples to place in the training partition;
     *                        must be between 0 and 1 inclusive.
     * @param seed            The seed for the shuffling {@link Random}.
     * @throws NullPointerException     if {@code dataSource} is null.
     * @throws IllegalArgumentException if {@code trainProportion} is NaN or outside [0, 1].
     */
    public TrainTestSplitter(DataSource<T> dataSource, double trainProportion, long seed) {
        Objects.requireNonNull(dataSource, "dataSource must not be null");
        // Reject invalid proportions up front: NaN would otherwise cast to 0 and silently yield
        // an empty training set, and out-of-range values either fail inside subList with an
        // index message or slip through when truncation happens to hide them.
        if (Double.isNaN(trainProportion) || trainProportion < 0.0d || trainProportion > 1.0d) {
            throw new IllegalArgumentException(
                    "trainProportion must be between 0.0 and 1.0 inclusive, found " + trainProportion);
        }
        this.seed = seed;
        this.trainProportion = trainProportion;
        this.originalProvenance = dataSource.getProvenance();

        List<Example<T>> examples = new ArrayList<>();
        dataSource.iterator().forEachRemaining(examples::add);

        // size must be assigned before the provenance objects below are built, as they read it.
        this.size = examples.size();
        Collections.shuffle(examples, new Random(seed));
        int trainCount = (int) (trainProportion * examples.size());

        // NOTE(review): subList views are handed straight to ListDataSource; this relies on
        // ListDataSource not being affected by the backing list afterwards - confirm.
        this.train = new ListDataSource<>(examples.subList(0, trainCount), dataSource.getOutputFactory(), new SplitDataSourceProvenance(this, true));
        this.test = new ListDataSource<>(examples.subList(trainCount, examples.size()), dataSource.getOutputFactory(), new SplitDataSourceProvenance(this, false));
    }

    /**
     * Returns the total number of examples read from the original data source.
     *
     * @return The combined size of the training and testing partitions.
     */
    public int totalSize() {
        return this.size;
    }

    /**
     * Returns the training partition.
     *
     * @return The training data source.
     */
    public DataSource<T> getTrain() {
        return this.train;
    }

    /**
     * Returns the testing partition.
     *
     * @return The testing data source.
     */
    public DataSource<T> getTest() {
        return this.test;
    }
}
