package org.deeplearning4j.rl4j.mdp.gym;

import java.io.IOException;
import java.util.Objects;
import org.bytedeco.cpython.PyCompilerFlags;
import org.bytedeco.cpython.PyObject;
import org.bytedeco.cpython.global.python;
import org.bytedeco.gym.presets.gym;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.SizeTPointer;
import org.bytedeco.numpy.PyArrayObject;
import org.bytedeco.numpy.global.numpy;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.ArrayObservationSpace;
import org.deeplearning4j.rl4j.space.Box;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/rl4j/mdp/gym/GymEnv.class */
/**
 * An {@link MDP} backed by an OpenAI Gym environment running inside an embedded CPython
 * interpreter, accessed through the JavaCPP cpython and numpy presets.
 *
 * <p>The interpreter is initialized once, in the static initializer, which releases the GIL when
 * done. Every method that talks to Python re-acquires it with {@code PyGILState_Ensure} for the
 * duration of the call. Each instance keeps its own Python locals dictionary holding its
 * {@code env} object.
 *
 * <p>The environment's {@code action_space} must expose {@code n}, and actions passed to
 * {@link #step} must be {@link Integer}s at runtime. Observations are produced as {@link Box}
 * instances.
 *
 * @param <OBSERVATION> observation type (instances handed out are {@link Box}es)
 * @param <A> action type (must be {@link Integer} at runtime)
 * @param <AS> action space type
 */
public class GymEnv<OBSERVATION extends Encodable, A, AS extends ActionSpace<A>> implements MDP<OBSERVATION, A, AS> {
    private static final Logger log = LoggerFactory.getLogger(GymEnv.class);

    /** Directory the Gym {@code Monitor} wrapper records into when monitoring is enabled. */
    public static final String GYM_MONITOR_DIR = "/tmp/gym-dqn";

    /** CPython {@code Py_single_input} start token: run a statement like the interactive REPL. */
    private static final int PY_SINGLE_INPUT = 256;

    /** CPython {@code Py_eval_input} start token: evaluate one expression and return its value. */
    private static final int PY_EVAL_INPUT = 258;

    /** Program name handed to {@code Py_SetProgramName}; must outlive the interpreter. */
    private static Pointer program;

    /** Globals dictionary of the interpreter's {@code __main__} module, shared by all instances. */
    private static PyObject globals;

    /** This instance's Python locals dictionary (holds {@code env}); {@code null} once closed. */
    private PyObject locals;

    protected final DiscreteSpace actionSpace;
    protected final ObservationSpace<OBSERVATION> observationSpace;
    private final String envId;
    private final boolean render;
    private final boolean monitor;

    /** Restricts the Gym action space to a subset of actions; {@code null} when unrestricted. */
    private ActionTransformer actionTransformer;

    /** Copy of the action subset given to the constructor, so {@link #newInstance()} can rebuild it. */
    private int[] availableActions;

    private boolean done;

    /**
     * Throws if a Python exception is pending. The exception is first printed (and cleared) with
     * {@code PyErr_Print}. Must be called with the GIL held.
     *
     * @throws RuntimeException if {@code PyErr_Occurred()} reports an error
     */
    private static void checkPythonError() {
        if (python.PyErr_Occurred() != null) {
            python.PyErr_Print();
            throw new RuntimeException("Python error occurred");
        }
    }

    /**
     * Creates an unseeded environment without an action subset.
     *
     * @param envId Gym environment id passed to {@code gym.make}
     * @param render whether to call {@code env.render()} before every step
     * @param monitor whether to wrap the environment in {@code gym.wrappers.Monitor}
     */
    public GymEnv(String envId, boolean render, boolean monitor) {
        this(envId, render, monitor, (Integer) null);
    }

    /**
     * Creates the Python environment and reads its observation and action spaces.
     *
     * @param envId Gym environment id passed to {@code gym.make}; must not be {@code null}
     * @param render whether to call {@code env.render()} before every step
     * @param monitor whether to wrap the environment in {@code gym.wrappers.Monitor}
     * @param seed value passed to {@code env.seed(...)}, or {@code null} to leave it unseeded
     * @throws NullPointerException if {@code envId} is {@code null}
     * @throws RuntimeException if any Python call fails
     */
    public GymEnv(String envId, boolean render, boolean monitor, Integer seed) {
        this.envId = Objects.requireNonNull(envId, "envId");
        this.render = render;
        this.monitor = monitor;
        int gilState = python.PyGILState_Ensure();
        try {
            this.locals = python.PyDict_New();
            checkPythonError();
            // Pass the id as a Python string object instead of splicing it into source code, so
            // ids containing quotes or backslashes can neither break nor inject Python code.
            bindString("env_id", envId);
            runStatement("import gym; env = gym.make(env_id)");
            if (monitor) {
                runStatement("env = gym.wrappers.Monitor(env, '" + GYM_MONITOR_DIR + "')");
            }
            if (seed != null) {
                runStatement("env.seed(" + seed + ")");
            }
            this.observationSpace = new ArrayObservationSpace<>(readObservationShape());
            this.actionSpace = new DiscreteSpace(readActionCount());
        } catch (RuntimeException | Error e) {
            // Do not leak the partially built environment when construction fails.
            python.Py_DecRef(this.locals);
            this.locals = null;
            throw e;
        } finally {
            python.PyGILState_Release(gilState);
        }
    }

