package org.tweetyproject.machinelearning;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;

import org.tweetyproject.commons.util.Pair;
import org.tweetyproject.machinelearning.Category;
import org.tweetyproject.machinelearning.Observation;

/* loaded from: input_file:org.tweetyproject.machinelearning-1.19-SNAPSHOT.jar:org/tweetyproject/machinelearning/CrossValidator.class */
/**
 * Estimates the quality of a {@link Trainer} by k-fold cross-validation.
 *
 * <p>The training set is split into {@code fold} partitions. The observations of each category are
 * dealt round-robin over the partitions, so every category is spread across them as evenly as
 * possible. For each partition in turn, a classifier is trained on the union of the other
 * partitions and evaluated on the held-out one via the inherited {@code test} overload; the
 * overall result is the arithmetic mean of these per-fold evaluations.
 *
 * @param <S> the type of observations
 * @param <T> the type of categories
 */
public class CrossValidator<S extends Observation, T extends Category> extends ClassificationTester<S, T> {

    /** The number of partitions (folds) the training set is split into; always at least 2. */
    private final int fold;

    /**
     * Creates a cross-validator using the given number of folds.
     *
     * @param fold the number of partitions the training set is split into
     * @throws IllegalArgumentException if {@code fold} is less than 2
     */
    public CrossValidator(int fold) {
        if (fold < 2) {
            throw new IllegalArgumentException("Number of partitions must be greater or equal to 2.");
        }
        this.fold = fold;
    }

    /**
     * Runs k-fold cross-validation of the given trainer on the given training set.
     *
     * <p>NOTE(review): if {@code trainingSet} holds fewer observations than there are folds, some
     * partitions are empty and are still used as test sets; how the inherited {@code test} overload
     * scores an empty test set is not visible here — confirm.
     *
     * @param trainer the trainer used to build a classifier for each fold
     * @param trainingSet the labelled observations to cross-validate on
     * @return the mean of the per-fold results of the inherited {@code test} overload
     */
    @Override // org.tweetyproject.machinelearning.ClassificationTester
    public double test(Trainer<S, T> trainer, TrainingSet<S, T> trainingSet) {
        List<TrainingSet<S, T>> partitions = partition(trainingSet);
        double sum = 0.0d;
        for (int held = 0; held < fold; held++) {
            // Train on everything except the held-out partition.
            TrainingSet<S, T> trainingPart = new TrainingSet<>();
            for (int other = 0; other < fold; other++) {
                if (other != held) {
                    trainingPart.addAll(partitions.get(other));
                }
            }
            sum += test(trainer.train(trainingPart), partitions.get(held));
        }
        return sum / fold;
    }

    /**
     * Splits the training set into {@code fold} partitions, dealing observations round-robin
     * category by category.
     *
     * @param trainingSet the set to split
     * @return exactly {@code fold} partitions whose union contains every observation
     */
    private List<TrainingSet<S, T>> partition(TrainingSet<S, T> trainingSet) {
        List<TrainingSet<S, T>> partitions = new ArrayList<>(fold);
        for (int k = 0; k < fold; k++) {
            partitions.add(new TrainingSet<S, T>());
        }
        // One cursor shared by all categories: restarting at partition 0 for each category would
        // pile every category's remainder onto the low-numbered partitions and can leave later
        // partitions empty. Continuing the cursor keeps each category within +/-1 per partition
        // while also balancing the total partition sizes.
        int next = 0;
        for (T category : trainingSet.getCategories()) {
            for (Pair<S, T> observation : trainingSet.getObservations(category)) {
                partitions.get(next).add(observation);
                next = (next + 1) % fold;
            }
        }
        return partitions;
    }
}
