package org.deeplearning4j.rl4j.builder;

import lombok.NonNull;
import org.apache.commons.lang3.builder.Builder;
import org.deeplearning4j.rl4j.agent.learning.algorithm.IUpdateAlgorithm;
import org.deeplearning4j.rl4j.agent.learning.algorithm.actorcritic.AdvantageActorCritic;
import org.deeplearning4j.rl4j.agent.learning.update.Gradients;
import org.deeplearning4j.rl4j.agent.learning.update.updater.async.AsyncSharedNetworksUpdateHandler;
import org.deeplearning4j.rl4j.builder.BaseAsyncAgentLearnerBuilder;
import org.deeplearning4j.rl4j.environment.Environment;
import org.deeplearning4j.rl4j.experience.StateActionReward;
import org.deeplearning4j.rl4j.network.ITrainableNeuralNet;
import org.deeplearning4j.rl4j.observation.transform.TransformProcess;
import org.deeplearning4j.rl4j.policy.ACPolicy;
import org.deeplearning4j.rl4j.policy.IPolicy;
import org.nd4j.linalg.api.rng.Random;

import java.util.Objects;

/* loaded from: input_file:org/deeplearning4j/rl4j/builder/AdvantageActorCriticBuilder.class */
public class AdvantageActorCriticBuilder extends BaseAsyncAgentLearnerBuilder<Configuration> {
    private final Random rnd;

    /* loaded from: input_file:org/deeplearning4j/rl4j/builder/AdvantageActorCriticBuilder$Configuration.class */
    public static class Configuration extends BaseAsyncAgentLearnerBuilder.Configuration {
        AdvantageActorCritic.Configuration advantageActorCriticConfiguration;

        /* loaded from: input_file:org/deeplearning4j/rl4j/builder/AdvantageActorCriticBuilder$Configuration$ConfigurationBuilder.class */
        public static abstract class ConfigurationBuilder<C extends Configuration, B extends ConfigurationBuilder<C, B>> extends BaseAsyncAgentLearnerBuilder.Configuration.ConfigurationBuilder<C, B> {
            private AdvantageActorCritic.Configuration advantageActorCriticConfiguration;

            /* JADX INFO: Access modifiers changed from: protected */
            @Override // org.deeplearning4j.rl4j.builder.BaseAsyncAgentLearnerBuilder.Configuration.ConfigurationBuilder, org.deeplearning4j.rl4j.builder.BaseAgentLearnerBuilder.Configuration.ConfigurationBuilder
            public abstract B self();

            @Override // org.deeplearning4j.rl4j.builder.BaseAsyncAgentLearnerBuilder.Configuration.ConfigurationBuilder, org.deeplearning4j.rl4j.builder.BaseAgentLearnerBuilder.Configuration.ConfigurationBuilder
            public abstract C build();

            public B advantageActorCriticConfiguration(AdvantageActorCritic.Configuration configuration) {
                this.advantageActorCriticConfiguration = configuration;
                return self();
            }

            @Override // org.deeplearning4j.rl4j.builder.BaseAsyncAgentLearnerBuilder.Configuration.ConfigurationBuilder, org.deeplearning4j.rl4j.builder.BaseAgentLearnerBuilder.Configuration.ConfigurationBuilder
            public String toString() {
                return "AdvantageActorCriticBuilder.Configuration.ConfigurationBuilder(super=" + super.toString() + ", advantageActorCriticConfiguration=" + this.advantageActorCriticConfiguration + ")";
            }
        }

        /* loaded from: input_file:org/deeplearning4j/rl4j/builder/AdvantageActorCriticBuilder$Configuration$ConfigurationBuilderImpl.class */
        private static final class ConfigurationBuilderImpl extends ConfigurationBuilder<Configuration, ConfigurationBuilderImpl> {
            private ConfigurationBuilderImpl() {
            }

            /* JADX INFO: Access modifiers changed from: protected */
            @Override // org.deeplearning4j.rl4j.builder.AdvantageActorCriticBuilder.Configuration.ConfigurationBuilder, org.deeplearning4j.rl4j.builder.BaseAsyncAgentLearnerBuilder.Configuration.ConfigurationBuilder, org.deeplearning4j.rl4j.builder.BaseAgentLearnerBuilder.Configuration.ConfigurationBuilder
            public ConfigurationBuilderImpl self() {
                return this;
            }

            @Override // org.deeplearning4j.rl4j.builder.AdvantageActorCriticBuilder.Configuration.ConfigurationBuilder, org.deeplearning4j.rl4j.builder.BaseAsyncAgentLearnerBuilder.Configuration.ConfigurationBuilder, org.deeplearning4j.rl4j.builder.BaseAgentLearnerBuilder.Configuration.ConfigurationBuilder
            public Configuration build() {
                return new Configuration(this);
            }
        }