    /**
     * Creates an unseeded environment restricted to the given actions.
     *
     * @param envId Gym environment id passed to {@code gym.make}
     * @param render whether to call {@code env.render()} before every step
     * @param monitor whether to wrap the environment in {@code gym.wrappers.Monitor}
     * @param actions action subset handed to {@link ActionTransformer}
     */
    public GymEnv(String envId, boolean render, boolean monitor, int[] actions) {
        this(envId, render, monitor, null, actions);
    }

    /**
     * Creates an environment restricted to the given actions.
     *
     * @param envId Gym environment id passed to {@code gym.make}
     * @param render whether to call {@code env.render()} before every step
     * @param monitor whether to wrap the environment in {@code gym.wrappers.Monitor}
     * @param seed value passed to {@code env.seed(...)}, or {@code null} to leave it unseeded
     * @param actions action subset handed to {@link ActionTransformer}
     */
    public GymEnv(String envId, boolean render, boolean monitor, Integer seed, int[] actions) {
        this(envId, render, monitor, seed);
        this.availableActions = actions == null ? null : actions.clone();
        // The transformer is not set yet, so the underlying discrete space is what it wraps.
        this.actionTransformer = new ActionTransformer(this.actionSpace, actions);
    }

    @Override
    public ObservationSpace<OBSERVATION> getObservationSpace() {
        return this.observationSpace;
    }

    /**
     * Returns the action transformer when an action subset was configured, otherwise the
     * environment's full discrete action space.
     */
    @Override
    @SuppressWarnings("unchecked")
    public AS getActionSpace() {
        return this.actionTransformer == null ? (AS) this.actionSpace : (AS) this.actionTransformer;
    }

    /**
     * Advances the environment by one action, rendering first if requested.
     *
     * @param a the action; must be an {@link Integer}
     * @return the new observation, the reward and the done flag (info is always {@code null})
     * @throws ClassCastException if {@code a} is not an {@link Integer}
     * @throws IllegalStateException if the environment has been closed
     * @throws RuntimeException if a Python call fails
     */
    @Override
    @SuppressWarnings("unchecked")
    public StepReply<OBSERVATION> step(A a) {
        int gilState = python.PyGILState_Ensure();
        try {
            if (this.render) {
                runStatement("env.render()");
            }
            runStatement("state, reward, done, info = env.step(" + ((Integer) a) + ")");
            double reward = python.PyFloat_AsDouble(python.PyDict_GetItemString(this.locals, "reward"));
            // PyObject_IsTrue accepts any truthy object (Python bool, numpy.bool_, ...), whereas
            // PyLong_AsLong may reject non-int types such as numpy.bool_.
            int doneFlag = python.PyObject_IsTrue(python.PyDict_GetItemString(this.locals, "done"));
            checkPythonError();
            this.done = doneFlag == 1;
            OBSERVATION observation = (OBSERVATION) readState();
            return new StepReply<>(observation, reward, this.done, null);
        } finally {
            python.PyGILState_Release(gilState);
        }
    }

    @Override
    public boolean isDone() {
        return this.done;
    }

    /**
     * Resets the Python environment and clears the done flag.
     *
     * @return the initial observation
     * @throws IllegalStateException if the environment has been closed
     * @throws RuntimeException if a Python call fails
     */
    @Override
    @SuppressWarnings("unchecked")
    public OBSERVATION reset() {
        int gilState = python.PyGILState_Ensure();
        try {
            runStatement("state = env.reset()");
            OBSERVATION observation = (OBSERVATION) readState();
            this.done = false;
            return observation;
        } finally {
            python.PyGILState_Release(gilState);
        }
    }

    /**
     * Calls {@code env.close()} and releases this instance's Python locals. The locals are
     * released even if {@code env.close()} fails. Calling this again after a close is a no-op.
     */
    @Override
    public void close() {
        if (this.locals == null) {
            return;
        }
        int gilState = python.PyGILState_Ensure();
        try {
            try {
                runStatement("env.close()");
            } finally {
                python.Py_DecRef(this.locals);
                this.locals = null;
            }
        } finally {
            python.PyGILState_Release(gilState);
        }
    }

