package ai.konduit.serving.models.deeplearning4j.step;

import ai.konduit.serving.annotation.runner.CanRun;
import ai.konduit.serving.models.deeplearning4j.step.keras.KerasStep;
import ai.konduit.serving.pipeline.api.context.Context;
import ai.konduit.serving.pipeline.api.data.Data;
import ai.konduit.serving.pipeline.api.data.NDArray;
import ai.konduit.serving.pipeline.api.exception.ModelLoadingException;
import ai.konduit.serving.pipeline.api.protocol.URIResolver;
import ai.konduit.serving.pipeline.api.step.PipelineStep;
import ai.konduit.serving.pipeline.api.step.PipelineStepRunner;
import ai.konduit.serving.pipeline.impl.data.JData;
import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.function.Function;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.KerasModel;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.DL4JModelValidator;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.validation.ValidationResult;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@CanRun({DL4JStep.class, KerasStep.class})
/* loaded from: input_file:ai/konduit/serving/models/deeplearning4j/step/DL4JRunner.class */
public class DL4JRunner implements PipelineStepRunner {
    private static final Logger log = LoggerFactory.getLogger(DL4JRunner.class);
    public static final String DEFAULT_OUT_NAME_SINGLE = "default";
    private DL4JStep step;
    private KerasStep kStep;
    private MultiLayerNetwork net;
    private ComputationGraph graph;

    public DL4JRunner(KerasStep kerasStep) {
        this.kStep = kerasStep;
        try {
            File file = URIResolver.getFile(kerasStep.modelUri());
            KerasModelBuilder enforceTrainingConfig = new KerasModel().modelBuilder().modelHdf5Filename(file.isAbsolute() ? file.getAbsolutePath() : file.getPath()).enforceTrainingConfig(false);
            try {
                this.graph = enforceTrainingConfig.buildModel().getComputationGraph();
                this.net = null;
            } catch (UnsupportedKerasConfigurationException e) {
                throw new RuntimeException("Unsupported Keras layer found in model", e);
            } catch (Throwable th) {
                if (th.getMessage() == null || !th.getMessage().toLowerCase().contains("sequential")) {
                    throw new ModelLoadingException("Failed to load Keras model: model file is invalid or can't be loaded by DL4JRunner", th);
                }
                try {
                    this.net = enforceTrainingConfig.buildSequential().getMultiLayerNetwork();
                    this.graph = null;
                } catch (Throwable th2) {
                    throw new ModelLoadingException("Failed to load Keras Sequential model: model file is invalid or can't be loaded by DL4JRunner", th2);
                }
            }
        } catch (InvalidKerasConfigurationException | UnsupportedKerasConfigurationException e2) {
            throw new ModelLoadingException("Failed to load Keras model: model file is invalid or can't be loaded by DL4JRunner", e2);
        } catch (IOException e3) {
            throw new ModelLoadingException("Failed to load Keras model", e3);
        }
    }

    public DL4JRunner(DL4JStep dL4JStep) {
        this.step = dL4JStep;
        if (this.step.loaderClass() != null) {
            try {
                Object apply = ((Function) Class.forName(this.step.loaderClass()).newInstance()).apply(dL4JStep.modelUri());
                if (apply instanceof MultiLayerNetwork) {
                    this.net = (MultiLayerNetwork) apply;
                    this.graph = null;
                } else {
                    if (!(apply instanceof ComputationGraph)) {
                        throw new ModelLoadingException("DL4JStep: loaderClass=\"" + this.step.loaderClass() + "\" return " + (apply == null ? "null" : apply.getClass().getName()) + " not a MultiLayerNetwork / ComputationGraph");
                    }
                    this.net = null;
                    this.graph = (ComputationGraph) apply;
                }
            } catch (ClassNotFoundException e) {
                throw new ModelLoadingException("DL4JStep: loaderClass=\"" + this.step.loaderClass() + "\" was provided but no class with this name exists", e);
            } catch (IllegalAccessException | InstantiationException e2) {
                throw new ModelLoadingException("DL4JStep: loaderClass=\"" + this.step.loaderClass() + "\" was provided but an instance of this class could not be constructed", e2);
            }
        } else {
            try {
                File file = URIResolver.getFile(dL4JStep.modelUri());
                Preconditions.checkState(file.exists() && file.isFile(), "Could not load MultiLayerNetwork/ComputationGraph from URI {}, file path {}: file does not exist", dL4JStep.modelUri(), file.getAbsolutePath());
                ValidationResult validateMultiLayerNetwork = DL4JModelValidator.validateMultiLayerNetwork(file);
                ValidationResult validateComputationGraph = DL4JModelValidator.validateComputationGraph(file);
                boolean isValid = validateMultiLayerNetwork.isValid();
                boolean z = !isValid && validateComputationGraph.isValid();
                if (!isValid && !z) {
                    StringBuilder sb = new StringBuilder("Model at URI " + dL4JStep.modelUri() + " is not a valid MultiLayerNetwork or ComputationGraph model.\n");
                    sb.append("Attempt to load as MultiLayerNetwork: \n");
                    sb.append("Issues: ").append(validateMultiLayerNetwork.getIssues()).append("\n");
                    if (validateMultiLayerNetwork.getException() != null) {
                        StringWriter stringWriter = new StringWriter();
                        validateMultiLayerNetwork.getException().printStackTrace(new PrintWriter(stringWriter));
                        sb.append(stringWriter.toString());
                        sb.append("\n");
                    }
                    sb.append("Attempt to load as ComputationGraph: \n");
                    sb.append("Issues: ").append(validateComputationGraph.getIssues());
                    if (validateComputationGraph.getException() != null) {
                        StringWriter stringWriter2 = new StringWriter();
                        validateComputationGraph.getException().printStackTrace(new PrintWriter(stringWriter2));
                        sb.append(stringWriter2.toString());
                        sb.append("\n");
                    }
                    throw new IllegalStateException(sb.toString());
                }
                if (isValid) {
                    try {
                        this.net = MultiLayerNetwork.load(file, false);
                        this.graph = null;
                    } catch (IOException e3) {
                        throw new ModelLoadingException("Failed to load Deeplearning4J MultiLayerNetwork from URI " + dL4JStep.modelUri(), e3);
                    }
                } else {
                    try {
                        this.graph = ComputationGraph.load(file, false);
                        this.net = null;
                    } catch (IOException e4) {
                        throw new ModelLoadingException("Failed to load Deeplearning4J ComputationGraph from URI " + dL4JStep.modelUri(), e4);
                    }
                }
            } catch (IOException e5) {
                throw new ModelLoadingException("Failed to load Deeplearning4J model (MultiLayerNetwork or ComputationGraph) from URI " + dL4JStep.modelUri(), e5);
            }
        }
        Nd4j.getExecutioner().enableDebugMode(dL4JStep.debugMode());
        Nd4j.getExecutioner().enableVerboseMode(dL4JStep.verboseMode());
    }

