package org.deeplearning4j.nn.modelimport.keras;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.modelimport.keras.KerasModel;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/nn/modelimport/keras/KerasSequentialModel.class */
/**
 * Keras {@code Sequential} model imported into DL4J.
 *
 * <p>A Sequential model is a linear stack of layers. During import a synthetic input layer is
 * prepended to the ordered layer list, and every layer is wired to use the layer immediately
 * before it as its single inbound layer. The result can then be converted into a
 * {@link MultiLayerConfiguration} or an initialized {@link MultiLayerNetwork}.
 */
public class KerasSequentialModel extends KerasModel {
    /** Name of the synthetic input layer that is prepended to every imported Sequential model. */
    private static final String INPUT_LAYER_NAME = "input1";

    /**
     * Creates a Sequential model from the configuration collected by a {@link KerasModel.ModelBuilder}.
     *
     * @param modelBuilder builder holding the model JSON/YAML, optional training JSON, weights and
     *     training flag
     * @throws UnsupportedKerasConfigurationException if the configuration uses unsupported features
     * @throws IOException if the configuration cannot be parsed
     * @throws InvalidKerasConfigurationException if the configuration is missing or not a Sequential model
     */
    public KerasSequentialModel(KerasModel.ModelBuilder modelBuilder) throws UnsupportedKerasConfigurationException, IOException, InvalidKerasConfigurationException {
        this(modelBuilder.modelJson, modelBuilder.modelYaml, modelBuilder.trainingJson, modelBuilder.weights, modelBuilder.train);
    }

    /**
     * Creates a Sequential model from a Keras model configuration.
     *
     * @param str model configuration as a JSON string; used in preference to {@code str2} when non-null
     * @param str2 model configuration as a YAML string; only consulted when {@code str} is null
     * @param str3 optional training configuration as a JSON string; ignored when null
     * @param map model weights, stored as-is for later copying into a network (presumably keyed by
     *     layer name, then parameter name — TODO confirm against KerasModel)
     * @param z value stored in {@code this.train}
     * @throws IOException if the configuration string cannot be parsed
     * @throws InvalidKerasConfigurationException if neither JSON nor YAML is given, the model class is
     *     not {@code Sequential}, or the model contains no layers
     * @throws UnsupportedKerasConfigurationException if the configuration uses unsupported features
     */
    public KerasSequentialModel(String str, String str2, String str3, Map<String, Map<String, INDArray>> map, boolean z) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> modelConfig;
        if (str != null) {
            modelConfig = parseJsonString(str);
        } else if (str2 != null) {
            modelConfig = parseYamlString(str2);
        } else {
            throw new InvalidKerasConfigurationException("Requires model configuration as either JSON or YAML string.");
        }

        this.className = (String) checkAndGetModelField(modelConfig, "class_name");
        // Constant-first comparison so a missing class name is reported instead of raising an NPE.
        if (!KerasModel.MODEL_CLASS_NAME_SEQUENTIAL.equals(this.className)) {
            throw new InvalidKerasConfigurationException("Model class name must be Sequential (found " + this.className + ")");
        }
        this.train = z;

        helperPrepareLayers((List) checkAndGetModelField(modelConfig, "config"));
        // The first and last layers are looked up below; an empty model would otherwise fail with an
        // unchecked IndexOutOfBoundsException instead of a configuration error.
        if (this.layerNamesOrdered.isEmpty()) {
            throw new InvalidKerasConfigurationException("Sequential model must contain at least one layer.");
        }

        // Sequential configs carry no explicit input layer, so synthesize one from the first
        // layer's input shape.
        String firstLayerName = this.layerNamesOrdered.get(0);
        String lastLayerName = this.layerNamesOrdered.get(this.layerNamesOrdered.size() - 1);
        KerasLayer inputLayer = KerasLayer.createInputLayer(INPUT_LAYER_NAME, this.layers.get(firstLayerName).getInputShape());
        this.layers.put(inputLayer.getName(), inputLayer);
        this.inputLayerNames = new ArrayList<>(Arrays.asList(inputLayer.getName()));
        this.outputLayerNames = new ArrayList<>(Arrays.asList(lastLayerName));
        this.layerNamesOrdered.add(0, inputLayer.getName());

