package org.campagnelab.dl.framework.models;

import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.charset.Charset;
import java.util.Properties;
import org.apache.commons.io.FileUtils;
import org.campagnelab.dl.framework.mappers.ConfigurableFeatureMapper;
import org.campagnelab.dl.framework.mappers.FeatureMapper;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/campagnelab/dl/framework/models/ModelLoader.class */
/**
 * Loads a serialized model, its {@code config.properties} file and its feature mapper
 * from a model directory.
 *
 * <p>All files are resolved by string concatenation against the model path given to the
 * constructor (for example {@code modelPath + "/config.properties"}).
 */
public class ModelLoader {
    /** Name (with leading separator) of the properties file stored in the model directory. */
    private static final String CONFIG_FILE_NAME = "/config.properties";
    private static final Logger LOG = LoggerFactory.getLogger(ModelLoader.class);

    private final Properties modelProperties;
    private final String modelPath;

    /**
     * Returns the properties that were read from {@code config.properties} when this
     * loader was constructed.
     *
     * @return the model properties instance held by this loader (not a copy).
     */
    public Properties getModelProperties() {
        return this.modelProperties;
    }

    /**
     * Creates a loader for the given model directory and eagerly reads its
     * {@code config.properties} file.
     *
     * @param str path of the model directory.
     * @throws RuntimeException if the properties file cannot be read; the underlying
     *                          {@link IOException} is attached as the cause.
     */
    public ModelLoader(String str) {
        this.modelPath = str;
        try {
            this.modelProperties = loadModelProperties();
        } catch (IOException e) {
            // Keep the cause so callers can tell a missing file from a read failure.
            throw new RuntimeException("Unable to load model properties at path " + str, e);
        }
    }

    /**
     * Re-reads {@code config.properties} from disk, sets the {@code testRecordCount}
     * property and writes the file back.
     *
     * <p>This is best effort: I/O failures are logged and not propagated. The properties
     * held in memory by this loader are not modified.
     *
     * @param j the number of test records to store.
     */
    public void writeTestCount(long j) {
        String configPath = this.modelPath + CONFIG_FILE_NAME;
        try {
            Properties properties = new Properties();
            // try-with-resources closes each stream even when load/store throws.
            try (FileInputStream in = new FileInputStream(configPath)) {
                properties.load(in);
            }
            properties.setProperty("testRecordCount", Long.toString(j));
            try (FileOutputStream out = new FileOutputStream(configPath)) {
                properties.store(out, "total testRecords added other settings, for use in statistics");
            }
        } catch (IOException e) {
            // FileNotFoundException is an IOException, so a missing file is reported here too.
            LOG.error("Unable to update testRecordCount in {}", configPath, e);
        }
    }

    /**
     * Instantiates the feature mapper whose class name is stored under the
     * {@code mapper} key of this loader's model properties.
     *
     * <p>The class is created through its no-argument constructor. When the instance is a
     * {@link ConfigurableFeatureMapper}, it is configured with the {@code properties}
     * argument (which is not necessarily the same object as this loader's model properties).
     *
     * @param properties properties handed to {@code configure} for configurable mappers.
     * @return the new feature mapper.
     * @throws RuntimeException if the class cannot be loaded, instantiated or configured.
     */
    public FeatureMapper loadFeatureMapper(Properties properties) {
        try {
            String mapperClassName = this.modelProperties.getProperty("mapper");
            Class<?> mapperClass = getClass().getClassLoader().loadClass(mapperClassName);
            LOG.info("Loaded class name: " + mapperClass.getName());
            FeatureMapper featureMapper = (FeatureMapper) mapperClass.getConstructor().newInstance();
            if (featureMapper instanceof ConfigurableFeatureMapper) {
                ((ConfigurableFeatureMapper) featureMapper).configure(properties);
            }
            return featureMapper;
        } catch (Exception e) {
            throw new RuntimeException("Unable to load model properties and initialize feature mapper.", e);
        }
    }

    /**
     * Reads {@code config.properties} from the model directory. When the
     * {@code precision} property equals {@code FP16}, the data type for the current
     * context is switched to {@code DataBuffer.Type.HALF}.
     *
     * @return the loaded properties.
     * @throws IOException if the file is missing or cannot be read.
     */
    private Properties loadModelProperties() throws IOException {
        Properties properties = new Properties();
        try (FileInputStream in = new FileInputStream(this.modelPath + CONFIG_FILE_NAME)) {
            properties.load(in);
        }
        // Constant-first equals also covers the case where the property is absent.
        if ("FP16".equals(properties.getProperty("precision"))) {
            LOG.info("Model uses FP16 precision. Activating support.");
            DataTypeUtil.setDTypeForContext(DataBuffer.Type.HALF);
        }
        return properties;
    }

