package org.deeplearning4j.remote;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import lombok.NonNull;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.parallelism.ParallelInference;
import org.nd4j.adapters.InferenceAdapter;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.remote.clients.serde.BinaryDeserializer;
import org.nd4j.remote.clients.serde.BinarySerializer;
import org.nd4j.remote.clients.serde.JsonDeserializer;
import org.nd4j.remote.clients.serde.JsonSerializer;
import org.nd4j.remote.serving.SameDiffServlet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/remote/DL4jServlet.class */
/**
 * Servlet that serves inference requests for a DL4J model on the {@code /v1/serving} path.
 *
 * <p>A request body is accepted as {@code application/json} (decoded with the JSON deserializer)
 * or {@code application/octet-stream} (decoded with the binary deserializer). The decoded input is
 * turned into a {@link MultiDataSet} by the {@link InferenceAdapter}. That input is run either
 * through a {@link ParallelInference} instance or, under a lock, directly through a single
 * {@link ComputationGraph} / {@link MultiLayerNetwork}. The adapter's output is written back with
 * the binary serializer when one is configured, otherwise with the JSON serializer.
 *
 * @param <I> request payload type
 * @param <O> response payload type
 */
public class DL4jServlet<I, O> extends SameDiffServlet<I, O> {
    private static final Logger log = LoggerFactory.getLogger(DL4jServlet.class);

    /** Path info that accepts inference requests; anything else is rejected via {@code sendError}. */
    private static final String SERVING_PATH = "/v1/serving";
    private static final String JSON_CONTENT_TYPE = "application/json";
    private static final String BINARY_CONTENT_TYPE = "application/octet-stream";
    private static final int CHAR_BUFFER_SIZE = 4096;

    /** Engine used when {@link #parallelEnabled} is {@code true}; {@code null} otherwise. */
    protected ParallelInference parallelInference;
    /** Model used directly (under {@code synchronized (this)}) when {@link #parallelEnabled} is {@code false}. */
    protected Model model;
    /** Selects between {@link #parallelInference} ({@code true}) and {@link #model} ({@code false}). */
    protected boolean parallelEnabled = true;

    /**
     * Fluent builder for {@link DL4jServlet}.
     *
     * @param <I> request payload type
     * @param <O> response payload type
     */
    public static class Builder<I, O> {
        private ParallelInference pi;
        private Model model;
        private InferenceAdapter<I, O> inferenceAdapter;
        private JsonSerializer<O> serializer;
        private JsonDeserializer<I> deserializer;
        private BinarySerializer<O> binarySerializer;
        private BinaryDeserializer<I> binaryDeserializer;
        /** Stored for API compatibility; {@link #build()} does not read it. */
        private int port;
        private boolean parallelEnabled = true;

        /**
         * Starts a builder that serves through the given {@link ParallelInference}.
         *
         * @throws NullPointerException if {@code parallelInference} is {@code null}
         */
        public Builder(@NonNull ParallelInference parallelInference) {
            this.pi = requireNonNullParam(parallelInference, "pi");
        }

        /**
         * Starts a builder that serves a single {@link Model} directly.
         *
         * @throws NullPointerException if {@code model} is {@code null}
         */
        public Builder(@NonNull Model model) {
            this.model = requireNonNullParam(model, "model");
        }

        /**
         * Sets the adapter that converts request payloads to inputs and model outputs to responses.
         *
         * @throws NullPointerException if {@code inferenceAdapter} is {@code null}
         */
        public Builder<I, O> inferenceAdapter(@NonNull InferenceAdapter<I, O> inferenceAdapter) {
            this.inferenceAdapter = requireNonNullParam(inferenceAdapter, "inferenceAdapter");
            return this;
        }

        /** Sets the JSON serializer for responses. */
        public Builder<I, O> serializer(JsonSerializer<O> jsonSerializer) {
            this.serializer = jsonSerializer;
            return this;
        }

        /** Sets the JSON deserializer for {@code application/json} requests. */
        public Builder<I, O> deserializer(JsonDeserializer<I> jsonDeserializer) {
            this.deserializer = jsonDeserializer;
            return this;
        }

