package org.deeplearning4j.rl4j.util;

import java.util.HashMap;
import java.util.Map;
import org.datavec.image.transform.ColorConversionTransform;
import org.datavec.image.transform.CropImageTransform;
import org.datavec.image.transform.ImageTransform;
import org.datavec.image.transform.MultiImageTransform;
import org.datavec.image.transform.ResizeImageTransform;
import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.observation.transform.EncodableToINDArrayTransform;
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
import org.deeplearning4j.rl4j.observation.transform.filter.UniformSkippingFilter;
import org.deeplearning4j.rl4j.observation.transform.legacy.EncodableToImageWritableTransform;
import org.deeplearning4j.rl4j.observation.transform.legacy.ImageWritableToINDArrayTransform;
import org.deeplearning4j.rl4j.observation.transform.operation.HistoryMergeTransform;
import org.deeplearning4j.rl4j.observation.transform.operation.SimpleNormalizationTransform;
import org.deeplearning4j.rl4j.space.ActionSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.space.ObservationSpace;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Adapts a legacy {@link MDP} whose observations are raw {@link Encodable}s into an
 * {@code MDP<Observation, A, AS>} whose observations have been run through a {@link TransformProcess}.
 *
 * <p>When an {@link IHistoryProcessor} is set and the wrapped observation shape has rank 3, the
 * transform process applies, in order:
 * <ul>
 *   <li>a {@link UniformSkippingFilter};</li>
 *   <li>crop, resize and colour conversion to grayscale;</li>
 *   <li>a {@link SimpleNormalizationTransform} over [0, 255];</li>
 *   <li>a {@link HistoryMergeTransform}.</li>
 * </ul>
 * All of these are configured from the history processor's configuration. Otherwise observations
 * are only converted to {@link INDArray}s. Raw observations are also recorded into the history
 * processor whenever one is set.
 *
 * @param <OBSERVATION> observation type produced by the wrapped MDP
 * @param <A> action type
 * @param <AS> action-space type
 */
public class LegacyMDPWrapper<OBSERVATION extends Encodable, A, AS extends ActionSpace<A>> implements MDP<Observation, A, AS> {
    /** Name of the single channel fed through the transform process. */
    private static final String DATA_CHANNEL = "data";

    /** OpenCV colour-conversion code {@code COLOR_BGR2GRAY}. */
    private static final int COLOR_BGR2GRAY = 6;

    private final MDP<OBSERVATION, A, AS> wrappedMDP;
    private final WrapperObservationSpace observationSpace;
    private final int[] shape;
    private TransformProcess transformProcess;
    private IHistoryProcessor historyProcessor;
    // NOTE(review): assigned in reset() but never read in this class.
    private int skipFrame = 1;
    /** Step index handed to the transform process for the next step() observation; restarted by reset(). */
    private int steps = 0;

    /**
     * Minimal observation space that only reports the shape of the wrapped MDP's observations.
     * Name and bounds are not tracked and are reported as {@code null}.
     */
    public static class WrapperObservationSpace implements ObservationSpace<Observation> {
        private final int[] shape;

        /**
         * @param shape observation shape to report; stored as-is (not copied)
         */
        public WrapperObservationSpace(int[] shape) {
            this.shape = shape;
        }

        /** @return always {@code null}; this wrapper has no name. */
        public String getName() {
            return null;
        }

        /** @return always {@code null}; lower bounds are not tracked. */
        public INDArray getLow() {
            return null;
        }

        /** @return always {@code null}; upper bounds are not tracked. */
        public INDArray getHigh() {
            return null;
        }

        /** @return the shape passed to the constructor */
        public int[] getShape() {
            return this.shape;
        }
    }

    /**
     * Wraps {@code mdp}, building the transform process from {@code historyProcessor}.
     *
     * @param mdp the legacy MDP to wrap; its observation-space shape is captured once here
     * @param historyProcessor optional history processor; {@code null} selects the plain
     *        Encodable-to-INDArray transform
     */
    public LegacyMDPWrapper(MDP<OBSERVATION, A, AS> mdp, IHistoryProcessor historyProcessor) {
        this.wrappedMDP = mdp;
        this.shape = mdp.getObservationSpace().getShape();
        this.observationSpace = new WrapperObservationSpace(this.shape);
        // Assign directly rather than calling the overridable setHistoryProcessor() from the constructor.
        this.historyProcessor = historyProcessor;
        createTransformProcess();
    }

    /**
     * Replaces the history processor and rebuilds the transform process to match it.
     *
     * @param historyProcessor the new history processor, or {@code null}
     */
    public void setHistoryProcessor(IHistoryProcessor historyProcessor) {
        this.historyProcessor = historyProcessor;
        createTransformProcess();
    }

