package org.tribuo.evaluation;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Trainer;
import org.tribuo.evaluation.Evaluation;
import org.tribuo.evaluation.KFoldSplitter;

/* loaded from: input_file:org/tribuo/evaluation/CrossValidation.class */
/**
 * Runs k-fold cross-validation: the dataset is split into {@code k} train/test folds by a
 * {@link KFoldSplitter}, a model is trained on each fold's training portion with the supplied
 * {@link Trainer}, and that model is scored on the fold's test portion with the supplied
 * {@link Evaluator}.
 *
 * @param <T> the output type of the dataset and models
 * @param <E> the evaluation type produced by the evaluator
 */
public class CrossValidation<T extends Output<T>, E extends Evaluation<T>> {
    private static final Logger logger = Logger.getLogger(CrossValidation.class.getName());

    /** Trainer invoked once per fold. */
    private final Trainer<T> trainer;
    /** Number of folds requested at construction time. */
    private final int numFolds;
    /** The full dataset that is split into folds. */
    private final Dataset<T> data;
    /** Scores each fold's model against that fold's test split. */
    private final Evaluator<T, E> evaluator;
    /** Produces the train/test folds; constructed from {@code numFolds} and the seed. */
    private final KFoldSplitter<T> splitter;

    /**
     * Builds a cross-validation run using {@link Trainer#DEFAULT_SEED} for the fold splitter.
     *
     * @param trainer the trainer used to fit a model on each fold; must not be null
     * @param dataset the dataset to split into folds; must not be null
     * @param evaluator the evaluator applied to each fold's model; must not be null
     * @param k the number of folds
     * @throws NullPointerException if {@code trainer}, {@code dataset} or {@code evaluator} is null
     */
    public CrossValidation(Trainer<T> trainer, Dataset<T> dataset, Evaluator<T, E> evaluator, int k) {
        this(trainer, dataset, evaluator, k, Trainer.DEFAULT_SEED);
    }

    /**
     * Builds a cross-validation run with an explicit seed for the fold splitter.
     *
     * @param trainer the trainer used to fit a model on each fold; must not be null
     * @param dataset the dataset to split into folds; must not be null
     * @param evaluator the evaluator applied to each fold's model; must not be null
     * @param k the number of folds (validity of the value is left to {@link KFoldSplitter})
     * @param seed the seed passed to the {@link KFoldSplitter}
     * @throws NullPointerException if {@code trainer}, {@code dataset} or {@code evaluator} is null
     */
    public CrossValidation(Trainer<T> trainer, Dataset<T> dataset, Evaluator<T, E> evaluator, int k, long seed) {
        // Fail fast here rather than with an opaque NPE inside evaluate().
        this.trainer = Objects.requireNonNull(trainer, "trainer");
        this.data = Objects.requireNonNull(dataset, "dataset");
        this.evaluator = Objects.requireNonNull(evaluator, "evaluator");
        this.numFolds = k;
        this.splitter = new KFoldSplitter<>(k, seed);
    }

    /**
     * Returns the number of folds this cross-validation was configured with.
     *
     * @return the fold count {@code k}
     */
    public int getK() {
        return this.numFolds;
    }

    /**
     * Trains and evaluates one model per fold.
     *
     * <p>Each call re-runs the split and retrains every model.
     *
     * @return one {@code (evaluation, model)} pair per fold, in the order the splitter yields folds
     */
    public List<Pair<E, Model<T>>> evaluate() {
        List<Pair<E, Model<T>>> results = new ArrayList<>();
        // NOTE(review): the boolean argument presumably requests shuffling before splitting;
        // confirm against KFoldSplitter.split.
        Iterator<KFoldSplitter.TrainTestFold<T>> folds = this.splitter.split(this.data, true);
        int fold = 0;
        while (folds.hasNext()) {
            logger.log(Level.INFO, "Training for fold " + fold);
            KFoldSplitter.TrainTestFold<T> current = folds.next();
            Model<T> model = this.trainer.train(current.train);
            E evaluation = this.evaluator.evaluate(model, current.test);
            results.add(new Pair<>(evaluation, model));
            fold++;
        }
        return results;
    }
}
