/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.rl4j.mdp.toy;

import java.util.Random;
import java.util.logging.Logger;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.mdp.toy.HardToyState;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.json.JSONObject;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class HardDeteministicToy
implements MDP<HardToyState, Integer, DiscreteSpace> {
    private static final int MAX_STEP = 20;
    private static final int SEED = 1234;
    private static final int ACTION_SIZE = 10;
    private static final HardToyState[] states = HardDeteministicToy.genToyStates(20, 1234);
    private DiscreteSpace actionSpace = new DiscreteSpace(10);
    private ObservationSpace<HardToyState> observationSpace = new ArrayObservationSpace(new int[]{10});
    private HardToyState hardToyState;

    public static void printTest(IDQN idqn) {
        INDArray input = Nd4j.create((int)20, (int)10);
        for (int i = 0; i < 20; ++i) {
            input.putRow((long)i, Nd4j.create((double[])states[i].toArray()));
        }
        INDArray output = Nd4j.max((INDArray)idqn.output(input), (int)1);
        Logger.getAnonymousLogger().info(output.toString());
    }

    public static int maxIndex(double[] values) {
        double maxValue = -4.9E-324;
        int maxIndex = -1;
        for (int i = 0; i < values.length; ++i) {
            if (!(values[i] > maxValue)) continue;
            maxValue = values[i];
            maxIndex = i;
        }
        return maxIndex;
    }

    public static HardToyState[] genToyStates(int size, int seed) {
        Random rd = new Random(seed);
        HardToyState[] hardToyStates = new HardToyState[size];
        for (int i = 0; i < size; ++i) {
            double[] values = new double[10];
            for (int j = 0; j < 10; ++j) {
                values[j] = rd.nextDouble();
            }
            hardToyStates[i] = new HardToyState(values, i);
        }
        return hardToyStates;
    }

    public void close() {
    }

    public boolean isDone() {
        return this.hardToyState.getStep() == 19;
    }

    public HardToyState reset() {
        this.hardToyState = states[0];
        return this.hardToyState;
    }

    public StepReply<HardToyState> step(Integer a) {
        double reward = 0.0;
        if (a == HardDeteministicToy.maxIndex(this.hardToyState.getValues())) {
            reward += 1.0;
        }
        this.hardToyState = states[this.hardToyState.getStep() + 1];
        return new StepReply((Object)this.hardToyState, reward, this.isDone(), new JSONObject("{}"));
    }

    public HardDeteministicToy newInstance() {
        return new HardDeteministicToy();
    }

    public DiscreteSpace getActionSpace() {
        return this.actionSpace;
    }

    public ObservationSpace<HardToyState> getObservationSpace() {
        return this.observationSpace;
    }
}

