/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.rl4j.learning.async;

import java.util.Stack;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.async.AsyncGlobal;
import org.deeplearning4j.rl4j.learning.async.AsyncThread;
import org.deeplearning4j.rl4j.learning.async.MiniTrans;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.policy.Policy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Asynchronous learner thread for MDPs whose actions are {@link Integer}s drawn
 * from a {@link DiscreteSpace}.
 *
 * <p>Each instance keeps its own network, {@code current}, which is cloned from the
 * shared {@link AsyncGlobal} at construction time and refreshed from it at the start
 * of every sub-epoch. Subclasses supply the gradient computation through
 * {@link #calcGradient(NeuralNet, Stack)}.
 *
 * @param <O>  observation type produced by the MDP
 * @param <NN> network type trained by this thread
 */
public abstract class AsyncThreadDiscrete<O extends Encodable, NN extends NeuralNet>
extends AsyncThread<O, Integer, DiscreteSpace, NN> {
    /** Thread-local network; synchronised with the global one in {@link #trainSubEpoch}. */
    private NN current;

    /**
     * Creates the thread and clones the global network while holding the monitor of
     * {@code asyncGlobal}.
     *
     * @param asyncGlobal  shared state holding the global network
     * @param threadNumber forwarded unchanged to the {@link AsyncThread} constructor
     */
    @SuppressWarnings("unchecked") // clone() is declared on the raw bound NeuralNet, so the cast to NN is unchecked
    public AsyncThreadDiscrete(AsyncGlobal<NN> asyncGlobal, int threadNumber) {
        super(asyncGlobal, threadNumber);
        synchronized (asyncGlobal) {
            this.current = (NN) asyncGlobal.getCurrent().clone();
        }
    }

    /**
     * Runs one sub-epoch: steps the MDP, collects the visited transitions, computes
     * gradients from them and enqueues those gradients on the shared {@link AsyncGlobal}.
     *
     * <p>The loop runs until the MDP reports done or {@code nstep * skipFrame} steps have
     * been taken, where {@code skipFrame} comes from the history processor configuration
     * (1 when there is no history processor). A new action is requested from the policy
     * on steps where {@code i % skipFrame == 0} or when no previous action exists; on
     * all other steps the previous action is repeated.
     *
     * @param sObs  observation to start from
     * @param nstep number of (non-skipped) steps to attempt
     * @return the number of MDP steps taken, the last observation, the sum of the raw
     *         (unscaled) rewards, and {@code current.getLatestScore()}
     */
    @Override
    public AsyncThread.SubEpochReturn<O> trainSubEpoch(O sObs, int nstep) {
        // Refresh the local network from the global one under the global monitor.
        synchronized (this.getAsyncGlobal()) {
            this.current.copy(this.getAsyncGlobal().getCurrent());
        }

        Stack<MiniTrans<Integer>> rewards = new Stack<MiniTrans<Integer>>();
        O obs = sObs;
        Policy<O, Integer> policy = this.getPolicy(this.current);
        Integer lastAction = null;
        IHistoryProcessor hp = this.getHistoryProcessor();
        int skipFrame = hp != null ? hp.getConf().getSkipFrame() : 1;
        double reward = 0.0;      // sum of raw rewards, reported to the caller
        double accuReward = 0.0;  // scaled reward accumulated since the last stored transition

        // Declared outside the loop because the step count is needed afterwards.
        int i;
        for (i = 0; !this.getMdp().isDone() && i < nstep * skipFrame; ++i) {
            INDArray input = Learning.getInput(this.getMdp(), obs);
            INDArray hstack = null;
            if (hp != null) {
                hp.record(input);
            }

            Integer action;
            if (i % skipFrame != 0 && lastAction != null) {
                // Skipped frame: repeat the previous action.
                action = lastAction;
            } else {
                hstack = this.processHistory(input);
                action = policy.nextAction(hstack);
            }

            StepReply<O> stepReply = this.getMdp().step(action);
            accuReward += stepReply.getReward() * this.getConf().getRewardFactor();

            // Store a transition on non-skipped frames, on the very first step, and
            // whenever the step ended the episode.
            if (i % skipFrame == 0 || lastAction == null || stepReply.isDone()) {
                obs = stepReply.getObservation();
                if (hstack == null) {
                    hstack = this.processHistory(input);
                }
                INDArray[] output = this.current.outputAll(hstack);
                rewards.add(new MiniTrans<Integer>(hstack, action, output, accuReward));
                accuReward = 0.0;
            }

            reward += stepReply.getReward();
            lastAction = action;
        }

        // Final stack entry built from the last observation. It carries no action; its
        // reward field holds the value pushed on top of the collected transitions.
        INDArray input = Learning.getInput(this.getMdp(), obs);
        INDArray hstack = this.processHistory(input);
        if (hp != null) {
            hp.record(input);
        }

        if (this.getMdp().isDone() && i < nstep * skipFrame) {
            // Episode finished before the step budget was used up: value 0, no network output.
            rewards.add(new MiniTrans<Integer>(hstack, null, null, 0.0));
        } else {
            INDArray[] bootstrapOutput;
            if (this.getConf().getTargetDqnUpdateFreq() == -1) {
                bootstrapOutput = this.current.outputAll(hstack);
            } else {
                // The target network is owned by the global state; query it under its monitor.
                synchronized (this.getAsyncGlobal()) {
                    bootstrapOutput = this.getAsyncGlobal().getTarget().outputAll(hstack);
                }
            }
            // Maximum over the first output array of the chosen network.
            double maxQ = Nd4j.max(bootstrapOutput[0]).getDouble(0L);
            rewards.add(new MiniTrans<Integer>(hstack, null, bootstrapOutput, maxQ));
        }

        this.getAsyncGlobal().enqueue(this.calcGradient(this.current, rewards), i);
        return new AsyncThread.SubEpochReturn<O>(i, obs, reward, this.current.getLatestScore());
    }

    /**
     * Turns a single input into the array that is fed to the network.
     *
     * <p>When a history processor is present the input is added to it, the processor's
     * history is concatenated and the result is multiplied by {@code 1 / hp.getScale()}.
     * Otherwise the input alone is concatenated. The result is then reshaped with
     * {@link Learning#makeShape}: with an extra trailing argument of 1 for recurrent
     * networks, or, for non-recurrent networks, only when it has more than two dimensions.
     *
     * @param input input array for the current step
     * @return the stacked (and possibly scaled and reshaped) array
     */
    protected INDArray processHistory(INDArray input) {
        INDArray[] history;
        IHistoryProcessor hp = this.getHistoryProcessor();
        if (hp != null) {
            hp.add(input);
            history = hp.getHistory();
        } else {
            history = new INDArray[]{input};
        }
        INDArray hstack = Transition.concat(history);
        if (hp != null) {
            // Return value is ignored, so this relies on muli updating hstack itself.
            hstack.muli((Number)(1.0 / hp.getScale()));
        }
        if (this.getCurrent().isRecurrent()) {
            hstack = hstack.reshape(Learning.makeShape(1, ArrayUtil.toInts((long[])hstack.shape()), 1));
        } else if (hstack.shape().length > 2) {
            hstack = hstack.reshape(Learning.makeShape(1, ArrayUtil.toInts((long[])hstack.shape())));
        }
        return hstack;
    }

    /**
     * Computes the gradients for one sub-epoch.
     *
     * @param var1 the thread-local network the transitions were collected with
     * @param var2 the transitions collected by {@link #trainSubEpoch}, with the final
     *             action-less entry on top of the stack
     * @return the gradients to enqueue on the shared {@link AsyncGlobal}
     */
    public abstract Gradient[] calcGradient(NN var1, Stack<MiniTrans<Integer>> var2);

    /**
     * @return the thread-local network
     */
    @Override
    public NN getCurrent() {
        return this.current;
    }
}

