package org.deeplearning4j.spark.iterator;

import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import lombok.NonNull;
import org.apache.spark.TaskContext;
import org.apache.spark.TaskContextHelper;
import org.nd4j.linalg.dataset.AsyncMultiDataSetIterator;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.dataset.callbacks.DataSetCallback;
import org.nd4j.linalg.dataset.callbacks.DefaultCallback;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/spark/iterator/SparkAMDSI.class */
/**
 * Spark-aware variant of {@link AsyncMultiDataSetIterator}.
 *
 * <p>The Spark {@link TaskContext} of the thread that constructs this iterator is captured in the
 * constructor and re-installed on the calling thread by {@link #externalCall()}.
 * NOTE(review): this presumes the base class calls {@code externalCall()} from its background
 * prefetch thread so Spark task-scoped state is visible there — confirm against the base class.
 */
public class SparkAMDSI extends AsyncMultiDataSetIterator {
    private static final Logger log = LoggerFactory.getLogger(SparkAMDSI.class);

    /** Queue/prefetch size used by {@link #SparkAMDSI(MultiDataSetIterator)}. */
    private static final int DEFAULT_QUEUE_SIZE = 8;

    /** Smallest queue/prefetch size accepted; smaller requested sizes are raised to this value. */
    private static final int MIN_QUEUE_SIZE = 2;

    /**
     * Spark task context captured from the constructing thread via {@link TaskContext#get()}.
     * May be null when the iterator is not created inside a running Spark task.
     */
    protected TaskContext context;

    /**
     * Prefetch thread used by {@link SparkAMDSI}. Its arguments are null-checked before they are
     * handed to the superclass constructor, so a null argument fails fast with a clear message.
     */
    protected class SparkPrefetchThread extends AsyncMultiDataSetIterator.AsyncPrefetchThread {
        /**
         * @param queue      buffer that prefetched elements are placed into
         * @param iterator   source iterator being prefetched
         * @param terminator sentinel element passed to the base thread (presumably marks end of
         *                   data — confirm against the base class)
         * @param deviceId   device id forwarded to the base prefetch thread
         * @throws NullPointerException if {@code queue}, {@code iterator} or {@code terminator} is null
         */
        protected SparkPrefetchThread(@NonNull BlockingQueue<MultiDataSet> queue,
                @NonNull MultiDataSetIterator iterator, @NonNull MultiDataSet terminator, int deviceId) {
            // The enclosing SparkAMDSI instance is supplied implicitly as the outer instance of the
            // inner superclass; passing it as an explicit argument would not match any constructor.
            super(Objects.requireNonNull(queue, "queue is marked non-null but is null"),
                    Objects.requireNonNull(iterator, "iterator is marked non-null but is null"),
                    Objects.requireNonNull(terminator, "terminator is marked non-null but is null"),
                    deviceId);
        }
    }

    /**
     * No-arg constructor for subclasses and for the main constructor's {@code this()} call.
     * It performs no initialisation of its own.
     */
    protected SparkAMDSI() {
    }

    /**
     * Wraps {@code iterator} with the default queue size and workspaces enabled.
     *
     * @param iterator source iterator to prefetch from
     */
    public SparkAMDSI(MultiDataSetIterator iterator) {
        this(iterator, DEFAULT_QUEUE_SIZE);
    }

    /**
     * Wraps {@code iterator} using the supplied queue, with workspaces enabled.
     *
     * @param iterator  source iterator to prefetch from
     * @param queueSize prefetch size; values below 2 are raised to 2
     * @param queue     buffer for prefetched elements
     */
    public SparkAMDSI(MultiDataSetIterator iterator, int queueSize, BlockingQueue<MultiDataSet> queue) {
        this(iterator, queueSize, queue, true);
    }

    /**
     * Wraps {@code iterator} with a new bounded queue, with workspaces enabled.
     *
     * @param iterator  source iterator to prefetch from
     * @param queueSize prefetch size and queue capacity; values below 2 are raised to 2
     */
    public SparkAMDSI(MultiDataSetIterator iterator, int queueSize) {
        this(iterator, queueSize, newQueue(queueSize));
    }

    /**
     * Wraps {@code iterator} with a new bounded queue.
     *
     * @param iterator     source iterator to prefetch from
     * @param queueSize    prefetch size and queue capacity; values below 2 are raised to 2
     * @param useWorkspace whether workspaces are used
     */
    public SparkAMDSI(MultiDataSetIterator iterator, int queueSize, boolean useWorkspace) {
        this(iterator, queueSize, newQueue(queueSize), useWorkspace);
    }

