package org.deeplearning4j.rl4j.network;

import lombok.NonNull;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.rl4j.network.ChannelToNetworkInputMapper;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/rl4j/network/QNetwork.class */
/**
 * Network wrapper whose single output is exposed under the name {@code "Q"}.
 *
 * <p>Instances are created through {@link #builder()}. The wrapped network may be either a
 * {@link ComputationGraph} or a {@link MultiLayerNetwork}; the first array produced by the
 * network is packaged into the {@link NeuralNetOutput} under the key {@code "Q"}.
 */
public class QNetwork extends BaseNetwork<QNetwork> {
    /**
     * Name used for the label, for the output key in {@link #packageResult(INDArray[])}, and as
     * the trailing name argument passed to {@code NetworkHelper.buildHandler}. Kept in a single
     * constant so these uses cannot drift apart.
     */
    private static final String Q_VALUES_NAME = "Q";

    /** Label names handed to the computation-graph handler. */
    private static final String[] LABEL_NAMES = {Q_VALUES_NAME};

    /**
     * Fluent builder for {@link QNetwork}.
     *
     * <p>A network must be supplied through one of the {@code withNetwork} overloads before
     * calling {@link #build()}.
     *
     * <p>NOTE(review): if both a {@link ComputationGraph} and a {@link MultiLayerNetwork} are
     * set, the computation graph silently takes precedence. {@link #inputBindings} is only
     * consulted for a computation graph and is ignored for a multi-layer network.
     */
    public static class Builder {
        private final NetworkHelper networkHelper = new NetworkHelper();
        private ChannelToNetworkInputMapper.NetworkInputToChannelBinding[] networkInputsToFeatureBindings;
        private String[] channelNames;
        private String inputChannelName;
        private ComputationGraph cgNetwork;
        private MultiLayerNetwork mlnNetwork;

        /**
         * Uses a computation graph as the underlying network.
         *
         * @param network the graph to wrap; must not be {@code null}
         * @return this builder
         * @throws NullPointerException if {@code network} is {@code null}
         */
        public Builder withNetwork(@NonNull ComputationGraph network) {
            if (network == null) {
                throw new NullPointerException("network is marked non-null but is null");
            }
            this.cgNetwork = network;
            return this;
        }

        /**
         * Uses a multi-layer network as the underlying network.
         *
         * @param network the network to wrap; must not be {@code null}
         * @return this builder
         * @throws NullPointerException if {@code network} is {@code null}
         */
        public Builder withNetwork(@NonNull MultiLayerNetwork network) {
            if (network == null) {
                throw new NullPointerException("network is marked non-null but is null");
            }
            this.mlnNetwork = network;
            return this;
        }

        /**
         * Sets explicit network-input-to-channel bindings. When non-null, these replace the
         * {@link #specificBinding(String)} path for a computation graph.
         *
         * @param networkInputsToFeatureBindings the bindings, or {@code null} to clear them
         * @return this builder
         */
        public Builder inputBindings(ChannelToNetworkInputMapper.NetworkInputToChannelBinding[] networkInputsToFeatureBindings) {
            this.networkInputsToFeatureBindings = networkInputsToFeatureBindings;
            return this;
        }

        /**
         * Sets the single input channel name forwarded to the handler when no explicit input
         * bindings are configured.
         *
         * @param inputChannelName the channel name
         * @return this builder
         */
        public Builder specificBinding(String inputChannelName) {
            this.inputChannelName = inputChannelName;
            return this;
        }

        /**
         * Sets the channel names forwarded to the handler.
         *
         * @param channelNames the channel names
         * @return this builder
         */
        public Builder channelNames(String[] channelNames) {
            this.channelNames = channelNames;
            return this;
        }

        /**
         * Builds the {@link QNetwork}.
         *
         * @return a new network wrapping the configured handler
         * @throws IllegalStateException if no network has been set (via
         *     {@code Preconditions.checkState})
         */
        public QNetwork build() {
            Preconditions.checkState(cgNetwork != null || mlnNetwork != null, "A network must be set.");

            final INetworkHandler handler;
            if (cgNetwork != null) {
                // Explicit bindings take priority over the single-channel binding.
                handler = networkInputsToFeatureBindings == null
                        ? networkHelper.buildHandler(cgNetwork, inputChannelName, channelNames, LABEL_NAMES, Q_VALUES_NAME)
                        : networkHelper.buildHandler(cgNetwork, networkInputsToFeatureBindings, channelNames, LABEL_NAMES, Q_VALUES_NAME);
            } else {
                handler = networkHelper.buildHandler(mlnNetwork, inputChannelName, channelNames, Q_VALUES_NAME, Q_VALUES_NAME);
            }
            return new QNetwork(handler);
        }
    }

    private QNetwork(INetworkHandler networkHandler) {
        super(networkHandler);
    }

    /**
     * Wraps the raw network output. Only the first array is used; it is stored under the key
     * {@code "Q"}.
     *
     * @param output the arrays produced by the network; assumed to contain at least one element
     * @return the packaged output
     */
    @Override // org.deeplearning4j.rl4j.network.BaseNetwork
    protected NeuralNetOutput packageResult(INDArray[] output) {
        NeuralNetOutput result = new NeuralNetOutput();
        result.put(Q_VALUES_NAME, output[0]);
        return result;
    }

    /**
     * Creates a copy of this network backed by a clone of the current network handler.
     *
     * <p>The method name is a decompiler rename of {@code clone()} and must match the name
     * declared in {@code BaseNetwork}.
     *
     * @return a new {@link QNetwork} wrapping the cloned handler
     */
    @Override // org.deeplearning4j.rl4j.network.BaseNetwork
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public QNetwork mo26clone() {
        return new QNetwork(getNetworkHandler().m27clone());
    }

    /**
     * Returns a fresh builder.
     *
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }
}
