package org.arbiter.optimize.runner;

import java.beans.ConstructorProperties;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicLong;
import org.arbiter.optimize.api.Candidate;
import org.arbiter.optimize.api.OptimizationResult;
import org.arbiter.optimize.api.saving.ResultReference;
import org.arbiter.optimize.api.saving.ResultSaver;
import org.arbiter.optimize.api.termination.TerminationCondition;
import org.arbiter.optimize.config.OptimizationConfiguration;
import org.arbiter.optimize.executor.CandidateExecutor;
import org.arbiter.optimize.runner.listener.runner.OptimizationRunnerStatusListener;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/arbiter/optimize/runner/OptimizationRunner.class */
/**
 * Drives an optimization run: repeatedly takes candidates from the configured candidate generator,
 * submits them to a {@link CandidateExecutor}, collects their results, tracks the best (lowest)
 * score seen so far, and stops as soon as any configured {@link TerminationCondition} fires.
 *
 * <p>{@link #execute()} runs the optimization loop on the calling thread. Each submitted future gets
 * a completion listener that runs on an internal pool of daemon threads and simply hands the future
 * back to the calling thread through a blocking queue; all result processing and bookkeeping happens
 * on the thread that called {@link #execute()}.
 *
 * <p>The candidate executor is shut down at the end of {@link #execute()}, so a runner is presumably
 * not meant to be executed more than once.
 */
public class OptimizationRunner<C, M, D, A> implements IOptimizationRunner<C, M, A> {
    /** How long (in {@link #POLLING_FREQUENCY_UNIT}) to wait for a completed task before re-checking termination. */
    private static final int POLLING_FREQUENCY = 10;
    private static final TimeUnit POLLING_FREQUENCY_UNIT = TimeUnit.SECONDS;
    /** Upper bound when fetching the result of a future that has already signalled completion. */
    private static final long RESULT_FETCH_TIMEOUT_MS = 100L;
    /** Idle time after which a completion-listener thread is allowed to exit. */
    private static final long LISTENER_THREAD_KEEP_ALIVE_SECONDS = 60L;

    private static final Logger log = LoggerFactory.getLogger(OptimizationRunner.class);

    private final OptimizationConfiguration<C, M, D, A> config;
    private final CandidateExecutor<C, M, D, A> executor;
    /** Runs the completion listeners attached to submitted futures. */
    private final ExecutorService futureListenerExecutor;
    /** Futures that have been submitted and not yet processed. */
    private final Queue<Future<OptimizationResult<C, M, A>>> queuedFutures =
            new ConcurrentLinkedQueue<Future<OptimizationResult<C, M, A>>>();
    /** Futures that have completed and are waiting to be processed by {@link #execute()}. */
    private final BlockingQueue<Future<OptimizationResult<C, M, A>>> completedFutures =
            new LinkedBlockingQueue<Future<OptimizationResult<C, M, A>>>();
    private int totalCandidateCount = 0;
    private int numCandidatesCompleted = 0;
    private int numCandidatesFailed = 0;
    /** Lowest score seen so far; {@code Double.MAX_VALUE} until the first candidate completes. */
    private double bestScore = Double.MAX_VALUE;
    private long bestScoreTime = 0;
    private int bestScoreCandidateIndex = -1;
    private final List<ResultReference<C, M, A>> allResults = new ArrayList<ResultReference<C, M, A>>();
    /** Latest known status per candidate index; a concurrent map because it may be read via {@link #getCandidateStatus()}. */
    private final Map<Integer, CandidateStatus> currentStatus = new ConcurrentHashMap<Integer, CandidateStatus>();
    private final List<OptimizationRunnerStatusListener> statusListeners = new ArrayList<OptimizationRunnerStatusListener>();

    /** Completion listener that hands a finished future back to the {@link #execute()} loop. */
    private class OnCompletionListener implements Runnable {
        private final Future<OptimizationResult<C, M, A>> future;

        public OnCompletionListener(Future<OptimizationResult<C, M, A>> future) {
            this.future = future;
        }

        @Override
        public void run() {
            completedFutures.add(future);
        }
    }

    /**
     * Creates a runner for the given configuration and executor.
     *
     * @param optimizationConfiguration the optimization setup; must define at least one termination condition
     * @param candidateExecutor executes candidates; its {@code maxConcurrentTasks()} bounds how many run at once
     * @throws IllegalArgumentException if the configuration's termination conditions are null or empty
     */
    public OptimizationRunner(OptimizationConfiguration<C, M, D, A> optimizationConfiguration,
                              CandidateExecutor<C, M, D, A> candidateExecutor) {
        this.config = optimizationConfiguration;
        this.executor = candidateExecutor;
        if (optimizationConfiguration.getTerminationConditions() == null
                || optimizationConfiguration.getTerminationConditions().size() == 0) {
            throw new IllegalArgumentException("Cannot create OptimizationRunner without TerminationConditions (termination conditions are null or empty)");
        }

        int poolSize = candidateExecutor.maxConcurrentTasks();
        ThreadPoolExecutor listenerPool = new ThreadPoolExecutor(
                poolSize, poolSize,
                LISTENER_THREAD_KEEP_ALIVE_SECONDS, TimeUnit.SECONDS,
                new LinkedBlockingQueue<Runnable>(),
                new ThreadFactory() {
                    private final AtomicLong counter = new AtomicLong(0);

                    @Override
                    public Thread newThread(Runnable runnable) {
                        Thread thread = Executors.defaultThreadFactory().newThread(runnable);
                        thread.setDaemon(true);
                        thread.setName("ArbiterOptimizationRunner-" + counter.getAndIncrement());
                        return thread;
                    }
                });
        // The pool is never explicitly shut down (late completions may still deliver listeners after
        // execute() returns); letting idle core threads time out keeps it from pinning threads forever.
        listenerPool.allowCoreThreadTimeOut(true);
        this.futureListenerExecutor = listenerPool;
    }