    /**
     * Builds {@link #transformProcess} from the current history processor and observation shape.
     * The image pipeline is only used when a history processor is set and the shape has rank 3.
     */
    private void createTransformProcess() {
        IHistoryProcessor hp = getHistoryProcessor();
        if (hp == null || shape.length != 3) {
            transformProcess = TransformProcess.builder()
                    .transform(DATA_CHANNEL, new EncodableToINDArrayTransform())
                    .build(DATA_CHANNEL);
            return;
        }

        int skipFrameCount = hp.getConf().getSkipFrame();
        int historyLength = hp.getConf().getHistoryLength();

        // NOTE(review): assumes shape is [channels, height, width] — confirm against the wrapped MDP.
        int height = shape[1];
        int width = shape[2];
        int cropBottom = height - hp.getConf().getCroppingHeight();
        int cropRight = width - hp.getConf().getCroppingWidth();

        MultiImageTransform imageTransform = new MultiImageTransform(new ImageTransform[] {
                new CropImageTransform(hp.getConf().getOffsetY(), hp.getConf().getOffsetX(), cropBottom, cropRight),
                new ResizeImageTransform(hp.getConf().getRescaledWidth(), hp.getConf().getRescaledHeight()),
                new ColorConversionTransform(COLOR_BGR2GRAY)
        });

        transformProcess = TransformProcess.builder()
                .filter(new UniformSkippingFilter(skipFrameCount))
                .transform(DATA_CHANNEL, new EncodableToImageWritableTransform())
                .transform(DATA_CHANNEL, imageTransform)
                .transform(DATA_CHANNEL, new ImageWritableToINDArrayTransform())
                .transform(DATA_CHANNEL, new SimpleNormalizationTransform(0.0d, 255.0d))
                .transform(DATA_CHANNEL, HistoryMergeTransform.builder()
                        .isFirstDimenstionBatch(true)
                        .build(historyLength))
                .build(DATA_CHANNEL);
    }

    /** @return the wrapped MDP's action space */
    public AS getActionSpace() {
        return wrappedMDP.getActionSpace();
    }

    /**
     * Starts a new episode. Resets the transform process and the wrapped MDP, records the
     * initial raw observation, and returns it transformed as step 0.
     *
     * @return the transformed initial observation
     */
    public Observation reset() {
        transformProcess.reset();
        // Step indices are per episode: the reset observation is step 0, so step() counts from 0 again.
        // NOTE(review): the first step() after reset() also receives index 0 (same as the reset
        // observation); confirm whether it should start at 1.
        steps = 0;

        OBSERVATION rawObservation = wrappedMDP.reset();
        record(rawObservation);
        if (historyProcessor != null) {
            skipFrame = historyProcessor.getConf().getSkipFrame();
        }
        return transformProcess.transform(buildChannelsData(rawObservation), 0, false);
    }

    /**
     * Decompiler-generated alias of {@link #reset()}, kept for backward compatibility.
     *
     * @deprecated use {@link #reset()}
     */
    @Deprecated
    public Observation m40reset() {
        return reset();
    }

    /**
     * Performs {@code action} on the wrapped MDP, records the raw observation and returns the
     * transformed reply. Reward, done flag and info are passed through unchanged.
     *
     * @param action the action to perform
     * @return the step reply with a transformed observation
     */
    public StepReply<Observation> step(A action) {
        StepReply<OBSERVATION> rawReply = wrappedMDP.step(action);
        OBSERVATION rawObservation = rawReply.getObservation();
        record(rawObservation);

        int stepOfObservation = steps++;
        Observation observation = transformProcess.transform(
                buildChannelsData(rawObservation), stepOfObservation, rawReply.isDone());
        return new StepReply<>(observation, rawReply.getReward(), rawReply.isDone(), rawReply.getInfo());
    }

    /** Records the raw observation's data into the history processor, if one is set. */
    private void record(OBSERVATION observation) {
        IHistoryProcessor hp = getHistoryProcessor();
        if (hp != null) {
            hp.record(getInput(observation));
        }
    }

    /**
     * Builds the channel map consumed by the transform process.
     * Kept as a mutable HashMap, as in the original, in case the transform process writes back into it.
     */
    private Map<String, Object> buildChannelsData(OBSERVATION observation) {
        Map<String, Object> channelsData = new HashMap<>();
        channelsData.put(DATA_CHANNEL, observation);
        return channelsData;
    }

    /** Closes the wrapped MDP. */
    public void close() {
        wrappedMDP.close();
    }

    /** @return whether the wrapped MDP reports the episode as done */
    public boolean isDone() {
        return wrappedMDP.isDone();
    }

    /** @return a new wrapper around a fresh instance of the wrapped MDP, sharing this history processor */
    public MDP<Observation, A, AS> newInstance() {
        return new LegacyMDPWrapper<OBSERVATION, A, AS>(wrappedMDP.newInstance(), historyProcessor);
    }

    /** @return the raw data of {@code observation} */
    private INDArray getInput(OBSERVATION observation) {
        return observation.getData();
    }

    /** @return the MDP being wrapped */
    public MDP<OBSERVATION, A, AS> getWrappedMDP() {
        return wrappedMDP;
    }

    /** @return the shape-only observation space describing the wrapped MDP's observations */
    public WrapperObservationSpace getObservationSpace() {
        return observationSpace;
    }

    /**
     * Decompiler-generated alias of {@link #getObservationSpace()}, kept for backward compatibility.
     *
     * @deprecated use {@link #getObservationSpace()}
     */
    @Deprecated
    public WrapperObservationSpace m41getObservationSpace() {
        return getObservationSpace();
    }

    /**
     * Overrides the transform process built from the history processor.
     *
     * @param transformProcess the transform process to use from now on
     */
    public void setTransformProcess(TransformProcess transformProcess) {
        this.transformProcess = transformProcess;
    }

    private IHistoryProcessor getHistoryProcessor() {
        return historyProcessor;
    }
}
