package org.deeplearning4j.spark.data;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.nd4j.linalg.dataset.DataSet;

/* loaded from: input_file:org/deeplearning4j/spark/data/BatchDataSetsFunction.class */
public class BatchDataSetsFunction implements FlatMapFunction<Iterator<DataSet>, DataSet> {
    private final int minibatchSize;

    public Iterator<DataSet> call(Iterator<DataSet> it) throws Exception {
        ArrayList arrayList = new ArrayList();
        while (it.hasNext()) {
            ArrayList arrayList2 = new ArrayList();
            int i = 0;
            while (i < this.minibatchSize && it.hasNext()) {
                DataSet next = it.next();
                i = (int) (i + next.getFeatures().size(0));
                arrayList2.add(next);
            }
            arrayList.add(arrayList2.isEmpty() ? (DataSet) arrayList2.get(0) : DataSet.merge(arrayList2));
        }
        return arrayList.iterator();
    }

    public BatchDataSetsFunction(int i) {
        this.minibatchSize = i;
    }
}
