package com.github.handong0123.tensorflow.deploy.session.model;

import com.github.handong0123.tensorflow.deploy.session.entity.ModelInput;
import com.github.handong0123.tensorflow.deploy.session.entity.ModelOutput;
import com.github.handong0123.tensorflow.deploy.session.entity.ModelParam;
import com.google.common.primitives.Longs;
import java.io.IOException;
import java.lang.reflect.Array;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import javax.annotation.PostConstruct;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.framework.ConfigProto;
import org.tensorflow.framework.GPUOptions;

/* loaded from: input_file:com/github/handong0123/tensorflow/deploy/session/model/TensorflowModelServiceImpl.class */
/**
 * {@link TensorflowModelService} that loads a serialized TensorFlow GraphDef from disk into a
 * {@link Graph} and serves predictions through a {@link Session}.
 *
 * <p>If {@code gpuId} is {@code "-1"} (or {@code null}) the session is created with TensorFlow's
 * default configuration. Any other value is passed to TensorFlow as the GPU visible-device list,
 * together with the per-process GPU memory fraction, allow-growth and soft placement. The same
 * configuration is used for the initial load and for {@link #modelReload()}.
 *
 * <p>Thread-safety: {@link #predict(ModelInput)} may be called concurrently. Loading or reloading
 * the model swaps the session under a write lock, so a swap waits for in-flight predictions, and a
 * replaced session/graph is closed only after no prediction can still be using it.
 */
public class TensorflowModelServiceImpl implements TensorflowModelService {
    private static final Logger LOG = LoggerFactory.getLogger(TensorflowModelServiceImpl.class);
    /** Sentinel GPU id selecting the default (non-GPU-configured) session. */
    private static final String DEFAULT_GPU_ID = "-1";
    private static final float DEFAULT_PER_GPU_MEMORY_FRACTION = 0.95f;

    private final String modelFile;
    private final String modelPath;
    private final String gpuId;
    private final float perGpuMemoryFraction;

    /** Guards {@link #session}/{@link #graph}: read lock while predicting, write lock while swapping. */
    private final ReadWriteLock lock = new ReentrantReadWriteLock();
    private Session session;
    private Graph graph;

    /**
     * Creates a service that runs on the default (CPU) session configuration.
     *
     * @param modelFile file name of the serialized GraphDef
     * @param modelPath directory containing {@code modelFile}
     */
    public TensorflowModelServiceImpl(String modelFile, String modelPath) {
        this(modelFile, modelPath, DEFAULT_GPU_ID, DEFAULT_PER_GPU_MEMORY_FRACTION);
    }

    /**
     * Creates a service with the default per-GPU memory fraction (0.95).
     *
     * @param modelFile file name of the serialized GraphDef
     * @param modelPath directory containing {@code modelFile}
     * @param gpuId GPU visible-device list, or {@code "-1"} for the default session configuration
     */
    public TensorflowModelServiceImpl(String modelFile, String modelPath, String gpuId) {
        this(modelFile, modelPath, gpuId, DEFAULT_PER_GPU_MEMORY_FRACTION);
    }

    /**
     * Creates the service and eagerly loads the model.
     *
     * <p>A load failure is logged, not thrown (this preserves the historical contract of these
     * constructors). Until a later {@link #init()} or {@link #modelReload()} succeeds,
     * {@link #predict(ModelInput)} logs an error and returns an empty output.
     *
     * @param modelFile file name of the serialized GraphDef
     * @param modelPath directory containing {@code modelFile}
     * @param gpuId GPU visible-device list, or {@code "-1"} for the default session configuration
     * @param perGpuMemoryFraction fraction of each visible GPU's memory TensorFlow may allocate
     */
    public TensorflowModelServiceImpl(String modelFile, String modelPath, String gpuId, float perGpuMemoryFraction) {
        this.modelFile = modelFile;
        this.modelPath = modelPath;
        this.gpuId = gpuId;
        this.perGpuMemoryFraction = perGpuMemoryFraction;
        try {
            init();
        } catch (Exception e) {
            LOG.error("model init failed, path={}, file={}", modelPath, modelFile, e);
        }
    }

    /**
     * Loads (or re-loads) the model from {@code modelPath/modelFile}.
     *
     * <p>This runs from the constructor and, when managed by a container, again via
     * {@code @PostConstruct}. Any previously loaded session/graph is closed after the swap, so
     * repeated calls do not leak native resources.
     *
     * @throws IOException if the model file cannot be read
     */
    @PostConstruct
    public void init() throws IOException {
        loadModel();
        LOG.info("{}:model init success,{}", isDefaultDevice() ? "CPU" : "GPU", Paths.get(this.modelPath, this.modelFile));
    }

