package org.deeplearning4j.rl4j.experience;

import java.util.List;
import java.util.Objects;
import org.deeplearning4j.rl4j.learning.sync.ExpReplay;
import org.deeplearning4j.rl4j.learning.sync.IExpReplay;
import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.linalg.api.rng.Random;

/**
 * {@link ExperienceHandler} that accumulates experience in a replay memory ({@link IExpReplay}) and
 * hands out training batches taken from it.
 *
 * <p>Each call to {@link #addExperience} creates a <em>pending</em> {@link StateActionRewardState}.
 * A pending element is only written to the replay memory once its next observation is known. That
 * happens on the following {@link #addExperience} call or on {@link #setFinalObservation}.
 * {@link #reset()} discards the pending element without storing it.
 *
 * @param <A> the action type
 */
public class ReplayMemoryExperienceHandler<A> implements ExperienceHandler<A, StateActionRewardState<A>> {
    private static final int DEFAULT_MAX_REPLAY_MEMORY_SIZE = 150000;
    private static final int DEFAULT_BATCH_SIZE = 32;

    /** Designated batch size of the replay memory, captured once at construction time. */
    private final int batchSize;

    /** Backing replay memory; never {@code null}. */
    private final IExpReplay<A> expReplay;

    /** Last added element still waiting for its next observation; {@code null} when there is none. */
    private StateActionRewardState<A> pendingStateActionRewardState;

    /**
     * Settings used to create the default {@link ExpReplay} replay memory.
     *
     * <p>When a value is not set on the builder, {@code maxReplayMemorySize} defaults to 150000 and
     * {@code batchSize} defaults to 32.
     */
    public static class Configuration {
        private int maxReplayMemorySize;
        private int batchSize;

        /**
         * Fluent builder for {@link Configuration}. It records which values were explicitly set, so
         * that unset ones can fall back to their defaults.
         *
         * @param <C> the configuration type produced by {@link #build()}
         * @param <B> the concrete builder type returned by the fluent setters
         */
        public static abstract class ConfigurationBuilder<C extends Configuration, B extends ConfigurationBuilder<C, B>> {
            private boolean maxReplayMemorySize$set;
            private int maxReplayMemorySize$value;
            private boolean batchSize$set;
            private int batchSize$value;

            /** Returns this builder typed as its concrete subtype, for fluent chaining. */
            protected abstract B self();

            /** Creates the configuration from the values collected so far. */
            public abstract C build();

            /**
             * Sets the maximum number of elements the replay memory may hold.
             *
             * @param maxReplayMemorySize the capacity passed to the replay memory
             * @return this builder
             */
            public B maxReplayMemorySize(int maxReplayMemorySize) {
                this.maxReplayMemorySize$value = maxReplayMemorySize;
                this.maxReplayMemorySize$set = true;
                return self();
            }

            /**
             * Sets the designated training batch size.
             *
             * @param batchSize the batch size passed to the replay memory
             * @return this builder
             */
            public B batchSize(int batchSize) {
                this.batchSize$value = batchSize;
                this.batchSize$set = true;
                return self();
            }

            @Override
            public String toString() {
                return "ReplayMemoryExperienceHandler.Configuration.ConfigurationBuilder(maxReplayMemorySize$value=" + this.maxReplayMemorySize$value + ", batchSize$value=" + this.batchSize$value + ")";
            }
        }

        /** The only concrete builder; it is returned by {@link Configuration#builder()}. */
        private static final class ConfigurationBuilderImpl extends ConfigurationBuilder<Configuration, ConfigurationBuilderImpl> {
            private ConfigurationBuilderImpl() {
            }

            @Override
            protected ConfigurationBuilderImpl self() {
                return this;
            }

            @Override
            public Configuration build() {
                return new Configuration(this);
            }
        }

        /**
         * Copies the builder's values, substituting defaults for any value that was never set.
         *
         * @param builder the source builder
         */
        protected Configuration(ConfigurationBuilder<?, ?> builder) {
            this.maxReplayMemorySize = builder.maxReplayMemorySize$set
                    ? builder.maxReplayMemorySize$value
                    : DEFAULT_MAX_REPLAY_MEMORY_SIZE;
            this.batchSize = builder.batchSize$set
                    ? builder.batchSize$value
                    : DEFAULT_BATCH_SIZE;
        }

        /** Returns a new builder with no values set. */
        public static ConfigurationBuilder<?, ?> builder() {
            return new ConfigurationBuilderImpl();
        }

        public int getMaxReplayMemorySize() {
            return this.maxReplayMemorySize;
        }

        public int getBatchSize() {
            return this.batchSize;
        }

        public void setMaxReplayMemorySize(int maxReplayMemorySize) {
            this.maxReplayMemorySize = maxReplayMemorySize;
        }

