/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.rl4j.policy;

import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.HistoryProcessor;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Base class for a policy that chooses actions for an {@link MDP} by querying a
 * {@link NeuralNet}.
 *
 * <p>Subclasses supply the network and the action-selection rule; the
 * {@code play} methods run a whole episode with that rule and report the
 * reward collected.
 *
 * @param <O> observation type emitted by the MDP
 * @param <A> action type accepted by the MDP
 */
public abstract class Policy<O extends Encodable, A> {
    /**
     * Returns the network backing this policy. {@code play} resets it at the
     * start of an episode and asks it whether it is recurrent in order to pick
     * the input shape.
     *
     * @return the neural network used by this policy
     */
    public abstract NeuralNet getNeuralNet();

    /**
     * Selects the action to take for the given network input.
     *
     * @param var1 the (already stacked, scaled and reshaped) input built by {@code play}
     * @return the chosen action
     */
    public abstract A nextAction(INDArray var1);

    /**
     * Plays one episode without any history processor.
     *
     * @param mdp the environment to play
     * @return the total reward of the episode
     */
    public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp) {
        return this.play(mdp, (IHistoryProcessor)null);
    }

    /**
     * Plays one episode using a fresh {@link HistoryProcessor} built from {@code conf}.
     *
     * @param mdp  the environment to play
     * @param conf configuration handed to the {@link HistoryProcessor} constructor
     * @return the total reward of the episode
     */
    public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor.Configuration conf) {
        return this.play(mdp, new HistoryProcessor(conf));
    }

    /**
     * Plays one episode: steps the MDP until {@code mdp.isDone()} and sums the rewards.
     *
     * <p>A new action is requested from {@link #nextAction(INDArray)} only on steps
     * whose index is a multiple of the skip-frame value (1 when {@code hp} is
     * {@code null}); on every other step the previous action is repeated.
     *
     * @param mdp the environment to play
     * @param hp  optional history processor; when {@code null} each observation is
     *            used on its own as the network input
     * @return the reward reported by {@code Learning.initMdp} plus the reward of
     *         every step taken here
     */
    public <AS extends ActionSpace<A>> double play(MDP<O, A, AS> mdp, IHistoryProcessor hp) {
        this.getNeuralNet().reset();
        Learning.InitMdp<O> initMdp = Learning.initMdp(mdp, hp);
        O obs = initMdp.getLastObs();
        double reward = initMdp.getReward();
        A lastAction = mdp.getActionSpace().noOp();
        int step = initMdp.getSteps();
        INDArray[] history = null;
        // hp is never reassigned, so this flag is constant for the whole episode.
        boolean isHistoryProcessor = hp != null;

        while (!mdp.isDone()) {
            INDArray input = Learning.getInput(mdp, obs);
            if (isHistoryProcessor) {
                hp.record(input);
            }

            int skipFrame = isHistoryProcessor ? hp.getConf().getSkipFrame() : 1;
            A action;
            if (step % skipFrame != 0) {
                // Skipped frame: repeat whatever was done last.
                action = lastAction;
            } else {
                if (history == null) {
                    // Only reachable on the first decision of the episode; afterwards
                    // history is refreshed at the bottom of every iteration.
                    if (isHistoryProcessor) {
                        hp.add(input);
                        history = hp.getHistory();
                    } else {
                        history = new INDArray[]{input};
                    }
                }
                action = this.nextAction(this.buildNetworkInput(history, hp));
            }
            lastAction = action;

            StepReply<O> stepReply = mdp.step(action);
            reward += stepReply.getReward();

            // Advance the current observation so the next iteration builds (and
            // records) the frame that resulted from this step rather than the
            // episode's initial frame.
            obs = stepReply.getObservation();
            INDArray nextInput = Learning.getInput(mdp, obs);
            if (isHistoryProcessor) {
                hp.add(nextInput);
                history = hp.getHistory();
            } else {
                history = new INDArray[]{nextInput};
            }
            ++step;
        }
        return reward;
    }

    /**
     * Turns a frame history into the array passed to {@link #nextAction(INDArray)}:
     * concatenates the frames, multiplies by {@code 1 / hp.getScale()} when a history
     * processor is present, then reshapes for the network.
     */
    private INDArray buildNetworkInput(INDArray[] history, IHistoryProcessor hp) {
        INDArray hstack = Transition.concat(history);
        if (hp != null) {
            hstack.muli(1.0 / hp.getScale());
        }
        if (this.getNeuralNet().isRecurrent()) {
            // Recurrent nets get the three-argument makeShape form (extra trailing 1).
            return hstack.reshape(Learning.makeShape(1, ArrayUtil.toInts(hstack.shape()), 1));
        }
        if (hstack.shape().length > 2) {
            return hstack.reshape(Learning.makeShape(1, ArrayUtil.toInts(hstack.shape())));
        }
        return hstack;
    }
}