        /** Sets the binary serializer; when set it takes precedence over JSON for responses. */
        public Builder<I, O> binarySerializer(BinarySerializer<O> binarySerializer) {
            this.binarySerializer = binarySerializer;
            return this;
        }

        /** Sets the binary deserializer for {@code application/octet-stream} requests. */
        public Builder<I, O> binaryDeserializer(BinaryDeserializer<I> binaryDeserializer) {
            this.binaryDeserializer = binaryDeserializer;
            return this;
        }

        /** Records a port number; currently not used by {@link #build()}. */
        public Builder<I, O> port(int i) {
            this.port = i;
            return this;
        }

        /** Chooses parallel ({@code true}, default) or single-model ({@code false}) inference. */
        public Builder<I, O> parallelEnabled(boolean z) {
            this.parallelEnabled = z;
            return this;
        }

        /**
         * Creates the servlet.
         *
         * <p>A builder created from a {@link ParallelInference} produces a parallel servlet unless
         * {@link #parallelEnabled(boolean)} was set to {@code false}. A builder created from a
         * {@link Model} always produces a single-model servlet: there is no {@code ParallelInference}
         * to use, and previously this case failed with a null-pointer error.
         *
         * @throws NullPointerException if {@link #inferenceAdapter(InferenceAdapter)} was never called
         * @throws IllegalStateException if parallel inference was disabled but no {@link Model} was supplied
         */
        public DL4jServlet<I, O> build() {
            if (parallelEnabled && pi != null) {
                return new DL4jServlet<>(pi, inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer);
            }
            if (model != null) {
                return new DL4jServlet<>(model, inferenceAdapter, serializer, deserializer, binarySerializer, binaryDeserializer);
            }
            throw new IllegalStateException(
                    "parallelEnabled(false) requires a Builder created from a Model, but only a ParallelInference was supplied");
        }
    }

    /**
     * Creates a JSON-only servlet backed by {@link ParallelInference}.
     *
     * @throws NullPointerException if {@code parallelInference} or {@code inferenceAdapter} is {@code null}
     */
    public DL4jServlet(@NonNull ParallelInference parallelInference, @NonNull InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> jsonSerializer, JsonDeserializer<I> jsonDeserializer) {
        super(inferenceAdapter, jsonSerializer, jsonDeserializer);
        initParallel(parallelInference, inferenceAdapter);
    }

    /**
     * Creates a JSON-only servlet backed by a single {@link Model}.
     *
     * @throws NullPointerException if {@code model} or {@code inferenceAdapter} is {@code null}
     */
    public DL4jServlet(@NonNull Model model, @NonNull InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> jsonSerializer, JsonDeserializer<I> jsonDeserializer) {
        super(inferenceAdapter, jsonSerializer, jsonDeserializer);
        initSingleModel(model, inferenceAdapter);
    }

    /**
     * Creates a binary-only servlet backed by {@link ParallelInference}.
     *
     * @throws NullPointerException if {@code parallelInference} or {@code inferenceAdapter} is {@code null}
     */
    public DL4jServlet(@NonNull ParallelInference parallelInference, @NonNull InferenceAdapter<I, O> inferenceAdapter, BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer) {
        super(inferenceAdapter, binarySerializer, binaryDeserializer);
        initParallel(parallelInference, inferenceAdapter);
    }

    /**
     * Creates a servlet backed by a single {@link Model} with both JSON and binary (de)serializers.
     *
     * @throws NullPointerException if {@code model} or {@code inferenceAdapter} is {@code null}
     */
    public DL4jServlet(@NonNull Model model, @NonNull InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> jsonSerializer, JsonDeserializer<I> jsonDeserializer, BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer) {
        super(inferenceAdapter, jsonSerializer, jsonDeserializer, binarySerializer, binaryDeserializer);
        initSingleModel(model, inferenceAdapter);
    }

