package org.deeplearning4j.nn.modelimport.keras;

import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Deque;
import java.util.HashMap;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.Loader;
import org.bytedeco.javacpp.hdf5;
import org.deeplearning4j.berkeley.StringUtils;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.KerasModel;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.core.type.TypeReference;
import org.nd4j.shade.jackson.databind.DeserializationFeature;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/modelimport/keras/KerasModelImport.class */
/**
 * Entry points for importing Keras models into DL4J.
 *
 * <p>A model is supplied either as a single HDF5 archive or as a JSON configuration file plus a
 * separate HDF5 weights file. In the single-archive case, the configuration is read from the
 * {@code model_config} and {@code training_config} attributes and the weights from the
 * {@code /model_weights} group. In the two-file case, the weights are read from the root group.
 * The parsed configuration and weights are handed to {@link KerasModel.ModelBuilder}. It builds a
 * {@link ComputationGraph} for Keras class {@code Model} and a {@link MultiLayerNetwork} for Keras
 * class {@code Sequential}.
 */
public class KerasModelImport {
    private static final Logger log = LoggerFactory.getLogger(KerasModelImport.class);

    /** {@code childObjType} code for a dataset; every other child object is opened as a group. */
    private static final int HDF5_OBJECT_TYPE_DATASET = 1;

    /** Trailing {@code ":<digits>"} suffix stripped from parameter names (e.g. {@code "W:0"} -> {@code "W"}). */
    private static final Pattern PARAM_NAME_INDEX_SUFFIX = Pattern.compile(":\\d+$");

    /** Size increment, in bytes, of the buffer used to read a JSON string attribute. */
    private static final int ATTRIBUTE_BUFFER_CHUNK_BYTES = 2000;

    /** Largest number of chunks tried before giving up on reading a JSON string attribute. */
    private static final int MAX_ATTRIBUTE_BUFFER_CHUNKS = 100;

    private final String modelJson;
    private final String trainingJson;
    private final String modelClassName;
    private final Map<String, Map<String, INDArray>> weights;

