package org.deeplearning4j.arbiter.optimize.runner;

import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicLong;
import org.deeplearning4j.arbiter.optimize.api.Candidate;
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import org.deeplearning4j.arbiter.optimize.api.TaskCreator;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
import org.deeplearning4j.arbiter.optimize.runner.listener.candidate.UICandidateStatusListenerImpl;

/* loaded from: input_file:org/deeplearning4j/arbiter/optimize/runner/LocalOptimizationRunner.class */
/**
 * {@link BaseOptimizationRunner} that executes candidate tasks in the current JVM, on a fixed-size
 * pool of daemon threads.
 *
 * <p>Each submitted candidate is turned into a task by the configured {@link TaskCreator} and
 * submitted to a Guava {@link ListeningExecutorService}. At most {@code maxConcurrentTasks}
 * tasks run at the same time.
 *
 * @param <C> candidate value type (as used by {@link Candidate})
 * @param <M> model type (as used by {@link ScoreFunction} and {@link OptimizationResult})
 * @param <D> data type (as used by {@link DataProvider})
 * @param <A> additional-evaluation type (as used by {@link OptimizationResult})
 */
public class LocalOptimizationRunner<C, M, D, A> extends BaseOptimizationRunner<C, M, D, A> {
    /** Number of worker threads used when no explicit concurrency limit is given. */
    public static final int DEFAULT_MAX_CONCURRENT_TASKS = 1;

    private final int maxConcurrentTasks;
    private final TaskCreator<C, M, D, A> taskCreator;
    private final ListeningExecutorService executor;

    /**
     * Creates a runner that executes {@link #DEFAULT_MAX_CONCURRENT_TASKS} task(s) at a time.
     *
     * @param optimizationConfiguration configuration passed to the base runner
     * @param taskCreator               factory used to build one task per candidate
     */
    public LocalOptimizationRunner(OptimizationConfiguration<C, M, D, A> optimizationConfiguration,
                    TaskCreator<C, M, D, A> taskCreator) {
        this(DEFAULT_MAX_CONCURRENT_TASKS, optimizationConfiguration, taskCreator);
    }

    /**
     * Creates a runner backed by a fixed pool of {@code maxConcurrentTasks} daemon threads.
     *
     * @param maxConcurrentTasks        number of worker threads; must be strictly positive
     * @param optimizationConfiguration configuration passed to the base runner
     * @param taskCreator               factory used to build one task per candidate
     * @throws IllegalArgumentException if {@code maxConcurrentTasks <= 0}
     */
    public LocalOptimizationRunner(int maxConcurrentTasks, OptimizationConfiguration<C, M, D, A> optimizationConfiguration,
                    TaskCreator<C, M, D, A> taskCreator) {
        super(optimizationConfiguration);
        if (maxConcurrentTasks <= 0) {
            throw new IllegalArgumentException("maxConcurrentTasks must be > 0 (got: " + maxConcurrentTasks + ")");
        }
        this.maxConcurrentTasks = maxConcurrentTasks;
        this.taskCreator = taskCreator;
        this.executor = MoreExecutors.listeningDecorator(
                        Executors.newFixedThreadPool(maxConcurrentTasks, new CandidateThreadFactory()));
        // init() is defined by the base runner and runs only after the executor exists, because it may
        // start submitting work. NOTE(review): confirm what init() does in BaseOptimizationRunner.
        init();
    }

    /** Returns the concurrency limit this runner was constructed with. */
    @Override
    protected int maxConcurrentTasks() {
        return this.maxConcurrentTasks;
    }

    /**
     * Submits a single candidate for execution.
     *
     * @return the future for that candidate's result
     */
    @Override
    protected ListenableFuture<OptimizationResult<C, M, A>> execute(Candidate<C> candidate, DataProvider<D> dataProvider,
                    ScoreFunction<M, D> scoreFunction) {
        return execute(Collections.singletonList(candidate), dataProvider, scoreFunction).get(0);
    }

    /**
     * Submits one task per candidate to the local executor.
     *
     * <p>Each task is created by the {@link TaskCreator}. Its status listener is a
     * {@link UICandidateStatusListenerImpl} keyed by the candidate's index.
     *
     * @return futures in the same order as {@code candidates}
     */
    @Override
    protected List<ListenableFuture<OptimizationResult<C, M, A>>> execute(List<Candidate<C>> candidates,
                    DataProvider<D> dataProvider, ScoreFunction<M, D> scoreFunction) {
        List<ListenableFuture<OptimizationResult<C, M, A>>> futures = new ArrayList<>(candidates.size());
        for (Candidate<C> candidate : candidates) {
            futures.add(this.executor.submit(this.taskCreator.create(candidate, dataProvider, scoreFunction,
                            new UICandidateStatusListenerImpl(candidate.getIndex()))));
        }
        return futures;
    }

    /**
     * Stops the executor immediately.
     *
     * <p>Uses {@code shutdownNow()}: running tasks are interrupted and queued tasks are discarded.
     */
    @Override
    protected void shutdown() {
        this.executor.shutdownNow();
    }

    /**
     * Thread factory for candidate workers.
     *
     * <p>It produces daemon threads, so an unfinished search does not keep the JVM alive. The
     * threads are named {@code LocalCandidateExecutor-<n>}, numbered from 0 per factory instance.
     * It is a static nested class, so it holds no reference to the enclosing runner.
     */
    private static final class CandidateThreadFactory implements ThreadFactory {
        private final AtomicLong counter = new AtomicLong(0);

        @Override
        public Thread newThread(Runnable runnable) {
            Thread thread = Executors.defaultThreadFactory().newThread(runnable);
            thread.setDaemon(true);
            thread.setName("LocalCandidateExecutor-" + this.counter.getAndIncrement());
            return thread;
        }
    }
}
