package org.springframework.cloud.fn.image.recognition;

import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import org.springframework.cloud.fn.common.tensorflow.GraphRunner;
import org.springframework.cloud.fn.common.tensorflow.GraphRunnerMemory;
import org.springframework.cloud.fn.common.tensorflow.ProtoBufGraphDefinition;
import org.springframework.core.io.DefaultResourceLoader;
import org.springframework.core.io.Resource;
import org.springframework.util.StreamUtils;
import org.tensorflow.Tensor;
import org.tensorflow.op.core.Max;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.dtypes.Cast;
import org.tensorflow.op.image.DecodeJpeg;
import org.tensorflow.op.image.ResizeBilinear;
import org.tensorflow.op.nn.TopK;

/**
 * Image classifier that chains three TensorFlow graphs for every request.
 *
 * <ol>
 * <li>A normalization graph decodes a JPEG (3 channels), casts it to float, and adds a leading
 * batch dimension. It then resizes the image bilinearly and computes {@code (pixel - mean) / scale}.</li>
 * <li>The recognition model is loaded from a protobuf graph definition.</li>
 * <li>A post-processing graph selects either the single best category
 * ({@link #recognizeMax(byte[])}) or the K largest probabilities ({@link #recognizeTopK(byte[])}).</li>
 * </ol>
 *
 * <p>Category indices are turned into names using a labels file with one label per line.
 * {@link #close()} releases all four {@link GraphRunner}s held by an instance.
 */
public class ImageRecognition implements AutoCloseable {

    /** Category names; the entry at position {@code i} names model output index {@code i}. */
    private final List<String> labels;

    /** Feeds {@code "raw_image"} (JPEG bytes) and fetches {@code "normalized_image"}. */
    private final GraphRunner imageNormalization;

    /** The pre-trained recognition model loaded from a protobuf graph definition. */
    private final GraphRunner imageRecognition;

    /**
     * Feeds {@code "recognition_result"} and fetches the arg-max index ({@code "category"}) and
     * the max value ({@code "probability"}) along axis 1.
     */
    private final GraphRunner maxProbability = new GraphRunner(Arrays.asList("recognition_result"),
            Arrays.asList("category", "probability")).withGraphDefinition(ops -> {
                Placeholder placeholder = ops.withName("recognition_result").placeholder(Float.class, new Placeholder.Options[0]);
                ops.withName("category").math.argMax(placeholder, ops.constant(1));
                ops.withName("probability").max(placeholder, ops.constant(1), new Max.Options[0]);
            });

    /** Feeds {@code "recognition_result"} and fetches its K largest values, sorted, as {@code "topK"}. */
    private final GraphRunner topKProbabilities;

    /**
     * Builds the normalization, recognition and post-processing graphs.
     *
     * @param modelUri location of the protobuf model, resolved with Spring's
     *     {@link DefaultResourceLoader} (e.g. {@code classpath:}, {@code file:} or URL prefixes)
     * @param labelsUri location of the UTF-8 labels file, one label per line, resolved the same way
     * @param imageHeight first element of the size passed to ResizeBilinear
     * @param imageWidth second element of the size passed to ResizeBilinear
     * @param mean value subtracted from every pixel after resizing
     * @param scale divisor applied after the mean subtraction
     * @param modelInputName feed name of the recognition model graph
     * @param modelOutputName fetch name of the recognition model graph
     * @param topK number of categories kept by {@link #recognizeTopK(byte[])}
     * @param cacheModel flag forwarded to {@link ProtoBufGraphDefinition}; presumably enables
     *     caching of the downloaded model (NOTE(review): confirm against that class)
     * @throws RuntimeException if the labels file cannot be read
     */
    public ImageRecognition(String modelUri, String labelsUri, int imageHeight, int imageWidth,
            float mean, float scale, String modelInputName, String modelOutputName, int topK,
            boolean cacheModel) {
        this.labels = labels(labelsUri);
        this.imageNormalization = new GraphRunner("raw_image", "normalized_image").withGraphDefinition(ops -> {
            // decodeJpeg -> cast to float -> add batch dim -> resize -> (x - mean) / scale
            ops.withName("normalized_image").math.div(
                    ops.math.sub(
                            ops.image.resizeBilinear(
                                    ops.expandDims(
                                            ops.dtypes.cast(
                                                    ops.image.decodeJpeg(
                                                            ops.withName("raw_image").placeholder(String.class, new Placeholder.Options[0]),
                                                            new DecodeJpeg.Options[] {DecodeJpeg.channels(3L)}),
                                                    Float.class, new Cast.Options[0]),
                                            ops.constant(0)),
                                    ops.constant(new int[] {imageHeight, imageWidth}),
                                    new ResizeBilinear.Options[0]),
                            ops.constant(mean)),
                    ops.constant(scale));
        });
        this.imageRecognition = new GraphRunner(modelInputName, modelOutputName)
                .withGraphDefinition(new ProtoBufGraphDefinition(toResource(modelUri), cacheModel));
        this.topKProbabilities = new GraphRunner("recognition_result", "topK").withGraphDefinition(ops -> {
            ops.withName("topK").nn.topK(
                    ops.withName("recognition_result").placeholder(Float.class, new Placeholder.Options[0]),
                    ops.constant(topK), new TopK.Options[] {TopK.sorted(true)});
        });
    }