        // Chain the layers: each layer's only inbound layer is its predecessor in the ordered list.
        String previousLayerName = null;
        for (String layerName : this.layerNamesOrdered) {
            if (previousLayerName != null) {
                this.layers.get(layerName).setInboundLayerNames(Arrays.asList(previousLayerName));
            }
            previousLayerName = layerName;
        }

        helperPrepareGraph();
        if (str3 != null) {
            helperImportTrainingConfiguration(str3);
        }
        this.weights = map;
    }

    /** No-arg constructor for subclasses; leaves all model state unset. */
    protected KerasSequentialModel() {
    }

    /**
     * Converts this model into a DL4J {@link MultiLayerConfiguration}.
     *
     * <p>Keras layers that do not map to a DL4J layer are skipped. DL4J layer indices are assigned
     * consecutively over the layers that do map.
     *
     * @return the built configuration
     * @throws InvalidKerasConfigurationException if the model is not Sequential, or does not have
     *     exactly one input and one output
     * @throws UnsupportedKerasConfigurationException if {@code truncatedBPTT} is 0 (recurrent model
     *     without fixed-length sequence input)
     */
    public MultiLayerConfiguration getMultiLayerConfiguration() throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        if (!KerasModel.MODEL_CLASS_NAME_SEQUENTIAL.equals(this.className)) {
            throw new InvalidKerasConfigurationException("Keras model class name " + this.className + " incompatible with MultiLayerNetwork");
        }
        if (this.inputLayerNames.size() != 1) {
            throw new InvalidKerasConfigurationException("MultiLayerNetwork expects only 1 input (found " + this.inputLayerNames.size() + ")");
        }
        if (this.outputLayerNames.size() != 1) {
            throw new InvalidKerasConfigurationException("MultiLayerNetwork expects only 1 output (found " + this.outputLayerNames.size() + ")");
        }

        NeuralNetConfiguration.ListBuilder listBuilder = new NeuralNetConfiguration.Builder().list();
        int dl4jLayerIndex = 0;
        for (String layerName : this.layerNamesOrdered) {
            KerasLayer layer = this.layers.get(layerName);
            if (layer.isDl4jLayer()) {
                listBuilder.layer(dl4jLayerIndex++, layer.getDl4jLayer());
            }
        }

        InputType inputType = inferInputType(this.inputLayerNames.get(0));
        if (inputType != null) {
            listBuilder.setInputType(inputType);
        }

        // 0 is rejected; a positive value enables truncated BPTT with equal forward/backward
        // lengths; negative values leave BPTT unconfigured (presumably "not recurrent" — confirm).
        if (this.truncatedBPTT == 0) {
            throw new UnsupportedKerasConfigurationException("Cannot import recurrent models without fixed length sequence input.");
        }
        if (this.truncatedBPTT > 0) {
            listBuilder.tBPTTForwardLength(this.truncatedBPTT).tBPTTBackwardLength(this.truncatedBPTT);
        }
        return listBuilder.build();
    }

    /**
     * Builds and initializes a {@link MultiLayerNetwork}, importing the stored weights.
     *
     * @return the initialized network with weights copied in
     * @throws InvalidKerasConfigurationException see {@link #getMultiLayerConfiguration()}
     * @throws UnsupportedKerasConfigurationException see {@link #getMultiLayerConfiguration()}
     */
    public MultiLayerNetwork getMultiLayerNetwork() throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        return getMultiLayerNetwork(true);
    }

    /**
     * Builds and initializes a {@link MultiLayerNetwork}.
     *
     * @param z when true, the stored weights are copied into the freshly initialized network;
     *     when false, the network keeps its initial parameters
     * @return the initialized network
     * @throws InvalidKerasConfigurationException see {@link #getMultiLayerConfiguration()}
     * @throws UnsupportedKerasConfigurationException see {@link #getMultiLayerConfiguration()}
     */
    public MultiLayerNetwork getMultiLayerNetwork(boolean z) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        MultiLayerNetwork network = new MultiLayerNetwork(getMultiLayerConfiguration());
        network.init();
        if (z) {
            network = (MultiLayerNetwork) copyWeightsToModel(network, this.weights, this.layers);
        }
        return network;
    }
}
