package ai.djl.basicdataset.cv;

import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.training.dataset.RandomAccessDataset;
import java.io.IOException;
import java.util.Optional;

/* loaded from: input_file:ai/djl/basicdataset/cv/ImageDataset.class */
/**
 * Base class for random-access datasets whose records contain an image.
 *
 * <p>Subclasses supply the raw {@link Image} for a record index and may declare a fixed output
 * size; {@link #getRecordImage(NDManager, long)} turns that image into an {@link NDArray} using
 * the configured {@link Image.Flag} and resizes it when both dimensions are declared.
 */
public abstract class ImageDataset extends RandomAccessDataset {

    /** Colour mode used when converting images to arrays; copied from the builder. */
    protected Image.Flag flag;

    /**
     * Builder base for {@link ImageDataset} subclasses.
     *
     * @param <T> the concrete builder type, returned from fluent setters
     */
    public static abstract class BaseBuilder<T extends BaseBuilder<T>> extends RandomAccessDataset.BaseBuilder<T> {

        // Defaults to COLOR unless overridden through optFlag.
        Image.Flag flag = Image.Flag.COLOR;

        /**
         * Sets the colour mode used when loading images.
         *
         * @param flag the image flag to apply during conversion
         * @return this builder
         */
        public T optFlag(Image.Flag flag) {
            this.flag = flag;
            return (T) self();
        }
    }

    /**
     * Creates the dataset from a builder, taking the image flag it carries.
     *
     * @param baseBuilder the builder holding dataset configuration
     */
    public ImageDataset(BaseBuilder<?> baseBuilder) {
        super(baseBuilder);
        flag = baseBuilder.flag;
    }

    /**
     * Loads the image for a record and converts it to an {@link NDArray}.
     *
     * <p>The image is converted with {@link #flag}. If both {@link #getImageWidth()} and
     * {@link #getImageHeight()} report a value, the array is resized to that width and height;
     * otherwise it is returned at its original size.
     *
     * <p>NOTE(review): the decompiler reports this method was originally {@code protected}; it
     * is kept {@code public} so existing callers keep compiling.
     *
     * @param manager the manager that owns the created array
     * @param index the record index to load
     * @return the image as an array, possibly resized
     * @throws IOException if {@link #getImage(long)} fails to load the image
     */
    public NDArray getRecordImage(NDManager manager, long index) throws IOException {
        NDArray image = getImage(index).toNDArray(manager, flag);
        Optional<Integer> width = getImageWidth();
        Optional<Integer> height = getImageHeight();
        // Only resize when the subclass fixes both dimensions.
        if (!width.isPresent() || !height.isPresent()) {
            return image;
        }
        return NDImageUtils.resize(image, width.get(), height.get());
    }

    /**
     * Returns the raw image for a record.
     *
     * @param index the record index
     * @return the image stored at {@code index}
     * @throws IOException if the image cannot be read
     */
    protected abstract Image getImage(long index) throws IOException;

    /**
     * Returns the number of channels implied by the configured image flag.
     *
     * @return {@code flag.numChannels()}
     */
    public int getImageChannels() {
        return flag.numChannels();
    }

    /**
     * Returns the fixed output width, if the dataset declares one.
     *
     * @return the width, or empty when images are not resized to a fixed width
     */
    public abstract Optional<Integer> getImageWidth();

    /**
     * Returns the fixed output height, if the dataset declares one.
     *
     * @return the height, or empty when images are not resized to a fixed height
     */
    public abstract Optional<Integer> getImageHeight();
}
