package org.deeplearning4j.remote;

import java.util.List;
import lombok.NonNull;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.ModelAdapter;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.parallelism.ParallelInference;
import org.deeplearning4j.parallelism.inference.InferenceMode;
import org.deeplearning4j.parallelism.inference.LoadBalanceMode;
import org.deeplearning4j.remote.DL4jServlet;
import org.nd4j.adapters.InferenceAdapter;
import org.nd4j.adapters.InputAdapter;
import org.nd4j.adapters.OutputAdapter;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.remote.SameDiffJsonModelServer;
import org.nd4j.remote.clients.serde.BinaryDeserializer;
import org.nd4j.remote.clients.serde.BinarySerializer;
import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer;

/* loaded from: input_file:org/deeplearning4j/remote/JsonModelServer.class */
/**
 * Model server that accepts JSON or binary requests and serves either a SameDiff or a DL4J model.
 *
 * <p>A {@link SameDiff} model is served entirely by {@link SameDiffJsonModelServer}. A
 * {@link ComputationGraph} or {@link MultiLayerNetwork} is served through a {@code DL4jServlet},
 * optionally backed by a {@link ParallelInference} instance. A caller-supplied
 * {@link ParallelInference} can also be served directly. Instances are created via {@link Builder}.
 *
 * @param <I> request type, converted into model input by the configured {@link InferenceAdapter}
 * @param <O> response type, produced from model output by the configured {@link InferenceAdapter}
 */
public class JsonModelServer<I, O> extends SameDiffJsonModelServer<I, O> {
    /** Parallel inference wrapper: either supplied by the caller or built lazily in {@link #start()}. */
    protected ParallelInference parallelInference;
    /** Not read by this class; kept for subclasses. */
    protected ModelAdapter<O> modelAdapter;
    protected ComputationGraph cgModel;
    protected MultiLayerNetwork mlnModel;
    /** Inference mode used when {@link #start()} builds its own {@link ParallelInference}. */
    protected InferenceMode inferenceMode;
    /** Worker count used when {@link #start()} builds its own {@link ParallelInference}; must be >= 1. */
    protected int numWorkers;
    /** Whether DL4J models are served through {@link ParallelInference}; defaults to {@code true}. */
    protected boolean enabledParallel = true;

    /**
     * Fluent builder for {@link JsonModelServer}.
     *
     * <p>The model source is fixed by the constructor used. An {@link InferenceAdapter} must be
     * supplied, either directly or as an {@link InputAdapter} + {@link OutputAdapter} pair.
     * {@link #inferenceMode} and {@link #numWorkers} only affect ComputationGraph and
     * MultiLayerNetwork servers.
     */
    public static class Builder<I, O> {
        private SameDiff sdModel;
        private ComputationGraph cgModel;
        private MultiLayerNetwork mlnModel;
        private ParallelInference pi;
        private String[] orderedInputNodes;
        private String[] orderedOutputNodes;
        private InferenceAdapter<I, O> inferenceAdapter;
        private JsonSerializer<O> serializer;
        private JsonDeserializer<I> deserializer;
        private BinarySerializer<O> binarySerializer;
        private BinaryDeserializer<I> binaryDeserializer;
        private InputAdapter<I> inputAdapter;
        private OutputAdapter<O> outputAdapter;
        private int port;
        private boolean parallelMode = true;
        private InferenceMode inferenceMode = InferenceMode.BATCHED;
        private int numWorkers = Nd4j.getAffinityManager().getNumberOfDevices();

        /** Creates a builder that serves the given SameDiff model. */
        public Builder(@NonNull SameDiff sameDiff) {
            if (sameDiff == null) {
                throw new NullPointerException("sdModel is marked @NonNull but is null");
            }
            this.sdModel = sameDiff;
        }

        /** Creates a builder that serves the given MultiLayerNetwork. */
        public Builder(@NonNull MultiLayerNetwork multiLayerNetwork) {
            if (multiLayerNetwork == null) {
                throw new NullPointerException("mlnModel is marked @NonNull but is null");
            }
            this.mlnModel = multiLayerNetwork;
        }

