package org.springframework.cloud.fn.object.detection;

import java.awt.image.BufferedImage;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.List;
import java.util.function.BiFunction;
import javax.imageio.ImageIO;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.cloud.fn.common.tensorflow.deprecated.GraphicsUtils;
import org.springframework.cloud.fn.object.detection.domain.ObjectDetection;
import org.springframework.util.CollectionUtils;

/**
 * Augments an encoded image with the results of an object detection run.
 *
 * <p>For every {@link ObjectDetection} a bounding box labelled {@code "<name>: <confidence>%"} is drawn
 * onto the decoded image. When the augmenter is constructed with {@code withMask = true}, detections that
 * carry a mask also get that mask overlaid at the top-left corner of their bounding box. The augmented
 * image is re-encoded in {@link #getImageFormat()} (default {@value #DEFAULT_IMAGE_FORMAT}).
 *
 * <p>Detection coordinates are multiplied by the image height/width, so they are presumably normalized to
 * {@code [0, 1]} — TODO confirm against the producer of {@link ObjectDetection}.
 */
public class ObjectDetectionImageAugmenter implements BiFunction<byte[], List<ObjectDetection>, byte[]> {

    private static final Log logger = LogFactory.getLog(ObjectDetectionImageAugmenter.class);

    /** Image format used to re-encode the augmented image unless {@link #setImageFormat(String)} is called. */
    public static final String DEFAULT_IMAGE_FORMAT = "jpg";

    /** Format name passed to {@code GraphicsUtils.toImageByteArray} when encoding the result. */
    private String imageFormat = DEFAULT_IMAGE_FORMAT;

    /** Whether detection masks (when present) are overlaid in addition to bounding boxes. */
    private final boolean withMask;

    /** When {@code true}, boxes and masks are not colored per class id. */
    private boolean agnosticColors = false;

    /** Creates an augmenter that draws bounding boxes only (no mask overlays). */
    public ObjectDetectionImageAugmenter() {
        this(false);
    }

    /**
     * Creates an augmenter.
     * @param withMask if {@code true}, detection masks are overlaid on top of the bounding boxes
     */
    public ObjectDetectionImageAugmenter(boolean withMask) {
        this.withMask = withMask;
    }

    /** @return whether class-agnostic (non per-class) colors are used */
    public boolean isAgnosticColors() {
        return this.agnosticColors;
    }

    /** @param agnosticColors {@code true} to disable per-class coloring of boxes and masks */
    public void setAgnosticColors(boolean agnosticColors) {
        this.agnosticColors = agnosticColors;
    }

    /** @return the format name used to encode the augmented image */
    public String getImageFormat() {
        return this.imageFormat;
    }

    /** @param imageFormat the format name used to encode the augmented image (e.g. {@code "jpg"}) */
    public void setImageFormat(String imageFormat) {
        this.imageFormat = imageFormat;
    }

    /**
     * Draws the given detections onto the image and returns the re-encoded result.
     * @param imageBytes the encoded input image; may be {@code null}
     * @param detections the detections to draw; may be {@code null} or empty
     * @return the augmented, re-encoded image. The input bytes are returned unchanged in these cases:
     * there are no detections, the input is {@code null}, the image cannot be decoded, or an I/O error
     * occurs.
     */
    @Override
    public byte[] apply(byte[] imageBytes, List<ObjectDetection> detections) {
        if (imageBytes == null || CollectionUtils.isEmpty(detections)) {
            return imageBytes;
        }
        try {
            BufferedImage image = ImageIO.read(new ByteArrayInputStream(imageBytes));
            if (image == null) {
                // ImageIO.read signals "no registered reader can decode this" by returning null, not by
                // throwing; fall back to the original bytes instead of failing with an NPE below.
                logger.warn("Unable to decode the input image; returning it without augmentation");
                return imageBytes;
            }
            for (ObjectDetection detection : detections) {
                drawDetection(image, detection);
            }
            return GraphicsUtils.toImageByteArray(image, getImageFormat());
        }
        catch (IOException ex) {
            logger.error("Failed to augment the image with object detections", ex);
            return imageBytes;
        }
    }

    /** Draws one detection's labelled bounding box, and its mask if enabled and present, onto the image. */
    private void drawDetection(BufferedImage image, ObjectDetection detection) {
        int imageWidth = image.getWidth();
        int imageHeight = image.getHeight();
        int x1 = (int) (detection.getX1() * imageWidth);
        int y1 = (int) (detection.getY1() * imageHeight);
        int x2 = (int) (detection.getX2() * imageWidth);
        int y2 = (int) (detection.getY2() * imageHeight);
        int cid = detection.getCid();
        String label = detection.getName() + ": " + ((int) (100.0f * detection.getConfidence())) + "%";

        GraphicsUtils.drawBoundingBox(image, cid, label, x1, y1, x2, y2, this.agnosticColors);

        if (this.withMask) {
            float[][] mask = detection.getMask();
            if (mask != null) {
                // The mask image is sized to the bounding box and overlaid at its top-left corner.
                GraphicsUtils.overlayImages(image,
                        GraphicsUtils.createMaskImage(mask, x2 - x1, y2 - y1,
                                this.agnosticColors ? null : GraphicsUtils.getClassColor(cid)),
                        x1, y1);
            }
        }
    }
}
