package org.deeplearning4j.rl4j.mdp.ale;

import org.bytedeco.ale.ALEInterface;
import org.bytedeco.javacpp.IntPointer;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/rl4j/mdp/ale/ALEMDP.class */
/**
 * {@link MDP} implementation that drives an Arcade Learning Environment emulator through
 * {@link ALEInterface}.
 *
 * <p>Observations are {@link GameScreen} instances built from the emulator's RGB screen, actions are
 * indices into the action set read from the emulator at construction time, and the reward returned
 * by {@link #step(Integer)} is the emulator reward multiplied by the current scale factor.
 *
 * <p>NOTE(review): the methods {@code m1newInstance}, {@code m2reset}, {@code m3getActionSpace} and
 * {@code m4dup} carry "renamed from" markers and look like decompiler output. Their names are kept
 * unchanged so existing call sites keep working; confirm against the original sources before
 * restoring the un-prefixed names.
 */
public class ALEMDP implements MDP<ALEMDP.GameScreen, Integer, DiscreteSpace> {
    private static final Logger log = LoggerFactory.getLogger(ALEMDP.class);

    /** Emulator handle; created in the constructor and released by {@link #close()}. */
    protected ALEInterface ale;
    /** Emulator action codes; an MDP action is an index into this array. */
    protected final int[] actions;
    /** Action space whose size equals {@code actions.length}. */
    protected final DiscreteSpace discreteSpace;
    /** Observation space built from the shape {@code {3, screenHeight, screenWidth}}. */
    protected final ObservationSpace<GameScreen> observationSpace;
    /** ROM path handed to {@code ALEInterface.loadROM} in {@link #setupGame()}. */
    protected final String romFile;
    /** Value used for both the {@code display_screen} and {@code sound} emulator options. */
    protected final boolean render;
    /** Emulator settings applied in {@link #setupGame()}. */
    protected final Configuration configuration;
    /** Multiplier applied to the emulator reward in {@link #step(Integer)}; starts at 1.0. */
    protected double scaleFactor;
    /**
     * Scratch buffer the emulator writes the RGB screen into. The same array is reused for every
     * frame.
     *
     * <p>NOTE(review): every {@link GameScreen} is created from this shared array; confirm that
     * {@code Nd4j.create} copies the bytes rather than wrapping them.
     */
    private final byte[] screenBuffer;

    /* loaded from: input_file:org/deeplearning4j/rl4j/mdp/ale/ALEMDP$Configuration.class */
    /**
     * Immutable set of emulator options. {@link ALEMDP#setupGame()} forwards the first four values
     * to the emulator as {@code random_seed}, {@code repeat_action_probability},
     * {@code max_num_frames} and {@code max_num_frames_per_episode}; the flag selects between the
     * emulator's minimal and legal action sets in the {@link ALEMDP} constructor.
     */
    public static final class Configuration {
        private final int randomSeed;
        private final float repeatActionProbability;
        private final int maxNumFrames;
        private final int maxNumFramesPerEpisode;
        private final boolean minimalActionSet;

        /**
         * @param i  random seed
         * @param f  repeat-action probability
         * @param i2 maximum number of frames
         * @param i3 maximum number of frames per episode
         * @param z  {@code true} to use the emulator's minimal action set, {@code false} for the
         *           legal action set
         */
        public Configuration(int i, float f, int i2, int i3, boolean z) {
            this.randomSeed = i;
            this.repeatActionProbability = f;
            this.maxNumFrames = i2;
            this.maxNumFramesPerEpisode = i3;
            this.minimalActionSet = z;
        }

        /** @return the random seed */
        public int getRandomSeed() {
            return this.randomSeed;
        }

        /** @return the repeat-action probability */
        public float getRepeatActionProbability() {
            return this.repeatActionProbability;
        }

        /** @return the maximum number of frames */
        public int getMaxNumFrames() {
            return this.maxNumFrames;
        }

        /** @return the maximum number of frames per episode */
        public int getMaxNumFramesPerEpisode() {
            return this.maxNumFramesPerEpisode;
        }

        /** @return whether the minimal action set is selected instead of the legal one */
        public boolean isMinimalActionSet() {
            return this.minimalActionSet;
        }

