package org.deeplearning4j.rl4j.mdp.toy;

import java.util.Random;
import java.util.logging.Logger;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.json.JSONObject;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * A deterministic toy MDP with 10 discrete actions and a fixed sequence of 20 states.
 *
 * <p>The states are generated once, from a {@link Random} seeded with a constant. Each state
 * carries 10 values drawn from {@link Random#nextDouble()}. Taking the action whose index equals
 * the position of the largest value in the current state yields a reward of 1; any other action
 * yields 0. Regardless of the action, the environment advances to the next state in the sequence.
 * An episode starts at state 0 and is done once state {@code MAX_STEP - 1} is reached.
 */
public class HardDeteministicToy implements MDP<HardToyState, Integer, DiscreteSpace> {
    private static final int ACTION_SIZE = 10;
    private static final int MAX_STEP = 20;
    private static final int SEED = 1234;

    /** The fixed, seeded state sequence shared by every instance. */
    private static final HardToyState[] states = genToyStates(MAX_STEP, SEED);

    private final DiscreteSpace actionSpace = new DiscreteSpace(ACTION_SIZE);
    private final ObservationSpace<HardToyState> observationSpace =
            new ArrayObservationSpace<>(new int[] {ACTION_SIZE});

    /** Current state; {@code null} until {@link #reset()} is called. */
    private HardToyState hardToyState;

    /**
     * Feeds every fixed state through {@code idqn} and logs the maximum over dimension 1 of the
     * resulting output.
     *
     * @param idqn network to evaluate; presumably outputs one Q-value per action for each row
     *     (TODO confirm against the IDQN contract)
     */
    public static void printTest(IDQN idqn) {
        INDArray input = Nd4j.create(MAX_STEP, ACTION_SIZE);
        for (int row = 0; row < MAX_STEP; row++) {
            input.putRow(row, Nd4j.create(states[row].toArray()));
        }
        Logger.getAnonymousLogger().info(Nd4j.max(idqn.output(input), 1).toString());
    }

    /**
     * Returns the index of the largest element of {@code dArr}. Ties resolve to the first
     * occurrence.
     *
     * <p>NaN elements are never selected. The method returns {@code -1} when the array is empty
     * or contains only NaN.
     *
     * @param dArr values to scan; must not be {@code null}
     * @return index of the maximum, or {@code -1} if there is none
     */
    public static int maxIndex(double[] dArr) {
        int best = -1;
        double max = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < dArr.length; i++) {
            double value = dArr[i];
            // The first non-NaN element always seeds the search. The previous sentinel
            // (-Double.MIN_VALUE, about -4.9e-324) is almost zero, so arrays whose values
            // were all negative wrongly produced -1.
            boolean better = best < 0 ? !Double.isNaN(value) : value > max;
            if (better) {
                max = value;
                best = i;
            }
        }
        return best;
    }

    /**
     * Builds {@code count} states whose values are drawn from a {@link Random} seeded with
     * {@code seed}. The same arguments always produce the same sequence.
     *
     * @param count number of states to generate; state {@code k} is constructed with step {@code k}
     * @param seed seed for the random generator
     * @return the generated states, in step order
     */
    public static HardToyState[] genToyStates(int count, int seed) {
        Random random = new Random(seed);
        HardToyState[] generated = new HardToyState[count];
        for (int step = 0; step < count; step++) {
            double[] values = new double[ACTION_SIZE];
            for (int a = 0; a < ACTION_SIZE; a++) {
                values[a] = random.nextDouble();
            }
            generated[step] = new HardToyState(values, step);
        }
        return generated;
    }

    /** No resources are held, so this is a no-op. */
    public void close() {
    }

    /**
     * Reports whether the current state is the last one in the sequence.
     *
     * @return {@code true} once the final state has been reached
     */
    public boolean isDone() {
        return this.hardToyState.getStep() == states.length - 1;
    }

    /**
     * Starts a new episode at the first state of the fixed sequence.
     *
     * @return the initial state
     */
    public HardToyState reset() {
        this.hardToyState = states[0];
        return this.hardToyState;
    }

    /**
     * Legacy decompiler-generated alias of {@link #reset()}.
     *
     * @deprecated use {@link #reset()}
     */
    @Deprecated
    public HardToyState m3reset() {
        return reset();
    }

    /**
     * Scores {@code num} against the current state and advances to the next state.
     *
     * @param num chosen action index
     * @return the next state, a reward of 1.0 if {@code num} is the argmax of the current state's
     *     values (0.0 otherwise), the done flag, and an empty info object
     * @throws IllegalStateException if {@link #reset()} has not been called, or the episode is
     *     already done
     */
    public StepReply<HardToyState> step(Integer num) {
        if (this.hardToyState == null) {
            throw new IllegalStateException("reset() must be called before step()");
        }
        int nextStep = this.hardToyState.getStep() + 1;
        if (nextStep >= states.length) {
            throw new IllegalStateException("Episode is done; call reset() before stepping again");
        }
        double reward = num.intValue() == maxIndex(this.hardToyState.getValues()) ? 1.0d : 0.0d;
        this.hardToyState = states[nextStep];
        return new StepReply<>(this.hardToyState, reward, isDone(), new JSONObject());
    }

    /**
     * Creates an independent environment. The fixed state sequence is shared, because it is
     * immutable static data.
     *
     * @return a fresh, not-yet-reset instance
     */
    public HardDeteministicToy newInstance() {
        return new HardDeteministicToy();
    }

    /**
     * Legacy decompiler-generated alias of {@link #newInstance()}.
     *
     * @deprecated use {@link #newInstance()}
     */
    @Deprecated
    public HardDeteministicToy m2newInstance() {
        return newInstance();
    }

    /**
     * Returns the action space.
     *
     * @return a discrete space of 10 actions
     */
    public DiscreteSpace getActionSpace() {
        return this.actionSpace;
    }

    /**
     * Legacy decompiler-generated alias of {@link #getActionSpace()}.
     *
     * @deprecated use {@link #getActionSpace()}
     */
    @Deprecated
    public DiscreteSpace m4getActionSpace() {
        return getActionSpace();
    }

    /**
     * Returns the observation space.
     *
     * @return an array observation space of shape {@code {10}}
     */
    public ObservationSpace<HardToyState> getObservationSpace() {
        return this.observationSpace;
    }
}
