package com.eshore.tensorflow;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.io.IOUtils;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
import org.tensorflow.framework.ConfigProto;
import org.tensorflow.types.UInt8;

/* loaded from: input_file:com/eshore/tensorflow/TensorFlowInferenceInterface.class */
/**
 * Convenience wrapper around a TensorFlow {@link Graph} and {@link Session} for running inference
 * on a serialized {@code GraphDef}.
 *
 * <p>Usage: register inputs with the {@code feed} methods or {@link #addFeed}, call
 * {@link #run(String[])} with the output names, then read results with {@code fetch} or
 * {@link #getTensor}. Feed tensors are closed at the end of every {@code run} (successful or not);
 * fetched tensors stay valid until the next {@code run} or {@link #close()}.
 *
 * <p>Per-call state is kept in unsynchronized fields, so an instance must not be used from
 * several threads at once.
 */
public class TensorFlowInferenceInterface {
    private final String modelName;
    private final Graph g;
    private final Session sess;
    /** Collects feeds/fetches for the next run; replaced by a fresh runner after each run. */
    private Session.Runner runner;
    /** Names passed to feed/addFeed since the last run, parallel to {@link #feedTensors}. */
    private final List<String> feedNames = new ArrayList<>();
    /** Input tensors owned by this object until the next run, which closes them. */
    private final List<Tensor<?>> feedTensors = new ArrayList<>();
    /** Output names requested by the last run, parallel to {@link #fetchTensors}. */
    private final List<String> fetchNames = new ArrayList<>();
    /** Outputs of the last run; left empty when that run failed. */
    private List<Tensor<?>> fetchTensors = new ArrayList<>();
    private RunStats runStats;

    /**
     * A tensor reference of the form {@code "op_name"} or {@code "op_name:output_index"}.
     */
    public static class TensorId {
        String name;
        int outputIndex;

        private TensorId() {
        }

        /**
         * Splits {@code str} at its last ':' into an operation name and an output index.
         *
         * <p>If there is no ':' or the text after the last ':' is not an integer, the whole
         * string is used as the operation name and the output index is 0.
         *
         * @param str tensor reference, e.g. {@code "softmax:0"} or {@code "softmax"}
         * @return the parsed reference
         */
        public static TensorId parse(String str) {
            TensorId tensorId = new TensorId();
            tensorId.name = str;
            tensorId.outputIndex = 0;
            int colon = str.lastIndexOf(':');
            if (colon >= 0) {
                try {
                    tensorId.outputIndex = Integer.parseInt(str.substring(colon + 1));
                    tensorId.name = str.substring(0, colon);
                } catch (NumberFormatException ignored) {
                    // Suffix is not an index: keep the whole string as the op name, index 0.
                }
            }
            return tensorId;
        }
    }

    /**
     * Loads a {@code GraphDef} from a file and opens a session on it.
     *
     * <p>NOTE(review): this constructor enables device-placement logging while the
     * {@link InputStream} constructor disables it. That looks like a leftover debug setting;
     * confirm before changing, since it only affects logging.
     *
     * @param modelPath path of the serialized {@code GraphDef} file
     * @param modelName name used in log and error messages
     * @throws RuntimeException wrapping the {@link IOException} if the file cannot be read, or as
     *     thrown by TensorFlow if the graph cannot be imported; native resources are released
     *     before the exception propagates
     */
    public TensorFlowInferenceInterface(String modelPath, String modelName) {
        this.modelName = modelName;
        this.g = new Graph();
        this.sess = openSession(this.g, true);
        this.runner = this.sess.runner();
        try {
            loadGraph(modelPath, this.g);
        } catch (IOException e) {
            closeNative();
            throw new RuntimeException("Failed to load model from '" + modelName + "'", e);
        } catch (RuntimeException e) {
            // e.g. a malformed GraphDef rejected by importGraphDef; do not leak the session/graph.
            closeNative();
            throw e;
        }
        System.out.println("Successfully loaded model from '" + modelName + "'");
    }

    /** Returns the underlying graph (same object as {@link #graph()}). */
    public Graph getGraph() {
        return this.g;
    }

    /**
     * Loads a {@code GraphDef} from a stream and opens a session on it. The stream is read to
     * the end but not closed.
     *
     * @param inputStream stream containing the serialized {@code GraphDef}
     * @throws IOException if the stream cannot be read; native resources are released first
     */
    public TensorFlowInferenceInterface(InputStream inputStream) throws IOException {
        this.modelName = "unnamed";
        this.g = new Graph();
        this.sess = openSession(this.g, false);
        this.runner = this.sess.runner();
        try {
            loadGraph(inputStream, this.g);
        } catch (IOException | RuntimeException e) {
            closeNative();
            throw e;
        }
        System.out.println("Successfully loaded model from input stream");
    }

