package org.deeplearning4j.arbiter.optimize.runner;

import com.google.common.util.concurrent.ListenableFuture;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.deeplearning4j.arbiter.optimize.api.Candidate;
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition;
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/arbiter/optimize/runner/BaseOptimizationRunner.class */
/**
 * Skeleton of an optimization runner.
 *
 * <p>The {@link #execute()} loop works as follows:
 * <ul>
 *   <li>It repeatedly draws candidates from the configured candidate generator.</li>
 *   <li>It submits each candidate for asynchronous evaluation via
 *       {@link #execute(Candidate, DataProvider, ScoreFunction)}.</li>
 *   <li>It collects the futures as they complete and tracks the best score seen so far.</li>
 *   <li>It notifies the registered {@link StatusListener}s.</li>
 *   <li>It stops when one of the configured {@link TerminationCondition}s is met.</li>
 * </ul>
 *
 * <p>Subclasses decide how a candidate is evaluated and how many tasks may be queued at once
 * ({@link #maxConcurrentTasks()}). {@link #init()} must run before {@link #execute()}, because
 * {@link #futureListenerExecutor} (which runs the completion callbacks) is only created there.
 *
 * <p>Threading: completion callbacks only enqueue finished futures onto {@link #completedFutures}.
 * All result bookkeeping (counters, {@link #bestScore}, {@link #allResults}, {@link #currentStatus})
 * is updated by the thread running {@link #execute()}.
 */
public abstract class BaseOptimizationRunner implements IOptimizationRunner {
    /** How long {@link #execute()} blocks waiting for a completed task before re-checking termination. */
    private static final int POLLING_FREQUENCY = 1;
    private static final TimeUnit POLLING_FREQUENCY_UNIT = TimeUnit.SECONDS;
    private static final Logger log = LoggerFactory.getLogger(BaseOptimizationRunner.class);

    protected OptimizationConfiguration config;
    /** Runs the completion callbacks attached to submitted futures; created by {@link #init()}. */
    protected ExecutorService futureListenerExecutor;
    /** Submitted futures whose results have not been processed yet. */
    protected Queue<Future<OptimizationResult>> queuedFutures = new ConcurrentLinkedQueue<>();
    /** Finished futures (enqueued by completion callbacks) waiting to be processed by {@link #execute()}. */
    protected BlockingQueue<Future<OptimizationResult>> completedFutures = new LinkedBlockingQueue<>();
    protected AtomicInteger totalCandidateCount = new AtomicInteger();
    protected AtomicInteger numCandidatesCompleted = new AtomicInteger();
    protected AtomicInteger numCandidatesFailed = new AtomicInteger();
    protected Double bestScore = null;
    protected Long bestScoreTime = null;
    protected AtomicInteger bestScoreCandidateIndex = new AtomicInteger(-1);
    protected List<ResultReference> allResults = new ArrayList<>();
    protected Map<Integer, CandidateInfo> currentStatus = new ConcurrentHashMap<>();
    protected List<StatusListener> statusListeners = new ArrayList<>();

    /**
     * Candidate index of every submitted, not-yet-processed future.
     *
     * <p>A task that fails with an {@link ExecutionException} yields no {@link OptimizationResult}.
     * Without this map, such a task could not be attributed to a candidate.
     */
    private final Map<Future<OptimizationResult>, Integer> futureCandidateIndices = new ConcurrentHashMap<>();

    /** Completion callback: hands a finished future over to the {@link #execute()} loop. */
    private class OnCompletionListener implements Runnable {
        private final Future<OptimizationResult> future;

        public OnCompletionListener(Future<OptimizationResult> future) {
            this.future = future;
        }

        @Override
        public void run() {
            completedFutures.add(future);
        }
    }

    /**
     * Creates a runner for the given configuration.
     *
     * @param optimizationConfiguration configuration for this run; must define at least one termination condition
     * @throws IllegalArgumentException if the configuration's termination conditions are null or empty
     */
    public BaseOptimizationRunner(OptimizationConfiguration optimizationConfiguration) {
        // NOTE(review): the decompiler reports this constructor was originally protected.
        this.config = optimizationConfiguration;
        if (optimizationConfiguration.getTerminationConditions() == null
                || optimizationConfiguration.getTerminationConditions().isEmpty()) {
            throw new IllegalArgumentException("Cannot create BaseOptimizationRunner without TerminationConditions (termination conditions are null or empty)");
        }
    }