    /**
     * Creates a fresh environment with the same id, render/monitor flags and action subset.
     * The seed is not carried over.
     */
    @Override
    public GymEnv<OBSERVATION, A, AS> newInstance() {
        if (this.actionTransformer == null) {
            return new GymEnv<>(this.envId, this.render, this.monitor);
        }
        int[] actions = this.availableActions == null ? null : this.availableActions.clone();
        return new GymEnv<>(this.envId, this.render, this.monitor, actions);
    }

    /** Alias of {@link #newInstance()}, kept for source compatibility with existing callers. */
    public GymEnv<OBSERVATION, A, AS> m1newInstance() {
        return newInstance();
    }

    public String getEnvId() {
        return this.envId;
    }

    public boolean isRender() {
        return this.render;
    }

    public boolean isMonitor() {
        return this.monitor;
    }

    /** Returns the live locals dictionary, failing fast once {@link #close()} has run. */
    private PyObject requireLocals() {
        if (this.locals == null) {
            throw new IllegalStateException("GymEnv " + this.envId + " has been closed");
        }
        return this.locals;
    }

    /** Runs a Python statement in this instance's namespace; GIL must be held. */
    private void runStatement(String code) {
        python.Py_DecRef(python.PyRun_StringFlags(code, PY_SINGLE_INPUT, globals, requireLocals(), (PyCompilerFlags) null));
        checkPythonError();
    }

    /** Evaluates a Python expression; returns a new reference the caller must release. GIL must be held. */
    private PyObject evaluate(String expression) {
        PyObject result = python.PyRun_StringFlags(expression, PY_EVAL_INPUT, globals, requireLocals(), (PyCompilerFlags) null);
        checkPythonError();
        return result;
    }

    /** Stores {@code value} as a Python str under {@code name} in this instance's locals. GIL must be held. */
    private void bindString(String name, String value) {
        PyObject pyValue = python.PyUnicode_FromString(value);
        checkPythonError();
        try {
            python.PyDict_SetItemString(requireLocals(), name, pyValue);
            checkPythonError();
        } finally {
            // PyDict_SetItemString does not steal the reference, so drop ours.
            python.Py_DecRef(pyValue);
        }
    }

    /** Reads {@code env.observation_space.shape} as an int array. GIL must be held. */
    private int[] readObservationShape() {
        PyObject shape = evaluate("env.observation_space.shape");
        try {
            long rank = python.PyTuple_Size(shape);
            // Check before allocating: PyTuple_Size returns -1 with an error set for non-tuples.
            checkPythonError();
            int[] dims = new int[(int) rank];
            for (int i = 0; i < dims.length; i++) {
                dims[i] = (int) python.PyLong_AsLong(python.PyTuple_GetItem(shape, i));
            }
            checkPythonError();
            return dims;
        } finally {
            python.Py_DecRef(shape);
        }
    }

    /** Reads {@code env.action_space.n}. GIL must be held. */
    private int readActionCount() {
        PyObject count = evaluate("env.action_space.n");
        try {
            long n = python.PyLong_AsLong(count);
            checkPythonError();
            return (int) n;
        } finally {
            python.Py_DecRef(count);
        }
    }

    /**
     * Copies the Python variable {@code state} into a {@link Box}. GIL must be held.
     *
     * <p>The value is first converted with
     * {@code numpy.ascontiguousarray(state, dtype='float64')}. Reading the raw buffer as doubles
     * without that conversion would misinterpret narrower dtypes (e.g. float32, uint8) and read
     * past the end of their buffer.
     */
    private Box readState() {
        PyObject array = evaluate("__import__('numpy').ascontiguousarray(state, dtype='float64')");
        try {
            PyArrayObject pyArray = new PyArrayObject(array);
            DoublePointer data = new DoublePointer(numpy.PyArray_BYTES(pyArray)).capacity(numpy.PyArray_Size(pyArray));
            double[] values = new double[(int) data.capacity()];
            data.get(values);
            return new Box(values);
        } finally {
            python.Py_DecRef(array);
        }
    }

    static {
        try {
            python.Py_AddPath(gym.cachePackages());
            program = python.Py_DecodeLocale(GymEnv.class.getSimpleName(), (SizeTPointer) null);
            python.Py_SetProgramName(program);
            python.Py_Initialize();
            python.PyEval_InitThreads();
            python.PySys_SetArgvEx(1, program, 0);
            // The numpy C API must be imported before any PyArray_* function is used.
            if (numpy._import_array() < 0) {
                python.PyErr_Print();
                throw new RuntimeException("numpy.core.multiarray failed to import");
            }
            globals = python.PyModule_GetDict(python.PyImport_AddModule("__main__"));
            // Release the GIL taken by initialization; instance methods re-acquire it per call.
            python.PyEval_SaveThread();
        } catch (IOException e) {
            python.PyMem_RawFree(program);
            throw new RuntimeException(e);
        }
    }
}
