package org.deeplearning4j.rl4j.agent.learning.algorithm.actorcritic;

import java.util.List;
import lombok.NonNull;
import org.deeplearning4j.rl4j.agent.learning.algorithm.IUpdateAlgorithm;
import org.deeplearning4j.rl4j.agent.learning.update.Features;
import org.deeplearning4j.rl4j.agent.learning.update.FeaturesBuilder;
import org.deeplearning4j.rl4j.agent.learning.update.FeaturesLabels;
import org.deeplearning4j.rl4j.agent.learning.update.Gradients;
import org.deeplearning4j.rl4j.experience.StateActionReward;
import org.deeplearning4j.rl4j.network.ITrainableNeuralNet;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/AdvantageActorCritic.class */
/**
 * Update algorithm that turns a batch of {@link StateActionReward} transitions into
 * {@link Gradients} for a network exposing a {@code "value"} output and trained against
 * {@code "value"} and {@code "policy"} labels.
 *
 * <p>For each transition, walking the batch from the last element to the first, a running
 * discounted return is accumulated ({@code reward + gamma * previousReturn}). That return is
 * stored as the value label, and the difference between it and the network's current value
 * estimate for the transition's observation is handed to the {@link ActorCriticHelper} to fill
 * the policy label.
 */
public class AdvantageActorCritic implements IUpdateAlgorithm<Gradients, StateActionReward<Integer>> {
    /** Network used both to estimate values and to compute the gradients. */
    private final ITrainableNeuralNet threadCurrent;
    /** Discount multiplier applied to the running return at every step. */
    private final double gamma;
    /** Creates the label arrays and writes policy entries; variant chosen from the network's recurrence. */
    private final ActorCriticHelper algorithmHelper;
    /** Builds the {@link Features} fed to the network from the training batch. */
    private final FeaturesBuilder featuresBuilder;

    /**
     * Settings of {@link AdvantageActorCritic}. Instances are obtained through {@link #builder()}.
     */
    public static class Configuration {
        /** Value used for {@code gamma} when the builder was never given one. */
        private static final double DEFAULT_GAMMA = 0.99d;

        /** Discount multiplier; defaults to {@code 0.99}. */
        double gamma;

        /**
         * Self-typed builder base so that subclasses of {@link Configuration} can extend the
         * builder while keeping fluent calls correctly typed.
         *
         * @param <C> the configuration type produced by {@link #build()}
         * @param <B> the concrete builder type returned by the fluent setters
         */
        public static abstract class ConfigurationBuilder<C extends Configuration, B extends ConfigurationBuilder<C, B>> {
            /** {@code true} once {@link #gamma(double)} has been called. */
            private boolean gammaSet;
            /** Value received by the last {@link #gamma(double)} call. */
            private double gammaValue;

            /**
             * @return this builder, typed as the concrete builder class
             */
            protected abstract B self();

            /**
             * @return a configuration populated from this builder
             */
            public abstract C build();

            /**
             * Records an explicit gamma, overriding the default.
             *
             * @param d the gamma value to use
             * @return this builder
             */
            public B gamma(double d) {
                this.gammaValue = d;
                this.gammaSet = true;
                return self();
            }

            @Override
            public String toString() {
                return "AdvantageActorCritic.Configuration.ConfigurationBuilder(gamma$value=" + this.gammaValue + ")";
            }
        }

        /** Concrete builder returned by {@link Configuration#builder()}. */
        private static final class ConfigurationBuilderImpl extends ConfigurationBuilder<Configuration, ConfigurationBuilderImpl> {
            private ConfigurationBuilderImpl() {
            }

            @Override
            public ConfigurationBuilderImpl self() {
                return this;
            }

            @Override
            public Configuration build() {
                return new Configuration(this);
            }
        }

        /**
         * Copies the settings out of a builder, falling back to the default for anything
         * that was not explicitly set.
         *
         * @param configurationBuilder the builder to read from
         */
        protected Configuration(ConfigurationBuilder<?, ?> configurationBuilder) {
            this.gamma = configurationBuilder.gammaSet ? configurationBuilder.gammaValue : DEFAULT_GAMMA;
        }

        /**
         * @return a fresh builder with no value set
         */
        public static ConfigurationBuilder<?, ?> builder() {
            return new ConfigurationBuilderImpl();
        }

        /**
         * @return the configured gamma
         */
        public double getGamma() {
            return this.gamma;
        }

        /**
         * @param d the new gamma
         */
        public void setGamma(double d) {
            this.gamma = d;
        }

