package org.deeplearning4j.nn.modelimport.keras;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.graph.PreprocessorVertex;
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.DropoutLayer;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.LossLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/modelimport/keras/KerasLayer.class */
public class KerasLayer {
    // ------------------------------------------------------------------
    // Keras layer class names (values of the top-level "class_name" field).
    // buildLayerFromConfig dispatches on these values.
    // ------------------------------------------------------------------
    /** Top-level field of a Keras layer description holding the layer class name. */
    public static final String LAYER_FIELD_CLASS_NAME = "class_name";
    public static final String LAYER_CLASS_NAME_INPUT = "InputLayer";
    public static final String LAYER_CLASS_NAME_ACTIVATION = "Activation";
    public static final String LAYER_CLASS_NAME_DROPOUT = "Dropout";
    public static final String LAYER_CLASS_NAME_DENSE = "Dense";
    public static final String LAYER_CLASS_NAME_TIME_DISTRIBUTED_DENSE = "TimeDistributedDense";
    public static final String LAYER_CLASS_NAME_LSTM = "LSTM";
    public static final String LAYER_CLASS_NAME_CONVOLUTION_2D = "Convolution2D";
    public static final String LAYER_CLASS_NAME_MAX_POOLING_2D = "MaxPooling2D";
    public static final String LAYER_CLASS_NAME_AVERAGE_POOLING_2D = "AveragePooling2D";
    public static final String LAYER_CLASS_NAME_FLATTEN = "Flatten";
    public static final String LAYER_CLASS_NAME_RESHAPE = "Reshape";
    public static final String LAYER_CLASS_NAME_REPEATVECTOR = "RepeatVector";
    public static final String LAYER_CLASS_NAME_MERGE = "Merge";
    public static final String LAYER_CLASS_NAME_BATCHNORMALIZATION = "BatchNormalization";
    // ------------------------------------------------------------------
    // Field names read from a Keras layer description / its merged config map.
    // ------------------------------------------------------------------
    /** Top-level field holding the nested layer configuration map. */
    public static final String LAYER_FIELD_CONFIG = "config";
    public static final String LAYER_FIELD_NAME = "name";
    public static final String LAYER_FIELD_DROPOUT = "dropout";
    public static final String LAYER_FIELD_OUTPUT_DIM = "output_dim";
    public static final String LAYER_FIELD_SUBSAMPLE = "subsample";
    public static final String LAYER_FIELD_NB_ROW = "nb_row";
    public static final String LAYER_FIELD_NB_COL = "nb_col";
    public static final String LAYER_FIELD_NB_FILTER = "nb_filter";
    public static final String LAYER_FIELD_STRIDES = "strides";
    public static final String LAYER_FIELD_POOL_SIZE = "pool_size";
    public static final String LAYER_FIELD_DROPOUT_U = "dropout_U";
    public static final String LAYER_FIELD_DROPOUT_W = "dropout_W";
    /** Input shape including a leading batch dimension (see getInputShapeFromConfig). */
    public static final String LAYER_FIELD_BATCH_INPUT_SHAPE = "batch_input_shape";
    /** Connections to upstream layers (see getInboundLayerNamesFromConfig). */
    public static final String LAYER_FIELD_INBOUND_NODES = "inbound_nodes";
    public static final String LAYER_FIELD_BORDER_MODE = "border_mode";
    public static final String LAYER_FIELD_GAMMA_REGULARIZER = "gamma_regularizer";
    public static final String LAYER_FIELD_BETA_REGULARIZER = "beta_regularizer";
    public static final String LAYER_FIELD_MODE = "mode";
    public static final String LAYER_FIELD_AXIS = "axis";
    public static final String LAYER_FIELD_EPSILON = "epsilon";
    public static final String LAYER_FIELD_MOMENTUM = "momentum";
    // Accepted values of "border_mode" (handled in buildConvolutionLayer).
    public static final String LAYER_BORDER_MODE_SAME = "same";
    public static final String LAYER_BORDER_MODE_VALID = "valid";
    public static final String LAYER_BORDER_MODE_FULL = "full";
    // NOTE(review): presumably Keras BatchNormalization "mode" values; no visible code reads them
    // for that purpose — confirm against the builder outside this view.
    public static final int LAYER_BATCHNORM_MODE_1 = 1;
    public static final int LAYER_BATCHNORM_MODE_2 = 2;
    // Regularizer fields, and the keys read inside a regularizer config map
    // (see getL1Regularization / getL2Regularization / checkForUnknownRegularizer).
    public static final String LAYER_FIELD_W_REGULARIZER = "W_regularizer";
    public static final String LAYER_FIELD_B_REGULARIZER = "b_regularizer";
    public static final String REGULARIZATION_TYPE_L1 = "l1";
    public static final String REGULARIZATION_TYPE_L2 = "l2";
    // Weight initialization fields and the Keras initializer names handled by
    // mapWeightInitialization.
    public static final String LAYER_FIELD_INIT = "init";
    public static final String LAYER_FIELD_INNER_INIT = "inner_init";
    public static final String INIT_UNIFORM = "uniform";
    public static final String INIT_ZERO = "zero";
    public static final String INIT_GLOROT_NORMAL = "glorot_normal";
    public static final String INIT_GLOROT_UNIFORM = "glorot_uniform";
    public static final String INIT_HE_NORMAL = "he_normal";
    public static final String INIT_HE_UNIFORM = "he_uniform";
    public static final String INIT_LECUN_UNIFORM = "lecun_uniform";
    public static final String INIT_NORMAL = "normal";
    public static final String INIT_ORTHOGONAL = "orthogonal";
    public static final String INIT_IDENTITY = "identity";
    // Activation fields and the Keras -> DL4J activation renames applied by mapActivation.
    public static final String LAYER_FIELD_ACTIVATION = "activation";
    public static final String LAYER_FIELD_INNER_ACTIVATION = "inner_activation";
    public static final String KERAS_ACTIVATION_LINEAR = "linear";
    public static final String DL4J_ACTIVATION_IDENTITY = "identity";
    public static final String KERAS_ACTIVATION_HARD_SIGMOID = "hard_sigmoid";
    public static final String DL4J_ACTIVATION_HARDSIGMOID = "hardsigmoid";
    // NOTE(review): presumably the LSTM forget-gate bias init field and its values, read by
    // buildGravesLstmLayer outside this view — confirm.
    public static final String LAYER_FIELD_FORGET_BIAS_INIT = "forget_bias_init";
    public static final String LSTM_FORGET_BIAS_INIT_ZERO = "zero";
    public static final String LSTM_FORGET_BIAS_INIT_ONE = "one";
    // Dimension ordering field and values (see getDimOrderFromConfig).
    public static final String LAYER_FIELD_DIM_ORDERING = "dim_ordering";
    public static final String DIM_ORDERING_THEANO = "th";
    public static final String DIM_ORDERING_TENSORFLOW = "tf";
    // Class name and field used for the loss layer built by createLossLayer.
    public static final String LAYER_CLASS_NAME_LOSS = "Loss";
    public static final String LAYER_FIELD_LOSS = "loss";
    // Keras loss names recognized by mapLossFunction; *_1 / *_2 pairs are long and short aliases.
    // NOTE(review): LOSS_SQUARED_LOSS_1 lacks the KERAS_ prefix of its siblings; renaming it would
    // break external references, so it is left as-is.
    public static final String LOSS_SQUARED_LOSS_1 = "mean_squared_error";
    public static final String KERAS_LOSS_SQUARED_LOSS_2 = "mse";
    public static final String KERAS_LOSS_MEAN_ABSOLUTE_ERROR_1 = "mean_absolute_error";
    public static final String KERAS_LOSS_MEAN_ABSOLUTE_ERROR_2 = "mae";
    public static final String KERAS_LOSS_MEAN_ABSOLUTE_PERCENTAGE_ERROR_1 = "mean_absolute_percentage_error";
    public static final String KERAS_LOSS_MEAN_ABSOLUTE_PERCENTAGE_ERROR_2 = "mape";
    public static final String KERAS_LOSS_MEAN_SQUARED_LOGARITHMIC_ERROR_1 = "mean_squared_logarithmic_error";
    public static final String KERAS_LOSS_MEAN_SQUARED_LOGARITHMIC_ERROR_2 = "msle";
    public static final String KERAS_LOSS_SQUARED_HINGE = "squared_hinge";
    public static final String KERAS_LOSS_HINGE = "hinge";
    public static final String KERAS_LOSS_XENT = "binary_crossentropy";
    public static final String KERAS_LOSS_MCXENT = "categorical_crossentropy";
    public static final String KERAS_LOSS_SP_XE = "sparse_categorical_crossentropy";
    public static final String KERAS_LOSS_KL_DIVERGENCE_1 = "kullback_leibler_divergence";
    public static final String KERAS_LOSS_KL_DIVERGENCE_2 = "kld";
    public static final String KERAS_LOSS_POISSON = "poisson";
    public static final String KERAS_LOSS_COSINE_PROXIMITY = "cosine_proximity";
    private static Logger log = LoggerFactory.getLogger(KerasLayer.class);
    /** Merged layer config: nested "config" entries plus the other top-level fields (e.g. class_name). */
    private Map<String, Object> layerConfig;
    /** Keras layer class name (value of "class_name"). */
    private String className;
    /** Keras layer name (value of "name" in the config). */
    private String layerName;
    /** Dimension ordering parsed from "dim_ordering". */
    private DimOrder dimOrder;
    /** Input shape without the batch dimension; null when the config has no batch_input_shape. */
    private int[] inputShape;
    /** Names of upstream layers feeding this one. */
    private List<String> inboundLayerNames;
    /** Equivalent DL4J layer, or null when the Keras layer has no DL4J layer (e.g. Flatten, InputLayer). */
    private Layer dl4jLayer;
    /** Flag passed through to the per-type layer builders. */
    private boolean train;

