package ai.djl.training.dataset;

import ai.djl.training.dataset.Sampler;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;

/* loaded from: input_file:ai/djl/training/dataset/BatchSampler.class */
/**
 * A {@link Sampler} that groups the item indices produced by a {@link Sampler.SubSampler} into
 * batches of at most {@code batchSize} indices.
 *
 * <p>The number of batches is derived from the dataset size: when {@code dropLast} is set, a
 * trailing batch smaller than {@code batchSize} is not emitted; otherwise it is emitted as a
 * shorter final batch.
 */
public class BatchSampler implements Sampler {
    private final Sampler.SubSampler subSampler;
    private final int batchSize;
    private final boolean dropLast;

    /** Iterator that pulls indices from the sub-sampler and hands them out one batch at a time. */
    class Iterate implements Iterator<List<Long>> {
        /** Total number of batches this iterator will produce. */
        private final long size;
        /** Source of individual item indices. */
        private final Iterator<Long> itemSampler;
        /** Number of batches handed out so far. */
        private long current;

        /**
         * Creates a batch iterator over the given dataset.
         *
         * @param dataset the dataset whose size determines the batch count and which is passed to
         *     the sub-sampler
         */
        Iterate(RandomAccessDataset dataset) {
            long total = dataset.size();
            if (dropLast) {
                // Integer division discards the incomplete trailing batch.
                this.size = total / batchSize;
            } else {
                // Ceiling division keeps the incomplete trailing batch.
                this.size = (total + batchSize - 1) / batchSize;
            }
            this.itemSampler = subSampler.sample(dataset);
            this.current = 0;
        }

        /** {@inheritDoc} */
        @Override
        public boolean hasNext() {
            return current < size;
        }

        /**
         * Returns the next batch of item indices.
         *
         * @return a list holding up to {@code batchSize} indices
         * @throws NoSuchElementException if every batch has already been returned
         */
        @Override
        public List<Long> next() {
            // Honor the Iterator contract; without this guard a caller could pull the
            // partial batch that dropLast is meant to suppress.
            if (!hasNext()) {
                throw new NoSuchElementException("No more batches: " + current + " of " + size);
            }
            // Not presized: batchSize may be far larger than the number of remaining items.
            List<Long> batch = new ArrayList<>();
            while (batch.size() < batchSize && itemSampler.hasNext()) {
                batch.add(itemSampler.next());
            }
            current++;
            return batch;
        }
    }

    /**
     * Creates a sampler that keeps the final, possibly shorter, batch.
     *
     * @param subSampler the source of individual item indices
     * @param batchSize the maximum number of indices per batch; must be positive
     * @throws IllegalArgumentException if {@code batchSize} is not positive
     */
    public BatchSampler(Sampler.SubSampler subSampler, int batchSize) {
        this(subSampler, batchSize, false);
    }

    /**
     * Creates a sampler.
     *
     * @param subSampler the source of individual item indices
     * @param batchSize the maximum number of indices per batch; must be positive
     * @param dropLast whether to omit a trailing batch smaller than {@code batchSize}
     * @throws IllegalArgumentException if {@code batchSize} is not positive
     */
    public BatchSampler(Sampler.SubSampler subSampler, int batchSize, boolean dropLast) {
        // Fail fast: a zero batch size would otherwise surface later as a division by zero
        // when sampling, and a negative one would silently produce no batches.
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive, got: " + batchSize);
        }
        this.subSampler = subSampler;
        this.batchSize = batchSize;
        this.dropLast = dropLast;
    }

    /**
     * Starts a new pass over the dataset.
     *
     * @param randomAccessDataset the dataset to sample from
     * @return an iterator over batches of item indices
     */
    @Override
    public Iterator<List<Long>> sample(RandomAccessDataset randomAccessDataset) {
        return new Iterate(randomAccessDataset);
    }

    /**
     * Returns the configured batch size.
     *
     * @return the maximum number of indices per batch
     */
    @Override
    public int getBatchSize() {
        return batchSize;
    }
}