    /**
     * Classifies a JPEG image and returns its single most probable category.
     *
     * <p>Only the first element of the result is read, so a batch of one is assumed.
     *
     * @param inputImage JPEG-encoded image bytes
     * @return a one-entry map from the best category's label to its probability
     */
    public Map<String, Double> recognizeMax(byte[] inputImage) {
        // Both resources are closed on every path, including when a graph run fails.
        try (Tensor<?> rawImage = Tensor.create(inputImage);
                GraphRunnerMemory memory = new GraphRunnerMemory()) {
            Map<String, ?> result = this.imageNormalization.andThen(memory)
                    .andThen(this.imageRecognition).andThen(memory)
                    .andThen(this.maxProbability).andThen(memory)
                    .apply(Collections.singletonMap("raw_image", rawImage));

            long[] category = new long[1];
            ((Tensor<?>) result.get("category")).copyTo(category);
            float[] probability = new float[1];
            ((Tensor<?>) result.get("probability")).copyTo(probability);

            return Collections.singletonMap(this.labels.get((int) category[0]), Double.valueOf(probability[0]));
        }
    }

    /**
     * Classifies a JPEG image and returns the K most probable categories.
     *
     * <p>Only row 0 of the model output is read, so a batch of one is assumed. Categories whose
     * label text is identical still collapse into one map entry, because the label is the key.
     *
     * @param inputImage JPEG-encoded image bytes
     * @return label to probability, in descending probability order
     */
    public Map<String, Double> recognizeTopK(byte[] inputImage) {
        try (Tensor<?> rawImage = Tensor.create(inputImage);
                GraphRunnerMemory memory = new GraphRunnerMemory()) {
            Map<String, ?> result = this.imageNormalization.andThen(memory)
                    .andThen(this.imageRecognition).andThen(memory)
                    .andThen(this.topKProbabilities).andThen(memory)
                    .apply(Collections.singletonMap("raw_image", rawImage));

            // The full probability vector is the recognition model's output, kept by 'memory'.
            Tensor<?> allTensor = (Tensor<?>) memory.getTensorMap().get(this.imageRecognition.getSingleFetchName());
            float[][] allProbabilities = new float[(int) allTensor.shape()[0]][(int) allTensor.shape()[1]];
            allTensor.copyTo(allProbabilities);

            Tensor<Float> topKTensor = ((Tensor<?>) result.get("topK")).expect(Float.class);
            float[][] topKValues = new float[(int) topKTensor.shape()[0]][(int) topKTensor.shape()[1]];
            topKTensor.copyTo(topKValues);

            return labelTopK(allProbabilities[0], topKValues[0]);
        }
    }

    /**
     * Pairs each top-K probability with the label of the category that produced it.
     *
     * <p>Each value is matched against the full probability vector, and an index is used at
     * most once. Tied probabilities therefore map to distinct categories, lower index first.
     */
    private Map<String, Double> labelTopK(float[] probabilities, float[] topKValues) {
        boolean[] assigned = new boolean[probabilities.length];
        Map<String, Double> labeled = new LinkedHashMap<>();
        for (float value : topKValues) {
            int index = firstUnassignedIndexOf(probabilities, value, assigned);
            assigned[index] = true;
            labeled.put(this.labels.get(index), Double.valueOf(value));
        }
        return labeled;
    }