    /** Runs inference without collecting run statistics; see {@link #run(String[], boolean)}. */
    public void run(String[] outputNames) {
        run(outputNames, false);
    }

    /**
     * Runs the graph with the currently registered feeds and fetches {@code outputNames}.
     *
     * <p>Previous fetch results are closed first. Whether the run succeeds or fails, all
     * registered feed tensors are closed afterwards and a fresh runner is prepared for the next
     * call.
     *
     * @param outputNames tensor references to fetch ({@code "op"} or {@code "op:index"})
     * @param enableStats when true, collects run metadata into the stats shown by
     *     {@link #getStatString()}
     * @throws RuntimeException as thrown by TensorFlow, after logging the inputs and outputs
     */
    public void run(String[] outputNames, boolean enableStats) {
        closeFetches();
        try {
            for (String outputName : outputNames) {
                this.fetchNames.add(outputName);
                TensorId id = TensorId.parse(outputName);
                this.runner.fetch(id.name, id.outputIndex);
            }
            if (enableStats) {
                Session.Run result = this.runner.setOptions(RunStats.runOptions()).runAndFetchMetadata();
                this.fetchTensors = result.outputs;
                if (this.runStats == null) {
                    this.runStats = new RunStats();
                }
                this.runStats.add(result.metadata);
            } else {
                this.fetchTensors = this.runner.run();
            }
        } catch (RuntimeException e) {
            System.err.println("Failed to run TensorFlow inference with inputs:[" + this.feedNames + "], outputs:[" + this.fetchNames + "]");
            throw e;
        } finally {
            // Runs after the session has executed (or failed), never between fetch registrations.
            closeFeeds();
            this.runner = this.sess.runner();
        }
    }

    /** Returns the underlying graph (same object as {@link #getGraph()}). */
    public Graph graph() {
        return this.g;
    }

    /** Feeds a float tensor of shape {@code jArr} built from {@code fArr}. */
    public void feed(String str, float[] fArr, long... jArr) {
        addOwnedFeed(str, Tensor.create(jArr, FloatBuffer.wrap(fArr)));
    }

    /** Feeds a float tensor of shape {@code jArr} built from {@code floatBuffer}. */
    public void feed(String str, FloatBuffer floatBuffer, long... jArr) {
        addOwnedFeed(str, Tensor.create(jArr, floatBuffer));
    }

    /** Feeds a {@link UInt8} tensor of shape {@code jArr} built from {@code bArr}. */
    public void feed(String str, byte[] bArr, long... jArr) {
        addOwnedFeed(str, Tensor.create(UInt8.class, jArr, ByteBuffer.wrap(bArr)));
    }

    /**
     * Registers {@code tensor} as the value of input {@code str} for the next run. Ownership is
     * transferred: the tensor is closed at the end of the next {@link #run}.
     *
     * @param str input reference ({@code "op"} or {@code "op:index"})
     * @param tensor value to feed
     * @return this instance, for chaining
     */
    public TensorFlowInferenceInterface addFeed(String str, Tensor<?> tensor) {
        TensorId parse = TensorId.parse(str);
        this.runner.feed(parse.name, parse.outputIndex, tensor);
        this.feedNames.add(str);
        this.feedTensors.add(tensor);
        return this;
    }

    /** Like {@link #addFeed}, but closes {@code tensor} if registration fails (e.g. unknown op). */
    private void addOwnedFeed(String str, Tensor<?> tensor) {
        try {
            addFeed(str, tensor);
        } catch (RuntimeException e) {
            tensor.close();
            throw e;
        }
    }

    /**
     * Returns the number of elements described by the shape {@code jArr} (1 for an empty shape).
     *
     * @throws ArithmeticException if a dimension or the element count does not fit in an int
     */
    public static int length(long[] jArr) {
        long total = 1;
        for (long dim : jArr) {
            if (dim < Integer.MIN_VALUE || dim > Integer.MAX_VALUE) {
                throw new ArithmeticException("Dimension " + dim + " does not fit in an int");
            }
            // |total| and |dim| are both <= 2^31 here, so the product cannot overflow a long.
            total *= dim;
            if (total < Integer.MIN_VALUE || total > Integer.MAX_VALUE) {
                throw new ArithmeticException("Element count does not fit in an int");
            }
        }
        return (int) total;
    }

    /** Copies the fetched output {@code str} of the last run into {@code fArr}. */
    public void fetch(String str, float[] fArr) {
        fetch(str, FloatBuffer.wrap(fArr));
    }

