package org.eclipse.deeplearning4j.omnihub;

import java.io.BufferedInputStream;
import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.HttpURLConnection;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URL;
import java.net.URLConnection;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.input.CountingInputStream;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.omnihub.OmnihubConfig;
import org.nd4j.autodiff.samediff.SameDiff;

/**
 * Helpers for fetching pretrained models from the Omnihub model zoo and loading them as DL4J
 * ({@link MultiLayerNetwork}, {@link ComputationGraph}) or {@link SameDiff} models.
 *
 * <p>A model {@code name} for framework directory {@code fw} is downloaded from
 * {@code OmnihubConfig.getOmnihubUrl() + "/" + fw + "/" + name} and cached at
 * {@code new File(new File(OmnihubConfig.getOmnihubHome(), fw), name)}. Later calls reuse the
 * cached file unless a forced re-download is requested.
 */
public class OmniHubUtils {

    /** Zoo sub-directory that holds DL4J models. */
    private static final String DL4J_FRAMEWORK = "dl4j";

    /** Zoo sub-directory that holds SameDiff models. */
    private static final String SAMEDIFF_FRAMEWORK = "samediff";

    /**
     * Loads a {@link MultiLayerNetwork} from the zoo, reusing a cached copy if one exists.
     *
     * @param modelName file name of the model inside the {@code dl4j} zoo directory
     * @return the loaded network
     * @throws IOException if downloading or loading the model fails
     */
    public static MultiLayerNetwork loadNetwork(String modelName) throws IOException {
        return loadNetwork(modelName, false);
    }

    /**
     * Loads a {@link MultiLayerNetwork} from the zoo.
     *
     * @param modelName file name of the model inside the {@code dl4j} zoo directory
     * @param forceDownload if {@code true}, delete any cached copy and download it again
     * @return the loaded network
     * @throws IOException if downloading or loading the model fails
     */
    public static MultiLayerNetwork loadNetwork(String modelName, boolean forceDownload)
            throws IOException {
        // The boolean flag asks DL4J to restore the updater state as well.
        return MultiLayerNetwork.load(download(DL4J_FRAMEWORK, modelName, forceDownload), true);
    }

    /**
     * Loads a {@link ComputationGraph} from the zoo, reusing a cached copy if one exists.
     *
     * @param modelName file name of the model inside the {@code dl4j} zoo directory
     * @return the loaded graph
     * @throws IOException if downloading or loading the model fails
     */
    public static ComputationGraph loadCompGraph(String modelName) throws IOException {
        return loadCompGraph(modelName, false);
    }

    /**
     * Loads a {@link ComputationGraph} from the zoo.
     *
     * @param modelName file name of the model inside the {@code dl4j} zoo directory
     * @param forceDownload if {@code true}, delete any cached copy and download it again
     * @return the loaded graph
     * @throws IOException if downloading or loading the model fails
     */
    public static ComputationGraph loadCompGraph(String modelName, boolean forceDownload)
            throws IOException {
        // The boolean flag asks DL4J to restore the updater state as well.
        return ComputationGraph.load(download(DL4J_FRAMEWORK, modelName, forceDownload), true);
    }

    /**
     * Loads a {@link SameDiff} model from the zoo, reusing a cached copy if one exists.
     *
     * @param modelName file name of the model inside the {@code samediff} zoo directory
     * @return the loaded model
     * @throws UncheckedIOException if the model cannot be downloaded
     */
    public static SameDiff loadSameDiffModel(String modelName) {
        return loadSameDiffModel(modelName, false);
    }

    /**
     * Loads a {@link SameDiff} model from the zoo.
     *
     * @param modelName file name of the model inside the {@code samediff} zoo directory
     * @param forceDownload if {@code true}, delete any cached copy and download it again
     * @return the loaded model
     * @throws UncheckedIOException if the model cannot be downloaded
     */
    public static SameDiff loadSameDiffModel(String modelName, boolean forceDownload) {
        // The boolean flag asks SameDiff to restore the updater state as well.
        return SameDiff.load(
                downloadAndLoadFromZoo(SAMEDIFF_FRAMEWORK, modelName, forceDownload), true);
    }

    /**
     * Ensures a zoo model is present in the local cache and returns its file. Despite the name,
     * this only downloads; it does not deserialize the model.
     *
     * @param framework zoo sub-directory, e.g. {@code "dl4j"} or {@code "samediff"}
     * @param modelName file name of the model inside that directory
     * @param forceDownload if {@code true}, delete any cached copy and download it again
     * @return the cached model file, which exists when this method returns normally
     * @throws UncheckedIOException if the download fails; no partial file is left behind
     */
    public static File downloadAndLoadFromZoo(
            String framework, String modelName, boolean forceDownload) {
        try {
            return download(framework, modelName, forceDownload);
        } catch (IOException e) {
            throw new UncheckedIOException(
                    "Failed to download model '" + modelName + "' for " + framework, e);
        }
    }

    /**
     * Checked-exception core of {@link #downloadAndLoadFromZoo}. It resolves the cache location
     * and downloads the model if the file is missing or {@code forceDownload} is set.
     */
    private static File download(String framework, String modelName, boolean forceDownload)
            throws IOException {
        File file = new File(new File(OmnihubConfig.getOmnihubHome(), framework), modelName);
        if (forceDownload && file.exists()) {
            // Throws instead of silently falling back to the stale cached copy.
            FileUtils.forceDelete(file);
        }
        if (file.exists()) {
            return file;
        }

        URL url = URI.create(OmnihubConfig.getOmnihubUrl() + "/" + framework + "/" + modelName)
                .toURL();
        // One connection supplies both the body and its Content-Length (-1 if unknown),
        // so no separate HEAD request is needed. The arguments are evaluated left to right:
        // the stream is opened before the length is read.
        URLConnection connection = url.openConnection();
        try (CountingInputStream in = new ProgressInputStream(
                new BufferedInputStream(connection.getInputStream()),
                connection.getContentLength())) {
            FileUtils.copyInputStreamToFile(in, file);
        } catch (IOException | RuntimeException e) {
            // A truncated file would be mistaken for a valid cache entry on the next call.
            FileUtils.deleteQuietly(file);
            throw e;
        } finally {
            if (connection instanceof HttpURLConnection) {
                ((HttpURLConnection) connection).disconnect();
            }
        }
        return file;
    }
}