    /**
     * Creates a servlet backed by {@link ParallelInference} with both JSON and binary (de)serializers.
     *
     * @throws NullPointerException if {@code parallelInference} or {@code inferenceAdapter} is {@code null}
     */
    public DL4jServlet(@NonNull ParallelInference parallelInference, @NonNull InferenceAdapter<I, O> inferenceAdapter, JsonSerializer<O> jsonSerializer, JsonDeserializer<I> jsonDeserializer, BinarySerializer<O> binarySerializer, BinaryDeserializer<I> binaryDeserializer) {
        super(inferenceAdapter, jsonSerializer, jsonDeserializer, binarySerializer, binaryDeserializer);
        initParallel(parallelInference, inferenceAdapter);
    }

    /** No-arg constructor; leaves {@link #parallelEnabled} at its default of {@code true}. */
    public DL4jServlet() {
    }

    /** Throws an NPE in the same format the Lombok-generated checks used, else returns {@code value}. */
    private static <T> T requireNonNullParam(T value, String name) {
        if (value == null) {
            throw new NullPointerException(name + " is marked non-null but is null");
        }
        return value;
    }

    /** Validates constructor arguments and switches this servlet into parallel-inference mode. */
    private void initParallel(ParallelInference parallelInference, InferenceAdapter<I, O> inferenceAdapter) {
        requireNonNullParam(parallelInference, "parallelInference");
        requireNonNullParam(inferenceAdapter, "inferenceAdapter");
        this.parallelInference = parallelInference;
        this.model = null;
        this.parallelEnabled = true;
    }

    /** Validates constructor arguments and switches this servlet into single-model mode. */
    private void initSingleModel(Model model, InferenceAdapter<I, O> inferenceAdapter) {
        requireNonNullParam(model, "model");
        requireNonNullParam(inferenceAdapter, "inferenceAdapter");
        this.model = model;
        this.parallelInference = null;
        this.parallelEnabled = false;
    }

    /**
     * Runs inference on the given input and converts the result with the inference adapter.
     *
     * <p>In single-model mode, calls are serialized with {@code synchronized (this)}. A
     * {@link MultiLayerNetwork} uses only the first feature array and, if present, the first mask.
     *
     * @throws IllegalArgumentException if a MultiLayerNetwork receives no feature arrays
     * @throws IllegalStateException if the configured model is neither a ComputationGraph nor a MultiLayerNetwork
     */
    private O process(MultiDataSet multiDataSet) {
        INDArray[] features = multiDataSet.getFeatures();
        INDArray[] featureMasks = multiDataSet.getFeaturesMaskArrays();
        if (parallelEnabled) {
            return inferenceAdapter.apply(parallelInference.output(features, featureMasks));
        }
        synchronized (this) {
            if (model instanceof ComputationGraph) {
                ComputationGraph graph = (ComputationGraph) model;
                return inferenceAdapter.apply(graph.output(false, features, featureMasks));
            }
            if (model instanceof MultiLayerNetwork) {
                // features[0] is dereferenced below, so a mask-only input must be rejected too.
                Preconditions.checkArgument(features != null && features.length > 0, "Input data for MultilayerNetwork is invalid!");
                INDArray mask = featureMasks != null && featureMasks.length > 0 ? featureMasks[0] : null;
                INDArray output = ((MultiLayerNetwork) model).output(features[0], false, mask, (INDArray) null);
                return inferenceAdapter.apply(new INDArray[] {output});
            }
        }
        throw new IllegalStateException("Unsupported model type: " + (model == null ? "null" : model.getClass().getName()));
    }

