package ai.djl.nn.recurrent;

import ai.djl.nn.recurrent.RecurrentBlock;
import ai.djl.util.Preconditions;

/* loaded from: input_file:ai/djl/nn/recurrent/RNN.class */
public class RNN extends RecurrentBlock {

    /**
     * Returns a new, unconfigured {@link Builder} for an {@link RNN}.
     *
     * @return a fresh builder instance
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Creates the block from a configured builder.
     *
     * <p>The recurrent mode becomes {@code "rnn_relu"} when the builder's activation is
     * {@link Activation#RELU}. Any other value (including {@code null}) selects {@code "rnn_tanh"}.
     * The gate count is fixed at one.
     *
     * @param builder the builder holding this block's configuration
     */
    RNN(Builder builder) {
        super(builder);
        String selectedMode;
        if (Activation.RELU == builder.activation) {
            selectedMode = "rnn_relu";
        } else {
            selectedMode = "rnn_tanh";
        }
        mode = selectedMode;
        gates = 1;
    }

    /** Fluent builder that configures and creates {@link RNN} blocks. */
    public static final class Builder extends RecurrentBlock.BaseBuilder<Builder> {

        /**
         * Returns this builder, typed as {@link Builder}, for fluent chaining.
         *
         * @return this builder
         */
        @Override // ai.djl.nn.recurrent.RecurrentBlock.BaseBuilder
        public Builder self() {
            return this;
        }

        /**
         * Records the activation that decides which RNN mode the built block uses.
         *
         * @param activation the activation; {@link Activation#RELU} selects {@code "rnn_relu"},
         *     anything else selects {@code "rnn_tanh"}
         * @return this builder
         */
        public Builder setActivation(Activation activation) {
            this.activation = activation;
            return this;
        }

        /**
         * Validates the configuration and creates the {@link RNN}.
         *
         * <p>Both {@code stateSize} and {@code numStackedLayers} must be positive. Otherwise the
         * check in {@code Preconditions.checkArgument} fails. NOTE(review): presumably it throws
         * {@code IllegalArgumentException}; confirm against {@code ai.djl.util.Preconditions}.
         *
         * @return a new block built from this configuration
         */
        public RNN build() {
            boolean configured = this.stateSize > 0 && this.numStackedLayers > 0;
            Preconditions.checkArgument(configured, "Must set stateSize and numStackedLayers");
            return new RNN(this);
        }
    }

    /** Activation choices for an {@link RNN}; each one determines the block's mode string. */
    public enum Activation {
        /** Selects the {@code "rnn_relu"} mode. */
        RELU,
        /** Selects the {@code "rnn_tanh"} mode (also the fallback for any non-RELU value). */
        TANH
    }
}
