/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;

import java.util.ArrayList;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.policy.DQNPolicy;
import org.deeplearning4j.rl4j.policy.EpsGreedy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Q-learning trainer for an {@link MDP} whose actions are integers drawn from a
 * {@link DiscreteSpace}.
 *
 * <p>Each call to {@link #trainStep(Encodable)} picks an action (either by repeating
 * the previous one or by querying the epsilon-greedy policy), steps the MDP, stores the
 * resulting {@link Transition} in the experience replay and, once enough steps have
 * elapsed, fits the current network on a batch built by {@link #setTarget(ArrayList)}.
 *
 * @param <O> observation type produced by the MDP
 */
public abstract class QLearningDiscrete<O extends Encodable>
extends QLearning<O, Integer, DiscreteSpace> {
    /** Minimum number of steps between two monitoring (video) sessions started in {@link #preEpoch()}. */
    private static final int MONITOR_INTERVAL_STEPS = 10000;

    private final QLearning.QLConfiguration configuration;
    private final DataManager dataManager;
    private final MDP<O, Integer, DiscreteSpace> mdp;
    private final IDQN currentDQN;
    private DQNPolicy<O> policy;
    private EpsGreedy<O, Integer, DiscreteSpace> egPolicy;
    private IDQN targetDQN;
    /** Action taken on the previous step; repeated on steps that are not a multiple of the skip-frame. */
    private int lastAction;
    /** Observation history the next stored transition starts from; {@code null} until the first decision step of an epoch. */
    private INDArray[] history = null;
    /** Scaled reward accumulated since the last stored transition. */
    private double accuReward = 0.0;
    /** Step counter value at which monitoring was last started; initialised so the first eligible epoch is monitored. */
    private int lastMonitor = -MONITOR_INTERVAL_STEPS;

    /**
     * Builds the trainer, its greedy policy and its epsilon-greedy exploration policy.
     *
     * @param mdp           environment to learn on; its action space is seeded with {@code conf.getSeed()}
     * @param dqn           network used as the current DQN; a clone of it becomes the initial target DQN
     * @param conf          Q-learning configuration
     * @param dataManager   data manager returned by {@link #getDataManager()}
     * @param epsilonNbStep step count forwarded to the {@link EpsGreedy} constructor
     */
    public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN dqn, QLearning.QLConfiguration conf, DataManager dataManager, int epsilonNbStep) {
        super(conf);
        this.configuration = conf;
        this.mdp = mdp;
        this.dataManager = dataManager;
        this.currentDQN = dqn;
        this.targetDQN = dqn.clone();
        // Use the constructor argument directly rather than the overridable getter:
        // a subclass override would otherwise run before the subclass is initialised.
        this.policy = new DQNPolicy<O>(dqn);
        this.egPolicy = new EpsGreedy<O, Integer, DiscreteSpace>(this.policy, mdp, conf.getUpdateStart(), epsilonNbStep, this.getRandom(), conf.getMinEpsilon(), this);
        mdp.getActionSpace().setSeed(conf.getSeed());
    }

    /** Stops the history processor's monitor, if a history processor is configured. */
    @Override
    public void postEpoch() {
        if (this.getHistoryProcessor() != null) {
            this.getHistoryProcessor().stopMonitor();
        }
    }

    /**
     * Resets the per-epoch state (history, last action, accumulated reward) and, when at
     * least {@link #MONITOR_INTERVAL_STEPS} steps have passed since the last monitoring
     * session, a history processor is present and the data manager saves data, starts a new
     * monitor writing to a video file under the data manager's video directory.
     */
    @Override
    public void preEpoch() {
        this.history = null;
        this.lastAction = 0;
        this.accuReward = 0.0;
        if (this.getStepCounter() - this.lastMonitor >= MONITOR_INTERVAL_STEPS && this.getHistoryProcessor() != null && this.getDataManager().isSaveData()) {
            this.lastMonitor = this.getStepCounter();
            int[] shape = this.getMdp().getObservationSpace().getShape();
            this.getHistoryProcessor().startMonitor(this.getDataManager().getVideoDir() + "/video-" + this.getEpochCounter() + "-" + this.getStepCounter() + ".mp4", shape);
        }
    }

    /**
     * Performs one training step: chooses an action, steps the MDP, records the transition
     * and possibly fits the current network.
     *
     * @param obs current observation
     * @return the maximum Q-value of the current network for this step ({@code NaN} when the
     *         previous action was repeated and no Q-values were computed), the network's
     *         latest score and the MDP's step reply
     */
    @Override
    protected QLearning.QLStepReturn<O> trainStep(O obs) {
        INDArray input = this.getInput(obs);
        boolean isHistoryProcessor = this.getHistoryProcessor() != null;
        if (isHistoryProcessor) {
            this.getHistoryProcessor().record(input);
        }
        // Without a history processor every step is a decision step with a one-frame history.
        int skipFrame = isHistoryProcessor ? this.getHistoryProcessor().getConf().getSkipFrame() : 1;
        int historyLength = isHistoryProcessor ? this.getHistoryProcessor().getConf().getHistoryLength() : 1;
        // Fitting only begins after the configured start plus enough steps to fill a batch of histories.
        int updateStart = this.getConfiguration().getUpdateStart() + (this.getConfiguration().getBatchSize() + historyLength) * skipFrame;

        Double maxQ = Double.NaN;
        Integer action;
        if (this.getStepCounter() % skipFrame != 0) {
            // Skipped frame: repeat the previous action without querying the network.
            action = this.lastAction;
        } else {
            if (this.history == null) {
                if (isHistoryProcessor) {
                    this.getHistoryProcessor().add(input);
                    this.history = this.getHistoryProcessor().getHistory();
                } else {
                    this.history = new INDArray[]{input};
                }
            }
            INDArray hstack = Transition.concat(Transition.dup(this.history));
            if (isHistoryProcessor) {
                hstack.muli(1.0 / this.getHistoryProcessor().getScale());
            }
            if (hstack.shape().length > 2) {
                // Prepend a batch dimension of size 1.
                hstack = hstack.reshape(Learning.makeShape(1, ArrayUtil.toInts(hstack.shape())));
            }
            INDArray qs = this.getCurrentDQN().output(hstack);
            int maxAction = Learning.getMaxAction(qs);
            maxQ = qs.getDouble((long) maxAction);
            action = this.getEgPolicy().nextAction(hstack);
        }
        this.lastAction = action;

        StepReply<O> stepReply = this.getMdp().step(action);
        this.accuReward += stepReply.getReward() * this.configuration.getRewardFactor();

        if (this.getStepCounter() % skipFrame == 0 || stepReply.isDone()) {
            INDArray ninput = this.getInput(stepReply.getObservation());
            if (isHistoryProcessor) {
                this.getHistoryProcessor().add(ninput);
            }
            INDArray[] nhistory = isHistoryProcessor ? this.getHistoryProcessor().getHistory() : new INDArray[]{ninput};
            // history is still null when the epoch began on a skipped frame and the episode
            // ended before the first decision step. Storing such a transition would make
            // setTarget dereference a null observation array later, so it is dropped.
            if (this.history != null) {
                Transition<Integer> trans = new Transition<Integer>(this.history, action, this.accuReward, stepReply.isDone(), nhistory[0]);
                this.getExpReplay().store(trans);
            }
            if (this.getStepCounter() > updateStart) {
                Pair<INDArray, INDArray> targets = this.setTarget(this.getExpReplay().getBatch());
                this.getCurrentDQN().fit(targets.getFirst(), targets.getSecond());
            }
            this.history = nhistory;
            this.accuReward = 0.0;
        }
        return new QLearning.QLStepReturn<O>(maxQ, this.getCurrentDQN().getLatestScore(), stepReply);
    }

    /**
     * Builds the network input and the training labels for a batch of transitions.
     *
     * <p>For every transition the label row is the output of {@code dqnOutput} on its
     * observation, with the entry of the taken action replaced by the reward plus, for
     * non-terminal transitions, gamma times a next-state value. The replacement is clamped to
     * within the configured error clamp of the previous value.
     *
     * @param transitions batch of transitions; must not be empty
     * @return pair of (stacked observations, labels)
     * @throws IllegalArgumentException if {@code transitions} is empty
     */
    protected Pair<INDArray, INDArray> setTarget(ArrayList<Transition<Integer>> transitions) {
        if (transitions.isEmpty()) {
            throw new IllegalArgumentException("too few transitions");
        }
        int size = transitions.size();
        int[] shape = this.getHistoryProcessor() == null ? this.getMdp().getObservationSpace().getShape() : this.getHistoryProcessor().getConf().getShape();
        int[] nshape = Learning.makeShape(size, shape);
        INDArray obs = Nd4j.create(nshape);
        INDArray nextObs = Nd4j.create(nshape);
        int[] actions = new int[size];
        boolean[] areTerminal = new boolean[size];
        for (int i = 0; i < size; ++i) {
            Transition<Integer> trans = transitions.get(i);
            areTerminal[i] = trans.isTerminal();
            actions[i] = trans.getAction();
            writeFrames(obs, i, trans.getObservation());
            writeFrames(nextObs, i, Transition.append(trans.getObservation(), trans.getNextObservation()));
        }
        if (this.getHistoryProcessor() != null) {
            obs.muli(1.0 / this.getHistoryProcessor().getScale());
            nextObs.muli(1.0 / this.getHistoryProcessor().getScale());
        }

        INDArray dqnOutputAr = this.dqnOutput(obs);
        INDArray dqnOutputNext = this.dqnOutput(nextObs);
        boolean doubleDqn = this.getConfiguration().isDoubleDQN();
        INDArray targetDqnOutputNext = null;
        INDArray tempQ = null;
        INDArray getMaxAction = null;
        if (doubleDqn) {
            // Double DQN: pick the action with dqnOutput, evaluate it with targetDqnOutput.
            targetDqnOutputNext = this.targetDqnOutput(nextObs);
            getMaxAction = Nd4j.argMax(dqnOutputNext, 1);
        } else {
            // NOTE(review): the bootstrap value here comes from dqnOutput(nextObs), not
            // targetDqnOutput(nextObs); which network dqnOutput uses is defined in the
            // superclass - confirm this is intended.
            tempQ = Nd4j.max(dqnOutputNext, 1);
        }

        double gamma = this.getConfiguration().getGamma();
        double errorClamp = this.getConfiguration().getErrorClamp();
        for (int i = 0; i < size; ++i) {
            double yTar = transitions.get(i).getReward();
            if (!areTerminal[i]) {
                double q = doubleDqn
                        ? targetDqnOutputNext.getDouble((long) i, (long) getMaxAction.getInt(i))
                        : tempQ.getDouble((long) i);
                yTar += gamma * q;
            }
            double previousV = dqnOutputAr.getDouble((long) i, (long) actions[i]);
            double lowB = previousV - errorClamp;
            double highB = previousV + errorClamp;
            double clamped = Math.min(highB, Math.max(yTar, lowB));
            dqnOutputAr.putScalar((long) i, (long) actions[i], clamped);
        }
        return new Pair<INDArray, INDArray>(obs, dqnOutputAr);
    }

    /**
     * Writes one transition's frames into row {@code row} of {@code batch}: the first frame
     * as a whole row when the batch is rank 2, otherwise each frame at index
     * {@code (row, j)}.
     */
    private static void writeFrames(INDArray batch, int row, INDArray[] frames) {
        if (batch.rank() == 2) {
            batch.putRow((long) row, frames[0]);
            return;
        }
        for (int j = 0; j < frames.length; ++j) {
            batch.put(new INDArrayIndex[]{NDArrayIndex.point((long) row), NDArrayIndex.point((long) j)}, frames[j]);
        }
    }

    /** @return the configuration passed to the constructor */
    @Override
    public QLearning.QLConfiguration getConfiguration() {
        return this.configuration;
    }

    /** @return the data manager passed to the constructor */
    @Override
    public DataManager getDataManager() {
        return this.dataManager;
    }

    /** @return the MDP passed to the constructor */
    @Override
    public MDP<O, Integer, DiscreteSpace> getMdp() {
        return this.mdp;
    }

    /** @return the network passed to the constructor, which {@link #trainStep(Encodable)} fits */
    @Override
    public IDQN getCurrentDQN() {
        return this.currentDQN;
    }

    /** @return the policy built on the current DQN in the constructor */
    public DQNPolicy<O> getPolicy() {
        return this.policy;
    }

    /** @return the epsilon-greedy policy used by {@link #trainStep(Encodable)} to choose actions */
    @Override
    public EpsGreedy<O, Integer, DiscreteSpace> getEgPolicy() {
        return this.egPolicy;
    }

    /** @return the target network; initially a clone of the current DQN */
    @Override
    public IDQN getTargetDQN() {
        return this.targetDQN;
    }

    /** Replaces the target network. */
    @Override
    public void setTargetDQN(IDQN targetDQN) {
        this.targetDQN = targetDQN;
    }
}