        /**
         * Two configurations are equal when both agree they can be compared (see
         * {@link #canEqual(Object)}) and their gammas compare equal with
         * {@link Double#compare(double, double)}.
         */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof Configuration)) {
                return false;
            }
            Configuration other = (Configuration) obj;
            if (!other.canEqual(this)) {
                return false;
            }
            return Double.compare(getGamma(), other.getGamma()) == 0;
        }

        /**
         * Hook letting subclasses refuse equality with instances of unrelated types.
         *
         * @param obj the candidate on the other side of the comparison
         * @return {@code true} if {@code obj} is a {@link Configuration}
         */
        protected boolean canEqual(Object obj) {
            return obj instanceof Configuration;
        }

        @Override
        public int hashCode() {
            final int prime = 59;
            long gammaBits = Double.doubleToLongBits(getGamma());
            // Fold the 64 bits of the double into 32 by XOR-ing the two halves.
            int gammaHash = (int) (gammaBits ^ (gammaBits >>> 32));
            return prime + gammaHash;
        }

        @Override
        public String toString() {
            return "AdvantageActorCritic.Configuration(gamma=" + getGamma() + ")";
        }
    }

    /**
     * Creates the algorithm around a network and a configuration.
     *
     * @param iTrainableNeuralNet the network to query for value estimates and gradients; must not be null
     * @param i forwarded unchanged to the {@link ActorCriticHelper} constructor
     *          (NOTE(review): presumably the number of discrete actions; confirm against the helpers)
     * @param configuration source of gamma; must not be null
     * @throws NullPointerException if {@code iTrainableNeuralNet} or {@code configuration} is null
     */
    public AdvantageActorCritic(@NonNull ITrainableNeuralNet iTrainableNeuralNet, int i, @NonNull Configuration configuration) {
        if (iTrainableNeuralNet == null) {
            throw new NullPointerException("threadCurrent is marked non-null but is null");
        }
        if (configuration == null) {
            throw new NullPointerException("configuration is marked non-null but is null");
        }

        this.threadCurrent = iTrainableNeuralNet;
        this.gamma = configuration.getGamma();

        // Recurrent and non-recurrent networks each get their own helper implementation.
        if (iTrainableNeuralNet.isRecurrent()) {
            this.algorithmHelper = new RecurrentActorCriticHelper(i);
        } else {
            this.algorithmHelper = new NonRecurrentActorCriticHelper(i);
        }
        this.featuresBuilder = new FeaturesBuilder(iTrainableNeuralNet.isRecurrent());
    }

    /**
     * Computes the gradients for one training batch.
     *
     * <p>The running return starts at {@code 0} when the last transition is terminal, otherwise
     * at the network's current value estimate for the last transition's observation. The batch
     * is then traversed backwards, filling the {@code "value"} and {@code "policy"} labels.
     *
     * @param list the transitions of the batch, in order; the last element is read
     *             unconditionally, so an empty list fails on that access
     * @return the gradients computed by the network from the features and both label sets
     */
    @Override
    public Gradients compute(List<StateActionReward<Integer>> list) {
        final int batchSize = list.size();

        final Features features = this.featuresBuilder.build(list);
        final INDArray valueLabels = this.algorithmHelper.createValueLabels(batchSize);
        final INDArray policyLabels = this.algorithmHelper.createPolicyLabels(batchSize);

        final StateActionReward<Integer> lastTransition = list.get(batchSize - 1);
        double discountedReturn;
        if (lastTransition.isTerminal()) {
            discountedReturn = 0.0d;
        } else {
            // Not terminal: start from the network's own estimate for the last observation.
            discountedReturn = estimateValue(lastTransition);
        }

        for (int step = batchSize - 1; step >= 0; step--) {
            final StateActionReward<Integer> transition = list.get(step);

            discountedReturn = transition.getReward() + (this.gamma * discountedReturn);
            valueLabels.putScalar(step, discountedReturn);

            // The action is read before querying the network, matching the evaluation order
            // of the arguments passed to setPolicy.
            final int action = transition.getAction().intValue();
            final double advantage = discountedReturn - estimateValue(transition);
            this.algorithmHelper.setPolicy(policyLabels, step, action, advantage);
        }

        final FeaturesLabels trainingData = new FeaturesLabels(features);
        trainingData.putLabels("value", valueLabels);
        trainingData.putLabels("policy", policyLabels);
        return this.threadCurrent.computeGradients(trainingData);
    }

    /**
     * Queries the network for the first element of its {@code "value"} output on the
     * observation of the given transition.
     */
    private double estimateValue(StateActionReward<Integer> transition) {
        return this.threadCurrent.output(transition.getObservation()).get("value").getDouble(0L);
    }
}