        /** Creates a builder that serves the given ComputationGraph. */
        public Builder(@NonNull ComputationGraph computationGraph) {
            if (computationGraph == null) {
                throw new NullPointerException("cgModel is marked @NonNull but is null");
            }
            this.cgModel = computationGraph;
        }

        /**
         * Creates a builder that serves a caller-supplied ParallelInference instance.
         * Such a server requires parallel mode (the default) to be enabled.
         */
        public Builder(@NonNull ParallelInference parallelInference) {
            if (parallelInference == null) {
                throw new NullPointerException("pi is marked @NonNull but is null");
            }
            this.pi = parallelInference;
        }

        /** Sets the adapter converting requests to model input and model output to responses. */
        public Builder<I, O> inferenceAdapter(@NonNull InferenceAdapter<I, O> inferenceAdapter) {
            if (inferenceAdapter == null) {
                throw new NullPointerException("inferenceAdapter is marked @NonNull but is null");
            }
            this.inferenceAdapter = inferenceAdapter;
            return this;
        }

        /** Sets the request-side adapter; used only when no {@link InferenceAdapter} is configured. */
        public Builder<I, O> inputAdapter(@NonNull InputAdapter<I> inputAdapter) {
            if (inputAdapter == null) {
                throw new NullPointerException("inputAdapter is marked @NonNull but is null");
            }
            this.inputAdapter = inputAdapter;
            return this;
        }

        /** Sets the response-side adapter; used only when no {@link InferenceAdapter} is configured. */
        public Builder<I, O> outputAdapter(@NonNull OutputAdapter<O> outputAdapter) {
            if (outputAdapter == null) {
                throw new NullPointerException("outputAdapter is marked @NonNull but is null");
            }
            this.outputAdapter = outputAdapter;
            return this;
        }

        /** Sets the JSON serializer for responses. */
        public Builder<I, O> outputSerializer(@NonNull JsonSerializer<O> serializer) {
            if (serializer == null) {
                throw new NullPointerException("serializer is marked @NonNull but is null");
            }
            this.serializer = serializer;
            return this;
        }

        /** Sets the JSON deserializer for requests. */
        public Builder<I, O> inputDeserializer(@NonNull JsonDeserializer<I> deserializer) {
            if (deserializer == null) {
                throw new NullPointerException("deserializer is marked @NonNull but is null");
            }
            this.deserializer = deserializer;
            return this;
        }

        /** Sets the binary serializer for responses. */
        public Builder<I, O> outputBinarySerializer(@NonNull BinarySerializer<O> serializer) {
            if (serializer == null) {
                throw new NullPointerException("serializer is marked @NonNull but is null");
            }
            this.binarySerializer = serializer;
            return this;
        }

        /** Sets the binary deserializer for requests. */
        public Builder<I, O> inputBinaryDeserializer(@NonNull BinaryDeserializer<I> deserializer) {
            if (deserializer == null) {
                throw new NullPointerException("deserializer is marked @NonNull but is null");
            }
            this.binaryDeserializer = deserializer;
            return this;
        }

        /** Sets the inference mode for an internally built ParallelInference (default BATCHED). */
        public Builder<I, O> inferenceMode(@NonNull InferenceMode inferenceMode) {
            if (inferenceMode == null) {
                throw new NullPointerException("inferenceMode is marked @NonNull but is null");
            }
            this.inferenceMode = inferenceMode;
            return this;
        }

        /**
         * Sets the worker count for an internally built ParallelInference. The default is the
         * device count reported by the Nd4j affinity manager. The value is validated in
         * {@link JsonModelServer#start()}.
         */
        public Builder<I, O> numWorkers(int numWorkers) {
            this.numWorkers = numWorkers;
            return this;
        }

        /** Sets the ordered input node names (SameDiff models only). */
        public Builder<I, O> orderedInputNodes(String... nodes) {
            this.orderedInputNodes = nodes;
            return this;
        }

        /** Sets the ordered input node names (SameDiff models only). */
        public Builder<I, O> orderedInputNodes(@NonNull List<String> args) {
            if (args == null) {
                throw new NullPointerException("args is marked @NonNull but is null");
            }
            this.orderedInputNodes = args.toArray(new String[0]);
            return this;
        }