        public void setBatchSize(int batchSize) {
            this.batchSize = batchSize;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof Configuration)) {
                return false;
            }
            Configuration other = (Configuration) obj;
            return other.canEqual(this)
                    && getMaxReplayMemorySize() == other.getMaxReplayMemorySize()
                    && getBatchSize() == other.getBatchSize();
        }

        /** Symmetry hook for {@link #equals(Object)} when subclasses add state. */
        protected boolean canEqual(Object obj) {
            return obj instanceof Configuration;
        }

        @Override
        public int hashCode() {
            int result = 59 + getMaxReplayMemorySize();
            return result * 59 + getBatchSize();
        }

        @Override
        public String toString() {
            return "ReplayMemoryExperienceHandler.Configuration(maxReplayMemorySize=" + getMaxReplayMemorySize() + ", batchSize=" + getBatchSize() + ")";
        }
    }

    /**
     * Creates a handler backed by the given replay memory.
     *
     * @param expReplay the replay memory; its designated batch size is read once, here
     * @throws NullPointerException if {@code expReplay} is {@code null}
     */
    public ReplayMemoryExperienceHandler(IExpReplay<A> expReplay) {
        this.expReplay = Objects.requireNonNull(expReplay, "expReplay");
        this.batchSize = expReplay.getDesignatedBatchSize();
    }

    /**
     * Creates a handler backed by a new {@link ExpReplay} built from {@code configuration}.
     *
     * @param configuration replay memory capacity and batch size
     * @param random random generator handed to the {@link ExpReplay}
     */
    public ReplayMemoryExperienceHandler(Configuration configuration, Random random) {
        // Parameterize explicitly; the raw ExpReplay previously produced an unchecked conversion.
        this(new ExpReplay<A>(configuration.maxReplayMemorySize, configuration.batchSize, random));
    }

    /**
     * Completes and stores the pending element (if any) using {@code observation} as its next
     * observation. It then makes a new pending element from the arguments.
     *
     * @param observation the observation the action was taken in
     * @param action the action taken
     * @param reward the value stored as the element's reward
     * @param isTerminal passed through to {@link StateActionRewardState}; presumably marks the end of
     *     an episode. TODO: confirm.
     */
    @Override
    public void addExperience(Observation observation, A action, double reward, boolean isTerminal) {
        setNextObservationOnPending(observation);
        this.pendingStateActionRewardState = new StateActionRewardState<>(observation, action, reward, isTerminal);
    }

    /**
     * Completes and stores the pending element (if any) with {@code observation}, then clears it.
     *
     * @param observation the last observation of the sequence
     */
    @Override
    public void setFinalObservation(Observation observation) {
        setNextObservationOnPending(observation);
        this.pendingStateActionRewardState = null;
    }

    /**
     * Returns the replay memory's {@code getBatchSize()} value. This is presumably the size a batch
     * drawn now would have; verify against the {@link IExpReplay} implementation.
     */
    @Override
    public int getTrainingBatchSize() {
        return this.expReplay.getBatchSize();
    }

    /**
     * Returns {@code true} once the replay memory's {@code getBatchSize()} reaches the designated
     * batch size captured at construction.
     */
    @Override
    public boolean isTrainingBatchReady() {
        return this.expReplay.getBatchSize() >= this.batchSize;
    }

    /** Delegates to the replay memory's {@code getBatch()}. */
    @Override
    public List<StateActionRewardState<A>> generateTrainingBatch() {
        return this.expReplay.getBatch();
    }

    /** Drops the pending element without storing it. The replay memory itself is left untouched. */
    @Override
    public void reset() {
        this.pendingStateActionRewardState = null;
    }

    /** Sets {@code observation} as the pending element's next observation and stores it, if one exists. */
    private void setNextObservationOnPending(Observation observation) {
        if (this.pendingStateActionRewardState != null) {
            this.pendingStateActionRewardState.setNextObservation(observation);
            this.expReplay.store(this.pendingStateActionRewardState);
        }
    }

    /**
     * Two handlers are equal when their designated batch size, replay memory and pending element are
     * equal. The pending element changes on every {@link #addExperience} call, so equality (and the
     * hash code) changes with it.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof ReplayMemoryExperienceHandler)) {
            return false;
        }
        ReplayMemoryExperienceHandler<?> other = (ReplayMemoryExperienceHandler<?>) obj;
        return other.canEqual(this)
                && this.batchSize == other.batchSize
                && Objects.equals(this.expReplay, other.expReplay)
                && Objects.equals(this.pendingStateActionRewardState, other.pendingStateActionRewardState);
    }

    /** Symmetry hook for {@link #equals(Object)} when subclasses add state. */
    protected boolean canEqual(Object obj) {
        return obj instanceof ReplayMemoryExperienceHandler;
    }

    /** Hash over the same state as {@link #equals(Object)}; it is not stable while experience is added. */
    @Override
    public int hashCode() {
        int result = 59 + this.batchSize;
        result = result * 59 + (this.expReplay == null ? 43 : this.expReplay.hashCode());
        result = result * 59 + (this.pendingStateActionRewardState == null ? 43 : this.pendingStateActionRewardState.hashCode());
        return result;
    }
}