        /**
         * Two configurations are equal when all five settings match; the float setting is compared
         * with {@link Float#compare(float, float)}.
         */
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof Configuration)) {
                return false;
            }
            Configuration other = (Configuration) obj;
            return this.randomSeed == other.randomSeed
                    && Float.compare(this.repeatActionProbability, other.repeatActionProbability) == 0
                    && this.maxNumFrames == other.maxNumFrames
                    && this.maxNumFramesPerEpisode == other.maxNumFramesPerEpisode
                    && this.minimalActionSet == other.minimalActionSet;
        }

        /** Hash over all five settings, consistent with {@link #equals(Object)}. */
        @Override
        public int hashCode() {
            // Same multiplier (59) and boolean codes (79/97) as before, so hash values are unchanged.
            int result = 1;
            result = (result * 59) + this.randomSeed;
            result = (result * 59) + Float.floatToIntBits(this.repeatActionProbability);
            result = (result * 59) + this.maxNumFrames;
            result = (result * 59) + this.maxNumFramesPerEpisode;
            result = (result * 59) + (this.minimalActionSet ? 79 : 97);
            return result;
        }

        @Override
        public String toString() {
            return "ALEMDP.Configuration(randomSeed=" + this.randomSeed
                    + ", repeatActionProbability=" + this.repeatActionProbability
                    + ", maxNumFrames=" + this.maxNumFrames
                    + ", maxNumFramesPerEpisode=" + this.maxNumFramesPerEpisode
                    + ", minimalActionSet=" + this.minimalActionSet + ")";
        }
    }

    /* loaded from: input_file:org/deeplearning4j/rl4j/mdp/ale/ALEMDP$GameScreen.class */
    /** Observation wrapping one emulator screen as an {@link INDArray}. */
    public static class GameScreen implements Encodable {
        final INDArray data;

        /**
         * Builds the observation from raw screen bytes: the bytes are loaded as a UINT8 array of
         * shape {@code {iArr[1], iArr[2], 3}} and the axes are then permuted with (2, 0, 1).
         *
         * @param iArr observation shape; only elements 1 and 2 are read
         * @param bArr raw screen bytes
         */
        public GameScreen(int[] iArr, byte[] bArr) {
            long[] rawShape = {iArr[1], iArr[2], 3};
            this.data = Nd4j.create(bArr, rawShape, DataType.UINT8).permute(new int[]{2, 0, 1});
        }

        /** Copy constructor used by {@link #m4dup()}; stores a {@code dup()} of the given array. */
        private GameScreen(INDArray iNDArray) {
            this.data = iNDArray.dup();
        }

        /**
         * @return the contents of the array's data buffer converted to doubles
         *     NOTE(review): for instances built from raw bytes {@code data} is a permuted array;
         *     confirm the element order of the underlying buffer is what callers expect.
         */
        public double[] toArray() {
            return this.data.data().asDouble();
        }

        /** @return always {@code false} */
        public boolean isSkipped() {
            return false;
        }

        /** @return the wrapped array itself (not a copy) */
        public INDArray getData() {
            return this.data;
        }

        /* renamed from: dup, reason: merged with bridge method [inline-methods] */
        /** @return a new {@code GameScreen} holding a duplicate of this screen's array */
        public GameScreen m4dup() {
            return new GameScreen(this.data);
        }
    }

    /**
     * Creates an environment for the given ROM with rendering disabled and the default
     * configuration.
     *
     * @param str path of the ROM file
     */
    public ALEMDP(String str) {
        this(str, false);
    }

    /**
     * Creates an environment for the given ROM with the default configuration
     * (seed 123, repeat-action probability 0, both frame limits 0, minimal action set).
     *
     * @param str path of the ROM file
     * @param z   value for the emulator's {@code display_screen} and {@code sound} options
     */
    public ALEMDP(String str, boolean z) {
        this(str, z, new Configuration(123, 0.0f, 0, 0, true));
    }

    /**
     * Creates the emulator, applies the configuration, loads the ROM and derives the action and
     * observation spaces from the emulator.
     *
     * @param str           path of the ROM file
     * @param z             value for the emulator's {@code display_screen} and {@code sound} options
     * @param configuration emulator settings; must not be {@code null}
     * @throws NullPointerException if {@code configuration} is {@code null}
     */
    public ALEMDP(String str, boolean z, Configuration configuration) {
        // Fail before the native emulator is allocated; previously a null configuration only
        // surfaced as an NPE inside setupGame(), after the ALEInterface had been created.
        if (configuration == null) {
            throw new NullPointerException("configuration must not be null");
        }
        this.scaleFactor = 1.0d;
        this.romFile = str;
        this.configuration = configuration;
        this.render = z;
        this.ale = new ALEInterface();
        // setupGame() is overridable: a subclass override runs before the subclass's own fields
        // have been initialised.
        setupGame();

        IntPointer actionSet = getConfiguration().isMinimalActionSet()
                ? this.ale.getMinimalActionSet()
                : this.ale.getLegalActionSet();
        this.actions = new int[(int) actionSet.limit()];
        actionSet.get(this.actions);

        int height = (int) this.ale.getScreen().height();
        int width = (int) this.ale.getScreen().width();
        int[] shape = {3, height, width};
        this.discreteSpace = new DiscreteSpace(this.actions.length);
        this.observationSpace = new ArrayObservationSpace<>(shape);
        this.screenBuffer = new byte[shape[0] * shape[1] * shape[2]];
    }

    /**
     * Pushes the configuration and render flag to the emulator and then loads the ROM. Called from
     * the constructor.
     */
    public void setupGame() {
        Configuration conf = getConfiguration();
        this.ale.setInt("random_seed", conf.getRandomSeed());
        this.ale.setFloat("repeat_action_probability", conf.getRepeatActionProbability());
        this.ale.setBool("display_screen", this.render);
        this.ale.setBool("sound", this.render);
        this.ale.setInt("max_num_frames", conf.getMaxNumFrames());
        this.ale.setInt("max_num_frames_per_episode", conf.getMaxNumFramesPerEpisode());
        this.ale.loadROM(this.romFile);
    }

    /** @return the emulator's {@code game_over()} flag */
    public boolean isDone() {
        return this.ale.game_over();
    }

    /* renamed from: reset, reason: merged with bridge method [inline-methods] */
    /**
     * Resets the game in the emulator and captures the current screen.
     *
     * @return an observation built from the screen after the reset
     */
    public GameScreen m2reset() {
        this.ale.reset_game();
        this.ale.getScreenRGB(this.screenBuffer);
        return new GameScreen(this.observationSpace.getShape(), this.screenBuffer);
    }

    /** Releases the emulator by calling {@code deallocate()} on it. */
    public void close() {
        this.ale.deallocate();
    }

    /**
     * Applies one action in the emulator and captures the resulting screen.
     *
     * @param num index into the action set read at construction time
     * @return the new observation, the emulator reward multiplied by the scale factor, the
     *     emulator's game-over flag, and a {@code null} info object
     * @throws NullPointerException           if {@code num} is {@code null}
     * @throws ArrayIndexOutOfBoundsException if {@code num} is not a valid index into the action set
     */
    public StepReply<GameScreen> step(Integer num) {
        double reward = this.ale.act(this.actions[num.intValue()]) * this.scaleFactor;
        // Parameterized logging: same text as before, without building the string eagerly.
        log.info("{} {} {} ", this.ale.getEpisodeFrameNumber(), reward, num);
        this.ale.getScreenRGB(this.screenBuffer);
        GameScreen observation = new GameScreen(this.observationSpace.getShape(), this.screenBuffer);
        return new StepReply<>(observation, reward, this.ale.game_over(), (Object) null);
    }

    /** @return the observation space created in the constructor */
    public ObservationSpace<GameScreen> getObservationSpace() {
        return this.observationSpace;
    }

    /* renamed from: getActionSpace, reason: merged with bridge method [inline-methods] */
    /** @return the discrete action space created in the constructor */
    public DiscreteSpace m3getActionSpace() {
        return this.discreteSpace;
    }

    /* renamed from: newInstance, reason: merged with bridge method [inline-methods] */
    /**
     * @return a new {@code ALEMDP} built from the same ROM path, render flag and configuration;
     *     the current scale factor is not carried over
     */
    public ALEMDP m1newInstance() {
        return new ALEMDP(this.romFile, this.render, this.configuration);
    }

    /** @return the ROM path given at construction */
    public String getRomFile() {
        return this.romFile;
    }

    /** @return the render flag given at construction */
    public boolean isRender() {
        return this.render;
    }

    /** @return the configuration given at construction */
    public Configuration getConfiguration() {
        return this.configuration;
    }

    /**
     * Sets the multiplier applied to emulator rewards in {@link #step(Integer)}.
     *
     * @param d new scale factor
     */
    public void setScaleFactor(double d) {
        this.scaleFactor = d;
    }
}