    /**
     * Runs the optimization loop on the calling thread until a termination condition fires or the
     * thread is interrupted.
     *
     * <p>Each iteration tops the executor up to {@code maxConcurrentTasks()} running candidates,
     * notifies listeners if anything changed, waits up to {@link #POLLING_FREQUENCY} seconds for a
     * completion, and processes every completed future. On exit the candidate executor is shut down
     * and any futures that have already completed are processed. If the run was stopped by an
     * interrupt, the thread's interrupt status is restored before returning.
     */
    @Override
    public void execute() {
        log.info("OptimizationRunner: execution started");
        for (OptimizationRunnerStatusListener listener : statusListeners) {
            listener.onInitialization(this);
        }
        for (TerminationCondition condition : config.getTerminationConditions()) {
            condition.initialize(this);
        }

        List<Future<OptimizationResult<C, M, A>>> completed =
                new ArrayList<Future<OptimizationResult<C, M, A>>>(100);
        boolean interrupted = false;
        boolean statusChanged = false;
        // Submit before polling: the completion queue is necessarily empty until something has been
        // submitted, so polling first would idle for a full polling period at start-up.
        while (!interrupted && !terminate()) {
            if (submitNewCandidates()) {
                statusChanged = true;
            }
            if (statusChanged) {
                notifyStatusChange();
            }

            try {
                Future<OptimizationResult<C, M, A>> next =
                        completedFutures.poll(POLLING_FREQUENCY, POLLING_FREQUENCY_UNIT);
                if (next != null) {
                    completed.add(next);
                }
            } catch (InterruptedException e) {
                // Treat an interrupt as a request to stop. The flag is restored only after the final
                // drain below, so it cannot interfere with the Future.get calls made while draining.
                log.warn("OptimizationRunner interrupted; stopping optimization");
                interrupted = true;
            }
            statusChanged = processCompletedFutures(completed);
        }

        executor.shutdown();
        processCompletedFutures(completed);

        log.info("Optimization runner: execution complete");
        for (OptimizationRunnerStatusListener listener : statusListeners) {
            listener.onShutdown(this);
        }
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Moves every future currently in the completion queue into {@code buffer}, processes each one,
     * and clears the buffer.
     *
     * @param buffer scratch list; may already contain futures taken from the completion queue
     * @return true if at least one future was processed
     */
    private boolean processCompletedFutures(List<Future<OptimizationResult<C, M, A>>> buffer) {
        completedFutures.drainTo(buffer);
        boolean processedAny = !buffer.isEmpty();
        for (Future<OptimizationResult<C, M, A>> future : buffer) {
            queuedFutures.remove(future);
            processReturnedTask(future);
        }
        buffer.clear();
        return processedAny;
    }

    /**
     * Submits new candidates until {@code maxConcurrentTasks()} futures are outstanding, recording
     * each one as {@code Status.Created}.
     *
     * @return true if at least one candidate was submitted
     */
    private boolean submitNewCandidates() {
        boolean submitted = false;
        while (queuedFutures.size() < executor.maxConcurrentTasks()) {
            Candidate<C> candidate = config.getCandidateGenerator().getCandidate();
            // NOTE(review): addListener is not a member of java.util.concurrent.Future; execute()
            // presumably returns a listenable future type — confirm and tighten this declared type.
            Future<OptimizationResult<C, M, A>> future =
                    executor.execute(candidate, config.getDataProvider(), config.getScoreFunction());
            future.addListener(new OnCompletionListener(future), futureListenerExecutor);
            queuedFutures.add(future);
            totalCandidateCount++;
            submitted = true;
            currentStatus.put(candidate.getIndex(), new CandidateStatus(candidate.getIndex(),
                    Status.Created, null, System.currentTimeMillis(), null, null));
        }
        return submitted;
    }

    /** Invokes {@code onStatusChange} on every registered status listener. */
    private void notifyStatusChange() {
        for (OptimizationRunnerStatusListener listener : statusListeners) {
            listener.onStatusChange(this);
        }
    }

    /**
     * Handles one completed future: records its status and score, notifies listeners, updates the
     * best score (lower is better), and saves the result if a result saver is configured.
     *
     * <p>A future that failed with an {@link ExecutionException} is logged and counted as failed; its
     * candidate status entry is left unchanged because the failed future carries no result index.
     *
     * @throws RuntimeException if the thread is interrupted (interrupt status is preserved) or the
     *     future unexpectedly does not yield a result within {@link #RESULT_FETCH_TIMEOUT_MS}
     */
    private void processReturnedTask(Future<OptimizationResult<C, M, A>> future) {
        OptimizationResult<C, M, A> result;
        try {
            result = future.get(RESULT_FETCH_TIMEOUT_MS, TimeUnit.MILLISECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Unexpected InterruptedException thrown for task", e);
        } catch (ExecutionException e) {
            log.warn("Task failed", e);
            numCandidatesFailed++;
            return;
        } catch (TimeoutException e) {
            // Only futures whose completion listener has fired are processed, so this is unexpected.
            throw new RuntimeException(e);
        }

        int index = result.getIndex();
        double score = result.getScore();
        CandidateStatus previous = currentStatus.get(index);
        // Guard against a missing Created entry instead of failing the whole run with an NPE.
        long createdTime = previous != null ? previous.getCreatedTime() : System.currentTimeMillis();
        currentStatus.put(index, new CandidateStatus(index, Status.Complete, score, createdTime, null, null));

        for (OptimizationRunnerStatusListener listener : statusListeners) {
            listener.onCompletion(result);
        }

        log.info("Completed task {}, score = {}", index, score);
        if (score < bestScore) {
            if (bestScore == Double.MAX_VALUE) {
                log.info("New best score: {} (first completed model)", score);
            } else {
                log.info("New best score: {} (prev={})", score, bestScore);
            }
            bestScore = score;
            bestScoreTime = System.currentTimeMillis();
            bestScoreCandidateIndex = index;
        }
        numCandidatesCompleted++;

        ResultSaver<C, M, A> resultSaver = config.getResultSaver();
        if (resultSaver != null) {
            try {
                ResultReference<C, M, A> reference = resultSaver.saveModel(result);
                if (reference != null) {
                    allResults.add(reference);
                }
            } catch (IOException e) {
                log.warn("Error saving model (id={}): IOException thrown. ", index, e);
            }
        }
    }

    /** Returns the number of candidates submitted so far. */
    @Override
    public int numCandidatesTotal() {
        return totalCandidateCount;
    }

    /** Returns the number of candidates whose results were processed successfully. */
    @Override
    public int numCandidatesCompleted() {
        return numCandidatesCompleted;
    }

    /** Returns the number of candidates whose futures completed exceptionally. */
    @Override
    public int numCandidatesFailed() {
        return numCandidatesFailed;
    }

    /** Returns the number of submitted candidates that have not been processed yet. */
    @Override
    public int numCandidatesQueued() {
        return queuedFutures.size();
    }

    /** Returns the lowest score seen so far, or {@code Double.MAX_VALUE} if none has completed. */
    @Override
    public double bestScore() {
        return bestScore;
    }

    /** Returns the wall-clock time (ms) at which the current best score was recorded, or 0. */
    @Override
    public long bestScoreTime() {
        return bestScoreTime;
    }

    /** Returns the index of the candidate with the best score, or -1 if none has completed. */
    @Override
    public int bestScoreCandidateIndex() {
        return bestScoreCandidateIndex;
    }

    /** Returns a copy of the saved result references collected so far. */
    @Override
    public List<ResultReference<C, M, A>> getResults() {
        return new ArrayList<ResultReference<C, M, A>>(allResults);
    }

    @Override
    public OptimizationConfiguration<C, M, ?, A> getConfiguration() {
        return config;
    }

    /** Registers each listener that is not already registered (by {@code equals}). */
    @Override
    public void addListeners(OptimizationRunnerStatusListener... optimizationRunnerStatusListenerArr) {
        for (OptimizationRunnerStatusListener listener : optimizationRunnerStatusListenerArr) {
            if (!statusListeners.contains(listener)) {
                statusListeners.add(listener);
            }
        }
    }

    /** Unregisters each given listener; listeners that are not registered are ignored. */
    @Override
    public void removeListeners(OptimizationRunnerStatusListener... optimizationRunnerStatusListenerArr) {
        for (OptimizationRunnerStatusListener listener : optimizationRunnerStatusListenerArr) {
            statusListeners.remove(listener);
        }
    }

    @Override
    public void removeAllListeners() {
        statusListeners.clear();
    }

    /** Returns a snapshot of the latest known status of every candidate. */
    @Override
    public List<CandidateStatus> getCandidateStatus() {
        return new ArrayList<CandidateStatus>(currentStatus.values());
    }

    /** Returns true (and logs which one) as soon as any configured termination condition fires. */
    private boolean terminate() {
        for (TerminationCondition terminationCondition : config.getTerminationConditions()) {
            if (terminationCondition.terminate(this)) {
                log.info("OptimizationRunner global termination condition hit: {}", terminationCondition);
                return true;
            }
        }
        return false;
    }
}