    /** Returns the lowest index whose probability equals {@code value} and is not yet assigned. */
    private static int firstUnassignedIndexOf(float[] probabilities, float value, boolean[] assigned) {
        for (int i = 0; i < probabilities.length; i++) {
            // Float.compare matches the Float.equals semantics the previous HashMap lookup relied on.
            if (!assigned[i] && Float.compare(probabilities[i], value) == 0) {
                return i;
            }
        }
        throw new IllegalStateException("No category found for top-K probability " + value);
    }

    /** Resolves a location string with Spring's {@link DefaultResourceLoader}. */
    private Resource toResource(String uri) {
        return new DefaultResourceLoader().getResource(uri);
    }

    /**
     * Reads the labels file as UTF-8 and splits it on {@code '\n'}.
     *
     * @throws RuntimeException wrapping the {@link IOException} if the file cannot be read
     */
    private List<String> labels(String labelsUri) {
        try (InputStream inputStream = toResource(labelsUri).getInputStream()) {
            String content = StreamUtils.copyToString(inputStream, StandardCharsets.UTF_8);
            return Arrays.asList(content.split("\n"));
        }
        catch (IOException e) {
            throw new RuntimeException("Failed to initialize the Vocabulary", e);
        }
    }

    /**
     * Creates a recognizer configured for an Inception model.
     *
     * <p>Images are resized to {@code imageSize x imageSize} and normalized as {@code (x - 117) / 1}.
     * The model uses input node {@code "input"} and output node {@code "output"}.
     */
    public static ImageRecognition inception(String modelUri, int imageSize, int topK, boolean cacheModel) {
        return new ImageRecognition(modelUri, "classpath:/labels/inception_labels.txt", imageSize, imageSize,
                117.0f, 1.0f, "input", "output", topK, cacheModel);
    }

    /**
     * Creates a recognizer configured for a MobileNet V2 model.
     *
     * <p>Images are resized to {@code imageSize x imageSize} and normalized as {@code x / 127}.
     */
    public static ImageRecognition mobileNetV2(String modelUri, int imageSize, int topK, boolean cacheModel) {
        return new ImageRecognition(modelUri, "classpath:/labels/mobilenet_labels.txt", imageSize, imageSize,
                0.0f, 127.0f, "input", "MobilenetV2/Predictions/Reshape_1", topK, cacheModel);
    }

    /**
     * Creates a recognizer configured for a MobileNet V1 model.
     *
     * <p>Images are resized to {@code imageSize x imageSize} and normalized as {@code x / 127}.
     */
    public static ImageRecognition mobileNetV1(String modelUri, int imageSize, int topK, boolean cacheModel) {
        return new ImageRecognition(modelUri, "classpath:/labels/mobilenet_labels.txt", imageSize, imageSize,
                0.0f, 127.0f, "input", "MobilenetV1/Predictions/Reshape_1", topK, cacheModel);
    }

    /** Converts a label-to-probability map into {@link RecognitionResponse}s, in the map's iteration order. */
    public static List<RecognitionResponse> toRecognitionResponse(Map<String, Double> map) {
        return map.entrySet().stream()
                .map(entry -> new RecognitionResponse(entry.getKey(), entry.getValue()))
                .collect(Collectors.toList());
    }

    /**
     * Closes every graph runner, even if an earlier one fails.
     *
     * <p>The first failure is rethrown, and any later failures are attached to it as suppressed.
     */
    @Override
    public void close() {
        RuntimeException failure = null;
        for (GraphRunner runner : Arrays.asList(this.imageNormalization, this.imageRecognition,
                this.maxProbability, this.topKProbabilities)) {
            try {
                runner.close();
            }
            catch (RuntimeException ex) {
                if (failure == null) {
                    failure = ex;
                }
                else {
                    failure.addSuppressed(ex);
                }
            }
        }
        if (failure != null) {
            throw failure;
        }
    }
}