    public void close() {
        try {
            if (this.net != null) {
                this.net.close();
            } else {
                this.graph.close();
            }
        } catch (Throwable th) {
            log.warn("Error when closing model", th);
        }
    }

    public PipelineStep getPipelineStep() {
        return this.step != null ? this.step : this.kStep;
    }

    public Data exec(Context context, Data data) {
        INDArray[] iNDArrayArr;
        INDArray[] output;
        INDArray output2;
        int numInputArrays = this.net != null ? 1 : this.graph.getNumInputArrays();
        Preconditions.checkArgument(numInputArrays == data.size(), "Expected %s inputs to DL4JStep but got Data instance with %s inputs (keys: %s)", Integer.valueOf(numInputArrays), Integer.valueOf(data.size()), data.keys());
        if (this.net != null) {
            INDArray onlyArray = getOnlyArray(data);
            synchronized (this.net) {
                output2 = this.net.output(onlyArray);
            }
            return Data.singleton(outputName(), NDArray.create(output2));
        }
        if (numInputArrays == 1) {
            iNDArrayArr = new INDArray[]{getOnlyArray(data)};
        } else if (this.step.inputNames() != null) {
            iNDArrayArr = new INDArray[numInputArrays];
            int i = 0;
            Iterator it = this.step.inputNames().iterator();
            while (it.hasNext()) {
                int i2 = i;
                i++;
                iNDArrayArr[i2] = (INDArray) data.getNDArray((String) it.next()).get();
            }
        } else {
            List networkInputs = this.graph.getConfiguration().getNetworkInputs();
            if (!data.hasAll(networkInputs)) {
                throw new IllegalStateException("Network has " + numInputArrays + " inputs, but no Data input names were specified. Attempting to infer input names also failed: Model has input names " + networkInputs + " but Data object has keys " + data.keys());
            }
            iNDArrayArr = new INDArray[numInputArrays];
            int i3 = 0;
            Iterator it2 = networkInputs.iterator();
            while (it2.hasNext()) {
                int i4 = i3;
                i3++;
                iNDArrayArr[i4] = (INDArray) data.getNDArray((String) it2.next()).get();
            }
        }
        synchronized (this.graph) {
            output = this.graph.output(iNDArrayArr);
        }
        List<String> outputNames = outputNames() != null ? outputNames() : output.length == 1 ? Collections.singletonList(DEFAULT_OUT_NAME_SINGLE) : this.graph.getConfiguration().getNetworkOutputs();
        Preconditions.checkState(outputNames.size() == output.length);
        JData.DataBuilder builder = JData.builder();
        for (int i5 = 0; i5 < output.length; i5++) {
            builder.add(outputNames.get(i5), NDArray.create(output[i5]));
        }
        return builder.build();
    }

    private List<String> outputNames() {
        return this.step != null ? this.step.outputNames() : this.kStep.outputNames();
    }

    private String outputName() {
        return this.step != null ? (this.step.outputNames() == null || this.step.outputNames().isEmpty()) ? DEFAULT_OUT_NAME_SINGLE : (String) this.step.outputNames().get(0) : (this.kStep.outputNames() == null || this.kStep.outputNames().isEmpty()) ? DEFAULT_OUT_NAME_SINGLE : (String) this.kStep.outputNames().get(0);
    }

    private INDArray getOnlyArray(Data data) {
        return (INDArray) data.getNDArray((String) data.keys().get(0)).getAs(INDArray.class);
    }
}
