/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.rl4j.learning;

import java.util.Random;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.HistoryProcessor;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.ILearning;
import org.deeplearning4j.rl4j.learning.NeuralNetFetchable;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.DataManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Common base for learning algorithms that train a neural net against an {@link MDP}.
 *
 * <p>Holds the state shared by concrete learners: a {@link Random} seeded from the
 * configuration, a step counter, an epoch counter and an optional
 * {@link IHistoryProcessor}. It also provides static helpers to turn observations into
 * {@link INDArray} inputs, to reset/warm up an MDP, and to build batch shapes.
 *
 * @param <O>  observation type produced by the MDP
 * @param <A>  action type
 * @param <AS> action space type
 * @param <NN> neural network type exposed through {@link #getNeuralNet()}
 */
public abstract class Learning<O extends Encodable, A, AS extends ActionSpace<A>, NN extends NeuralNet>
implements ILearning<O, A, AS>,
NeuralNetFetchable<NN> {
    private static final Logger log = LoggerFactory.getLogger(Learning.class);
    private final Random random;
    private int stepCounter = 0;
    private int epochCounter = 0;
    private IHistoryProcessor historyProcessor = null;

    /**
     * Creates the learner and seeds its random generator from the configuration.
     *
     * @param conf learning configuration; only {@code getSeed()} is read here
     */
    public Learning(ILearning.LConfiguration conf) {
        this.random = new Random(conf.getSeed());
    }

    /**
     * Returns the first element of {@code Nd4j.argMax} computed on {@code vector} with the
     * dimension argument {@code Integer.MAX_VALUE}.
     *
     * <p>NOTE(review): {@code Integer.MAX_VALUE} is presumably ND4J's marker for "reduce over
     * the whole array", which would make this the index of the largest entry - confirm
     * against the ND4J version in use.
     *
     * @param vector values to search, typically one score per action
     * @return the arg-max index, boxed
     */
    public static Integer getMaxAction(INDArray vector) {
        return Nd4j.argMax(vector, Integer.MAX_VALUE).getInt(0);
    }

    /**
     * Encodes an observation as an {@link INDArray} shaped like the MDP's observation space.
     *
     * @param mdp  environment whose observation space supplies the target shape
     * @param obs  observation to encode through {@code obs.toArray()}
     * @param <O>  observation type
     * @param <A>  action type
     * @param <AS> action space type
     * @return the array as created by {@code Nd4j.create} when the observation space shape
     *         has exactly one dimension, otherwise that array reshaped to the full shape
     */
    public static <O extends Encodable, A, AS extends ActionSpace<A>> INDArray getInput(MDP<O, A, AS> mdp, O obs) {
        double[] flat = obs.toArray();
        INDArray arr = Nd4j.create(flat);
        int[] shape = mdp.getObservationSpace().getShape();
        if (shape.length == 1) {
            return arr;
        }
        return arr.reshape(shape);
    }

    /**
     * Resets the MDP and, when a history processor is supplied, plays no-op actions so the
     * processor can be fed the frames preceding the first real decision.
     *
     * <p>The number of warm-up steps is {@code skipFrame * (historyLength - 1)}, both values
     * read from {@code hp.getConf()}. Every warm-up frame is passed to {@code hp.record};
     * every {@code skipFrame}-th frame (starting with the first) is also passed to
     * {@code hp.add}. Without a history processor no step is taken.
     *
     * <p>Each iteration encodes the <em>most recent</em> observation. The previous
     * implementation kept encoding the observation returned by {@code reset()}, so the
     * history processor only ever saw the reset frame while the environment advanced.
     *
     * @param mdp  environment to reset and warm up
     * @param hp   history processor to feed, or {@code null} for none
     * @param <O>  observation type
     * @param <A>  action type
     * @param <AS> action space type
     * @return the number of warm-up steps taken, the last observation received and the
     *         reward accumulated during warm-up
     */
    public static <O extends Encodable, A, AS extends ActionSpace<A>> InitMdp<O> initMdp(MDP<O, A, AS> mdp, IHistoryProcessor hp) {
        O lastObs = mdp.reset();
        if (hp == null) {
            // No history to pre-fill: zero warm-up steps, zero reward.
            return new InitMdp<O>(0, lastObs, 0.0);
        }

        int skipFrame = hp.getConf().getSkipFrame();
        int requiredFrame = skipFrame * (hp.getConf().getHistoryLength() - 1);
        double reward = 0.0;
        int step = 0;
        while (step < requiredFrame) {
            // Encode the latest observation, not the one captured at reset.
            INDArray input = Learning.getInput(mdp, lastObs);
            hp.record(input);

            A action = mdp.getActionSpace().noOp();
            if (step % skipFrame == 0) {
                hp.add(input);
            }

            StepReply<O> stepReply = mdp.step(action);
            reward += stepReply.getReward();
            lastObs = stepReply.getObservation();
            ++step;
        }
        return new InitMdp<O>(step, lastObs, reward);
    }

    /**
     * Builds a shape made of {@code size} followed by every entry of {@code shape}.
     *
     * @param size  value of the new leading dimension
     * @param shape remaining dimensions; not modified
     * @return a new array of length {@code shape.length + 1}
     */
    public static int[] makeShape(int size, int[] shape) {
        int[] nshape = new int[shape.length + 1];
        nshape[0] = size;
        System.arraycopy(shape, 0, nshape, 1, shape.length);
        return nshape;
    }

    /**
     * Builds the three-dimensional shape {@code {batch, product(shape), length}}.
     *
     * @param batch  first dimension
     * @param shape  dimensions multiplied together to form the second entry; an empty array
     *               yields 1
     * @param length third dimension
     * @return a new array of length 3
     */
    public static int[] makeShape(int batch, int[] shape, int length) {
        int flattened = 1;
        for (int dim : shape) {
            flattened *= dim;
        }
        return new int[]{batch, flattened, length};
    }

    /**
     * @return the data manager used by the concrete learner
     */
    protected abstract DataManager getDataManager();

    /**
     * @return the neural network used by the concrete learner
     */
    @Override
    public abstract NN getNeuralNet();

    /**
     * Increments the step counter.
     *
     * @return the counter value <em>before</em> the increment
     */
    public int incrementStep() {
        return this.stepCounter++;
    }

    /**
     * Increments the epoch counter.
     *
     * @return the counter value <em>before</em> the increment
     */
    public int incrementEpoch() {
        return this.epochCounter++;
    }

    /**
     * Installs a new {@link HistoryProcessor} built from the given configuration.
     *
     * @param conf history processor configuration
     */
    public void setHistoryProcessor(IHistoryProcessor.Configuration conf) {
        this.historyProcessor = new HistoryProcessor(conf);
    }

    /**
     * Installs the given history processor as-is.
     *
     * @param historyProcessor processor to use; stored without validation
     */
    public void setHistoryProcessor(IHistoryProcessor historyProcessor) {
        this.historyProcessor = historyProcessor;
    }

    /**
     * Encodes an observation using this learner's MDP.
     *
     * @param obs observation to encode
     * @return the result of {@link #getInput(MDP, Encodable)} for {@code getMdp()}
     */
    public INDArray getInput(O obs) {
        return Learning.getInput(this.getMdp(), obs);
    }

    /**
     * Resets the neural net, then resets and warms up this learner's MDP with its current
     * history processor.
     *
     * @return the warm-up summary produced by {@link #initMdp(MDP, IHistoryProcessor)}
     */
    public InitMdp<O> initMdp() {
        this.getNeuralNet().reset();
        return Learning.initMdp(this.getMdp(), this.getHistoryProcessor());
    }

    /**
     * @return the random generator seeded at construction
     */
    public Random getRandom() {
        return this.random;
    }

    /**
     * @return the current step counter
     */
    @Override
    public int getStepCounter() {
        return this.stepCounter;
    }

    /**
     * @param stepCounter new value of the step counter
     */
    public void setStepCounter(int stepCounter) {
        this.stepCounter = stepCounter;
    }

    /**
     * @return the current epoch counter
     */
    public int getEpochCounter() {
        return this.epochCounter;
    }

    /**
     * @param epochCounter new value of the epoch counter
     */
    public void setEpochCounter(int epochCounter) {
        this.epochCounter = epochCounter;
    }

    /**
     * @return the installed history processor, or {@code null} when none was set
     */
    public IHistoryProcessor getHistoryProcessor() {
        return this.historyProcessor;
    }

    /**
     * Immutable summary of an MDP initialisation: warm-up steps taken, last observation
     * received and reward accumulated.
     *
     * @param <O> observation type
     */
    public static final class InitMdp<O> {
        private final int steps;
        private final O lastObs;
        private final double reward;

        /**
         * @param steps   number of warm-up steps taken
         * @param lastObs last observation received; may be {@code null}
         * @param reward  reward accumulated during warm-up
         */
        public InitMdp(int steps, O lastObs, double reward) {
            this.steps = steps;
            this.lastObs = lastObs;
            this.reward = reward;
        }

        /**
         * @return number of warm-up steps taken
         */
        public int getSteps() {
            return this.steps;
        }

        /**
         * @return last observation received
         */
        public O getLastObs() {
            return this.lastObs;
        }

        /**
         * @return reward accumulated during warm-up
         */
        public double getReward() {
            return this.reward;
        }

        /**
         * Two instances are equal when steps, last observation and reward all match; the
         * reward is compared with {@link Double#compare(double, double)}.
         */
        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof InitMdp)) {
                return false;
            }
            InitMdp<?> other = (InitMdp<?>) o;
            if (this.steps != other.steps) {
                return false;
            }
            Object mine = this.lastObs;
            Object theirs = other.lastObs;
            if (mine == null ? theirs != null : !mine.equals(theirs)) {
                return false;
            }
            return Double.compare(this.reward, other.reward) == 0;
        }

        /**
         * Hash over steps, last observation and reward, consistent with
         * {@link #equals(Object)}. Multiplier 59 and null placeholder 43 are kept from the
         * original generated code so hash values do not change.
         */
        @Override
        public int hashCode() {
            int result = 1;
            result = result * 59 + this.steps;
            result = result * 59 + (this.lastObs == null ? 43 : this.lastObs.hashCode());
            long rewardBits = Double.doubleToLongBits(this.reward);
            result = result * 59 + (int) (rewardBits >>> 32 ^ rewardBits);
            return result;
        }

        @Override
        public String toString() {
            return "Learning.InitMdp(steps=" + this.steps + ", lastObs=" + this.lastObs + ", reward=" + this.reward + ")";
        }
    }
}