        protected Configuration(ConfigurationBuilder<?, ?> configurationBuilder) {
            super(configurationBuilder);
            this.advantageActorCriticConfiguration = ((ConfigurationBuilder) configurationBuilder).advantageActorCriticConfiguration;
        }

        public static ConfigurationBuilder<?, ?> builder() {
            return new ConfigurationBuilderImpl();
        }

        @Override // org.deeplearning4j.rl4j.builder.BaseAsyncAgentLearnerBuilder.Configuration, org.deeplearning4j.rl4j.builder.BaseAgentLearnerBuilder.Configuration
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof Configuration)) {
                return false;
            }
            Configuration configuration = (Configuration) obj;
            if (!configuration.canEqual(this) || !super.equals(obj)) {
                return false;
            }
            AdvantageActorCritic.Configuration advantageActorCriticConfiguration = getAdvantageActorCriticConfiguration();
            AdvantageActorCritic.Configuration advantageActorCriticConfiguration2 = configuration.getAdvantageActorCriticConfiguration();
            return advantageActorCriticConfiguration == null ? advantageActorCriticConfiguration2 == null : advantageActorCriticConfiguration.equals(advantageActorCriticConfiguration2);
        }

        @Override // org.deeplearning4j.rl4j.builder.BaseAsyncAgentLearnerBuilder.Configuration, org.deeplearning4j.rl4j.builder.BaseAgentLearnerBuilder.Configuration
        protected boolean canEqual(Object obj) {
            return obj instanceof Configuration;
        }

        @Override // org.deeplearning4j.rl4j.builder.BaseAsyncAgentLearnerBuilder.Configuration, org.deeplearning4j.rl4j.builder.BaseAgentLearnerBuilder.Configuration
        public int hashCode() {
            int hashCode = super.hashCode();
            AdvantageActorCritic.Configuration advantageActorCriticConfiguration = getAdvantageActorCriticConfiguration();
            return (hashCode * 59) + (advantageActorCriticConfiguration == null ? 43 : advantageActorCriticConfiguration.hashCode());
        }

        public AdvantageActorCritic.Configuration getAdvantageActorCriticConfiguration() {
            return this.advantageActorCriticConfiguration;
        }

        public void setAdvantageActorCriticConfiguration(AdvantageActorCritic.Configuration configuration) {
            this.advantageActorCriticConfiguration = configuration;
        }

        @Override // org.deeplearning4j.rl4j.builder.BaseAsyncAgentLearnerBuilder.Configuration, org.deeplearning4j.rl4j.builder.BaseAgentLearnerBuilder.Configuration
        public String toString() {
            return "AdvantageActorCriticBuilder.Configuration(advantageActorCriticConfiguration=" + getAdvantageActorCriticConfiguration() + ")";
        }
    }

    public AdvantageActorCriticBuilder(@NonNull Configuration configuration, @NonNull ITrainableNeuralNet iTrainableNeuralNet, @NonNull Builder<Environment<Integer>> builder, @NonNull Builder<TransformProcess> builder2, Random random) {
        super(configuration, iTrainableNeuralNet, builder, builder2);
        if (configuration == null) {
            throw new NullPointerException("configuration is marked non-null but is null");
        }
        if (iTrainableNeuralNet == null) {
            throw new NullPointerException("neuralNet is marked non-null but is null");
        }
        if (builder == null) {
            throw new NullPointerException("environmentBuilder is marked non-null but is null");
        }
        if (builder2 == null) {
            throw new NullPointerException("transformProcessBuilder is marked non-null but is null");
        }
        this.rnd = random;
    }

    @Override // org.deeplearning4j.rl4j.builder.BaseAgentLearnerBuilder
    protected IPolicy<Integer> buildPolicy() {
        return ACPolicy.builder().neuralNet(this.networks.getThreadCurrentNetwork()).isTraining(true).rnd(this.rnd).build();
    }

    @Override // org.deeplearning4j.rl4j.builder.BaseAgentLearnerBuilder
    protected IUpdateAlgorithm<Gradients, StateActionReward<Integer>> buildUpdateAlgorithm() {
        return new AdvantageActorCritic(this.networks.getThreadCurrentNetwork(), getEnvironment().getSchema().getActionSchema().getActionSpaceSize(), ((Configuration) this.configuration).getAdvantageActorCriticConfiguration());
    }

    @Override // org.deeplearning4j.rl4j.builder.BaseAsyncAgentLearnerBuilder
    protected AsyncSharedNetworksUpdateHandler buildAsyncSharedNetworksUpdateHandler() {
        return new AsyncSharedNetworksUpdateHandler(this.networks.getGlobalCurrentNetwork(), ((Configuration) this.configuration).getNeuralNetUpdaterConfiguration());
    }
}