    /**
     * Creates {@link #futureListenerExecutor}: a fixed pool of {@link #maxConcurrentTasks()} daemon threads,
     * named {@code ArbiterOptimizationRunner-N}.
     *
     * <p>Must be called before {@link #execute()}.
     */
    public void init() {
        // NOTE(review): the decompiler reports this method was originally protected.
        futureListenerExecutor = Executors.newFixedThreadPool(maxConcurrentTasks(), new ThreadFactory() {
            private final AtomicLong counter = new AtomicLong(0);

            @Override
            public Thread newThread(Runnable runnable) {
                Thread thread = Executors.defaultThreadFactory().newThread(runnable);
                thread.setDaemon(true);
                thread.setName("ArbiterOptimizationRunner-" + counter.getAndIncrement());
                return thread;
            }
        });
    }

    /**
     * Runs the optimization until a termination condition is met or the calling thread is interrupted.
     *
     * <p>If the thread is interrupted while waiting for results, the loop stops as if terminated.
     * The runner still shuts down normally and processes any remaining completed tasks. The interrupt
     * status is re-asserted before returning, so callers can observe it.
     */
    @Override
    public void execute() {
        log.info("{}: execution started", getClass().getSimpleName());
        config.setExecutionStartTime(System.currentTimeMillis());
        for (StatusListener listener : statusListeners) {
            listener.onInitialization(this);
        }
        // Start timers etc. for the termination conditions.
        for (TerminationCondition condition : config.getTerminationConditions()) {
            condition.initialize(this);
        }

        List<Future<OptimizationResult>> batch = new ArrayList<>(100);
        boolean interrupted = false;
        while (true) {
            try {
                Future<OptimizationResult> next = completedFutures.poll(POLLING_FREQUENCY, POLLING_FREQUENCY_UNIT);
                if (next != null) {
                    batch.add(next);
                }
            } catch (InterruptedException e) {
                // Interruption is a stop request. The flag is restored only at the very end, because
                // restoring it now would make the future.get() calls during final processing throw.
                log.info("{}: interrupted, stopping optimization", getClass().getSimpleName());
                interrupted = true;
            }

            if (processCompletedFutures(batch) > 0) {
                for (StatusListener listener : statusListeners) {
                    listener.onRunnerStatusChange(this);
                }
            }

            if (interrupted || terminate()) {
                break;
            }
            queueNewCandidates();
        }

        shutdown(true);
        // Pick up anything that completed while shutting down.
        processCompletedFutures(batch);

        log.info("Optimization runner: execution complete");
        for (StatusListener listener : statusListeners) {
            listener.onShutdown(this);
        }
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Drains {@link #completedFutures} into {@code batch}, processes every future in it, then clears it.
     *
     * @param batch reusable buffer; may already contain futures taken from the queue
     * @return the number of futures processed
     */
    private int processCompletedFutures(List<Future<OptimizationResult>> batch) {
        completedFutures.drainTo(batch);
        int processed = batch.size();
        for (Future<OptimizationResult> future : batch) {
            queuedFutures.remove(future);
            processReturnedTask(future);
        }
        batch.clear();
        return processed;
    }

    /**
     * Tops up the queue while the generator has candidates and fewer than {@link #maxConcurrentTasks()}
     * futures are queued.
     *
     * <p>Every new candidate, failed or submitted, is reported to the status listeners.
     */
    private void queueNewCandidates() {
        while (config.getCandidateGenerator().hasMoreCandidates() && queuedFutures.size() < maxConcurrentTasks()) {
            Candidate<?> candidate = config.getCandidateGenerator().getCandidate();
            // Candidates that failed during generation are recorded but never submitted.
            CandidateInfo status = candidate.getException() != null
                    ? processFailedCandidates(candidate)
                    : submitCandidate(candidate);
            for (StatusListener listener : statusListeners) {
                listener.onCandidateStatusChange(status, this, null);
            }
        }
    }

    /**
     * Submits one candidate for evaluation and records it with status {@code Created}.
     *
     * @param candidate a candidate whose generation succeeded
     * @return the status recorded in {@link #currentStatus}
     */
    private CandidateInfo submitCandidate(Candidate<?> candidate) {
        long createdTime = System.currentTimeMillis();
        ListenableFuture<OptimizationResult> future =
                execute(candidate, config.getDataProvider(), config.getScoreFunction());
        futureCandidateIndices.put(future, candidate.getIndex());
        future.addListener(new OnCompletionListener(future), futureListenerExecutor);
        queuedFutures.add(future);
        totalCandidateCount.getAndIncrement();

        CandidateInfo status = new CandidateInfo(candidate.getIndex(), CandidateStatus.Created, null, createdTime,
                null, null, candidate.getFlatParameters(), null);
        currentStatus.put(candidate.getIndex(), status);
        return status;
    }

    /**
     * Records a candidate whose generation failed as {@code Failed}.
     *
     * <p>Created, start and end times are all set to "now", and the generation exception's stack trace is
     * stored. Such candidates are not counted in {@link #numCandidatesTotal()} or {@link #numCandidatesFailed()}.
     */
    private CandidateInfo processFailedCandidates(Candidate<?> candidate) {
        long now = System.currentTimeMillis();
        CandidateInfo status = new CandidateInfo(candidate.getIndex(), CandidateStatus.Failed, null, now,
                Long.valueOf(now), Long.valueOf(now), candidate.getFlatParameters(),
                ExceptionUtils.getStackTrace(candidate.getException()));
        currentStatus.put(candidate.getIndex(), status);
        return status;
    }

    /**
     * Processes one finished future.
     *
     * <p>The future's candidate status is always updated. For a task reported as {@code Failed}, the
     * failure counter is incremented. Otherwise, the result is reported to the candidate generator, the
     * best score is updated if improved, the completed counter is incremented, and the result reference
     * (if any) is stored.
     *
     * @throws RuntimeException if interrupted, or if the (already completed) future does not return within 100 ms
     */
    private void processReturnedTask(Future<OptimizationResult> future) {
        long endTime = System.currentTimeMillis();
        Integer candidateIndex = futureCandidateIndices.remove(future);

        OptimizationResult result;
        try {
            result = future.get(100L, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Unexpected InterruptedException thrown for task", e);
        } catch (ExecutionException e) {
            // No OptimizationResult is available: record the failure against the submitted candidate.
            log.warn("Task failed", e);
            numCandidatesFailed.getAndIncrement();
            markTaskFailed(candidateIndex, endTime, e);
            return;
        } catch (TimeoutException e) {
            throw new RuntimeException(e);
        }

        CandidateInfo previous = currentStatus.get(result.getIndex());
        currentStatus.put(result.getIndex(), new CandidateInfo(result.getIndex(),
                result.getCandidateInfo().getCandidateStatus(), result.getScore(), previous.getCreatedTime(),
                result.getCandidateInfo().getStartTime(), Long.valueOf(endTime), previous.getFlatParams(),
                result.getCandidateInfo().getExceptionStackTrace()));

        if (result.getCandidateInfo().getCandidateStatus() == CandidateStatus.Failed) {
            log.info("Task {} failed during execution: {}", result.getIndex(),
                    result.getCandidateInfo().getExceptionStackTrace());
            numCandidatesFailed.getAndIncrement();
            return;
        }

        config.getCandidateGenerator().reportResults(result);
        Double score = result.getScore();
        log.info("Completed task {}, score = {}", result.getIndex(), result.getScore());

        if (isNewBestScore(score)) {
            if (bestScore == null) {
                log.info("New best score: {} (first completed model)", score);
            } else {
                log.info("New best score: {}, model {} (prev={}, model {})",
                        new Object[] {score, result.getIndex(), bestScore, bestScoreCandidateIndex.get()});
            }
            bestScore = score;
            bestScoreTime = System.currentTimeMillis();
            bestScoreCandidateIndex.set(result.getIndex());
        }
        numCandidatesCompleted.getAndIncrement();

        ResultReference resultReference = result.getResultReference();
        if (resultReference != null) {
            allResults.add(resultReference);
        }
    }

    /**
     * Returns whether {@code score} should replace the current best score.
     *
     * <p>Null and NaN scores never qualify. A NaN best score would make every later {@code <}/{@code >}
     * comparison false, so it could never be replaced.
     */
    private boolean isNewBestScore(Double score) {
        if (score == null || score.isNaN()) {
            return false;
        }
        if (bestScore == null) {
            return true;
        }
        return config.getScoreFunction().minimize() ? score < bestScore : score > bestScore;
    }

    /**
     * Marks the candidate behind a future that failed with an exception as {@code Failed}.
     *
     * <p>The previously recorded created time, start time and parameters are kept.
     * Does nothing if the candidate is unknown.
     */
    private void markTaskFailed(Integer candidateIndex, long endTime, Throwable error) {
        if (candidateIndex == null) {
            return;
        }
        CandidateInfo previous = currentStatus.get(candidateIndex);
        if (previous == null) {
            return;
        }
        currentStatus.put(candidateIndex, new CandidateInfo(candidateIndex, CandidateStatus.Failed, null,
                previous.getCreatedTime(), previous.getStartTime(), Long.valueOf(endTime),
                previous.getFlatParams(), ExceptionUtils.getStackTrace(error)));
    }

    /** Number of candidates successfully submitted for evaluation. */
    @Override
    public int numCandidatesTotal() {
        return totalCandidateCount.get();
    }

    /** Number of submitted candidates whose evaluation completed without failure. */
    @Override
    public int numCandidatesCompleted() {
        return numCandidatesCompleted.get();
    }

    /** Number of submitted candidates whose evaluation failed. */
    @Override
    public int numCandidatesFailed() {
        return numCandidatesFailed.get();
    }

    /** Number of submitted candidates whose results have not been processed yet. */
    @Override
    public int numCandidatesQueued() {
        return queuedFutures.size();
    }

    /** Best score seen so far, or {@code null} if no scored candidate has completed. */
    @Override
    public Double bestScore() {
        return bestScore;
    }

    /** Wall-clock time (ms) at which the current best score was recorded, or {@code null}. */
    @Override
    public Long bestScoreTime() {
        return bestScoreTime;
    }

    /** Index of the candidate with the best score, or -1 if none yet. */
    @Override
    public int bestScoreCandidateIndex() {
        return bestScoreCandidateIndex.get();
    }

    /** Returns a snapshot copy of the result references collected so far. */
    @Override
    public List<ResultReference> getResults() {
        return new ArrayList<>(allResults);
    }

    @Override
    public OptimizationConfiguration getConfiguration() {
        return config;
    }

    /** Registers the given listeners, ignoring any already registered. */
    @Override
    public void addListeners(StatusListener... statusListenerArr) {
        for (StatusListener listener : statusListenerArr) {
            if (!statusListeners.contains(listener)) {
                statusListeners.add(listener);
            }
        }
    }

    /** Unregisters the given listeners; unknown listeners are ignored. */
    @Override
    public void removeListeners(StatusListener... statusListenerArr) {
        for (StatusListener listener : statusListenerArr) {
            statusListeners.remove(listener);
        }
    }

    @Override
    public void removeAllListeners() {
        statusListeners.clear();
    }

    /** Returns a snapshot of the latest status of every candidate seen so far. */
    @Override
    public List<CandidateInfo> getCandidateStatus() {
        return new ArrayList<>(currentStatus.values());
    }

    /** Returns {@code true} (and logs which one) as soon as any termination condition is met. */
    private boolean terminate() {
        for (TerminationCondition condition : config.getTerminationConditions()) {
            if (condition.terminate(this)) {
                log.info("BaseOptimizationRunner global termination condition hit: {}", condition);
                return true;
            }
        }
        return false;
    }

    /** Maximum number of queued (in-flight) tasks; also the size of the completion-callback thread pool. */
    protected abstract int maxConcurrentTasks();

    /** Submits a single candidate for asynchronous evaluation. */
    protected abstract ListenableFuture<OptimizationResult> execute(Candidate candidate, DataProvider dataProvider, ScoreFunction scoreFunction);

    /** Batch variant of {@link #execute(Candidate, DataProvider, ScoreFunction)}; not called by this class. */
    protected abstract List<ListenableFuture<OptimizationResult>> execute(List<Candidate> list, DataProvider dataProvider, ScoreFunction scoreFunction);
}
