package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete;

import org.deeplearning4j.gym.StepReply;
import org.deeplearning4j.rl4j.agent.learning.algorithm.dqn.BaseTransitionTDAlgorithm;
import org.deeplearning4j.rl4j.agent.learning.algorithm.dqn.DoubleDQN;
import org.deeplearning4j.rl4j.agent.learning.algorithm.dqn.StandardDQN;
import org.deeplearning4j.rl4j.agent.learning.behavior.ILearningBehavior;
import org.deeplearning4j.rl4j.agent.learning.behavior.LearningBehavior;
import org.deeplearning4j.rl4j.agent.learning.update.UpdateRule;
import org.deeplearning4j.rl4j.agent.learning.update.updater.NeuralNetUpdaterConfiguration;
import org.deeplearning4j.rl4j.agent.learning.update.updater.sync.SyncLabelsNeuralNetUpdater;
import org.deeplearning4j.rl4j.experience.ReplayMemoryExperienceHandler;
import org.deeplearning4j.rl4j.learning.IHistoryProcessor;
import org.deeplearning4j.rl4j.learning.Learning;
import org.deeplearning4j.rl4j.learning.configuration.QLearningConfiguration;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning;
import org.deeplearning4j.rl4j.mdp.MDP;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.policy.DQNPolicy;
import org.deeplearning4j.rl4j.policy.EpsGreedy;
import org.deeplearning4j.rl4j.space.DiscreteSpace;
import org.deeplearning4j.rl4j.space.Encodable;
import org.deeplearning4j.rl4j.util.LegacyMDPWrapper;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.class */
public abstract class QLearningDiscrete<O extends Encodable> extends QLearning<O, Integer, DiscreteSpace> {
    private final QLearningConfiguration configuration;
    private final LegacyMDPWrapper<O, Integer, DiscreteSpace> mdp;
    private DQNPolicy<O> policy;
    private EpsGreedy<Integer> egPolicy;
    private final IDQN qNetwork;
    private int lastAction;
    private double accuReward;
    private final ILearningBehavior<Integer> learningBehavior;

    @Override // org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning
    protected LegacyMDPWrapper<O, Integer, DiscreteSpace> getLegacyMDPWrapper() {
        return this.mdp;
    }