        /** Sets the ordered output node names (SameDiff models only); at least one is required. */
        public Builder<I, O> orderedOutputNodes(String... nodes) {
            Preconditions.checkArgument(nodes != null && nodes.length > 0, "OutputNodes should contain at least 1 element");
            this.orderedOutputNodes = nodes;
            return this;
        }

        /** Sets the ordered output node names (SameDiff models only); at least one is required. */
        public Builder<I, O> orderedOutputNodes(@NonNull List<String> args) {
            if (args == null) {
                throw new NullPointerException("args is marked @NonNull but is null");
            }
            Preconditions.checkArgument(args.size() > 0, "OutputNodes should contain at least 1 element");
            this.orderedOutputNodes = args.toArray(new String[0]);
            return this;
        }

        /** Sets the port the server listens on. */
        public Builder<I, O> port(int port) {
            this.port = port;
            return this;
        }

        /** Enables or disables ParallelInference-backed serving for DL4J models (default true). */
        public Builder<I, O> parallelMode(boolean enabled) {
            this.parallelMode = enabled;
            return this;
        }

        /**
         * Builds a server for whichever model source this builder was constructed with.
         * The builder itself is not modified, so it can be reused.
         *
         * @throws IllegalArgumentException if neither an InferenceAdapter nor both an InputAdapter
         *     and an OutputAdapter were configured
         * @throws IllegalStateException if no model source is set
         */
        public JsonModelServer<I, O> build() {
            final InferenceAdapter<I, O> adapter = resolveInferenceAdapter();
            final JsonModelServer<I, O> server;
            if (sdModel != null) {
                server = new JsonModelServer<>(sdModel, adapter, serializer, deserializer, binarySerializer, binaryDeserializer, port, orderedInputNodes, orderedOutputNodes);
            } else if (cgModel != null) {
                server = new JsonModelServer<>(cgModel, adapter, serializer, deserializer, binarySerializer, binaryDeserializer, port, inferenceMode, numWorkers);
            } else if (mlnModel != null) {
                server = new JsonModelServer<>(mlnModel, adapter, serializer, deserializer, binarySerializer, binaryDeserializer, port, inferenceMode, numWorkers);
            } else if (pi != null) {
                server = new JsonModelServer<>(pi, adapter, serializer, deserializer, binarySerializer, binaryDeserializer, port);
            } else {
                throw new IllegalStateException("No models were defined for JsonModelServer");
            }
            server.enabledParallel = parallelMode;
            return server;
        }

        /**
         * Returns the configured InferenceAdapter. Otherwise it composes one from the input and
         * output adapters, capturing them now so later builder changes cannot affect a built server.
         */
        private InferenceAdapter<I, O> resolveInferenceAdapter() {
            if (inferenceAdapter != null) {
                return inferenceAdapter;
            }
            if (inputAdapter == null || outputAdapter == null) {
                throw new IllegalArgumentException("Either InferenceAdapter<I,O> or InputAdapter<I> + OutputAdapter<O> should be configured");
            }
            final InputAdapter<I> in = inputAdapter;
            final OutputAdapter<O> out = outputAdapter;
            return new InferenceAdapter<I, O>() {
                @Override
                public MultiDataSet apply(I input) {
                    return in.apply(input);
                }

                @Override
                public O apply(INDArray... outputs) {
                    return out.apply(outputs);
                }
            };
        }
    }

    /** Creates a server for a SameDiff model; all serving is delegated to the superclass. */
    protected JsonModelServer(@NonNull SameDiff sameDiff, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer, int port, String[] orderedInputNodes, String[] orderedOutputNodes) {
        super(sameDiff, inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer, port, orderedInputNodes, orderedOutputNodes);
        if (sameDiff == null) {
            throw new NullPointerException("sdModel is marked @NonNull but is null");
        }
    }