    /**
     * Dimension ordering declared by a Keras layer's {@code dim_ordering} field.
     * Constants are declared in their original order so ordinals are unchanged.
     */
    public enum DimOrder {
        /** No dimension ordering was specified. */
        NONE,
        /** Theano ordering ({@code "th"}). */
        THEANO,
        /** TensorFlow ordering ({@code "tf"}). */
        TENSORFLOW,
        /** Presumably marks an unrecognized ordering value. */
        UNKNOWN;
    }

    /**
     * Builds a KerasLayer from a Keras layer description, with the {@code train} flag disabled.
     *
     * @param layerDescription Keras layer description containing {@code class_name} and {@code config}
     */
    public KerasLayer(Map<String, Object> layerDescription) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this(layerDescription, /* train = */ false);
    }

    /**
     * Builds a KerasLayer from one Keras layer description.
     *
     * <p>The nested {@code config} map is copied, and every top-level field other than
     * {@code config} (for example {@code class_name}) is merged into the copy. The caller's
     * maps are not modified.
     *
     * @param map Keras layer description containing {@code class_name} and {@code config}
     * @param z   stored as the {@code train} flag and passed to the DL4J layer builders
     * @throws InvalidKerasConfigurationException    if the configuration is invalid (for example a
     *                                                required field is missing, or the layer type is
     *                                                not supported by buildLayerFromConfig)
     * @throws UnsupportedKerasConfigurationException if a layer builder rejects part of the config
     */
    @SuppressWarnings("unchecked")
    public KerasLayer(Map<String, Object> map, boolean z) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this.dimOrder = DimOrder.NONE;
        this.inboundLayerNames = new ArrayList<>();
        this.className = (String) checkAndGetField(map, LAYER_FIELD_CLASS_NAME);
        Map<String, Object> nestedConfig = (Map<String, Object>) checkAndGetField(map, LAYER_FIELD_CONFIG);
        // Merge into a copy: writing into nestedConfig directly would mutate the caller's map.
        Map<String, Object> mergedConfig = new HashMap<>(nestedConfig);
        for (Map.Entry<String, Object> entry : map.entrySet()) {
            if (!LAYER_FIELD_CONFIG.equals(entry.getKey())) {
                mergedConfig.put(entry.getKey(), entry.getValue());
            }
        }
        this.layerConfig = mergedConfig;
        this.train = z;
        this.layerName = (String) checkAndGetField(this.layerConfig, LAYER_FIELD_NAME);
        this.dl4jLayer = buildLayerFromConfig(this.layerConfig, this.train);
        this.dimOrder = getDimOrderFromConfig(this.layerConfig);
        // Must run after dl4jLayer is set: the input-shape logic inspects the layer type.
        this.inputShape = getInputShapeFromConfig(this.layerConfig, this.dimOrder);
        this.inboundLayerNames = getInboundLayerNamesFromConfig(this.layerConfig);
    }

    /** Returns the merged Keras layer configuration backing this layer (the live map, not a copy). */
    public Map<String, Object> getConfiguration() {
        return layerConfig;
    }

    /** Returns the Keras layer class name (the {@code class_name} field). */
    public String getClassName() {
        return className;
    }

    /** Returns the Keras layer name (the {@code name} config field). */
    public String getName() {
        return layerName;
    }

    /** Returns the dimension ordering parsed from the layer's {@code dim_ordering} field. */
    public DimOrder getDimOrder() {
        return dimOrder;
    }

    /**
     * Returns the input shape without the batch dimension, or {@code null} when the config had no
     * {@code batch_input_shape}. The internal array is returned, not a copy.
     */
    public int[] getInputShape() {
        return inputShape;
    }

    /** Returns the names of the layers feeding this one (the live list, not a copy). */
    public List<String> getInboundLayerNames() {
        return inboundLayerNames;
    }

    /**
     * Replaces the inbound layer names with a copy of the given list.
     *
     * @param names new inbound layer names; copied, so later changes to it are not reflected
     */
    public void setInboundLayerNames(List<String> names) {
        inboundLayerNames = new ArrayList<>(names);
    }

    /**
     * Appends one name to the inbound layer names.
     *
     * @param inboundLayerName name of an upstream layer
     */
    public void addInboundLayer(String inboundLayerName) {
        inboundLayerNames.add(inboundLayerName);
    }

    /** Returns the {@code train} flag this layer was built with. */
    public boolean getTrain() {
        return train;
    }

    /**
     * Returns whether this layer can be an inbound layer: it either has a DL4J layer or is a Keras
     * {@code InputLayer}.
     */
    public boolean isValidInboundLayer() {
        if (dl4jLayer != null) {
            return true;
        }
        return className.equals(LAYER_CLASS_NAME_INPUT);
    }

    /** Returns {@code true} when a DL4J layer was built for this Keras layer. */
    public boolean isDl4jLayer() {
        return dl4jLayer != null;
    }

    /** Returns the DL4J layer, or {@code null} if this Keras layer has no DL4J layer equivalent. */
    public Layer getDl4jLayer() {
        return dl4jLayer;
    }

    /**
     * Intended to report whether this Keras layer maps to a DL4J preprocessor. Not implemented yet.
     *
     * @throws UnsupportedKerasConfigurationException always
     */
    public boolean isDl4jPreprocessor() throws UnsupportedKerasConfigurationException {
        throw new UnsupportedKerasConfigurationException("Conversion from Keras layer to DL4J preprocessor not implemented.");
    }

    /**
     * Intended to convert this Keras layer into a DL4J preprocessor vertex. Not implemented yet.
     *
     * @throws UnsupportedKerasConfigurationException always
     */
    public PreprocessorVertex getDl4jPreprocessor() throws UnsupportedKerasConfigurationException {
        throw new UnsupportedKerasConfigurationException("Conversion from Keras layer to DL4J preprocessor not implemented.");
    }

    /**
     * Creates a Keras {@code InputLayer} with the given name and input shape.
     *
     * @param str  layer name
     * @param iArr input shape without the batch dimension
     * @return the new input layer, built with {@code train = false}
     */
    public static KerasLayer createInputLayer(String str, int[] iArr) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        List<Integer> batchInputShape = new ArrayList<>(iArr.length + 1);
        // The leading null stands for the unspecified batch dimension.
        batchInputShape.add(null);
        for (int dimension : iArr) {
            batchInputShape.add(dimension);
        }

        Map<String, Object> innerConfig = new HashMap<>();
        innerConfig.put(LAYER_FIELD_NAME, str);
        innerConfig.put(LAYER_FIELD_BATCH_INPUT_SHAPE, batchInputShape);

        Map<String, Object> description = new HashMap<>();
        description.put(LAYER_FIELD_CONFIG, innerConfig);
        description.put(LAYER_FIELD_CLASS_NAME, LAYER_CLASS_NAME_INPUT);
        return new KerasLayer(description, false);
    }

    /**
     * Creates a loss layer with the {@code train} flag enabled.
     *
     * @param layerName layer name
     * @param lossName  Keras loss function name
     */
    public static KerasLayer createLossLayer(String layerName, String lossName) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        final boolean train = true;
        return createLossLayer(layerName, lossName, train);
    }

    /**
     * Creates a loss layer (class name {@code Loss}) with the given name and Keras loss function.
     *
     * @param layerName layer name
     * @param lossName  Keras loss function name, stored under the {@code loss} field
     * @param train     {@code train} flag passed to the new layer
     */
    public static KerasLayer createLossLayer(String layerName, String lossName, boolean train) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> innerConfig = new HashMap<>();
        innerConfig.put(LAYER_FIELD_NAME, layerName);
        innerConfig.put(LAYER_FIELD_LOSS, lossName);

        Map<String, Object> description = new HashMap<>();
        description.put(LAYER_FIELD_CONFIG, innerConfig);
        description.put(LAYER_FIELD_CLASS_NAME, LAYER_CLASS_NAME_LOSS);
        return new KerasLayer(description, train);
    }

    /**
     * Converts a merged Keras layer config into a DL4J layer, with the {@code train} flag disabled.
     *
     * @param layerConfig merged Keras layer config containing {@code class_name}
     * @return the DL4J layer, or {@code null} for Keras layers without a DL4J layer equivalent
     */
    public static Layer buildLayerFromConfig(Map<String, Object> layerConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        final boolean train = false;
        return buildLayerFromConfig(layerConfig, train);
    }

    /**
     * Converts a merged Keras layer config into the corresponding DL4J layer.
     *
     * <p>Flatten, Reshape, RepeatVector, Merge and InputLayer produce no DL4J layer. A warning is
     * logged for them and {@code null} is returned.
     *
     * @param map merged Keras layer config; must contain {@code class_name}
     * @param z   flag passed through to the per-type builder methods
     * @return the DL4J layer, or {@code null} for Keras layers without a DL4J layer equivalent
     * @throws InvalidKerasConfigurationException    if {@code class_name} is missing or null, or
     *                                                names a Keras layer type that is not handled
     * @throws UnsupportedKerasConfigurationException if a builder rejects part of the config
     */
    public static Layer buildLayerFromConfig(Map<String, Object> map, boolean z) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        String str = (String) map.get(LAYER_FIELD_CLASS_NAME);
        if (str == null) {
            // Covers both an absent key and an explicit null, which would otherwise NPE in the switch.
            throw new InvalidKerasConfigurationException("Missing class_name field.");
        }
        Layer layer = null;
        switch (str) {
            case LAYER_CLASS_NAME_ACTIVATION:
                layer = buildActivationLayer(map, z);
                break;
            case LAYER_CLASS_NAME_DROPOUT:
                layer = buildDropoutLayer(map, z);
                break;
            case LAYER_CLASS_NAME_DENSE:
            case LAYER_CLASS_NAME_TIME_DISTRIBUTED_DENSE:
                layer = buildDenseLayer(map, z);
                break;
            case LAYER_CLASS_NAME_LSTM:
                layer = buildGravesLstmLayer(map, z);
                break;
            case LAYER_CLASS_NAME_CONVOLUTION_2D:
                layer = buildConvolutionLayer(map, z);
                break;
            case LAYER_CLASS_NAME_MAX_POOLING_2D:
            case LAYER_CLASS_NAME_AVERAGE_POOLING_2D:
                layer = buildSubsamplingLayer(map, z);
                break;
            case LAYER_CLASS_NAME_BATCHNORMALIZATION:
                layer = buildBatchNormalizationLayer(map, z);
                break;
            case LAYER_CLASS_NAME_LOSS:
                layer = buildLossLayer(map, z);
                break;
            case LAYER_CLASS_NAME_FLATTEN:
            case LAYER_CLASS_NAME_RESHAPE:
            case LAYER_CLASS_NAME_REPEATVECTOR:
            case LAYER_CLASS_NAME_MERGE:
            case LAYER_CLASS_NAME_INPUT:
                // No DL4J layer is created for these; the caller sees null.
                log.warn("Found Keras " + str + ". DL4J adds \"preprocessor\" layers during model compilation: https://github.com/deeplearning4j/deeplearning4j/blob/master/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java#L429");
                break;
            default:
                throw new InvalidKerasConfigurationException("Unsupported keras layer type " + str);
        }
        return layer;
    }

    /**
     * Maps a Keras activation name to the name DL4J uses.
     *
     * <p>Only {@code linear} (to {@code identity}) and {@code hard_sigmoid} (to
     * {@code hardsigmoid}) are renamed. Every other name is returned unchanged.
     *
     * @param str Keras activation name; must not be {@code null}
     * @return the DL4J activation name
     * @throws NullPointerException if {@code str} is {@code null}
     */
    public static String mapActivation(String str) {
        switch (str) {
            case KERAS_ACTIVATION_LINEAR:
                return DL4J_ACTIVATION_IDENTITY;
            case KERAS_ACTIVATION_HARD_SIGMOID:
                return DL4J_ACTIVATION_HARDSIGMOID;
            default:
                return str;
        }
    }

    /**
     * Maps a Keras weight initializer name to the equivalent DL4J {@link WeightInit}.
     *
     * @param str Keras initializer name, or {@code null} for the default
     * @return {@link WeightInit#XAVIER} when {@code str} is {@code null}; otherwise the mapped value
     * @throws UnsupportedKerasConfigurationException if the initializer is a known Keras initializer
     *         without a DL4J mapping here (uniform, normal, identity, orthogonal, lecun_uniform), or
     *         is not recognized at all
     */
    public static WeightInit mapWeightInitialization(String str) throws UnsupportedKerasConfigurationException {
        if (str == null) {
            return WeightInit.XAVIER;
        }
        switch (str) {
            case INIT_GLOROT_NORMAL:
                return WeightInit.XAVIER;
            case INIT_GLOROT_UNIFORM:
                return WeightInit.XAVIER_UNIFORM;
            case INIT_HE_NORMAL:
                return WeightInit.RELU;
            case INIT_HE_UNIFORM:
                return WeightInit.RELU_UNIFORM;
            case INIT_ZERO:
                return WeightInit.ZERO;
            case INIT_UNIFORM:
            case INIT_NORMAL:
            case INIT_IDENTITY:
            case INIT_ORTHOGONAL:
            case INIT_LECUN_UNIFORM:
                // Recognized Keras initializers that have no mapping to DL4J here.
                throw new UnsupportedKerasConfigurationException("Unsupported keras weight initializer " + str);
            default:
                // Report the offending name (the original reported the XAVIER default instead).
                throw new UnsupportedKerasConfigurationException("Unknown keras weight initializer " + str);
        }
    }

    /**
     * Maps a Keras loss function name to the equivalent DL4J loss function.
     *
     * <p>Both the long and short Keras aliases are accepted (e.g. {@code mean_squared_error} and
     * {@code mse}). {@code sparse_categorical_crossentropy} is mapped to multiclass cross entropy
     * and a warning is logged.
     *
     * @param str Keras loss function name; must not be {@code null}
     * @return the corresponding DL4J loss function
     * @throws UnsupportedKerasConfigurationException if the name is not recognized
     */
    public static LossFunctions.LossFunction mapLossFunction(String str) throws UnsupportedKerasConfigurationException {
        switch (str) {
            case LOSS_SQUARED_LOSS_1:
            case KERAS_LOSS_SQUARED_LOSS_2:
                return LossFunctions.LossFunction.SQUARED_LOSS;
            case KERAS_LOSS_MEAN_ABSOLUTE_ERROR_1:
            case KERAS_LOSS_MEAN_ABSOLUTE_ERROR_2:
                return LossFunctions.LossFunction.MEAN_ABSOLUTE_ERROR;
            case KERAS_LOSS_MEAN_ABSOLUTE_PERCENTAGE_ERROR_1:
            case KERAS_LOSS_MEAN_ABSOLUTE_PERCENTAGE_ERROR_2:
                return LossFunctions.LossFunction.MEAN_ABSOLUTE_PERCENTAGE_ERROR;
            case KERAS_LOSS_MEAN_SQUARED_LOGARITHMIC_ERROR_1:
            case KERAS_LOSS_MEAN_SQUARED_LOGARITHMIC_ERROR_2:
                return LossFunctions.LossFunction.MEAN_SQUARED_LOGARITHMIC_ERROR;
            case KERAS_LOSS_SQUARED_HINGE:
                return LossFunctions.LossFunction.SQUARED_HINGE;
            case KERAS_LOSS_HINGE:
                return LossFunctions.LossFunction.HINGE;
            case KERAS_LOSS_XENT:
                return LossFunctions.LossFunction.XENT;
            case KERAS_LOSS_SP_XE:
                log.warn("Sparse cross entropy not implemented, using multiclass cross entropy instead.");
                return LossFunctions.LossFunction.MCXENT;
            case KERAS_LOSS_MCXENT:
                return LossFunctions.LossFunction.MCXENT;
            case KERAS_LOSS_KL_DIVERGENCE_1:
            case KERAS_LOSS_KL_DIVERGENCE_2:
                return LossFunctions.LossFunction.KL_DIVERGENCE;
            case KERAS_LOSS_POISSON:
                return LossFunctions.LossFunction.POISSON;
            case KERAS_LOSS_COSINE_PROXIMITY:
                return LossFunctions.LossFunction.COSINE_PROXIMITY;
            default:
                throw new UnsupportedKerasConfigurationException("Unknown Keras loss function " + str);
        }
    }

    /**
     * Reads the layer's {@code dim_ordering} field.
     *
     * @param map merged Keras layer config
     * @return {@link DimOrder#NONE} if the field is absent; {@link DimOrder#TENSORFLOW} for
     *         {@code "tf"}; {@link DimOrder#THEANO} for {@code "th"}; otherwise
     *         {@link DimOrder#UNKNOWN}, after logging a warning that names the value
     */
    private DimOrder getDimOrderFromConfig(Map<String, Object> map) {
        if (!map.containsKey(LAYER_FIELD_DIM_ORDERING)) {
            return DimOrder.NONE;
        }
        String str = (String) map.get(LAYER_FIELD_DIM_ORDERING);
        // Constant-first equals keeps an explicit null value from throwing.
        if (DIM_ORDERING_TENSORFLOW.equals(str)) {
            return DimOrder.TENSORFLOW;
        }
        if (DIM_ORDERING_THEANO.equals(str)) {
            return DimOrder.THEANO;
        }
        // Log the offending value (the original logged the NONE default instead).
        log.warn("Keras layer has unknown Keras dimension order: " + str);
        return DimOrder.UNKNOWN;
    }

    /**
     * Extracts the input shape from {@code batch_input_shape}, dropping the leading batch dimension.
     *
     * <p>Null (unspecified) dimensions are recorded as 0. For a 3-D Theano-ordered input to a
     * {@link ConvolutionLayer}, the first entry is rotated to the end: (a, b, c) becomes (b, c, a).
     * This assumes Theano's (channels, rows, cols) layout.
     *
     * @param map      merged Keras layer config
     * @param dimOrder dimension ordering of this layer
     * @return the input shape, or {@code null} if {@code batch_input_shape} is absent or null
     */
    private int[] getInputShapeFromConfig(Map<String, Object> map, DimOrder dimOrder) {
        if (!map.containsKey(LAYER_FIELD_BATCH_INPUT_SHAPE)) {
            return null;
        }
        List<?> batchInputShape = (List<?>) map.get(LAYER_FIELD_BATCH_INPUT_SHAPE);
        if (batchInputShape == null) {
            return null;
        }
        // Math.max guards an empty list, which previously threw NegativeArraySizeException.
        int[] shape = new int[Math.max(0, batchInputShape.size() - 1)];
        for (int i = 1; i < batchInputShape.size(); i++) {
            Object dimension = batchInputShape.get(i);
            // Accept any Number (e.g. Long from a JSON parser), not only Integer.
            shape[i - 1] = dimension != null ? ((Number) dimension).intValue() : 0;
        }
        if (dimOrder == DimOrder.THEANO && shape.length == 3 && (this.dl4jLayer instanceof ConvolutionLayer)) {
            int first = shape[0];
            shape[0] = shape[1];
            shape[1] = shape[2];
            shape[2] = first;
        }
        return shape;
    }

    /**
     * Collects upstream layer names from the first entry of {@code inbound_nodes}.
     *
     * <p>Each connection in that first entry is itself a list whose element 0 is the upstream
     * layer name. Only the first node is inspected.
     *
     * @param map merged Keras layer config
     * @return the upstream layer names; empty if {@code inbound_nodes} is absent or empty
     */
    private static List<String> getInboundLayerNamesFromConfig(Map<String, Object> map) {
        List<String> inboundNames = new ArrayList<>();
        if (!map.containsKey(LAYER_FIELD_INBOUND_NODES)) {
            return inboundNames;
        }
        List<?> inboundNodes = (List<?>) map.get(LAYER_FIELD_INBOUND_NODES);
        if (inboundNodes.isEmpty()) {
            return inboundNames;
        }
        for (Object connection : (List<?>) inboundNodes.get(0)) {
            inboundNames.add((String) ((List<?>) connection).get(0));
        }
        return inboundNames;
    }

    /**
     * Returns the L1 coefficient from a Keras regularizer config.
     *
     * @param map regularizer config, or {@code null}
     * @return the {@code l1} value, or 0 if the map, key or value is missing
     */
    private static double getL1Regularization(Map<String, Object> map) {
        return getRegularizationCoefficient(map, REGULARIZATION_TYPE_L1);
    }

    /**
     * Returns the L2 coefficient from a Keras regularizer config.
     *
     * @param map regularizer config, or {@code null}
     * @return the {@code l2} value, or 0 if the map, key or value is missing
     */
    private static double getL2Regularization(Map<String, Object> map) {
        return getRegularizationCoefficient(map, REGULARIZATION_TYPE_L2);
    }

    /**
     * Shared lookup for getL1Regularization / getL2Regularization.
     * Accepts any {@link Number}, so an integer-valued coefficient no longer throws
     * ClassCastException. A non-numeric value still fails loudly.
     */
    private static double getRegularizationCoefficient(Map<String, Object> map, String key) {
        if (map == null) {
            return 0.0d;
        }
        Object value = map.get(key);
        return value == null ? 0.0d : ((Number) value).doubleValue();
    }

    /**
     * Checks a Keras regularizer config for fields other than {@code l1}, {@code l2} and
     * {@code name}.
     *
     * <p>The map is not modified. The original removed keys from {@code map.keySet()}, which is a
     * live view, and so deleted {@code l1}/{@code l2}/{@code name} from the layer config.
     *
     * @param map regularizer config, or {@code null} (nothing to check)
     * @param z   if {@code true}, the first unknown field raises an exception; otherwise a warning
     *            naming it is logged
     * @throws UnsupportedKerasConfigurationException if {@code z} is {@code true} and an unknown
     *                                                field is present
     */
    private static void checkForUnknownRegularizer(Map<String, Object> map, boolean z) throws UnsupportedKerasConfigurationException {
        if (map == null) {
            return;
        }
        for (String field : map.keySet()) {
            boolean known = REGULARIZATION_TYPE_L1.equals(field)
                    || REGULARIZATION_TYPE_L2.equals(field)
                    || LAYER_FIELD_NAME.equals(field);
            if (known) {
                continue;
            }
            // Report only the first unknown field, as before.
            if (z) {
                throw new UnsupportedKerasConfigurationException("Unknown regularization field " + field);
            }
            log.warn("Ignoring unknown regularization field " + field);
            return;
        }
    }

    /**
     * Builds a DL4J {@link ActivationLayer}. All settings are applied by finishLayerConfig.
     *
     * @param layerConfig merged Keras layer config
     * @param train       flag forwarded to finishLayerConfig
     */
    private static ActivationLayer buildActivationLayer(Map<String, Object> layerConfig, boolean train) throws UnsupportedKerasConfigurationException {
        ActivationLayer.Builder activationBuilder = new ActivationLayer.Builder();
        finishLayerConfig(activationBuilder, layerConfig, train);
        return activationBuilder.build();
    }

    /**
     * Builds a DL4J {@link DropoutLayer}. All settings are applied by finishLayerConfig.
     *
     * @param layerConfig merged Keras layer config
     * @param train       flag forwarded to finishLayerConfig
     */
    private static DropoutLayer buildDropoutLayer(Map<String, Object> layerConfig, boolean train) throws UnsupportedKerasConfigurationException {
        DropoutLayer.Builder dropoutBuilder = new DropoutLayer.Builder();
        finishLayerConfig(dropoutBuilder, layerConfig, train);
        return dropoutBuilder.build();
    }

    /**
     * Builds a DL4J {@link DenseLayer} whose output size is the Keras {@code output_dim}.
     *
     * @param layerConfig merged Keras layer config; must contain an Integer {@code output_dim}
     * @param train       flag forwarded to finishLayerConfig
     */
    private static DenseLayer buildDenseLayer(Map<String, Object> layerConfig, boolean train) throws UnsupportedKerasConfigurationException {
        DenseLayer.Builder denseBuilder = new DenseLayer.Builder();
        Integer outputDim = (Integer) layerConfig.get(LAYER_FIELD_OUTPUT_DIM);
        denseBuilder = denseBuilder.nOut(outputDim.intValue());
        finishLayerConfig(denseBuilder, layerConfig, train);
        return denseBuilder.build();
    }

    /**
     * Builds a DL4J {@link ConvolutionLayer} from a Keras {@code Convolution2D} config.
     *
     * <p>Border modes map as follows:
     * <ul>
     *   <li>{@code same} becomes {@link ConvolutionMode#Same};</li>
     *   <li>{@code valid} becomes {@link ConvolutionMode#Truncate};</li>
     *   <li>{@code full} becomes Truncate with padding (nb_row - 1, nb_col - 1).</li>
     * </ul>
     *
     * @param map merged Keras layer config with {@code subsample}, {@code nb_row}, {@code nb_col},
     *            {@code nb_filter} and {@code border_mode}
     * @param z   flag forwarded to finishLayerConfig
     * @throws UnsupportedKerasConfigurationException if {@code border_mode} is none of the above
     *         (previously such values were silently ignored)
     */
    private static ConvolutionLayer buildConvolutionLayer(Map<String, Object> map, boolean z) throws UnsupportedKerasConfigurationException {
        List<?> subsample = (List<?>) map.get(LAYER_FIELD_SUBSAMPLE);
        int nbRow = ((Integer) map.get(LAYER_FIELD_NB_ROW)).intValue();
        int nbCol = ((Integer) map.get(LAYER_FIELD_NB_COL)).intValue();
        String borderMode = (String) map.get(LAYER_FIELD_BORDER_MODE);
        ConvolutionLayer.Builder builder = new ConvolutionLayer.Builder()
                .stride(new int[]{((Integer) subsample.get(0)).intValue(), ((Integer) subsample.get(1)).intValue()})
                .kernelSize(new int[]{nbRow, nbCol})
                .nOut(((Integer) map.get(LAYER_FIELD_NB_FILTER)).intValue());
        switch (borderMode) {
            case LAYER_BORDER_MODE_SAME:
                builder.convolutionMode(ConvolutionMode.Same);
                break;
            case LAYER_BORDER_MODE_VALID:
                builder.convolutionMode(ConvolutionMode.Truncate);
                break;
            case LAYER_BORDER_MODE_FULL:
                // A "full" convolution equals a valid one over input zero-padded by kernel size - 1.
                builder.convolutionMode(ConvolutionMode.Truncate).padding(new int[]{nbRow - 1, nbCol - 1});
                break;
            default:
                throw new UnsupportedKerasConfigurationException("Unsupported Keras convolution border mode " + borderMode);
        }
        finishLayerConfig(builder, map, z);
        return builder.build();
    }

    /**
     * Builds a DL4J {@link SubsamplingLayer} from a Keras MaxPooling2D or AveragePooling2D config.
     *
     * <p>Reads the strides, pool size (kernel size), class name (pooling type) and border-mode fields.
     * Border modes are mapped as follows:
     * <ul>
     *   <li>{@code LAYER_BORDER_MODE_SAME} to {@link ConvolutionMode#Same}</li>
     *   <li>{@code LAYER_BORDER_MODE_VALID} to {@link ConvolutionMode#Truncate}</li>
     *   <li>{@code LAYER_BORDER_MODE_FULL} to {@link ConvolutionMode#Truncate} with padding of
     *       (pool size - 1) on each spatial dimension</li>
     * </ul>
     *
     * @param map Keras layer config
     * @param z   when true, unsupported settings throw instead of logging a warning
     * @return the configured pooling layer
     * @throws UnsupportedKerasConfigurationException if the class name is not a supported pooling layer,
     *         the border mode is missing or unsupported, or {@code finishLayerConfig} rejects the config
     */
    private static SubsamplingLayer buildSubsamplingLayer(Map<String, Object> map, boolean z) throws UnsupportedKerasConfigurationException {
        List<?> strides = (List<?>) map.get(LAYER_FIELD_STRIDES);
        List<?> poolSize = (List<?>) map.get(LAYER_FIELD_POOL_SIZE);
        int poolRows = ((Number) poolSize.get(0)).intValue();
        int poolCols = ((Number) poolSize.get(1)).intValue();
        SubsamplingLayer.Builder builder = new SubsamplingLayer.Builder()
                .stride(new int[] {((Number) strides.get(0)).intValue(), ((Number) strides.get(1)).intValue()})
                .kernelSize(new int[] {poolRows, poolCols});

        String className = (String) map.get(LAYER_FIELD_CLASS_NAME);
        if (LAYER_CLASS_NAME_MAX_POOLING_2D.equals(className)) {
            builder.poolingType(SubsamplingLayer.PoolingType.MAX);
        } else if (LAYER_CLASS_NAME_AVERAGE_POOLING_2D.equals(className)) {
            builder.poolingType(SubsamplingLayer.PoolingType.AVG);
        } else {
            throw new UnsupportedKerasConfigurationException("Unsupported Keras pooling layer " + className);
        }

        // Unknown/missing border modes used to be silently ignored; reject them like unknown pooling types.
        String borderMode = (String) map.get(LAYER_FIELD_BORDER_MODE);
        if (LAYER_BORDER_MODE_SAME.equals(borderMode)) {
            builder.convolutionMode(ConvolutionMode.Same);
        } else if (LAYER_BORDER_MODE_VALID.equals(borderMode)) {
            builder.convolutionMode(ConvolutionMode.Truncate);
        } else if (LAYER_BORDER_MODE_FULL.equals(borderMode)) {
            builder.convolutionMode(ConvolutionMode.Truncate).padding(new int[] {poolRows - 1, poolCols - 1});
        } else {
            throw new UnsupportedKerasConfigurationException("Unsupported Keras border mode " + borderMode);
        }

        finishLayerConfig(builder, map, z);
        return builder.build();
    }

    /**
     * Builds a DL4J {@link GravesLSTM} from a Keras LSTM layer config.
     *
     * <p>Reads output dim (nOut), inner activation (gate activation), forget-bias init and the
     * input-weight dropout, which is passed to {@code finishLayerConfig} as the layer dropout. The
     * caller's map is not modified; a copy carries the translated dropout value.
     *
     * @param map Keras layer config
     * @param z   when true, unsupported settings throw instead of logging a warning
     * @return the configured LSTM layer
     * @throws UnsupportedKerasConfigurationException if recurrent dropout is &gt; 0; or, when {@code z} is
     *         true, if init and inner init differ or the forget-bias init is unsupported; or if
     *         {@code finishLayerConfig} rejects the config
     */
    private static GravesLSTM buildGravesLstmLayer(Map<String, Object> map, boolean z) throws UnsupportedKerasConfigurationException {
        // Objects.equals: a missing init field must not NPE.
        if (!java.util.Objects.equals(map.get(LAYER_FIELD_INIT), map.get(LAYER_FIELD_INNER_INIT))) {
            if (z) {
                throw new UnsupportedKerasConfigurationException("Specifying different initialization for LSTM inner cells not supported.");
            }
            log.warn("Specifying different initialization for LSTM inner cells not supported.");
        }
        Object dropoutU = map.get(LAYER_FIELD_DROPOUT_U);
        if (dropoutU != null && ((Number) dropoutU).doubleValue() > 0.0d) {
            throw new UnsupportedKerasConfigurationException("Dropout > 0 on LSTM recurrent connections not supported.");
        }

        GravesLSTM.Builder builder = new GravesLSTM.Builder();
        builder.nOut(((Number) map.get(LAYER_FIELD_OUTPUT_DIM)).intValue());
        builder.gateActivationFunction(mapActivation((String) map.get(LAYER_FIELD_INNER_ACTIVATION)));

        String forgetBiasInit = (String) map.get(LAYER_FIELD_FORGET_BIAS_INIT);
        if ("zero".equals(forgetBiasInit)) {
            builder.forgetGateBiasInit(0.0d);
        } else if (LSTM_FORGET_BIAS_INIT_ONE.equals(forgetBiasInit)) {
            builder.forgetGateBiasInit(1.0d);
        } else if (z) {
            throw new UnsupportedKerasConfigurationException("Unsupported bias initialization: " + forgetBiasInit);
        } else {
            builder.forgetGateBiasInit(1.0d);
            log.warn("Unsupported bias initialization: " + forgetBiasInit + ". Using ONE instead");
        }

        // Work on a copy so the caller's config is not mutated as a side effect.
        Map<String, Object> layerConfig = new HashMap<>(map);
        Object dropoutW = map.get(LAYER_FIELD_DROPOUT_W);
        if (dropoutW != null) {
            layerConfig.put(LAYER_FIELD_DROPOUT, Double.valueOf(((Number) dropoutW).doubleValue()));
        }
        finishLayerConfig(builder, layerConfig, z);
        return builder.build();
    }

    /**
     * Builds a DL4J {@link BatchNormalization} layer from a Keras BatchNormalization config.
     *
     * <p>Reads epsilon and momentum. The axis field is ignored with a warning. Gamma/beta
     * regularizers are not supported: when present they raise an exception if {@code z} is true,
     * otherwise they are ignored with a warning.
     *
     * @param map Keras layer config
     * @param z   when true, unsupported settings throw instead of logging a warning
     * @return the configured batch-normalization layer
     * @throws UnsupportedKerasConfigurationException for Keras modes 1 and 2; for gamma/beta
     *         regularizers when {@code z} is true; or if {@code finishLayerConfig} rejects the config
     */
    private static BatchNormalization buildBatchNormalizationLayer(Map<String, Object> map, boolean z) throws UnsupportedKerasConfigurationException {
        // Previously the nesting was inverted: it warned when no regularizer was set (strict mode)
        // and stayed silent when one was set (lenient mode).
        if (map.get(LAYER_FIELD_GAMMA_REGULARIZER) != null) {
            if (z) {
                throw new UnsupportedKerasConfigurationException("Regularization for BatchNormalization gamma parameter not supported");
            }
            log.warn("Regularization for BatchNormalization gamma parameter not supported...ignoring.");
        }
        if (map.get(LAYER_FIELD_BETA_REGULARIZER) != null) {
            if (z) {
                throw new UnsupportedKerasConfigurationException("Regularization for BatchNormalization beta parameter not supported");
            }
            log.warn("Regularization for BatchNormalization beta parameter not supported...ignoring.");
        }

        int mode = ((Number) map.get(LAYER_FIELD_MODE)).intValue();
        if (mode == LAYER_BATCHNORM_MODE_1) {
            throw new UnsupportedKerasConfigurationException("Keras BatchNormalization mode 1 (sample-wise) not supported");
        }
        if (mode == LAYER_BATCHNORM_MODE_2) {
            throw new UnsupportedKerasConfigurationException("Keras BatchNormalization (per-batch statistics during testing) 2 not supported");
        }

        log.warn("Ignoring BatchNormalization axis=" + map.get(LAYER_FIELD_AXIS) + " config. DL4J BatchNormalization defaults to the \"channels\" axis");
        BatchNormalization.Builder builder = new BatchNormalization.Builder();
        builder.eps(((Number) map.get(LAYER_FIELD_EPSILON)).doubleValue())
                .momentum(((Number) map.get(LAYER_FIELD_MOMENTUM)).doubleValue());
        finishLayerConfig(builder, map, z);
        return builder.build();
    }

    /**
     * Builds a DL4J {@link LossLayer} from the "loss" field of a Keras layer config.
     *
     * <p>If the Keras loss has no DL4J mapping and {@code z} is false, SQUARED_LOSS is used instead,
     * with a warning.
     *
     * @param map Keras layer config
     * @param z   when true, unsupported settings throw instead of logging a warning
     * @return the configured loss layer
     * @throws InvalidKerasConfigurationException if the "loss" field is missing
     * @throws UnsupportedKerasConfigurationException if the loss is unsupported and {@code z} is true,
     *         or if {@code finishLayerConfig} rejects the config
     */
    private static LossLayer buildLossLayer(Map<String, Object> map, boolean z) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        LossFunctions.LossFunction dl4jLoss;
        try {
            String kerasLossName = (String) checkAndGetField(map, "loss");
            dl4jLoss = mapLossFunction(kerasLossName);
        } catch (UnsupportedKerasConfigurationException unsupported) {
            if (!z) {
                log.warn("Unsupported Keras loss function. Replacing with MSE.");
                dl4jLoss = LossFunctions.LossFunction.SQUARED_LOSS;
            } else {
                throw unsupported;
            }
        }
        LossLayer.Builder lossBuilder = new LossLayer.Builder(dl4jLoss);
        finishLayerConfig(lossBuilder, map, z);
        LossLayer lossLayer = lossBuilder.build();
        return lossLayer;
    }

    /**
     * Applies the settings common to all Keras layers to a DL4J layer builder.
     *
     * <p>These are: dropout (Keras drop rate converted to {@code 1 - rate}), activation, name, weight
     * init (with zero bias init when the weight init is ZERO), and L1/L2 weight regularization. Bias
     * regularization is not supported. Regularizer entries whose value is {@code null} are treated as
     * absent; presumably this is how Keras serializes "no regularizer" — TODO confirm for all layer types.
     *
     * @param builder DL4J layer builder to configure (mutated and returned)
     * @param map     Keras layer config
     * @param z       when true, unsupported settings throw instead of logging a warning
     * @return the same {@code builder}
     * @throws UnsupportedKerasConfigurationException when {@code z} is true and the weight init is
     *         unknown, a bias regularizer is set, or {@code checkForUnknownRegularizer} rejects a field
     */
    private static Layer.Builder finishLayerConfig(Layer.Builder builder, Map<String, Object> map, boolean z) throws UnsupportedKerasConfigurationException {
        Object dropout = map.get(LAYER_FIELD_DROPOUT);
        if (dropout != null) {
            // Keras stores the drop rate; DL4J's dropOut here receives the complement.
            builder.dropOut(1.0d - ((Number) dropout).doubleValue());
        }
        if (map.containsKey(LAYER_FIELD_ACTIVATION)) {
            builder.activation(mapActivation((String) map.get(LAYER_FIELD_ACTIVATION)));
        }
        builder.name((String) map.get(LAYER_FIELD_NAME));
        if (map.containsKey(LAYER_FIELD_INIT)) {
            String initName = (String) map.get(LAYER_FIELD_INIT);
            WeightInit weightInit;
            try {
                weightInit = mapWeightInitialization(initName);
            } catch (UnsupportedKerasConfigurationException e) {
                if (z) {
                    throw e;
                }
                weightInit = WeightInit.XAVIER;
                log.warn("Unknown weight initializer " + initName + " (Using XAVIER instead).");
            }
            builder.weightInit(weightInit);
            if (weightInit == WeightInit.ZERO) {
                builder.biasInit(0.0d);
            }
        }

        // A null value means "no regularizer"; passing it on would hand a null map to the helpers.
        @SuppressWarnings("unchecked")
        Map<String, Object> weightRegularizer = (Map<String, Object>) map.get(LAYER_FIELD_W_REGULARIZER);
        if (weightRegularizer != null) {
            double l1 = getL1Regularization(weightRegularizer);
            if (l1 > 0.0d) {
                builder.l1(l1);
            }
            double l2 = getL2Regularization(weightRegularizer);
            if (l2 > 0.0d) {
                builder.l2(l2);
            }
            checkForUnknownRegularizer(weightRegularizer, z);
        }

        @SuppressWarnings("unchecked")
        Map<String, Object> biasRegularizer = (Map<String, Object>) map.get(LAYER_FIELD_B_REGULARIZER);
        if (biasRegularizer != null) {
            double biasL1 = getL1Regularization(biasRegularizer);
            double biasL2 = getL2Regularization(biasRegularizer);
            if (biasL1 > 0.0d || biasL2 > 0.0d) {
                if (z) {
                    throw new UnsupportedKerasConfigurationException("Bias regularization not implemented");
                }
                log.warn("Bias regularization not supported. Ignoring.");
            }
            // Same unknown-field check as the weight regularizer, for consistency.
            checkForUnknownRegularizer(biasRegularizer, z);
        }
        return builder;
    }

    /**
     * Returns the value stored under {@code str} in a layer config, requiring that the key be present.
     *
     * <p>A key that is present but mapped to {@code null} yields {@code null} rather than an exception.
     *
     * @param map Keras layer config
     * @param str field name to look up
     * @return the (possibly null) value for the field
     * @throws InvalidKerasConfigurationException if the key is absent from the map
     */
    private static Object checkAndGetField(Map<String, Object> map, String str) throws InvalidKerasConfigurationException {
        if (!map.containsKey(str)) {
            throw new InvalidKerasConfigurationException("Field " + str + " missing from layer config");
        }
        return map.get(str);
    }
}