    public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN idqn, QLearningConfiguration qLearningConfiguration, int i) {
        this(mdp, idqn, qLearningConfiguration, i, Nd4j.getRandomFactory().getNewRandomInstance(qLearningConfiguration.getSeed().longValue()));
    }

    public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN idqn, QLearningConfiguration qLearningConfiguration, int i, Random random) {
        this(mdp, idqn, qLearningConfiguration, i, buildLearningBehavior(idqn, qLearningConfiguration, random), random);
    }

    public QLearningDiscrete(MDP<O, Integer, DiscreteSpace> mdp, IDQN idqn, QLearningConfiguration qLearningConfiguration, int i, ILearningBehavior<Integer> iLearningBehavior, Random random) {
        this.accuReward = 0.0d;
        this.configuration = qLearningConfiguration;
        this.mdp = new LegacyMDPWrapper<>(mdp, null);
        this.qNetwork = idqn;
        this.policy = new DQNPolicy<>(getQNetwork());
        this.egPolicy = new EpsGreedy<>(this.policy, mdp, qLearningConfiguration.getUpdateStart(), i, random, qLearningConfiguration.getMinEpsilon(), this);
        this.learningBehavior = iLearningBehavior;
    }

    /* JADX WARN: Type inference failed for: r0v1, types: [org.deeplearning4j.rl4j.network.IOutputNeuralNet, org.deeplearning4j.rl4j.network.ITrainableNeuralNet] */
    /* JADX WARN: Type inference failed for: r0v11, types: [org.deeplearning4j.rl4j.agent.learning.update.updater.NeuralNetUpdaterConfiguration$NeuralNetUpdaterConfigurationBuilder] */
    /* JADX WARN: Type inference failed for: r0v16, types: [org.deeplearning4j.rl4j.experience.ReplayMemoryExperienceHandler$Configuration$ConfigurationBuilder] */
    /* JADX WARN: Type inference failed for: r0v3, types: [org.deeplearning4j.rl4j.agent.learning.algorithm.dqn.BaseTransitionTDAlgorithm$Configuration$ConfigurationBuilder] */
    private static ILearningBehavior<Integer> buildLearningBehavior(IDQN idqn, QLearningConfiguration qLearningConfiguration, Random random) {
        ?? clone = idqn.mo26clone();
        BaseTransitionTDAlgorithm.Configuration build = BaseTransitionTDAlgorithm.Configuration.builder().gamma(qLearningConfiguration.getGamma()).errorClamp(qLearningConfiguration.getErrorClamp()).build();
        return LearningBehavior.builder().experienceHandler(new ReplayMemoryExperienceHandler(ReplayMemoryExperienceHandler.Configuration.builder().maxReplayMemorySize(qLearningConfiguration.getExpRepMaxSize()).batchSize(qLearningConfiguration.getBatchSize()).build(), random)).updateRule(new UpdateRule(qLearningConfiguration.isDoubleDQN() ? new DoubleDQN(idqn, clone, build) : new StandardDQN(idqn, clone, build), new SyncLabelsNeuralNetUpdater(idqn, clone, NeuralNetUpdaterConfiguration.builder().targetUpdateFrequency(qLearningConfiguration.getTargetDqnUpdateFreq()).build()))).build();
    }

    @Override // org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning, org.deeplearning4j.rl4j.learning.ILearning
    public MDP<O, Integer, DiscreteSpace> getMdp() {
        return this.mdp.getWrappedMDP();
    }

    @Override // org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning, org.deeplearning4j.rl4j.learning.sync.SyncLearning
    public void postEpoch() {
        if (getHistoryProcessor() != null) {
            getHistoryProcessor().stopMonitor();
        }
    }

    @Override // org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning, org.deeplearning4j.rl4j.learning.sync.SyncLearning
    public void preEpoch() {
        this.lastAction = this.mdp.getActionSpace().noOp().intValue();
        this.accuReward = 0.0d;
        this.learningBehavior.handleEpisodeStart();
    }

    @Override // org.deeplearning4j.rl4j.learning.Learning
    public void setHistoryProcessor(IHistoryProcessor iHistoryProcessor) {
        super.setHistoryProcessor(iHistoryProcessor);
        this.mdp.setHistoryProcessor(iHistoryProcessor);
    }

    @Override // org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning
    protected QLearning.QLStepReturn<Observation> trainStep(Observation observation) {
        Double valueOf = Double.valueOf(Double.NaN);
        if (!observation.isSkipped()) {
            valueOf = Double.valueOf(getQNetwork().output(observation).get("Q").getDouble(Learning.getMaxAction(r0).intValue()));
            this.lastAction = getEgPolicy().nextAction(observation).intValue();
        }
        StepReply<Observation> step = this.mdp.step(Integer.valueOf(this.lastAction));
        this.accuReward += step.getReward() * this.configuration.getRewardFactor();
        if (!observation.isSkipped()) {
            this.learningBehavior.handleNewExperience(observation, Integer.valueOf(this.lastAction), this.accuReward, step.isDone());
            this.accuReward = 0.0d;
        }
        return new QLearning.QLStepReturn<>(valueOf, getQNetwork().getLatestScore(), step);
    }

    @Override // org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning
    protected void finishEpoch(Observation observation) {
        this.learningBehavior.handleEpisodeEnd(observation);
    }

    @Override // org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning, org.deeplearning4j.rl4j.learning.ILearning
    public QLearningConfiguration getConfiguration() {
        return this.configuration;
    }

    @Override // org.deeplearning4j.rl4j.learning.ILearning
    public DQNPolicy<O> getPolicy() {
        return this.policy;
    }

    @Override // org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning
    public EpsGreedy<Integer> getEgPolicy() {
        return this.egPolicy;
    }

    @Override // org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning
    public IDQN getQNetwork() {
        return this.qNetwork;
    }
}