    /**
     * Runs the model on the given placeholder inputs and fetches the expected outputs.
     *
     * <p>{@code String} inputs are UTF-8 encoded into a one-element string tensor. {@code String[]}
     * inputs become a string tensor with one element per string. Any other value is passed to
     * {@code Tensor.create} as-is. Each fetched tensor is copied into a Java array whose component
     * type comes from the expected-output entry and whose dimensions match the tensor's shape.
     *
     * <p>This method never throws for model failures. Errors are logged, and the returned
     * {@link ModelOutput} holds whatever outputs were added before the failure (possibly none). All
     * input and output tensors are closed before returning.
     *
     * @param modelInput placeholder feeds and expected outputs; may be {@code null}
     * @return the collected outputs, or {@code null} if {@code modelInput} is {@code null}
     */
    @Override // com.github.handong0123.tensorflow.deploy.session.model.TensorflowModelService
    public ModelOutput predict(ModelInput modelInput) {
        if (modelInput == null) {
            return null;
        }
        ModelOutput modelOutput = new ModelOutput();
        long startMillis = System.currentTimeMillis();
        List<Tensor<?>> inputTensors = new ArrayList<>();
        List<Tensor<?>> outputTensors = new ArrayList<>();
        lock.readLock().lock();
        try {
            if (this.session == null) {
                LOG.error("Model Run Error: model is not loaded");
                return modelOutput;
            }
            Session.Runner runner = this.session.runner();
            for (ModelParam param : modelInput.getPlaceHolderInput()) {
                Tensor<?> tensor = toTensor(param.getData());
                // Track before feeding so the tensor is closed even if feed() fails.
                inputTensors.add(tensor);
                runner = runner.feed(param.getPlaceHolderName(), tensor);
            }
            // Snapshot the key order once: fetch order and result order must line up by index.
            List<String> outputNames = new ArrayList<>(modelInput.getExpectedOutput().keySet());
            for (String name : outputNames) {
                runner = runner.fetch(name);
            }
            outputTensors.addAll(runner.run());
            LOG.info("Model Run Cost Time: {}", System.currentTimeMillis() - startMillis);

            if (outputNames.size() != outputTensors.size()) {
                throw new IllegalStateException("Model Run Error: OutTensor Size Error");
            }
            for (int i = 0; i < outputTensors.size(); i++) {
                String name = outputNames.get(i);
                Tensor<?> tensor = outputTensors.get(i);
                // NOTE(review): scalar (rank-0) outputs yield an empty dims array, which
                // Array.newInstance rejects; such outputs end up in the error log below.
                Object target = Array.newInstance(
                        modelInput.getExpectedOutput().get(name).getType(), toIntDims(tensor.shape()));
                tensor.copyTo(target);
                modelOutput.addOutput(name, target);
            }
        } catch (Exception e) {
            LOG.error("Model Run Error", e);
        } finally {
            lock.readLock().unlock();
            closeAll(inputTensors);
            closeAll(outputTensors);
        }
        return modelOutput;
    }

    /**
     * Re-reads the model file and atomically replaces the live session/graph.
     *
     * <p>The new graph and session are fully built before the swap. If loading fails, the error is
     * logged and the previously loaded model stays in service.
     */
    @Override // com.github.handong0123.tensorflow.deploy.session.model.TensorflowModelService
    public void modelReload() {
        try {
            LOG.info("Start Model Reload...");
            loadModel();
            LOG.info("Finish Model Reload");
        } catch (Exception e) {
            LOG.error("Model Reload failed, keeping the previous model, path={}, file={}",
                    this.modelPath, this.modelFile, e);
        }
    }

    /** Reads the GraphDef, builds a new graph + configured session, and installs them. */
    private void loadModel() throws IOException {
        byte[] graphDef = Files.readAllBytes(Paths.get(this.modelPath, this.modelFile));
        Graph newGraph = new Graph();
        Session newSession;
        try {
            newGraph.importGraphDef(graphDef);
            newSession = createSession(newGraph);
        } catch (RuntimeException e) {
            // Don't leak the native graph when the GraphDef or session config is rejected.
            newGraph.close();
            throw e;
        }
        install(newGraph, newSession);
    }

    /** Creates a session on {@code g} using the GPU settings this service was constructed with. */
    private Session createSession(Graph g) {
        if (isDefaultDevice()) {
            return new Session(g);
        }
        GPUOptions gpuOptions = GPUOptions.newBuilder()
                .setVisibleDeviceList(this.gpuId)
                .setPerProcessGpuMemoryFraction(this.perGpuMemoryFraction)
                .setAllowGrowth(true)
                .build();
        ConfigProto config = ConfigProto.newBuilder()
                .setGpuOptions(gpuOptions)
                .setAllowSoftPlacement(true)
                .build();
        return new Session(g, config.toByteArray());
    }

    /** Returns true when no GPU-specific session configuration should be applied. */
    private boolean isDefaultDevice() {
        return this.gpuId == null || DEFAULT_GPU_ID.equals(this.gpuId);
    }

    /**
     * Swaps in a new graph/session under the write lock, then closes the replaced ones.
     *
     * <p>Holding the write lock means no prediction is running on the old session. After the swap,
     * new predictions only see the new session, so closing the old pair outside the lock is safe.
     */
    private void install(Graph newGraph, Session newSession) {
        Session oldSession;
        Graph oldGraph;
        lock.writeLock().lock();
        try {
            oldSession = this.session;
            oldGraph = this.graph;
            this.session = newSession;
            this.graph = newGraph;
        } finally {
            lock.writeLock().unlock();
        }
        // Close the session before the graph it was built on.
        if (oldSession != null) {
            oldSession.close();
        }
        if (oldGraph != null) {
            oldGraph.close();
        }
    }

    /** Converts a placeholder value into a tensor; strings are UTF-8 encoded. */
    private static Tensor<?> toTensor(Object data) {
        if (data instanceof String) {
            return Tensor.create(new byte[][] {((String) data).getBytes(StandardCharsets.UTF_8)});
        }
        if (data instanceof String[]) {
            String[] strings = (String[]) data;
            byte[][] encoded = new byte[strings.length][];
            for (int i = 0; i < strings.length; i++) {
                encoded[i] = strings[i].getBytes(StandardCharsets.UTF_8);
            }
            return Tensor.create(encoded);
        }
        return Tensor.create(data);
    }

    /** Converts a tensor shape to array dimensions, failing loudly on int overflow. */
    private static int[] toIntDims(long[] shape) {
        return Longs.asList(shape).stream().mapToInt(Math::toIntExact).toArray();
    }

    /** Closes every tensor in {@code tensors}. */
    private static void closeAll(List<Tensor<?>> tensors) {
        for (Tensor<?> tensor : tensors) {
            tensor.close();
        }
    }
}
