package org.springframework.cloud.fn.computer.vision;

import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.spring.configuration.DjlAutoConfiguration;
import java.awt.image.RenderedImage;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;
import javax.imageio.ImageIO;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.integration.support.MessageBuilder;
import org.springframework.messaging.Message;

/**
 * Auto-configuration exposing a single {@code computerVisionFunction} bean backed by a DJL
 * {@link Predictor}.
 *
 * <p>The {@code djl.output-class} property selects which variant is registered: object detection,
 * semantic segmentation, image classification or pose estimation. Every variant:
 * <ul>
 *   <li>takes a {@code Message<byte[]>} whose payload is an encoded image;</li>
 *   <li>runs the prediction;</li>
 *   <li>stores the JSON form of the result in the header named by
 *   {@code ComputerVisionFunctionProperties#getOutputHeaderName()};</li>
 *   <li>when augmentation is enabled, replaces the payload with the image rendered with the
 *   prediction, encoded in {@code ComputerVisionFunctionProperties#getOutputImageFormatName()}.
 *   Otherwise the original payload is passed through.</li>
 * </ul>
 */
@EnableConfigurationProperties({ComputerVisionFunctionProperties.class})
@AutoConfiguration(after = {DjlAutoConfiguration.class})
public class ComputerVisionFunctionConfiguration {

    /** Shared factory used to decode incoming image payloads. */
    private static final ImageFactory IMAGE_FACTORY = ImageFactory.getInstance();

    /** Supplies one predictor per handled message; each is closed once the message is processed. */
    private final Supplier<Predictor<?, ?>> predictorProvider;

    private final ComputerVisionFunctionProperties cvProperties;

    public ComputerVisionFunctionConfiguration(Supplier<Predictor<?, ?>> supplier, ComputerVisionFunctionProperties computerVisionFunctionProperties) {
        this.predictorProvider = supplier;
        this.cvProperties = computerVisionFunctionProperties;
    }

    /** Object detection: when augmenting, draws the detected bounding boxes on a copy of the input. */
    @ConditionalOnProperty(prefix = "djl", name = {"output-class"}, havingValue = "ai.djl.modality.cv.output.DetectedObjects")
    @Bean(name = {"computerVisionFunction"})
    public Function<Message<byte[]>, Message<byte[]>> objectDetection() {
        // Explicit lambda parameter types are what let the compiler infer the predictor output type T;
        // implicitly typed parameters give it nothing to infer from.
        return predictor(JsonHelper::toJson,
                (ai.djl.modality.cv.output.DetectedObjects detectedObjects, Image image) -> {
                    Image annotated = image.duplicate();
                    annotated.drawBoundingBoxes(detectedObjects);
                    return toImageBytes(annotated);
                });
    }

    /** Semantic segmentation: when augmenting, overlays the category mask on a copy of the input. */
    @ConditionalOnProperty(prefix = "djl", name = {"output-class"}, havingValue = "ai.djl.modality.cv.output.CategoryMask")
    @Bean(name = {"computerVisionFunction"})
    public Function<Message<byte[]>, Message<byte[]>> semanticSegmentation() {
        return predictor(JsonHelper::toJson,
                (ai.djl.modality.cv.output.CategoryMask categoryMask, Image image) -> {
                    Image annotated = image.duplicate();
                    // opacity 200, background 0, as passed to CategoryMask#drawMask
                    categoryMask.drawMask(annotated, 200, 0);
                    return toImageBytes(annotated);
                });
    }

    /**
     * Image classification: there is nothing to draw, so when augmenting the input image is simply
     * re-encoded in the configured output format.
     */
    @ConditionalOnProperty(prefix = "djl", name = {"output-class"}, havingValue = "ai.djl.modality.Classifications")
    @Bean(name = {"computerVisionFunction"})
    public Function<Message<byte[]>, Message<byte[]>> imageClassifications() {
        return predictor(JsonHelper::toJson,
                (ai.djl.modality.Classifications classifications, Image image) -> toImageBytes(image));
    }

    /** Pose estimation: when augmenting, draws the detected joints on a copy of the input. */
    @ConditionalOnProperty(prefix = "djl", name = {"output-class"}, havingValue = "ai.djl.modality.cv.output.Joints")
    @Bean(name = {"computerVisionFunction"})
    public Function<Message<byte[]>, Message<byte[]>> poseEstimation() {
        return predictor(JsonHelper::toJson,
                (ai.djl.modality.cv.output.Joints joints, Image image) -> {
                    Image annotated = image.duplicate();
                    annotated.drawJoints(joints);
                    return toImageBytes(annotated);
                });
    }

    /**
     * Builds the message function shared by all bean variants.
     *
     * @param function converts the prediction into the JSON stored in the output header
     * @param biFunction renders the prediction onto the decoded input image and encodes it; only
     *     invoked when augmentation is enabled
     * @param <T> the predictor's output type
     * @return a function mapping an encoded-image message to a message carrying the prediction
     * @throws IllegalStateException (from the returned function) wrapping any failure while
     *     decoding, predicting, converting or rendering
     */
    @SuppressWarnings("unchecked")
    private <T> Function<Message<byte[]>, Message<byte[]>> predictor(Function<T, String> function, BiFunction<T, Image, byte[]> biFunction) {
        return message -> {
            // The supplier is untyped. This unchecked cast assumes the configured predictor consumes
            // an Image and produces the output type selected via djl.output-class.
            // try-with-resources closes the predictor on both success and failure paths.
            try (Predictor<Image, T> predictor = (Predictor<Image, T>) this.predictorProvider.get()) {
                byte[] inputPayload = message.getPayload();
                Image image = IMAGE_FACTORY.fromInputStream(new ByteArrayInputStream(inputPayload));
                T prediction = predictor.predict(image);
                String json = function.apply(prediction);
                byte[] outputPayload = this.cvProperties.isAugmentEnabled()
                        ? biFunction.apply(prediction, image)
                        : inputPayload;
                return MessageBuilder.withPayload(outputPayload)
                        .setHeader(this.cvProperties.getOutputHeaderName(), json)
                        .build();
            } catch (Exception e) {
                throw new IllegalStateException(e);
            }
        };
    }

    /** Encodes a DJL image in the configured output image format. */
    private byte[] toImageBytes(Image image) {
        // NOTE(review): assumes the ImageFactory wraps java.awt images (e.g. BufferedImage);
        // an image whose wrapped object is not a RenderedImage fails this cast.
        return getByteArray((RenderedImage) image.getWrappedImage(), this.cvProperties.getOutputImageFormatName());
    }

    /**
     * Encodes {@code renderedImage} with ImageIO.
     *
     * @param renderedImage the image to encode
     * @param formatName informal ImageIO format name, e.g. "png" or "jpg"
     * @return the encoded bytes
     * @throws IllegalArgumentException if no ImageIO writer supports {@code formatName}
     * @throws UncheckedIOException if encoding fails
     */
    private static byte[] getByteArray(RenderedImage renderedImage, String formatName) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        try {
            // ImageIO.write signals an unsupported format by returning false instead of throwing;
            // surface it rather than silently emitting an empty payload.
            if (!ImageIO.write(renderedImage, formatName, out)) {
                throw new IllegalArgumentException("No ImageIO writer found for image format: " + formatName);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return out.toByteArray();
    }
}