    /**
     * Imports a Keras functional {@code Model} from an HDF5 archive stream.
     *
     * <p>Currently always fails, because the {@link InputStream} constructor is unsupported.
     *
     * @throws UnsupportedOperationException always, see {@link #KerasModelImport(InputStream)}
     */
    public static ComputationGraph importKerasModelAndWeights(InputStream modelHdf5Stream) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        KerasModelImport archive = new KerasModelImport(modelHdf5Stream);
        archive.requireModelClass(KerasModel.MODEL_CLASS_NAME_MODEL, "Model");
        return new KerasModel.ModelBuilder()
                .modelJson(archive.getModelJson())
                .trainingJson(archive.getTrainingJson())
                .weights(archive.getWeights())
                .train(false)
                .buildModel()
                .getComputationGraph();
    }

    /**
     * Imports a Keras {@code Sequential} model from an HDF5 archive stream.
     *
     * <p>Currently always fails, because the {@link InputStream} constructor is unsupported.
     *
     * @throws UnsupportedOperationException always, see {@link #KerasModelImport(InputStream)}
     */
    public static MultiLayerNetwork importKerasSequentialModelAndWeights(InputStream modelHdf5Stream) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        KerasModelImport archive = new KerasModelImport(modelHdf5Stream);
        // Was checking MODEL_CLASS_NAME_MODEL, inconsistent with the Sequential builder below.
        archive.requireModelClass(KerasModel.MODEL_CLASS_NAME_SEQUENTIAL, "Sequential");
        return new KerasModel.ModelBuilder()
                .modelJson(archive.getModelJson())
                .trainingJson(archive.getTrainingJson())
                .weights(archive.getWeights())
                .train(false)
                .buildSequential()
                .getMultiLayerNetwork();
    }

    /**
     * Imports a Keras functional {@code Model}, with training configuration and weights, from a
     * single HDF5 archive.
     *
     * @param modelHdf5Filename path to the HDF5 archive
     * @return the imported, non-training computation graph
     * @throws InvalidKerasConfigurationException if the archive does not hold a {@code Model}
     */
    public static ComputationGraph importKerasModelAndWeights(String modelHdf5Filename) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        KerasModelImport archive = new KerasModelImport(modelHdf5Filename);
        archive.requireModelClass(KerasModel.MODEL_CLASS_NAME_MODEL, "Model");
        return new KerasModel.ModelBuilder()
                .modelJson(archive.getModelJson())
                .trainingJson(archive.getTrainingJson())
                .weights(archive.getWeights())
                .train(false)
                .buildModel()
                .getComputationGraph();
    }

    /**
     * Imports a Keras {@code Sequential} model, with training configuration and weights, from a
     * single HDF5 archive.
     *
     * @param modelHdf5Filename path to the HDF5 archive
     * @return the imported, non-training multi-layer network
     * @throws InvalidKerasConfigurationException if the archive does not hold a {@code Sequential}
     */
    public static MultiLayerNetwork importKerasSequentialModelAndWeights(String modelHdf5Filename) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        KerasModelImport archive = new KerasModelImport(modelHdf5Filename);
        archive.requireModelClass(KerasModel.MODEL_CLASS_NAME_SEQUENTIAL, "Sequential");
        return new KerasModel.ModelBuilder()
                .modelJson(archive.getModelJson())
                .trainingJson(archive.getTrainingJson())
                .weights(archive.getWeights())
                .train(false)
                .buildSequential()
                .getMultiLayerNetwork();
    }

    /**
     * Imports a Keras functional {@code Model} from a JSON configuration file plus an HDF5
     * weights file.
     *
     * @param modelJsonFilename path to the model configuration JSON
     * @param weightsHdf5Filename path to the HDF5 weights file (weights read from its root group)
     * @return the imported, non-training computation graph
     * @throws InvalidKerasConfigurationException if the JSON does not describe a {@code Model}
     */
    public static ComputationGraph importKerasModelAndWeights(String modelJsonFilename, String weightsHdf5Filename) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        KerasModelImport archive = new KerasModelImport(modelJsonFilename, weightsHdf5Filename);
        // Was checking MODEL_CLASS_NAME_SEQUENTIAL, which rejected every functional Model.
        archive.requireModelClass(KerasModel.MODEL_CLASS_NAME_MODEL, "Model");
        return new KerasModel.ModelBuilder()
                .modelJson(archive.getModelJson())
                .weights(archive.getWeights())
                .train(false)
                .buildModel()
                .getComputationGraph();
    }

    /**
     * Imports a Keras {@code Sequential} model from a JSON configuration file plus an HDF5
     * weights file.
     *
     * @param modelJsonFilename path to the model configuration JSON
     * @param weightsHdf5Filename path to the HDF5 weights file (weights read from its root group)
     * @return the imported, non-training multi-layer network
     * @throws InvalidKerasConfigurationException if the JSON does not describe a {@code Sequential}
     */
    public static MultiLayerNetwork importKerasSequentialModelAndWeights(String modelJsonFilename, String weightsHdf5Filename) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        KerasModelImport archive = new KerasModelImport(modelJsonFilename, weightsHdf5Filename);
        archive.requireModelClass(KerasModel.MODEL_CLASS_NAME_SEQUENTIAL, "Sequential");
        return new KerasModel.ModelBuilder()
                .modelJson(archive.getModelJson())
                .trainingJson(archive.getTrainingJson())
                .weights(archive.getWeights())
                .train(false)
                .buildSequential()
                .getMultiLayerNetwork();
    }

    /**
     * Imports only the configuration of a Keras functional {@code Model} from a JSON file.
     *
     * @param modelJsonFilename path to the model configuration JSON (read as UTF-8)
     */
    public static ComputationGraphConfiguration importKerasModelConfiguration(String modelJsonFilename) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        return new KerasModel.ModelBuilder()
                .modelJson(readUtf8File(modelJsonFilename))
                .train(false)
                .buildModel()
                .getComputationGraphConfiguration();
    }

    /**
     * Imports only the configuration of a Keras {@code Sequential} model from a JSON file.
     *
     * @param modelJsonFilename path to the model configuration JSON (read as UTF-8)
     */
    public static MultiLayerConfiguration importKerasSequentialConfiguration(String modelJsonFilename) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        return new KerasModel.ModelBuilder()
                .modelJson(readUtf8File(modelJsonFilename))
                .train(false)
                .buildSequential()
                .getMultiLayerConfiguration();
    }

    /**
     * Not supported: always throws. The HDF5 C++ API used here cannot open a file from a memory
     * buffer.
     *
     * @throws UnsupportedOperationException always
     */
    public KerasModelImport(InputStream inputStream) throws UnsupportedOperationException, IOException, UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        log.warn("Importing a Keras model from an InputStream pointing to contents of an HDF5 archive currently not supported.");
        throw new UnsupportedOperationException("Importing a Keras model from an InputStream currently not supported because it is not possible to load an HDF5 file from a memory buffer using the HDF5 C++ API. See: http://stackoverflow.com/questions/18449972/how-can-i-open-hdf5-file-from-memory-buffer-using-hdf5-c-api");
    }

    /**
     * Reads model configuration, training configuration and weights from one HDF5 archive.
     *
     * @param modelHdf5Filename path to an HDF5 archive with {@code model_config} and
     *     {@code training_config} attributes and a {@code /model_weights} group
     */
    public KerasModelImport(String modelHdf5Filename) throws IOException, UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        hdf5.H5File file = new hdf5.H5File(modelHdf5Filename, hdf5.H5F_ACC_RDONLY);
        try {
            this.modelJson = readJsonStringFromHdf5Attribute(file, "model_config");
            this.modelClassName = getModelClassName(this.modelJson);
            this.trainingJson = readJsonStringFromHdf5Attribute(file, "training_config");
            this.weights = readWeightsFromHdf5(file, "/model_weights");
        } finally {
            // Release the native file handle even when parsing fails.
            file.close();
        }
    }

    /**
     * Reads the model configuration from a JSON file and the weights from a separate HDF5 file.
     * No training configuration is available in this mode ({@link #getTrainingJson()} is null).
     *
     * @param modelJsonFilename path to the model configuration JSON (read as UTF-8)
     * @param weightsHdf5Filename path to the HDF5 weights file (weights read from its root group)
     */
    public KerasModelImport(String modelJsonFilename, String weightsHdf5Filename) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this.modelJson = readUtf8File(modelJsonFilename);
        this.modelClassName = getModelClassName(this.modelJson);
        this.trainingJson = null;
        hdf5.H5File file = new hdf5.H5File(weightsHdf5Filename, hdf5.H5F_ACC_RDONLY);
        try {
            this.weights = readWeightsFromHdf5(file, "/");
        } finally {
            file.close();
        }
    }

    /** Returns the raw Keras model configuration JSON. */
    public String getModelJson() {
        return this.modelJson;
    }

    /** Returns the raw Keras training configuration JSON, or null when none was read. */
    public String getTrainingJson() {
        return this.trainingJson;
    }

    /** Returns the {@code class_name} field of the model configuration. */
    public String getModelClassName() {
        return this.modelClassName;
    }

    /** Returns the imported weights, keyed by layer name and then by parameter name. */
    public Map<String, Map<String, INDArray>> getWeights() {
        return this.weights;
    }

    /**
     * Fails unless this import's model class name equals {@code expectedClassName}.
     *
     * @param expectedLabel human-readable class name used in the error message
     */
    private void requireModelClass(String expectedClassName, String expectedLabel) throws InvalidKerasConfigurationException {
        if (!expectedClassName.equals(this.modelClassName)) {
            throw new InvalidKerasConfigurationException("Expected Keras model class name " + expectedLabel + " (found " + this.modelClassName + ")");
        }
    }

    /** Reads a whole file as UTF-8 text. */
    private static String readUtf8File(String path) throws IOException {
        return new String(Files.readAllBytes(Paths.get(path)), StandardCharsets.UTF_8);
    }

    /**
     * Walks the HDF5 group tree rooted at {@code rootGroupPath} breadth-first and reads every
     * dataset it finds as float weights. The caller owns {@code file} and must close it.
     *
     * @return map from layer name to (parameter name -> weight array)
     * @throws UnsupportedKerasConfigurationException if a dataset has rank other than 1, 2 or 4
     */
    private static Map<String, Map<String, INDArray>> readWeightsFromHdf5(hdf5.H5File file, String rootGroupPath) throws UnsupportedKerasConfigurationException {
        Map<String, Map<String, INDArray>> weights = new HashMap<String, Map<String, INDArray>>();
        Deque<hdf5.Group> pending = new ArrayDeque<hdf5.Group>();
        pending.add(file.asCommonFG().openGroup(rootGroupPath));
        try {
            while (!pending.isEmpty()) {
                hdf5.Group group = pending.removeFirst();
                try {
                    readGroupChildren(group, pending, weights);
                } finally {
                    group.close();
                }
            }
        } finally {
            // Only non-empty if reading failed part-way: release groups opened but never visited.
            for (hdf5.Group unvisited : pending) {
                unvisited.close();
            }
        }
        return weights;
    }

    /**
     * Queues every non-dataset child of {@code group} for later visiting. Each dataset child is
     * read into {@code weights}.
     *
     * <p>The dataset name is split on {@code "_"}. Its first two tokens form the layer name, and
     * the remaining tokens form the parameter name, minus any trailing {@code ":<digits>"}.
     */
    private static void readGroupChildren(hdf5.Group group, Deque<hdf5.Group> pending, Map<String, Map<String, INDArray>> weights) throws UnsupportedKerasConfigurationException {
        for (int idx = 0; idx < group.asCommonFG().getNumObjs(); idx++) {
            BytePointer childName = group.asCommonFG().getObjnameByIdx(idx);
            if (group.asCommonFG().childObjType(childName) != HDF5_OBJECT_TYPE_DATASET) {
                pending.add(group.asCommonFG().openGroup(childName));
                continue;
            }
            String[] tokens = childName.getString().split("_");
            String layerName = StringUtils.join(Arrays.copyOfRange(tokens, 0, 2), "_");
            String paramName = StringUtils.join(Arrays.copyOfRange(tokens, 2, tokens.length), "_");
            paramName = PARAM_NAME_INDEX_SUFFIX.matcher(paramName).replaceFirst("");

            hdf5.DataSet dataSet = group.asCommonFG().openDataSet(childName);
            INDArray values;
            try {
                values = readDataSet(dataSet);
            } finally {
                dataSet.close();
            }

            Map<String, INDArray> layerWeights = weights.get(layerName);
            if (layerWeights == null) {
                layerWeights = new HashMap<String, INDArray>();
                weights.put(layerName, layerWeights);
            }
            layerWeights.put(paramName, values);
        }
    }

    /**
     * Reads a rank-1, rank-2 or rank-4 dataset as native floats into an INDArray of the same
     * shape. The data are filled in row-major element order.
     *
     * @throws UnsupportedKerasConfigurationException for any other rank
     */
    private static INDArray readDataSet(hdf5.DataSet dataSet) throws UnsupportedKerasConfigurationException {
        hdf5.DataSpace space = dataSet.getSpace();
        try {
            int rank = space.getSimpleExtentNdims();
            long[] dims = new long[rank];
            space.getSimpleExtentDims(dims);
            switch (rank) {
                case 1: {
                    float[] data = readFloats(dataSet, dims[0]);
                    INDArray array = Nd4j.create((int) dims[0]);
                    for (int i = 0; i < dims[0]; i++) {
                        array.putScalar(i, data[i]);
                    }
                    return array;
                }
                case 2: {
                    float[] data = readFloats(dataSet, dims[0] * dims[1]);
                    INDArray array = Nd4j.create((int) dims[0], (int) dims[1]);
                    int flat = 0;
                    for (int i = 0; i < dims[0]; i++) {
                        for (int j = 0; j < dims[1]; j++) {
                            array.putScalar(i, j, data[flat++]);
                        }
                    }
                    return array;
                }
                case 4: {
                    float[] data = readFloats(dataSet, dims[0] * dims[1] * dims[2] * dims[3]);
                    INDArray array = Nd4j.create(new int[]{(int) dims[0], (int) dims[1], (int) dims[2], (int) dims[3]});
                    int flat = 0;
                    for (int a = 0; a < dims[0]; a++) {
                        for (int b = 0; b < dims[1]; b++) {
                            for (int c = 0; c < dims[2]; c++) {
                                for (int d = 0; d < dims[3]; d++) {
                                    array.putScalar(a, b, c, d, data[flat++]);
                                }
                            }
                        }
                    }
                    return array;
                }
                default:
                    throw new UnsupportedKerasConfigurationException("Cannot import weights with rank " + rank);
            }
        } finally {
            space.close();
        }
    }

    /** Reads {@code length} native floats from {@code dataSet} into a new Java array. */
    private static float[] readFloats(hdf5.DataSet dataSet, long length) {
        float[] buffer = new float[(int) length];
        FloatPointer pointer = new FloatPointer(buffer);
        dataSet.read(pointer, new hdf5.DataType(hdf5.PredType.NATIVE_FLOAT()));
        pointer.get(buffer);
        return buffer;
    }

    /**
     * Reads a variable-length string attribute holding JSON.
     *
     * <p>The string length is not known in advance. The attribute is therefore read into buffers
     * of increasing size (in {@value #ATTRIBUTE_BUFFER_CHUNK_BYTES}-byte steps) until the decoded
     * text parses as JSON. The returned string may carry trailing NUL padding from the buffer.
     *
     * @throws InvalidKerasConfigurationException if no buffer up to the size limit yields valid JSON
     */
    private static String readJsonStringFromHdf5Attribute(hdf5.H5File file, String attributeName) throws InvalidKerasConfigurationException {
        hdf5.Attribute attribute = file.openAttribute(attributeName);
        try {
            hdf5.VarLenType varLenType = attribute.getVarLenType();
            ObjectMapper mapper = new ObjectMapper();
            mapper.enable(DeserializationFeature.FAIL_ON_READING_DUP_TREE_KEY);
            for (int chunks = 1; chunks <= MAX_ATTRIBUTE_BUFFER_CHUNKS; chunks++) {
                byte[] buffer = new byte[chunks * ATTRIBUTE_BUFFER_CHUNK_BYTES];
                BytePointer pointer = new BytePointer(buffer);
                attribute.read(varLenType, pointer);
                pointer.get(buffer);
                String json = new String(buffer, StandardCharsets.UTF_8);
                try {
                    mapper.readTree(json);
                    return json;
                } catch (IOException truncated) {
                    // Most likely the buffer was too small and the JSON got cut off; retry larger.
                }
            }
            throw new InvalidKerasConfigurationException("Could not read abnormally long Keras config. Please file an issue!");
        } finally {
            attribute.close();
        }
    }

    /**
     * Extracts the top-level {@code class_name} field from a Keras model configuration.
     *
     * @throws InvalidKerasConfigurationException if the field is absent
     */
    private static String getModelClassName(String modelJson) throws IOException, InvalidKerasConfigurationException {
        Map<String, Object> config = new ObjectMapper().readValue(modelJson, new TypeReference<HashMap<String, Object>>() {
        });
        if (!config.containsKey("class_name")) {
            throw new InvalidKerasConfigurationException("Unable to determine Keras model class name.");
        }
        return (String) config.get("class_name");
    }

    static {
        try {
            Loader.load(hdf5.class);
        } catch (Exception e) {
            log.error("Failed to load native HDF5 library", e);
        }
    }
}