    /**
     * Loads the model identified by the given label and returns it only when it is a
     * {@link MultiLayerNetwork}.
     *
     * @param str model label, used as the file name prefix (see {@link #loadModel(String)}).
     * @return the network, or {@code null} when no model was found or the stored model is
     *         not a {@link MultiLayerNetwork}.
     * @throws IOException if reading a model file fails.
     */
    public MultiLayerNetwork loadMultiLayerNetwork(String str) throws IOException {
        Model model = loadModel(str);
        if (model instanceof MultiLayerNetwork) {
            return (MultiLayerNetwork) model;
        }
        return null;
    }

    /**
     * Extracts the model label from the path of a model file. The label is the file name
     * with the suffix {@code -ComputationGraph.bin} or {@code Model.bin} removed.
     *
     * @param str path of a model file, using {@code /} as the separator.
     * @return the label, or {@code null} when the input is {@code null}, has no file name
     *         component, or its file name does not end with one of the known suffixes.
     */
    public static String getModelLabel(String str) {
        if (str == null) {
            return null;
        }
        String[] segments = str.split("/");
        // split() drops trailing empty strings, so inputs such as "/" give an empty array.
        if (segments.length == 0) {
            return null;
        }
        String fileName = segments[segments.length - 1];
        if (!fileName.endsWith(".bin")) {
            return null;
        }
        String[] knownSuffixes = {"-ComputationGraph.bin", "Model.bin"};
        for (String suffix : knownSuffixes) {
            if (fileName.endsWith(suffix)) {
                return fileName.substring(0, fileName.length() - suffix.length());
            }
        }
        return null;
    }

    /**
     * Returns the model directory for a path that may point either at the directory
     * itself or at a {@code .bin} file inside it.
     *
     * @param str a model directory or a path ending in {@code .bin}.
     * @return the parent of {@code str} when it ends in {@code .bin} (this is
     *         {@code null} when the path names no parent directory), otherwise
     *         {@code str} unchanged.
     */
    public static String getModelPath(String str) {
        return str.endsWith(".bin") ? new File(str).getParent() : str;
    }

    /**
     * Loads the model stored under the given label. The following files are tried in
     * order, relative to the model directory:
     * <ol>
     *   <li>{@code <label>Model.bin}, restored as a multi-layer network;</li>
     *   <li>{@code <label>-ComputationGraph.bin}, restored as a computation graph;</li>
     *   <li>{@code <label>ModelParams.bin} together with {@code <label>ModelConf.json},
     *       from which a multi-layer network is rebuilt.</li>
     * </ol>
     *
     * @param str model label, used as the file name prefix.
     * @return the loaded model, or {@code null} when none of the files above exists.
     * @throws IOException if a model file exists but cannot be read, or if the parameter
     *                     file exists and the JSON configuration cannot be read.
     */
    public Model loadModel(String str) throws IOException {
        String networkPath = getPath(str, "/%sModel.bin");
        if (new File(networkPath).exists()) {
            return ModelSerializer.restoreMultiLayerNetwork(networkPath);
        }
        String graphPath = getPath(str, "/%s-ComputationGraph.bin");
        if (new File(graphPath).exists()) {
            return ModelSerializer.restoreComputationGraph(graphPath);
        }
        String paramsPath = getPath(str, "/%sModelParams.bin");
        if (!new File(paramsPath).exists()) {
            return null;
        }
        INDArray parameters;
        // Close the parameter stream as soon as the array has been read.
        try (DataInputStream in = new DataInputStream(new FileInputStream(paramsPath))) {
            parameters = Nd4j.read(in);
        }
        String configurationJson = FileUtils.readFileToString(
                new File(getPath(str, "/%sModelConf.json")), Charset.defaultCharset());
        MultiLayerNetwork multiLayerNetwork =
                new MultiLayerNetwork(MultiLayerConfiguration.fromJson(configurationJson));
        multiLayerNetwork.init();
        multiLayerNetwork.setParameters(parameters);
        return multiLayerNetwork;
    }

    /**
     * Builds a file path inside the model directory.
     *
     * @param str  value substituted for {@code %s} in the pattern (the model label).
     * @param str2 {@link String#format(String, Object...)} pattern, starting with {@code /}.
     * @return the model path followed by the formatted pattern.
     */
    private String getPath(String str, String str2) {
        return this.modelPath + String.format(str2, str);
    }
}