    /**
     * Wraps {@code iterator} with a new bounded queue, a {@link DefaultCallback} and an explicit device.
     *
     * @param iterator     source iterator to prefetch from
     * @param queueSize    prefetch size and queue capacity; values below 2 are raised to 2
     * @param useWorkspace whether workspaces are used
     * @param deviceId     device for the prefetch thread; if null, the current thread's device is used
     */
    public SparkAMDSI(MultiDataSetIterator iterator, int queueSize, boolean useWorkspace, Integer deviceId) {
        this(iterator, queueSize, newQueue(queueSize), useWorkspace, new DefaultCallback(), deviceId);
    }

    /**
     * Wraps {@code iterator} with a new bounded queue and the given callback.
     *
     * @param iterator     source iterator to prefetch from
     * @param queueSize    prefetch size and queue capacity; values below 2 are raised to 2
     * @param useWorkspace whether workspaces are used
     * @param callback     callback stored on the iterator
     */
    public SparkAMDSI(MultiDataSetIterator iterator, int queueSize, boolean useWorkspace, DataSetCallback callback) {
        this(iterator, queueSize, newQueue(queueSize), useWorkspace, callback);
    }

    /**
     * Wraps {@code iterator} using the supplied queue and no callback.
     *
     * @param iterator     source iterator to prefetch from
     * @param queueSize    prefetch size; values below 2 are raised to 2
     * @param queue        buffer for prefetched elements
     * @param useWorkspace whether workspaces are used
     */
    public SparkAMDSI(MultiDataSetIterator iterator, int queueSize, BlockingQueue<MultiDataSet> queue,
            boolean useWorkspace) {
        this(iterator, queueSize, queue, useWorkspace, null);
    }

    /**
     * Wraps {@code iterator} using the supplied queue and callback, on the current thread's device.
     *
     * @param iterator     source iterator to prefetch from
     * @param queueSize    prefetch size; values below 2 are raised to 2
     * @param queue        buffer for prefetched elements
     * @param useWorkspace whether workspaces are used
     * @param callback     callback stored on the iterator (may be null)
     */
    public SparkAMDSI(MultiDataSetIterator iterator, int queueSize, BlockingQueue<MultiDataSet> queue,
            boolean useWorkspace, DataSetCallback callback) {
        this(iterator, queueSize, queue, useWorkspace, callback,
                Nd4j.getAffinityManager().getDeviceForCurrentThread());
    }

    /**
     * Main constructor. It initialises the inherited state and resets the source iterator when
     * that is supported. It then captures the Spark task context and starts a daemon prefetch thread.
     *
     * @param iterator     source iterator to prefetch from
     * @param queueSize    prefetch size; values below 2 are raised to 2
     * @param queue        buffer for prefetched elements
     * @param useWorkspace whether workspaces are used
     * @param callback     callback stored on the iterator (may be null)
     * @param deviceId     device for the prefetch thread; if null, the current thread's device is used
     */
    public SparkAMDSI(MultiDataSetIterator iterator, int queueSize, BlockingQueue<MultiDataSet> queue,
            boolean useWorkspace, DataSetCallback callback, Integer deviceId) {
        this();
        this.callback = callback;
        this.buffer = queue;
        this.backedIterator = iterator;
        this.useWorkspaces = useWorkspace;
        this.prefetchSize = clampQueueSize(queueSize);
        this.workspaceId = "SAMDSI_ITER-" + UUID.randomUUID().toString();
        this.deviceId = deviceId;

        if (iterator.resetSupported()) {
            this.backedIterator.reset();
        }

        // Honour an explicitly requested device. Previously the prefetch thread was always bound
        // to the current thread's device, ignoring deviceId.
        int threadDevice = deviceId != null
                ? deviceId
                : Nd4j.getAffinityManager().getDeviceForCurrentThread();

        // Capture the task context before start(), so the started thread is guaranteed to see it.
        this.context = TaskContext.get();

        this.thread = new SparkPrefetchThread(this.buffer, iterator, this.terminator, threadDevice);
        this.thread.setDaemon(true);
        this.thread.start();
    }

    /**
     * Raises {@code queueSize} to the minimum of 2.
     * The convenience constructors and the main constructor share this rule.
     */
    private static int clampQueueSize(int queueSize) {
        return Math.max(MIN_QUEUE_SIZE, queueSize);
    }

    /**
     * Creates a bounded queue whose capacity follows the clamped queue size. This avoids an
     * {@link IllegalArgumentException} for non-positive sizes and a capacity smaller than the
     * prefetch size.
     */
    private static BlockingQueue<MultiDataSet> newQueue(int queueSize) {
        return new LinkedBlockingQueue<MultiDataSet>(clampQueueSize(queueSize));
    }

    /**
     * Installs the Spark {@link TaskContext} captured at construction time on the calling thread.
     * NOTE(review): presumably invoked by the base class from the prefetch thread — confirm.
     */
    protected void externalCall() {
        TaskContextHelper.setTaskContext(this.context);
    }
}
