package org.deeplearning4j.rl4j.agent;

import lombok.NonNull;
import org.deeplearning4j.rl4j.agent.listener.AgentListener;
import org.deeplearning4j.rl4j.agent.listener.AgentListenerList;
import org.deeplearning4j.rl4j.environment.Environment;
import org.deeplearning4j.rl4j.environment.StepResult;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.nd4j.common.base.Preconditions;

/* loaded from: input_file:org/deeplearning4j/rl4j/agent/Agent.class */
/**
 * Runs episodes of an {@link Environment}, choosing actions with an {@link IPolicy} and converting the
 * environment's raw channel data into {@link Observation}s through a {@link TransformProcess}.
 *
 * <p>Listeners are notified before and after each episode and each step. A {@code false} result from
 * any of these notifications stops the current episode. Subclasses customise behaviour through the
 * protected hooks: {@link #onBeforeEpisode()}, {@link #onBeforeStep()}, {@link #onAfterAction},
 * {@link #onAfterStep}, {@link #onAfterEpisode()}, {@link #computeReward} and {@link #getInitialAction()}.
 *
 * <p>An instance holds mutable per-episode state: the current observation, last action, step count and
 * accumulated reward.
 *
 * @param <ACTION> the action type accepted by the environment and produced by the policy
 */
public class Agent<ACTION> implements IAgent<ACTION> {
    private final String id;
    private final Environment<ACTION> environment;
    private final IPolicy<ACTION> policy;
    private final TransformProcess transformProcess;
    protected final AgentListenerList<ACTION> listeners;
    /** Maximum number of steps per episode, or {@code null} for no limit. */
    private final Integer maxEpisodeSteps;
    /** Observation the next action will be decided on. */
    private Observation observation;
    /** Most recently chosen action; reused when an observation is skipped. */
    private ACTION lastAction;
    private int episodeStepCount;
    /** Reward accumulated over the current episode. */
    private double reward;
    /** Set to {@code false} when a listener asks to stop the current episode. */
    protected boolean canContinue;

    /**
     * Agent settings, created through {@link #builder()}.
     *
     * <p>{@code maxEpisodeSteps} is the per-episode step limit. {@code null} (the default) means no limit.
     */
    public static class Configuration {
        Integer maxEpisodeSteps;

        /**
         * Fluent builder for {@link Configuration}. The type parameters let subclass builders keep their own
         * type through the fluent setters.
         *
         * @param <C> the configuration type produced by {@link #build()}
         * @param <B> the concrete builder type returned by the fluent setters
         */
        public static abstract class ConfigurationBuilder<C extends Configuration, B extends ConfigurationBuilder<C, B>> {
            /** Whether {@link #maxEpisodeSteps(Integer)} has been called. */
            private boolean maxEpisodeSteps$set;
            private Integer maxEpisodeSteps$value;

            /** Returns this builder typed as its concrete class. */
            protected abstract B self();

            /** Creates the configuration from the values set so far. */
            public abstract C build();

            /**
             * Sets the per-episode step limit.
             *
             * @param num maximum number of steps per episode, or {@code null} for no limit
             * @return this builder
             */
            public B maxEpisodeSteps(Integer num) {
                this.maxEpisodeSteps$value = num;
                this.maxEpisodeSteps$set = true;
                return self();
            }

            @Override
            public String toString() {
                return "Agent.Configuration.ConfigurationBuilder(maxEpisodeSteps$value=" + this.maxEpisodeSteps$value + ")";
            }
        }

        /** The builder returned by {@link Configuration#builder()}. */
        private static final class ConfigurationBuilderImpl extends ConfigurationBuilder<Configuration, ConfigurationBuilderImpl> {
            private ConfigurationBuilderImpl() {
            }

            // NOTE(review): the decompiler reports this method was protected in the original source.
            @Override
            public ConfigurationBuilderImpl self() {
                return this;
            }

            @Override
            public Configuration build() {
                return new Configuration(this);
            }
        }

        /** Value used for {@code maxEpisodeSteps} when the builder never set one (no limit). */
        private static Integer $default$maxEpisodeSteps() {
            return null;
        }

        /**
         * Copies a builder's state. Values the builder never set fall back to their defaults.
         *
         * <p>NOTE(review): the decompiler reports this constructor was protected in the original source.
         *
         * @param configurationBuilder the builder to copy from
         */
        public Configuration(ConfigurationBuilder<?, ?> configurationBuilder) {
            if (configurationBuilder.maxEpisodeSteps$set) {
                this.maxEpisodeSteps = configurationBuilder.maxEpisodeSteps$value;
            } else {
                this.maxEpisodeSteps = $default$maxEpisodeSteps();
            }
        }

        /** Returns a new builder with no values set. */
        public static ConfigurationBuilder<?, ?> builder() {
            return new ConfigurationBuilderImpl();
        }