    /** Copies the fetched output {@code str} of the last run into {@code floatBuffer}. */
    public void fetch(String str, FloatBuffer floatBuffer) {
        getTensor(str).writeTo(floatBuffer);
    }

    /**
     * Returns the tensor fetched for {@code str} by the last run. The tensor remains owned by
     * this object and is closed by the next run or {@link #close()}.
     *
     * @throws RuntimeException if {@code str} was not requested by the last run
     * @throws IllegalStateException if it was requested but the last run failed
     */
    public Tensor<?> getTensor(String str) {
        int index = this.fetchNames.indexOf(str);
        if (index < 0) {
            throw new RuntimeException("Node '" + str + "' was not provided to run(), so it cannot be read");
        }
        if (index >= this.fetchTensors.size()) {
            throw new IllegalStateException("Node '" + str + "' has no output because the last run() did not complete");
        }
        return this.fetchTensors.get(index);
    }

    /**
     * Looks up an operation by name in the loaded graph.
     *
     * @throws RuntimeException if no such operation exists
     */
    public Operation graphOperation(String str) {
        Operation operation = this.g.operation(str);
        if (operation == null) {
            throw new RuntimeException("Node '" + str + "' does not exist in model '" + this.modelName + "'");
        }
        return operation;
    }

    /** Returns the accumulated run-statistics summary, or "" if stats were never enabled. */
    public String getStatString() {
        return this.runStats == null ? "" : this.runStats.summary();
    }

    /** Closes pending feeds, fetched outputs, the session, the graph and any run statistics. */
    public void close() {
        closeFeeds();
        closeFetches();
        // The session must be closed before the graph it runs on.
        this.sess.close();
        this.g.close();
        if (this.runStats != null) {
            this.runStats.close();
        }
        this.runStats = null;
    }

    /** Last-resort cleanup if {@link #close()} was never called. */
    @Override
    protected void finalize() throws Throwable {
        try {
            close();
        } finally {
            super.finalize();
        }
    }

    /** Reads the file at {@code str} and imports it into {@code graph} as a {@code GraphDef}. */
    private void loadGraph(String str, Graph graph) throws IOException {
        long startMillis = System.currentTimeMillis();
        try {
            graph.importGraphDef(Files.readAllBytes(Paths.get(str)));
            System.out.println("Model load took " + (System.currentTimeMillis() - startMillis) + "ms, TensorFlow version: " + TensorFlow.version());
        } catch (IOException e) {
            System.err.println("Failed to read [" + str + "]: " + e.getMessage());
            throw e;
        }
    }

    /** Reads {@code inputStream} to the end and imports it into {@code graph} as a {@code GraphDef}. */
    private void loadGraph(InputStream inputStream, Graph graph) throws IOException {
        long startMillis = System.currentTimeMillis();
        try {
            // available() is only a sizing hint; +1 keeps the initial capacity positive.
            ByteArrayOutputStream buffer = new ByteArrayOutputStream(inputStream.available() + 1);
            IOUtils.copy(inputStream, buffer);
            graph.importGraphDef(buffer.toByteArray());
            System.out.println("Model load took " + (System.currentTimeMillis() - startMillis) + "ms, TensorFlow version: " + TensorFlow.version());
        } catch (IOException e) {
            System.err.println("Failed to read [" + inputStream + "]: " + e.getMessage());
            throw e;
        }
    }

    /** Closes and forgets every registered feed tensor. */
    private void closeFeeds() {
        for (Tensor<?> tensor : this.feedTensors) {
            tensor.close();
        }
        this.feedTensors.clear();
        this.feedNames.clear();
    }

    /** Closes and forgets every fetched output tensor. */
    private void closeFetches() {
        for (Tensor<?> tensor : this.fetchTensors) {
            tensor.close();
        }
        this.fetchTensors.clear();
        this.fetchNames.clear();
    }

    /** Releases the session and graph; used when construction fails part-way. */
    private void closeNative() {
        this.sess.close();
        this.g.close();
    }

    /** Opens a session on {@code graph}, closing the graph if the session cannot be created. */
    private static Session openSession(Graph graph, boolean logDevicePlacement) {
        try {
            return new Session(graph, sessionConfig(logDevicePlacement));
        } catch (RuntimeException e) {
            graph.close();
            throw e;
        }
    }

    /** Serialized {@link ConfigProto} with soft placement enabled. */
    private static byte[] sessionConfig(boolean logDevicePlacement) {
        return ConfigProto.newBuilder()
                .setLogDevicePlacement(logDevicePlacement)
                .setAllowSoftPlacement(true)
                .build()
                .toByteArray();
    }
}