    /** Creates a server for a ComputationGraph. */
    protected JsonModelServer(@NonNull ComputationGraph cgModel, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer, int port, @NonNull InferenceMode inferenceMode, int numWorkers) {
        super(inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer, port);
        if (cgModel == null) {
            throw new NullPointerException("cgModel is marked @NonNull but is null");
        }
        if (inferenceMode == null) {
            throw new NullPointerException("inferenceMode is marked @NonNull but is null");
        }
        this.cgModel = cgModel;
        this.inferenceMode = inferenceMode;
        this.numWorkers = numWorkers;
    }

    /** Creates a server for a MultiLayerNetwork. */
    protected JsonModelServer(@NonNull MultiLayerNetwork mlnModel, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer, int port, @NonNull InferenceMode inferenceMode, int numWorkers) {
        super(inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer, port);
        if (mlnModel == null) {
            throw new NullPointerException("mlnModel is marked @NonNull but is null");
        }
        if (inferenceMode == null) {
            throw new NullPointerException("inferenceMode is marked @NonNull but is null");
        }
        this.mlnModel = mlnModel;
        this.inferenceMode = inferenceMode;
        this.numWorkers = numWorkers;
    }

    /** Creates a server backed by a caller-supplied ParallelInference instance. */
    protected JsonModelServer(@NonNull ParallelInference pi, InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> serializer, JsonDeserializer<I> deserializer, BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer, int port) {
        super(inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer, port);
        if (pi == null) {
            throw new NullPointerException("pi is marked @NonNull but is null");
        }
        this.parallelInference = pi;
    }

    /**
     * Shuts down the ParallelInference instance (if any), then stops the underlying server.
     */
    @Override
    public void stop() throws Exception {
        if (parallelInference != null) {
            parallelInference.shutdown();
            // An instance built by start() from cgModel/mlnModel is dropped, so a later start()
            // builds a fresh one rather than reusing one that has already been shut down.
            if (cgModel != null || mlnModel != null) {
                parallelInference = null;
            }
        }
        super.stop();
    }

    /**
     * Starts serving.
     *
     * <p>SameDiff models are handled by the superclass. Otherwise a DL4jServlet is created:
     * <ul>
     *   <li>In parallel mode it uses the existing ParallelInference, or builds one from the
     *       DL4J model with FIFO load balancing, batch limit 16 and queue limit 128.</li>
     *   <li>In non-parallel mode it wraps the DL4J model directly.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if a DL4J model is needed but none is set, or if
     *     {@code numWorkers < 1} when a ParallelInference must be built
     */
    @Override
    public void start() throws Exception {
        // SameDiff serving needs no DL4J functionality.
        if (sdModel != null) {
            super.start();
            return;
        }

        final Model model = cgModel != null ? cgModel : mlnModel;
        if (enabledParallel) {
            // A caller-supplied ParallelInference needs no model of its own.
            if (parallelInference == null) {
                Preconditions.checkArgument(model != null, "Model serving requires either MultilayerNetwork or ComputationGraph defined");
                Preconditions.checkArgument(numWorkers >= 1, "Number of workers should be >= 1, got " + numWorkers + " instead");
                parallelInference = new ParallelInference.Builder(model)
                        .inferenceMode(inferenceMode)
                        .workers(numWorkers)
                        .loadBalanceMode(LoadBalanceMode.FIFO)
                        .batchLimit(16)
                        .queueLimit(128)
                        .build();
            }
            servingServlet = new DL4jServlet.Builder<I, O>(parallelInference)
                    .parallelEnabled(true)
                    .serializer(serializer)
                    .deserializer(deserializer)
                    .binarySerializer(binarySerializer)
                    .binaryDeserializer(binaryDeserializer)
                    .inferenceAdapter(inferenceAdapter)
                    .build();
        } else {
            Preconditions.checkArgument(model != null, "Model serving requires either MultilayerNetwork or ComputationGraph defined");
            servingServlet = new DL4jServlet.Builder<I, O>(model)
                    .parallelEnabled(false)
                    .serializer(serializer)
                    .deserializer(deserializer)
                    .binarySerializer(binarySerializer)
                    .binaryDeserializer(binaryDeserializer)
                    .inferenceAdapter(inferenceAdapter)
                    .build();
        }
        start(port, servingServlet);
    }
}
