/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.rl4j.learning.async;

import java.util.Objects;

import org.deeplearning4j.rl4j.learning.HistoryProcessor;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.StepCountable;
import org.deeplearning4j.rl4j.learning.async.AsyncConfiguration;
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public abstract class AsyncThread<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
extends Thread
implements StepCountable {
    private static final Logger log = LoggerFactory.getLogger(AsyncThread.class);
    private int threadNumber;
    private int stepCounter = 0;
    private int epochCounter = 0;
    private IHistoryProcessor historyProcessor;
    private int lastMonitor = -10000;

    public AsyncThread(AsyncGlobal<NN> asyncGlobal, int threadNumber) {
        this.threadNumber = threadNumber;
    }

    public void setHistoryProcessor(IHistoryProcessor.Configuration conf) {
        this.historyProcessor = new HistoryProcessor(conf);
    }

    public void setHistoryProcessor(IHistoryProcessor historyProcessor) {
        this.historyProcessor = historyProcessor;
    }

    protected void postEpoch() {
        if (this.getHistoryProcessor() != null) {
            this.getHistoryProcessor().stopMonitor();
        }
    }

    protected void preEpoch() {
        if (this.getStepCounter() - this.lastMonitor >= 10000 && this.getHistoryProcessor() != null && this.getDataManager().isSaveData()) {
            this.lastMonitor = this.getStepCounter();
            int[] shape = this.getMdp().getObservationSpace().getShape();
            this.getHistoryProcessor().startMonitor(this.getDataManager().getVideoDir() + "/video-" + this.threadNumber + "-" + this.getEpochCounter() + "-" + this.getStepCounter() + ".mp4", shape);
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    @Override
    public void run() {
        try {
            log.info("ThreadNum-" + this.threadNumber + " Started!");
            this.getCurrent().reset();
            Learning.InitMdp<O> initMdp = Learning.initMdp(this.getMdp(), this.historyProcessor);
            Encodable obs = (Encodable)initMdp.getLastObs();
            double rewards = initMdp.getReward();
            int length = initMdp.getSteps();
            this.preEpoch();
            while (!this.getAsyncGlobal().isTrainingComplete() && this.getAsyncGlobal().isRunning()) {
                int maxSteps = Math.min(this.getConf().getNstep(), this.getConf().getMaxEpochStep() - length);
                SubEpochReturn<Encodable> subEpochReturn = this.trainSubEpoch(obs, maxSteps);
                obs = subEpochReturn.getLastObs();
                this.stepCounter += subEpochReturn.getSteps();
                rewards += subEpochReturn.getReward();
                double score = subEpochReturn.getScore();
                if ((length += subEpochReturn.getSteps()) < this.getConf().getMaxEpochStep() && !this.getMdp().isDone()) continue;
                this.postEpoch();
                AsyncStatEntry statEntry = new AsyncStatEntry(this.getStepCounter(), this.epochCounter, rewards, length, score);
                this.getDataManager().appendStat(statEntry);
                log.info("ThreadNum-" + this.threadNumber + " Epoch: " + this.getEpochCounter() + ", reward: " + statEntry.getReward());
                this.getCurrent().reset();
                initMdp = Learning.initMdp(this.getMdp(), this.historyProcessor);
                obs = (Encodable)initMdp.getLastObs();
                rewards = initMdp.getReward();
                length = initMdp.getSteps();
                ++this.epochCounter;
                this.preEpoch();
            }
        }
        catch (Exception e) {
            log.error("Thread crashed: " + e.getCause());
            this.getAsyncGlobal().setRunning(false);
            e.printStackTrace();
        }
        finally {
            this.postEpoch();
        }
    }

    protected abstract NN getCurrent();

    protected abstract int getThreadNumber();

    protected abstract AsyncGlobal<NN> getAsyncGlobal();

    protected abstract MDP<O, A, AS> getMdp();

    protected abstract AsyncConfiguration getConf();

    protected abstract DataManager getDataManager();

    protected abstract Policy<O, A> getPolicy(NN var1);

    protected abstract SubEpochReturn<O> trainSubEpoch(O var1, int var2);

    @Override
    public int getStepCounter() {
        return this.stepCounter;
    }

    public void setStepCounter(int stepCounter) {
        this.stepCounter = stepCounter;
    }

    public int getEpochCounter() {
        return this.epochCounter;
    }

    public void setEpochCounter(int epochCounter) {
        this.epochCounter = epochCounter;
    }

    public IHistoryProcessor getHistoryProcessor() {
        return this.historyProcessor;
    }

    public int getLastMonitor() {
        return this.lastMonitor;
    }

    public static final class AsyncStatEntry
    implements DataManager.StatEntry {
        private final int stepCounter;
        private final int epochCounter;
        private final double reward;
        private final int episodeLength;
        private final double score;

        @Override
        public int getStepCounter() {
            return this.stepCounter;
        }

        @Override
        public int getEpochCounter() {
            return this.epochCounter;
        }

        @Override
        public double getReward() {
            return this.reward;
        }

        public int getEpisodeLength() {
            return this.episodeLength;
        }

        public double getScore() {
            return this.score;
        }

        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof AsyncStatEntry)) {
                return false;
            }
            AsyncStatEntry other = (AsyncStatEntry)o;
            if (this.getStepCounter() != other.getStepCounter()) {
                return false;
            }
            if (this.getEpochCounter() != other.getEpochCounter()) {
                return false;
            }
            if (Double.compare(this.getReward(), other.getReward()) != 0) {
                return false;
            }
            if (this.getEpisodeLength() != other.getEpisodeLength()) {
                return false;
            }
            return Double.compare(this.getScore(), other.getScore()) == 0;
        }

        public int hashCode() {
            int PRIME = 59;
            int result = 1;
            result = result * 59 + this.getStepCounter();
            result = result * 59 + this.getEpochCounter();
            long $reward = Double.doubleToLongBits(this.getReward());
            result = result * 59 + (int)($reward >>> 32 ^ $reward);
            result = result * 59 + this.getEpisodeLength();
            long $score = Double.doubleToLongBits(this.getScore());
            result = result * 59 + (int)($score >>> 32 ^ $score);
            return result;
        }

        public String toString() {
            return "AsyncThread.AsyncStatEntry(stepCounter=" + this.getStepCounter() + ", epochCounter=" + this.getEpochCounter() + ", reward=" + this.getReward() + ", episodeLength=" + this.getEpisodeLength() + ", score=" + this.getScore() + ")";
        }

        public AsyncStatEntry(int stepCounter, int epochCounter, double reward, int episodeLength, double score) {
            this.stepCounter = stepCounter;
            this.epochCounter = epochCounter;
            this.reward = reward;
            this.episodeLength = episodeLength;
            this.score = score;
        }
    }

    public static final class SubEpochReturn<O> {
        private final int steps;
        private final O lastObs;
        private final double reward;
        private final double score;

        public int getSteps() {
            return this.steps;
        }

        public O getLastObs() {
            return this.lastObs;
        }

        public double getReward() {
            return this.reward;
        }

        public double getScore() {
            return this.score;
        }

        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof SubEpochReturn)) {
                return false;
            }
            SubEpochReturn other = (SubEpochReturn)o;
            if (this.getSteps() != other.getSteps()) {
                return false;
            }
            O this$lastObs = this.getLastObs();
            O other$lastObs = other.getLastObs();
            if (this$lastObs == null ? other$lastObs != null : !this$lastObs.equals(other$lastObs)) {
                return false;
            }
            if (Double.compare(this.getReward(), other.getReward()) != 0) {
                return false;
            }
            return Double.compare(this.getScore(), other.getScore()) == 0;
        }

        public int hashCode() {
            int PRIME = 59;
            int result = 1;
            result = result * 59 + this.getSteps();
            O $lastObs = this.getLastObs();
            result = result * 59 + ($lastObs == null ? 43 : $lastObs.hashCode());
            long $reward = Double.doubleToLongBits(this.getReward());
            result = result * 59 + (int)($reward >>> 32 ^ $reward);
            long $score = Double.doubleToLongBits(this.getScore());
            result = result * 59 + (int)($score >>> 32 ^ $score);
            return result;
        }

        public String toString() {
            return "AsyncThread.SubEpochReturn(steps=" + this.getSteps() + ", lastObs=" + this.getLastObs() + ", reward=" + this.getReward() + ", score=" + this.getScore() + ")";
        }

        public SubEpochReturn(int steps, O lastObs, double reward, double score) {
            this.steps = steps;
            this.lastObs = lastObs;
            this.reward = reward;
            this.score = score;
        }
    }
}