    /**
     * Handles {@code POST /v1/serving}: decodes the body, runs inference and writes the result.
     *
     * <p>Error responses: an unknown path goes to {@code sendError}. The servlet replies 415 for an
     * unsupported content type or a missing deserializer, 411 for a missing Content-Length on binary
     * input, 400 for a truncated binary body, and 500 when the inference adapter yields no input.
     */
    @Override
    protected void doPost(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) throws IOException {
        if (!SERVING_PATH.equals(httpServletRequest.getPathInfo())) {
            sendError(httpServletRequest.getRequestURI(), httpServletResponse);
            return;
        }
        String contentType = httpServletRequest.getContentType();
        String mediaType = baseMediaType(contentType);
        MultiDataSet multiDataSet;
        if (JSON_CONTENT_TYPE.equalsIgnoreCase(mediaType)) {
            // NOTE(review): presumably validateRequest writes its own error response when it returns false.
            if (!validateRequest(httpServletRequest, httpServletResponse)) {
                return;
            }
            if (deserializer == null) {
                httpServletResponse.sendError(HttpServletResponse.SC_UNSUPPORTED_MEDIA_TYPE, "JSON input is not supported by this endpoint");
                return;
            }
            multiDataSet = inferenceAdapter.apply(deserializer.deserialize(readJsonBody(httpServletRequest)));
        } else if (BINARY_CONTENT_TYPE.equalsIgnoreCase(mediaType)) {
            if (binaryDeserializer == null) {
                httpServletResponse.sendError(HttpServletResponse.SC_UNSUPPORTED_MEDIA_TYPE, "Binary input is not supported by this endpoint");
                return;
            }
            byte[] body = readBinaryBody(httpServletRequest, httpServletResponse);
            if (body == null) {
                return; // an error response was already sent
            }
            multiDataSet = inferenceAdapter.apply(binaryDeserializer.deserialize(body));
        } else {
            httpServletResponse.sendError(HttpServletResponse.SC_UNSUPPORTED_MEDIA_TYPE, "Unsupported content type: " + contentType);
            return;
        }
        if (multiDataSet == null) {
            log.error("InferenceAdapter failed");
            httpServletResponse.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "InferenceAdapter failed");
            return;
        }
        writeResult(process(multiDataSet), httpServletResponse);
    }

    /** Returns the media type without parameters (e.g. {@code ;charset=...}), or {@code null}. */
    private static String baseMediaType(String contentType) {
        if (contentType == null) {
            return null;
        }
        int semicolon = contentType.indexOf(';');
        return (semicolon >= 0 ? contentType.substring(0, semicolon) : contentType).trim();
    }

    /** Reads the whole request body as UTF-8 text (the encoding JSON requires). */
    private static String readJsonBody(HttpServletRequest request) throws IOException {
        BufferedReader reader = new BufferedReader(new InputStreamReader(request.getInputStream(), StandardCharsets.UTF_8));
        StringBuilder body = new StringBuilder();
        char[] buffer = new char[CHAR_BUFFER_SIZE];
        int read;
        while ((read = reader.read(buffer)) != -1) {
            body.append(buffer, 0, read);
        }
        return body.toString();
    }

    /**
     * Reads exactly Content-Length bytes from the request.
     *
     * @return the body, or {@code null} after sending 411 (no length) or 400 (body ended early)
     */
    private static byte[] readBinaryBody(HttpServletRequest request, HttpServletResponse response) throws IOException {
        int contentLength = request.getContentLength();
        if (contentLength <= 0) {
            response.sendError(HttpServletResponse.SC_LENGTH_REQUIRED, "Content length is unavailable");
            return null;
        }
        byte[] body = new byte[contentLength];
        ServletInputStream in = request.getInputStream();
        int offset = 0;
        // A single read() may return fewer bytes than requested, so loop until the buffer is full.
        while (offset < contentLength) {
            int read = in.read(body, offset, contentLength - offset);
            if (read < 0) {
                response.sendError(HttpServletResponse.SC_BAD_REQUEST, "Request body is shorter than Content-Length");
                return null;
            }
            offset += read;
        }
        return body;
    }

    /** Writes the result, using the binary serializer when configured, else JSON as UTF-8. */
    private void writeResult(O result, HttpServletResponse response) throws IOException {
        if (binarySerializer != null) {
            byte[] serialized = binarySerializer.serialize(result);
            response.setContentType(BINARY_CONTENT_TYPE);
            response.setContentLength(serialized.length);
            response.getOutputStream().write(serialized);
            return;
        }
        if (serializer == null) {
            log.error("No serializer configured for the response");
            response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "No serializer configured for the response");
            return;
        }
        String serialized = serializer.serialize(result);
        response.setContentType(JSON_CONTENT_TYPE);
        response.setCharacterEncoding(StandardCharsets.UTF_8.name());
        try {
            response.getWriter().write(serialized);
        } catch (IOException e) {
            // Best-effort write, as before; keep the stack trace for diagnosis.
            log.error("Failed to write JSON response: {}", e.getMessage(), e);
        }
    }
}