        /** Returns the per-episode step limit, or {@code null} for no limit. */
        public Integer getMaxEpisodeSteps() {
            return this.maxEpisodeSteps;
        }

        /** Sets the per-episode step limit; {@code null} means no limit. */
        public void setMaxEpisodeSteps(Integer num) {
            this.maxEpisodeSteps = num;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof Configuration)) {
                return false;
            }
            Configuration other = (Configuration) obj;
            // canEqual keeps equals symmetric if a subclass adds state and overrides canEqual.
            if (!other.canEqual(this)) {
                return false;
            }
            Integer mine = getMaxEpisodeSteps();
            Integer theirs = other.getMaxEpisodeSteps();
            return mine == null ? theirs == null : mine.equals(theirs);
        }

        /** Returns whether {@code obj} may be considered equal to this configuration. */
        protected boolean canEqual(Object obj) {
            return obj instanceof Configuration;
        }

        @Override
        public int hashCode() {
            final int prime = 59;
            int result = 1;
            Integer steps = getMaxEpisodeSteps();
            result = result * prime + (steps == null ? 43 : steps.hashCode());
            return result;
        }

        @Override
        public String toString() {
            return "Agent.Configuration(maxEpisodeSteps=" + getMaxEpisodeSteps() + ")";
        }
    }

    /**
     * Creates an agent.
     *
     * @param environment      the environment to act in
     * @param transformProcess converts the environment's raw channel data into observations
     * @param iPolicy          chooses actions from observations
     * @param configuration    agent settings; {@code maxEpisodeSteps} must be {@code null} or greater than 0
     * @param str              identifier returned by {@link #getId()}
     * @throws NullPointerException     if {@code environment}, {@code transformProcess}, {@code iPolicy} or
     *                                  {@code configuration} is {@code null}
     * @throws IllegalArgumentException if {@code maxEpisodeSteps} is neither {@code null} nor greater than 0
     */
    public Agent(@NonNull Environment<ACTION> environment, @NonNull TransformProcess transformProcess, @NonNull IPolicy<ACTION> iPolicy, @NonNull Configuration configuration, String str) {
        if (environment == null) {
            throw new NullPointerException("environment is marked non-null but is null");
        }
        if (transformProcess == null) {
            throw new NullPointerException("transformProcess is marked non-null but is null");
        }
        if (iPolicy == null) {
            throw new NullPointerException("policy is marked non-null but is null");
        }
        if (configuration == null) {
            throw new NullPointerException("configuration is marked non-null but is null");
        }
        // Read once so that the value validated here is exactly the value stored.
        Integer maxSteps = configuration.getMaxEpisodeSteps();
        Preconditions.checkArgument(maxSteps == null || maxSteps > 0,
                "Configuration: maxEpisodeSteps must be null (no maximum) or greater than 0, got %s", maxSteps);
        this.environment = environment;
        this.transformProcess = transformProcess;
        this.policy = iPolicy;
        this.maxEpisodeSteps = maxSteps;
        this.id = str;
        // buildListenerList() is overridable and runs during construction. Overrides must not rely on
        // subclass fields, because those are not initialised yet at this point.
        this.listeners = buildListenerList();
    }

    /**
     * Creates the listener list. It is called once, from the constructor.
     *
     * @return a new, empty listener list
     */
    protected AgentListenerList<ACTION> buildListenerList() {
        return new AgentListenerList<>();
    }

    /**
     * Registers a listener to be notified of episode and step events.
     *
     * @param agentListener the listener to add
     */
    // NOTE(review): raw AgentListener parameter; AgentListener<ACTION> may be intended — confirm before changing.
    public void addListener(AgentListener agentListener) {
        this.listeners.add(agentListener);
    }

    /** Runs a single episode. */
    @Override
    public void run() {
        runEpisode();
    }

    /**
     * Hook called after the agent is reset and before listeners are notified that an episode starts.
     *
     * <p>NOTE(review): the decompiler reports this method was protected in the original source.
     */
    public void onBeforeEpisode() {
    }

    /** Hook called at the end of an episode that no listener stopped. */
    protected void onAfterEpisode() {
    }

    /**
     * Runs one episode. The agent is reset, then it steps until one of these happens: a listener stops
     * the episode, the environment reports the episode finished, or the step limit is reached.
     */
    protected void runEpisode() {
        reset();
        onBeforeEpisode();
        this.canContinue = this.listeners.notifyBeforeEpisode(this);
        while (this.canContinue && !this.environment.isEpisodeFinished() && !hasReachedMaxEpisodeSteps()) {
            performStep();
        }
        // When a listener stops the episode, the after-episode hook and notification are skipped too.
        if (this.canContinue) {
            onAfterEpisode();
            this.listeners.notifyAfterEpisode(this);
        }
    }

    /** Returns {@code true} when a step limit is configured and the current episode has reached it. */
    private boolean hasReachedMaxEpisodeSteps() {
        return this.maxEpisodeSteps != null && this.episodeStepCount >= this.maxEpisodeSteps;
    }

    /**
     * Prepares for a new episode. It resets the environment and the policy, clears the accumulated
     * reward, restores the initial action and re-enables continuation.
     *
     * <p>NOTE(review): the decompiler reports this method was protected in the original source.
     */
    public void reset() {
        resetEnvironment();
        resetPolicy();
        this.reward = 0.0d;
        this.lastAction = getInitialAction();
        this.canContinue = true;
    }

    /**
     * Resets the step count to 0. It then resets the environment and transforms the returned data into
     * the first observation (step 0, not terminal).
     */
    protected void resetEnvironment() {
        this.episodeStepCount = 0;
        this.observation = this.transformProcess.transform(this.environment.reset(), this.episodeStepCount, false);
    }

    /** Resets the policy. */
    protected void resetPolicy() {
        this.policy.reset();
    }

    /**
     * Returns the action used before the policy has chosen one.
     *
     * @return the no-op action from the environment's action schema
     */
    protected ACTION getInitialAction() {
        return this.environment.getSchema().getActionSchema().getNoOp();
    }

    /**
     * Performs one step. The agent decides an action, notifies listeners, acts, then notifies listeners
     * again. The step count is incremented only if neither notification stopped the episode.
     */
    protected void performStep() {
        onBeforeStep();

        ACTION action = decideAction(this.observation);

        this.canContinue = this.listeners.notifyBeforeStep(this, this.observation, action);
        if (!this.canContinue) {
            return;
        }

        StepResult stepResult = act(action);
        onAfterStep(stepResult);

        this.canContinue = this.listeners.notifyAfterStep(this, stepResult);
        if (!this.canContinue) {
            return;
        }

        incrementEpisodeStepCount();
    }

    /** Advances the episode step counter by one. */
    protected void incrementEpisodeStepCount() {
        this.episodeStepCount++;
    }

    /**
     * Chooses the next action. For a skipped observation the policy is not consulted and the previous
     * action is repeated.
     *
     * @param observation the current observation
     * @return the action to perform
     */
    protected ACTION decideAction(Observation observation) {
        if (!observation.isSkipped()) {
            this.lastAction = this.policy.nextAction(observation);
        }
        return this.lastAction;
    }

    /**
     * Applies {@code action} to the environment. It also updates the current observation and the
     * accumulated reward.
     *
     * @param action the action to perform
     * @return the environment's step result
     */
    protected StepResult act(ACTION action) {
        // onAfterAction receives the observation the action was decided on, not the new one.
        Observation previousObservation = this.observation;
        StepResult stepResult = this.environment.step(action);
        // The step counter is only incremented after this call returns, hence the +1.
        this.observation = convertChannelDataToObservation(stepResult, this.episodeStepCount + 1);
        this.reward += computeReward(stepResult);
        onAfterAction(previousObservation, action, stepResult);
        return stepResult;
    }

    /**
     * Transforms a step result's channel data into an observation.
     *
     * @param stepResult the environment's step result
     * @param i          the step index passed to the transform process
     * @return the transformed observation
     */
    protected Observation convertChannelDataToObservation(StepResult stepResult, int i) {
        return this.transformProcess.transform(stepResult.getChannelsData(), i, stepResult.isTerminal());
    }

    /**
     * Computes the reward added to the episode total for one step.
     *
     * @param stepResult the environment's step result
     * @return the step result's reward, unchanged
     */
    protected double computeReward(StepResult stepResult) {
        return stepResult.getReward();
    }

    /**
     * Hook called at the end of {@link #act}.
     *
     * @param observation the observation the action was decided on
     * @param action      the action performed
     * @param stepResult  the environment's step result
     */
    protected void onAfterAction(Observation observation, ACTION action, StepResult stepResult) {
    }

    /**
     * Hook called after acting and before listeners are notified of the step.
     *
     * @param stepResult the environment's step result
     */
    protected void onAfterStep(StepResult stepResult) {
    }

    /** Hook called at the start of every step, before the action is decided. */
    protected void onBeforeStep() {
    }

    @Override
    public String getId() {
        return this.id;
    }

    @Override
    public Environment<ACTION> getEnvironment() {
        return this.environment;
    }

    @Override
    public IPolicy<ACTION> getPolicy() {
        return this.policy;
    }

    /**
     * Returns the observation the next action will be decided on.
     *
     * <p>NOTE(review): the decompiler reports this method was protected in the original source.
     */
    public Observation getObservation() {
        return this.observation;
    }

    /** Returns the most recently chosen action (or the initial action before the first decision). */
    protected ACTION getLastAction() {
        return this.lastAction;
    }

    @Override
    public int getEpisodeStepCount() {
        return this.episodeStepCount;
    }

    @Override
    public double getReward() {
        return this.reward;
    }
}
