/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.rl4j.learning.sync.qlearning;

import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder;
import java.util.ArrayList;
import java.util.List;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.sync.ExpReplay;
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
import org.deeplearning4j.rl4j.learning.sync.SyncLearning;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.policy.EpsGreedy;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public abstract class QLearning<O extends Encodable, A, AS extends ActionSpace<A>>
extends SyncLearning<O, A, AS, IDQN> {
    private static final Logger log = LoggerFactory.getLogger(QLearning.class);
    private final IExpReplay<A> expReplay;

    /**
     * Creates the learner and its experience-replay store.
     *
     * <p>The replay store is built from the configuration's {@code expRepMaxSize},
     * {@code batchSize} and {@code seed} values, in that order.
     *
     * @param conf configuration handed to the superclass and used to size and seed the replay store
     */
    public QLearning(QLConfiguration conf) {
        super(conf);
        // Diamond instead of the raw type: the element type is inferred from the IExpReplay<A> field,
        // which removes the unchecked conversion the raw constructor call produced.
        this.expReplay = new ExpReplay<>(conf.getExpRepMaxSize(), conf.getBatchSize(), conf.getSeed());
    }

    /**
     * Supplies the epsilon-greedy policy of this learner.
     *
     * <p>{@link #trainEpoch()} reads the policy's current epsilon when it builds the epoch statistics.
     *
     * @return the epsilon-greedy policy
     */
    protected abstract EpsGreedy<O, A, AS> getEgPolicy();

    /**
     * Supplies the MDP this learner trains on.
     *
     * <p>{@link #trainEpoch()} polls {@code isDone()} on it to decide when an epoch ends.
     *
     * @return the MDP
     */
    @Override
    public abstract MDP<O, A, AS> getMdp();

    /**
     * Supplies the network that {@link #dqnOutput(INDArray)} evaluates, that
     * {@link #getNeuralNet()} exposes, and that {@link #updateTargetNetwork()} clones.
     *
     * @return the current DQN
     */
    protected abstract IDQN getCurrentDQN();

    /**
     * Supplies the network that {@link #targetDqnOutput(INDArray)} evaluates.
     *
     * @return the target DQN
     */
    protected abstract IDQN getTargetDQN();

    /**
     * Replaces the target network.
     *
     * <p>{@link #updateTargetNetwork()} calls this with a clone of the current DQN.
     *
     * @param var1 the network to use as the target from now on
     */
    protected abstract void setTargetDQN(IDQN var1);

    /**
     * Evaluates the current DQN on the given input.
     *
     * @param input array passed unchanged to the current network's {@code output} call
     * @return the value returned by the current network for {@code input}
     */
    protected INDArray dqnOutput(INDArray input) {
        IDQN current = getCurrentDQN();
        return current.output(input);
    }

    /**
     * Evaluates the target DQN on the given input.
     *
     * @param input array passed unchanged to the target network's {@code output} call
     * @return the value returned by the target network for {@code input}
     */
    protected INDArray targetDqnOutput(INDArray input) {
        IDQN target = getTargetDQN();
        return target.output(input);
    }

    /**
     * Logs the update and installs a clone of the current DQN as the new target DQN.
     */
    protected void updateTargetNetwork() {
        log.info("Update target network");
        IDQN snapshot = (IDQN) getCurrentDQN().clone();
        setTargetDQN(snapshot);
    }

    /**
     * Exposes the current DQN as this learner's neural network.
     *
     * @return the same instance that {@link #getCurrentDQN()} returns
     */
    @Override
    public IDQN getNeuralNet() {
        IDQN current = getCurrentDQN();
        return current;
    }

    /**
     * Supplies the Q-learning configuration.
     *
     * <p>{@link #trainEpoch()} reads {@code maxEpochStep} and {@code targetDqnUpdateFreq} from it
     * on every loop iteration.
     *
     * @return the configuration
     */
    @Override
    public abstract QLConfiguration getConfiguration();

    /**
     * Epoch hook implemented by subclasses.
     * NOTE(review): presumably invoked by the superclass before each epoch; the call site is not in this file, so confirm there.
     */
    @Override
    protected abstract void preEpoch();

    /**
     * Epoch hook implemented by subclasses.
     * NOTE(review): presumably invoked by the superclass after each epoch; the call site is not in this file, so confirm there.
     */
    @Override
    protected abstract void postEpoch();

    /**
     * Performs one training step starting from the given observation.
     *
     * <p>{@link #trainEpoch()} calls this once per loop iteration. It treats a NaN {@code maxQ} in the
     * result as "no Q value for this step", records only non-zero scores, and takes the reward and
     * next observation from the result's step reply.
     *
     * @param var1 the observation to act on
     * @return the step outcome: max Q, score and the step reply
     */
    protected abstract QLStepReturn<O> trainStep(O var1);

    /**
     * Runs one training epoch and summarises it.
     *
     * <p>Starting from the state returned by {@code initMdp()}, the loop calls
     * {@link #trainStep} until either the configured {@code maxEpochStep} is reached or the MDP
     * reports that it is done. Before each step, the target network is refreshed whenever the global
     * step counter is a multiple of {@code targetDqnUpdateFreq}.
     *
     * <p>Collected statistics: accumulated reward (seeded with the reward from {@code initMdp()}),
     * the first non-NaN max Q ({@code startQ}, NaN if none was seen), the mean of all non-NaN max Q
     * values, and every non-zero step score.
     *
     * @return a {@link QLStatEntry} describing the epoch
     * @throws ArithmeticException if the configured {@code targetDqnUpdateFreq} is zero (integer modulo)
     */
    @Override
    protected DataManager.StatEntry trainEpoch() {
        // Parameterised with O so the observation flows into trainStep(O) without casts.
        Learning.InitMdp<O> initMdp = this.initMdp();
        O obs = initMdp.getLastObs();
        double reward = initMdp.getReward();
        double startQ = Double.NaN;
        double sumQ = 0.0;
        int numQ = 0;
        List<Double> scores = new ArrayList<Double>();

        int step;
        for (step = initMdp.getSteps(); step < this.getConfiguration().getMaxEpochStep() && !this.getMdp().isDone(); ++step) {
            if (this.getStepCounter() % this.getConfiguration().getTargetDqnUpdateFreq() == 0) {
                this.updateTargetNetwork();
            }

            QLStepReturn<O> stepR = this.trainStep(obs);

            // A NaN max Q marks a step that produced no Q estimate; it is excluded from the statistics.
            Double maxQ = stepR.getMaxQ();
            if (!maxQ.isNaN()) {
                if (Double.isNaN(startQ)) {
                    startQ = maxQ;
                }
                ++numQ;
                sumQ += maxQ;
            }

            double score = stepR.getScore();
            if (score != 0.0) {
                scores.add(score);
            }

            StepReply<O> reply = stepR.getStepReply();
            reward += reply.getReward();
            obs = reply.getObservation();
            this.incrementStep();
        }

        // The 0.001 offset keeps the division defined when no step yielded a Q value.
        double meanQ = sumQ / ((double) numQ + 0.001);
        return new QLStatEntry(this.getStepCounter(), this.getEpochCounter(), reward, step, scores, this.getEgPolicy().getEpsilon(), startQ, meanQ);
    }

    /**
     * Returns the experience-replay store created in the constructor.
     *
     * @return the replay store; the same instance on every call
     */
    public IExpReplay<A> getExpReplay() {
        return expReplay;
    }

    /**
     * Mutable holder for the thirteen Q-learning settings.
     *
     * <p>Instances are created through the all-arguments constructor or through
     * {@link QLConfigurationBuilder}, which is also the builder Jackson uses for deserialization.
     * Equality, hashing and the string form cover every field and read the values through the getters.
     */
    @JsonDeserialize(builder=QLConfigurationBuilder.class)
    public static class QLConfiguration
    implements ILearning.LConfiguration {
        int seed;
        int maxEpochStep;
        int maxStep;
        int expRepMaxSize;
        int batchSize;
        int targetDqnUpdateFreq;
        int updateStart;
        double rewardFactor;
        double gamma;
        double errorClamp;
        float minEpsilon;
        int epsilonNbStep;
        boolean doubleDQN;

        /**
         * Creates a configuration with every field set explicitly; each argument is stored as given.
         */
        public QLConfiguration(int seed, int maxEpochStep, int maxStep, int expRepMaxSize, int batchSize, int targetDqnUpdateFreq, int updateStart, double rewardFactor, double gamma, double errorClamp, float minEpsilon, int epsilonNbStep, boolean doubleDQN) {
            this.seed = seed;
            this.maxEpochStep = maxEpochStep;
            this.maxStep = maxStep;
            this.expRepMaxSize = expRepMaxSize;
            this.batchSize = batchSize;
            this.targetDqnUpdateFreq = targetDqnUpdateFreq;
            this.updateStart = updateStart;
            this.rewardFactor = rewardFactor;
            this.gamma = gamma;
            this.errorClamp = errorClamp;
            this.minEpsilon = minEpsilon;
            this.epsilonNbStep = epsilonNbStep;
            this.doubleDQN = doubleDQN;
        }

        /**
         * @return a fresh builder with every field at its Java default value
         */
        public static QLConfigurationBuilder builder() {
            return new QLConfigurationBuilder();
        }

        // ---- accessors ----

        @Override
        public int getSeed() {
            return seed;
        }

        @Override
        public int getMaxEpochStep() {
            return maxEpochStep;
        }

        @Override
        public int getMaxStep() {
            return maxStep;
        }

        public int getExpRepMaxSize() {
            return expRepMaxSize;
        }

        public int getBatchSize() {
            return batchSize;
        }

        public int getTargetDqnUpdateFreq() {
            return targetDqnUpdateFreq;
        }

        public int getUpdateStart() {
            return updateStart;
        }

        public double getRewardFactor() {
            return rewardFactor;
        }

        @Override
        public double getGamma() {
            return gamma;
        }

        public double getErrorClamp() {
            return errorClamp;
        }

        public float getMinEpsilon() {
            return minEpsilon;
        }

        public int getEpsilonNbStep() {
            return epsilonNbStep;
        }

        public boolean isDoubleDQN() {
            return doubleDQN;
        }

        // ---- mutators ----

        public void setSeed(int seed) {
            this.seed = seed;
        }

        public void setMaxEpochStep(int maxEpochStep) {
            this.maxEpochStep = maxEpochStep;
        }

        public void setMaxStep(int maxStep) {
            this.maxStep = maxStep;
        }

        public void setExpRepMaxSize(int expRepMaxSize) {
            this.expRepMaxSize = expRepMaxSize;
        }

        public void setBatchSize(int batchSize) {
            this.batchSize = batchSize;
        }

        public void setTargetDqnUpdateFreq(int targetDqnUpdateFreq) {
            this.targetDqnUpdateFreq = targetDqnUpdateFreq;
        }

        public void setUpdateStart(int updateStart) {
            this.updateStart = updateStart;
        }

        public void setRewardFactor(double rewardFactor) {
            this.rewardFactor = rewardFactor;
        }

        public void setGamma(double gamma) {
            this.gamma = gamma;
        }

        public void setErrorClamp(double errorClamp) {
            this.errorClamp = errorClamp;
        }

        public void setMinEpsilon(float minEpsilon) {
            this.minEpsilon = minEpsilon;
        }

        public void setEpsilonNbStep(int epsilonNbStep) {
            this.epsilonNbStep = epsilonNbStep;
        }

        public void setDoubleDQN(boolean doubleDQN) {
            this.doubleDQN = doubleDQN;
        }

        /**
         * @return {@code QLearning.QLConfiguration(name=value, ...)} listing all fields in declaration order
         */
        @Override
        public String toString() {
            StringBuilder text = new StringBuilder("QLearning.QLConfiguration(seed=");
            text.append(getSeed());
            text.append(", maxEpochStep=").append(getMaxEpochStep());
            text.append(", maxStep=").append(getMaxStep());
            text.append(", expRepMaxSize=").append(getExpRepMaxSize());
            text.append(", batchSize=").append(getBatchSize());
            text.append(", targetDqnUpdateFreq=").append(getTargetDqnUpdateFreq());
            text.append(", updateStart=").append(getUpdateStart());
            text.append(", rewardFactor=").append(getRewardFactor());
            text.append(", gamma=").append(getGamma());
            text.append(", errorClamp=").append(getErrorClamp());
            text.append(", minEpsilon=").append(getMinEpsilon());
            text.append(", epsilonNbStep=").append(getEpsilonNbStep());
            text.append(", doubleDQN=").append(isDoubleDQN());
            return text.append(")").toString();
        }

        /**
         * Field-by-field equality. Floating-point fields are compared with
         * {@code Double.compare}/{@code Float.compare}, and the other side must agree via
         * {@link #canEqual(Object)}.
         */
        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof QLConfiguration)) {
                return false;
            }
            QLConfiguration that = (QLConfiguration) o;
            return that.canEqual(this)
                    && getSeed() == that.getSeed()
                    && getMaxEpochStep() == that.getMaxEpochStep()
                    && getMaxStep() == that.getMaxStep()
                    && getExpRepMaxSize() == that.getExpRepMaxSize()
                    && getBatchSize() == that.getBatchSize()
                    && getTargetDqnUpdateFreq() == that.getTargetDqnUpdateFreq()
                    && getUpdateStart() == that.getUpdateStart()
                    && Double.compare(getRewardFactor(), that.getRewardFactor()) == 0
                    && Double.compare(getGamma(), that.getGamma()) == 0
                    && Double.compare(getErrorClamp(), that.getErrorClamp()) == 0
                    && Float.compare(getMinEpsilon(), that.getMinEpsilon()) == 0
                    && getEpsilonNbStep() == that.getEpsilonNbStep()
                    && isDoubleDQN() == that.isDoubleDQN();
        }

        /**
         * Lets a subclass opt out of equality with plain {@code QLConfiguration} instances.
         *
         * @param other the object asking to be compared
         * @return {@code true} when {@code other} is a {@code QLConfiguration}
         */
        protected boolean canEqual(Object other) {
            return other instanceof QLConfiguration;
        }

        /**
         * Hash over all fields, combined as {@code hash = hash * 59 + component} starting from 1.
         */
        @Override
        public int hashCode() {
            final int prime = 59;
            // The initializer evaluates top to bottom, so the getters run in declaration order.
            int[] components = {
                getSeed(),
                getMaxEpochStep(),
                getMaxStep(),
                getExpRepMaxSize(),
                getBatchSize(),
                getTargetDqnUpdateFreq(),
                getUpdateStart(),
                foldBits(Double.doubleToLongBits(getRewardFactor())),
                foldBits(Double.doubleToLongBits(getGamma())),
                foldBits(Double.doubleToLongBits(getErrorClamp())),
                Float.floatToIntBits(getMinEpsilon()),
                getEpsilonNbStep(),
                isDoubleDQN() ? 79 : 97
            };
            int hash = 1;
            for (int component : components) {
                hash = hash * prime + component;
            }
            return hash;
        }

        /** Folds a 64-bit pattern into 32 bits by XOR-ing the high half into the low half. */
        private static int foldBits(long bits) {
            return (int)(bits >>> 32 ^ bits);
        }

        /**
         * Fluent builder for {@link QLConfiguration}. Setter methods carry no prefix, as declared
         * to Jackson by {@code withPrefix=""}.
         */
        @JsonPOJOBuilder(withPrefix="")
        public static final class QLConfigurationBuilder {
            private int seed;
            private int maxEpochStep;
            private int maxStep;
            private int expRepMaxSize;
            private int batchSize;
            private int targetDqnUpdateFreq;
            private int updateStart;
            private double rewardFactor;
            private double gamma;
            private double errorClamp;
            private float minEpsilon;
            private int epsilonNbStep;
            private boolean doubleDQN;

            QLConfigurationBuilder() {
            }

            public QLConfigurationBuilder seed(int seed) {
                this.seed = seed;
                return this;
            }

            public QLConfigurationBuilder maxEpochStep(int maxEpochStep) {
                this.maxEpochStep = maxEpochStep;
                return this;
            }

            public QLConfigurationBuilder maxStep(int maxStep) {
                this.maxStep = maxStep;
                return this;
            }

            public QLConfigurationBuilder expRepMaxSize(int expRepMaxSize) {
                this.expRepMaxSize = expRepMaxSize;
                return this;
            }

            public QLConfigurationBuilder batchSize(int batchSize) {
                this.batchSize = batchSize;
                return this;
            }

            public QLConfigurationBuilder targetDqnUpdateFreq(int targetDqnUpdateFreq) {
                this.targetDqnUpdateFreq = targetDqnUpdateFreq;
                return this;
            }

            public QLConfigurationBuilder updateStart(int updateStart) {
                this.updateStart = updateStart;
                return this;
            }

            public QLConfigurationBuilder rewardFactor(double rewardFactor) {
                this.rewardFactor = rewardFactor;
                return this;
            }

            public QLConfigurationBuilder gamma(double gamma) {
                this.gamma = gamma;
                return this;
            }

            public QLConfigurationBuilder errorClamp(double errorClamp) {
                this.errorClamp = errorClamp;
                return this;
            }

            public QLConfigurationBuilder minEpsilon(float minEpsilon) {
                this.minEpsilon = minEpsilon;
                return this;
            }

            public QLConfigurationBuilder epsilonNbStep(int epsilonNbStep) {
                this.epsilonNbStep = epsilonNbStep;
                return this;
            }

            public QLConfigurationBuilder doubleDQN(boolean doubleDQN) {
                this.doubleDQN = doubleDQN;
                return this;
            }

            /**
             * @return a new configuration holding the values collected so far
             */
            public QLConfiguration build() {
                return new QLConfiguration(seed, maxEpochStep, maxStep, expRepMaxSize, batchSize, targetDqnUpdateFreq, updateStart, rewardFactor, gamma, errorClamp, minEpsilon, epsilonNbStep, doubleDQN);
            }

            @Override
            public String toString() {
                StringBuilder text = new StringBuilder("QLearning.QLConfiguration.QLConfigurationBuilder(seed=");
                text.append(seed);
                text.append(", maxEpochStep=").append(maxEpochStep);
                text.append(", maxStep=").append(maxStep);
                text.append(", expRepMaxSize=").append(expRepMaxSize);
                text.append(", batchSize=").append(batchSize);
                text.append(", targetDqnUpdateFreq=").append(targetDqnUpdateFreq);
                text.append(", updateStart=").append(updateStart);
                text.append(", rewardFactor=").append(rewardFactor);
                text.append(", gamma=").append(gamma);
                text.append(", errorClamp=").append(errorClamp);
                text.append(", minEpsilon=").append(minEpsilon);
                text.append(", epsilonNbStep=").append(epsilonNbStep);
                text.append(", doubleDQN=").append(doubleDQN);
                return text.append(")").toString();
            }
        }
    }

    /**
     * Immutable outcome of one training step: the max Q value, the score and the step reply.
     *
     * <p>{@code maxQ} is a boxed {@code Double}, so it may be {@code null}; {@code equals} and
     * {@code hashCode} tolerate that, as they do a {@code null} step reply.
     *
     * @param <O> observation type carried by the step reply
     */
    public static final class QLStepReturn<O> {
        private final Double maxQ;
        private final double score;
        private final StepReply<O> stepReply;

        /**
         * @param maxQ      max Q value of the step (stored as given, may be {@code null})
         * @param score     score of the step
         * @param stepReply reply of the step (stored as given, may be {@code null})
         */
        public QLStepReturn(Double maxQ, double score, StepReply<O> stepReply) {
            this.maxQ = maxQ;
            this.score = score;
            this.stepReply = stepReply;
        }

        /**
         * @param <O> observation type of the result being built
         * @return a new, empty builder
         */
        public static <O> QLStepReturnBuilder<O> builder() {
            // Explicit type argument: the raw constructor call forced an unchecked conversion.
            return new QLStepReturnBuilder<O>();
        }

        public Double getMaxQ() {
            return this.maxQ;
        }

        public double getScore() {
            return this.score;
        }

        public StepReply<O> getStepReply() {
            return this.stepReply;
        }

        /**
         * Null-safe, field-by-field equality; {@code score} is compared with {@code Double.compare}.
         */
        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof QLStepReturn)) {
                return false;
            }
            // Wildcard instead of the raw type: the observation type is erased and irrelevant here.
            QLStepReturn<?> other = (QLStepReturn<?>) o;
            if (this.maxQ == null ? other.maxQ != null : !this.maxQ.equals(other.maxQ)) {
                return false;
            }
            if (Double.compare(this.score, other.score) != 0) {
                return false;
            }
            return this.stepReply == null ? other.stepReply == null : this.stepReply.equals(other.stepReply);
        }

        /**
         * Hash over all three fields, combined as {@code result * 59 + component};
         * a {@code null} field contributes 43.
         */
        @Override
        public int hashCode() {
            int result = 1;
            result = result * 59 + (this.maxQ == null ? 43 : this.maxQ.hashCode());
            long scoreBits = Double.doubleToLongBits(this.score);
            result = result * 59 + (int)(scoreBits >>> 32 ^ scoreBits);
            result = result * 59 + (this.stepReply == null ? 43 : this.stepReply.hashCode());
            return result;
        }

        @Override
        public String toString() {
            return "QLearning.QLStepReturn(maxQ=" + this.maxQ + ", score=" + this.score + ", stepReply=" + this.stepReply + ")";
        }

        /**
         * Fluent builder for {@link QLStepReturn}.
         *
         * @param <O> observation type of the result being built
         */
        public static class QLStepReturnBuilder<O> {
            private Double maxQ;
            private double score;
            private StepReply<O> stepReply;

            QLStepReturnBuilder() {
            }

            public QLStepReturnBuilder<O> maxQ(Double maxQ) {
                this.maxQ = maxQ;
                return this;
            }

            public QLStepReturnBuilder<O> score(double score) {
                this.score = score;
                return this;
            }

            public QLStepReturnBuilder<O> stepReply(StepReply<O> stepReply) {
                this.stepReply = stepReply;
                return this;
            }

            /**
             * @return a new {@link QLStepReturn} holding the values collected so far
             */
            public QLStepReturn<O> build() {
                return new QLStepReturn<O>(this.maxQ, this.score, this.stepReply);
            }

            @Override
            public String toString() {
                return "QLearning.QLStepReturn.QLStepReturnBuilder(maxQ=" + this.maxQ + ", score=" + this.score + ", stepReply=" + this.stepReply + ")";
            }
        }
    }

    /**
     * Immutable statistics record with eight values: step counter, epoch counter, reward,
     * episode length, list of scores, epsilon, start Q and mean Q.
     *
     * <p>The {@code scores} list is stored and returned by reference, without copying.
     */
    public static final class QLStatEntry
    implements DataManager.StatEntry {
        private final int stepCounter;
        private final int epochCounter;
        private final double reward;
        private final int episodeLength;
        private final List<Double> scores;
        private final float epsilon;
        private final double startQ;
        private final double meanQ;

        /**
         * Creates an entry with every field set explicitly; each argument is stored as given.
         */
        public QLStatEntry(int stepCounter, int epochCounter, double reward, int episodeLength, List<Double> scores, float epsilon, double startQ, double meanQ) {
            this.stepCounter = stepCounter;
            this.epochCounter = epochCounter;
            this.reward = reward;
            this.episodeLength = episodeLength;
            this.scores = scores;
            this.epsilon = epsilon;
            this.startQ = startQ;
            this.meanQ = meanQ;
        }

        /**
         * @return a new, empty builder
         */
        public static QLStatEntryBuilder builder() {
            return new QLStatEntryBuilder();
        }

        @Override
        public int getStepCounter() {
            return stepCounter;
        }

        @Override
        public int getEpochCounter() {
            return epochCounter;
        }

        @Override
        public double getReward() {
            return reward;
        }

        public int getEpisodeLength() {
            return episodeLength;
        }

        public List<Double> getScores() {
            return scores;
        }

        public float getEpsilon() {
            return epsilon;
        }

        public double getStartQ() {
            return startQ;
        }

        public double getMeanQ() {
            return meanQ;
        }

        /**
         * Field-by-field equality; floating-point fields use {@code Double.compare}/{@code Float.compare},
         * and a {@code null} score list equals only another {@code null}.
         */
        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof QLStatEntry)) {
                return false;
            }
            QLStatEntry that = (QLStatEntry) o;
            return stepCounter == that.stepCounter
                    && epochCounter == that.epochCounter
                    && Double.compare(reward, that.reward) == 0
                    && episodeLength == that.episodeLength
                    && (scores == null ? that.scores == null : scores.equals(that.scores))
                    && Float.compare(epsilon, that.epsilon) == 0
                    && Double.compare(startQ, that.startQ) == 0
                    && Double.compare(meanQ, that.meanQ) == 0;
        }

        /**
         * Hash over all fields, combined as {@code hash = hash * 59 + component} starting from 1;
         * a {@code null} score list contributes 43.
         */
        @Override
        public int hashCode() {
            final int prime = 59;
            int hash = 1;
            hash = hash * prime + stepCounter;
            hash = hash * prime + epochCounter;
            hash = hash * prime + foldBits(Double.doubleToLongBits(reward));
            hash = hash * prime + episodeLength;
            hash = hash * prime + (scores == null ? 43 : scores.hashCode());
            hash = hash * prime + Float.floatToIntBits(epsilon);
            hash = hash * prime + foldBits(Double.doubleToLongBits(startQ));
            hash = hash * prime + foldBits(Double.doubleToLongBits(meanQ));
            return hash;
        }

        /** Folds a 64-bit pattern into 32 bits by XOR-ing the high half into the low half. */
        private static int foldBits(long bits) {
            return (int)(bits >>> 32 ^ bits);
        }

        @Override
        public String toString() {
            StringBuilder text = new StringBuilder("QLearning.QLStatEntry(stepCounter=");
            text.append(stepCounter);
            text.append(", epochCounter=").append(epochCounter);
            text.append(", reward=").append(reward);
            text.append(", episodeLength=").append(episodeLength);
            text.append(", scores=").append(scores);
            text.append(", epsilon=").append(epsilon);
            text.append(", startQ=").append(startQ);
            text.append(", meanQ=").append(meanQ);
            return text.append(")").toString();
        }

        /**
         * Fluent builder for {@link QLStatEntry}; unset values keep their Java defaults.
         */
        public static class QLStatEntryBuilder {
            private int stepCounter;
            private int epochCounter;
            private double reward;
            private int episodeLength;
            private List<Double> scores;
            private float epsilon;
            private double startQ;
            private double meanQ;

            QLStatEntryBuilder() {
            }

            public QLStatEntryBuilder stepCounter(int stepCounter) {
                this.stepCounter = stepCounter;
                return this;
            }

            public QLStatEntryBuilder epochCounter(int epochCounter) {
                this.epochCounter = epochCounter;
                return this;
            }

            public QLStatEntryBuilder reward(double reward) {
                this.reward = reward;
                return this;
            }

            public QLStatEntryBuilder episodeLength(int episodeLength) {
                this.episodeLength = episodeLength;
                return this;
            }

            public QLStatEntryBuilder scores(List<Double> scores) {
                this.scores = scores;
                return this;
            }

            public QLStatEntryBuilder epsilon(float epsilon) {
                this.epsilon = epsilon;
                return this;
            }

            public QLStatEntryBuilder startQ(double startQ) {
                this.startQ = startQ;
                return this;
            }

            public QLStatEntryBuilder meanQ(double meanQ) {
                this.meanQ = meanQ;
                return this;
            }

            /**
             * @return a new {@link QLStatEntry} holding the values collected so far
             */
            public QLStatEntry build() {
                return new QLStatEntry(stepCounter, epochCounter, reward, episodeLength, scores, epsilon, startQ, meanQ);
            }

            @Override
            public String toString() {
                StringBuilder text = new StringBuilder("QLearning.QLStatEntry.QLStatEntryBuilder(stepCounter=");
                text.append(stepCounter);
                text.append(", epochCounter=").append(epochCounter);
                text.append(", reward=").append(reward);
                text.append(", episodeLength=").append(episodeLength);
                text.append(", scores=").append(scores);
                text.append(", epsilon=").append(epsilon);
                text.append(", startQ=").append(startQ);
                text.append(", meanQ=").append(meanQ);
                return text.append(")").toString();
            }
        }
    }
}

