package ai.djl.modality.cv;

import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Joints;
import ai.djl.modality.cv.output.Landmark;
import ai.djl.modality.cv.output.Mask;
import ai.djl.modality.cv.output.Point;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.util.RandomUtils;
import java.awt.BasicStroke;
import java.awt.Color;
import java.awt.FontMetrics;
import java.awt.Graphics2D;
import java.awt.RenderingHints;
import java.awt.image.BufferedImage;
import java.awt.image.ImageObserver;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.file.Path;
import java.util.List;
import java.util.stream.IntStream;
import javax.imageio.ImageIO;

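/**
 * An {@link ImageFactory} whose images are backed by {@link java.awt.image.BufferedImage} and
 * read/written through {@link javax.imageio.ImageIO}.
 *
 * <p>Decompiled from: input_file:ai/djl/modality/cv/BufferedImageFactory.class
 */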
public class BufferedImageFactory extends ImageFactory {

    /**
     * {@link Image} implementation that wraps a single {@link BufferedImage}.
     *
     * <p>Decompiled from:
     * input_file:ai/djl/modality/cv/BufferedImageFactory$BufferedImageWrapper.class
     */
    private class BufferedImageWrapper implements Image {
        private BufferedImage image;

        /**
         * Wraps the given AWT image. The image is stored by reference, not copied.
         *
         * @param bufferedImage the image that backs this wrapper
         */
        BufferedImageWrapper(BufferedImage bufferedImage) {
            image = bufferedImage;
        }

        /**
         * Returns the width of the wrapped image.
         *
         * @return the value reported by {@link BufferedImage#getWidth()}
         */
        @Override
        public int getWidth() {
            final int width = image.getWidth();
            return width;
        }

        /**
         * Returns the height of the wrapped image.
         *
         * @return the value reported by {@link BufferedImage#getHeight()}
         */
        @Override
        public int getHeight() {
            final int height = image.getHeight();
            return height;
        }

        /**
         * Exposes the underlying AWT image.
         *
         * @return the wrapped {@link BufferedImage} itself (not a copy)
         */
        @Override
        public Object getWrappedImage() {
            final Object wrapped = image;
            return wrapped;
        }

        /**
         * Returns a wrapper around a rectangular region of this image.
         *
         * <p>The four arguments are forwarded, in order, to {@link
         * BufferedImage#getSubimage(int, int, int, int)} (x, y, width, height).
         *
         * @param i x coordinate of the region's upper-left corner
         * @param i2 y coordinate of the region's upper-left corner
         * @param i3 width of the region
         * @param i4 height of the region
         * @return a new wrapper around the sub-image
         */
        @Override
        public Image getSubImage(int i, int i2, int i3, int i4) {
            BufferedImage region = image.getSubimage(i, i2, i3, i4);
            return new BufferedImageWrapper(region);
        }

        /**
         * Creates an independent deep copy of the wrapped image.
         *
         * <p>The copy is built from the source image's own colour model and a freshly allocated
         * compatible raster, so it works regardless of the backing buffer type (byte- or
         * int-based) and also for images whose type is {@code TYPE_CUSTOM}. The previous
         * implementation assumed a byte-backed data buffer.
         *
         * @return a new wrapper whose pixel data is not shared with this image
         */
        @Override
        public Image duplicate() {
            java.awt.image.ColorModel colorModel = image.getColorModel();
            // A fresh raster at origin (0, 0) with the image's dimensions; copyData fills exactly
            // the bounds of the raster it is handed.
            java.awt.image.WritableRaster pixels =
                    image.getRaster()
                            .createCompatibleWritableRaster(image.getWidth(), image.getHeight());
            image.copyData(pixels);
            BufferedImage copy =
                    new BufferedImage(colorModel, pixels, colorModel.isAlphaPremultiplied(), null);
            return new BufferedImageWrapper(copy);
        }

        /**
         * Ensures the wrapped image is of type {@link BufferedImage#TYPE_INT_ARGB}.
         *
         * <p>If it already is, nothing happens. Otherwise the current image is painted onto a new
         * ARGB image of the same size and the wrapper switches to that new image.
         */
        private void convertIdNeeded() {
            if (image.getType() != BufferedImage.TYPE_INT_ARGB) {
                BufferedImage argb =
                        new BufferedImage(
                                image.getWidth(), image.getHeight(), BufferedImage.TYPE_INT_ARGB);
                Graphics2D canvas = argb.createGraphics();
                canvas.drawImage(image, 0, 0, null);
                canvas.dispose();
                image = argb;
            }
        }

        /**
         * Converts the wrapped image into a UINT8 {@link NDArray} of shape (height, width,
         * channels).
         *
         * <p>The channel count is 1 when {@code flag} is {@code GRAYSCALE} and 3 for any other
         * flag. Images of type {@link BufferedImage#TYPE_BYTE_GRAY} are read sample-by-sample from
         * the raster (the gray value is repeated for each output channel); all other images are
         * read through {@link BufferedImage#getRGB(int, int, int, int, int[], int, int)} and, for
         * grayscale output, reduced with the weights 0.299 / 0.587 / 0.114.
         *
         * @param nDManager the manager used to allocate the buffer and create the array
         * @param flag selects grayscale or colour output
         * @return the pixel data as an array with data type {@code UINT8}
         */
        @Override
        public NDArray toNDArray(NDManager nDManager, Image.Flag flag) {
            final int width = image.getWidth();
            final int height = image.getHeight();
            final boolean grayscale = flag == Image.Flag.GRAYSCALE;
            final int channels = grayscale ? 1 : 3;
            ByteBuffer buffer = nDManager.allocateDirect(channels * height * width);

            if (image.getType() == BufferedImage.TYPE_BYTE_GRAY) {
                int[] samples = new int[width * height];
                image.getData().getPixels(0, 0, width, height, samples);
                for (int sample : samples) {
                    byte gray = (byte) sample;
                    // One write for grayscale output, three identical writes otherwise.
                    for (int channel = 0; channel < channels; channel++) {
                        buffer.put(gray);
                    }
                }
            } else {
                int[] packed = image.getRGB(0, 0, width, height, null, 0, width);
                for (int pixel : packed) {
                    int red = (pixel >> 16) & 0xFF;
                    int green = (pixel >> 8) & 0xFF;
                    int blue = pixel & 0xFF;
                    if (grayscale) {
                        float luminance = (0.299f * red) + (0.587f * green) + (0.114f * blue);
                        buffer.put((byte) Math.round(luminance));
                        continue;
                    }
                    buffer.put((byte) red);
                    buffer.put((byte) green);
                    buffer.put((byte) blue);
                }
            }

            buffer.rewind();
            Shape shape = new Shape(height, width, channels);
            return nDManager.create(buffer, shape, DataType.UINT8);
        }

        /**
         * Writes the wrapped image to a stream by delegating to the enclosing factory's
         * {@code save} method.
         *
         * @param outputStream the destination stream
         * @param str the image format name passed through to the factory
         * @throws IOException if the factory's {@code save} throws
         */
        @Override
        public void save(OutputStream outputStream, String str) throws IOException {
            BufferedImageFactory factory = BufferedImageFactory.this;
            factory.save(image, outputStream, str);
        }

        /**
         * Not implemented for {@link BufferedImage}-backed images.
         *
         * @return never returns normally
         * @throws UnsupportedOperationException always
         */
        @Override
        public List<BoundingBox> findBoundingBoxes() {
            final String reason = "Not supported for BufferedImage";
            throw new UnsupportedOperationException(reason);
        }

        /**
         * Draws every detected object onto the wrapped image.
         *
         * <p>The image is first converted to ARGB if necessary. For each detection a rectangle is
         * outlined in a darkened random colour, with the class name drawn as a label at the
         * rectangle's upper-left corner. The bounds returned by the detection are multiplied by
         * the image width/height before drawing. Detections whose bounding box is a {@link Mask}
         * or a {@link Landmark} additionally get their mask or landmark points drawn.
         *
         * @param detectedObjects the detections to render
         */
        @Override
        public void drawBoundingBoxes(DetectedObjects detectedObjects) {
            convertIdNeeded();

            final int stroke = 2;
            final int padding = 4;
            Graphics2D canvas = image.createGraphics();
            canvas.setStroke(new BasicStroke(stroke));
            canvas.setRenderingHint(
                    RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON);

            final int imageWidth = image.getWidth();
            final int imageHeight = image.getHeight();

            for (DetectedObjects.DetectedObject detection : detectedObjects.items()) {
                String label = detection.getClassName();
                BoundingBox box = detection.getBoundingBox();

                // drawText leaves the paint set to white, so the colour is re-set every iteration.
                canvas.setPaint(randomColor().darker());

                Rectangle bounds = box.getBounds();
                int left = (int) (bounds.getX() * imageWidth);
                int top = (int) (bounds.getY() * imageHeight);
                int boxWidth = (int) (bounds.getWidth() * imageWidth);
                int boxHeight = (int) (bounds.getHeight() * imageHeight);
                canvas.drawRect(left, top, boxWidth, boxHeight);
                drawText(canvas, label, left, top, stroke, padding);

                if (box instanceof Mask) {
                    drawMask((Mask) box);
                } else if (box instanceof Landmark) {
                    drawLandmarks(box);
                }
            }

            canvas.dispose();
        }

        /**
         * Draws every joint as a filled 10x10 oval on the wrapped image.
         *
         * <p>The image is first converted to ARGB if necessary. Each joint's x/y value is
         * multiplied by the image width/height to obtain the pixel position, which is used as the
         * upper-left corner of the oval's bounding square. Every joint gets its own darkened
         * random colour.
         *
         * @param joints the joints to render
         */
        @Override
        public void drawJoints(Joints joints) {
            convertIdNeeded();
            // createGraphics() returns Graphics2D directly; getGraphics() is declared to return
            // Graphics and cannot be assigned to a Graphics2D without a cast.
            Graphics2D graphics = image.createGraphics();
            graphics.setStroke(new BasicStroke(2));
            int width = image.getWidth();
            int height = image.getHeight();
            for (Joints.Joint joint : joints.getJoints()) {
                graphics.setPaint(randomColor().darker());
                int x = (int) (joint.getX() * width);
                int y = (int) (joint.getY() * height);
                graphics.fillOval(x, y, 10, 10);
            }
            graphics.dispose();
        }

        /**
         * Picks a random opaque colour.
         *
         * <p>{@link Color#Color(int)} interprets its argument as packed RGB (red in bits 16-23,
         * green in bits 8-15, blue in bits 0-7). The bound therefore has to span the full 24-bit
         * range; a bound of 255 would only ever populate the blue channel.
         *
         * @return a colour drawn from the 24-bit RGB range
         */
        private Color randomColor() {
            return new Color(RandomUtils.nextInt(1 << 24));
        }

        /**
         * Draws a text label on a filled background box.
         *
         * <p>The box is filled with whatever paint is currently set on {@code graphics2D}; the
         * paint is then switched to white for the text and is left white on return.
         *
         * @param graphics2D the graphics context to draw on
         * @param str the text to draw
         * @param i x coordinate of the anchor point
         * @param i2 y coordinate of the anchor point
         * @param i3 stroke width; half of it is used to offset the box from the anchor
         * @param i4 horizontal padding applied on both sides of the text
         */
        private void drawText(Graphics2D graphics2D, String str, int i, int i2, int i3, int i4) {
            final int halfStroke = i3 / 2;
            final int left = i + halfStroke;
            final int top = i2 + halfStroke;

            FontMetrics metrics = graphics2D.getFontMetrics();
            int boxWidth = metrics.stringWidth(str) + (i4 * 2) - halfStroke;
            int boxHeight = metrics.getHeight() + metrics.getDescent();

            java.awt.Rectangle background = new java.awt.Rectangle(left, top, boxWidth, boxHeight);
            graphics2D.fill(background);

            graphics2D.setPaint(Color.WHITE);
            graphics2D.drawString(str, left + i4, top + metrics.getAscent());
        }

        /**
         * Paints a semi-transparent overlay for a segmentation mask onto the wrapped image.
         *
         * <p>One random colour is chosen for the whole mask. Each probability value becomes the
         * alpha of one overlay pixel (scaled by 0.8), and the first index of the probability grid
         * is used as the overlay's x coordinate. The overlay is drawn at the mask's x/y position
         * multiplied by the image width/height, clamped so it is never negative.
         *
         * <p>An absent or empty probability grid is ignored; previously it caused an exception.
         *
         * @param mask the mask to render
         */
        private void drawMask(Mask mask) {
            float[][] probDist = mask.getProbDist();
            if (probDist == null || probDist.length == 0 || probDist[0].length == 0) {
                // Nothing to paint, and BufferedImage rejects zero-sized dimensions.
                return;
            }

            float red = RandomUtils.nextFloat();
            float green = RandomUtils.nextFloat();
            float blue = RandomUtils.nextFloat();

            int imageWidth = image.getWidth();
            int imageHeight = image.getHeight();
            // Keep the overlay origin inside the image when the box starts off-canvas.
            int x = Math.max(0, (int) (mask.getX() * imageWidth));
            int y = Math.max(0, (int) (mask.getY() * imageHeight));

            BufferedImage overlay =
                    new BufferedImage(
                            probDist.length, probDist[0].length, BufferedImage.TYPE_INT_ARGB);
            for (int col = 0; col < probDist.length; col++) {
                for (int row = 0; row < probDist[col].length; row++) {
                    float opacity = probDist[col][row] * 0.8f;
                    Color shade = new Color(red, green, blue, opacity).darker();
                    overlay.setRGB(col, row, shade.getRGB());
                }
            }

            // createGraphics() returns Graphics2D directly, unlike getGraphics().
            Graphics2D graphics = image.createGraphics();
            graphics.drawImage(overlay, x, y, null);
            graphics.dispose();
        }

        /**
         * Marks every point of the bounding box's path with a small orange square.
         *
         * <p>Point coordinates are cast to {@code int} and used as-is; unlike the box bounds they
         * are not multiplied by the image size here.
         *
         * @param boundingBox the bounding box whose path points are drawn
         */
        private void drawLandmarks(BoundingBox boundingBox) {
            // createGraphics() returns Graphics2D directly; getGraphics() is declared to return
            // Graphics and cannot be assigned to a Graphics2D without a cast.
            Graphics2D graphics = image.createGraphics();
            graphics.setColor(new Color(246, 96, 0));
            graphics.setStroke(
                    new BasicStroke(4.0f, BasicStroke.CAP_BUTT, BasicStroke.JOIN_MITER));
            for (Point point : boundingBox.getPath()) {
                graphics.drawRect((int) point.getX(), (int) point.getY(), 2, 2);
            }
            graphics.dispose();
        }
    }

    /**
     * Loads an image from a file using {@link ImageIO}.
     *
     * @param path the file to read
     * @return a wrapper around the decoded image
     * @throws IOException if reading fails or {@link ImageIO} returns {@code null} for the file
     */
    @Override
    public Image fromFile(Path path) throws IOException {
        BufferedImage decoded = ImageIO.read(path.toFile());
        if (decoded != null) {
            return new BufferedImageWrapper(decoded);
        }
        throw new IOException("Failed to read image from: " + path);
    }

    /**
     * Loads an image from a stream using {@link ImageIO}.
     *
     * @param inputStream the stream to read from
     * @return a wrapper around the decoded image
     * @throws IOException if reading fails or {@link ImageIO} returns {@code null} for the stream
     */
    @Override
    public Image fromInputStream(InputStream inputStream) throws IOException {
        BufferedImage decoded = ImageIO.read(inputStream);
        if (decoded != null) {
            return new BufferedImageWrapper(decoded);
        }
        throw new IOException("Failed to read image from input stream");
    }

    /**
     * Wraps an already-decoded image object.
     *
     * @param obj the image to wrap; must be a {@link BufferedImage}
     * @return a wrapper around {@code obj}
     * @throws IllegalArgumentException if {@code obj} is not a {@link BufferedImage}
     */
    @Override
    public Image fromImage(Object obj) {
        if (!(obj instanceof BufferedImage)) {
            throw new IllegalArgumentException("only BufferedImage allowed");
        }
        BufferedImage source = (BufferedImage) obj;
        return new BufferedImageWrapper(source);
    }

    /**
     * Builds an RGB image from a 3-dimensional INT8/UINT8 array.
     *
     * <p>Both layouts distinguished by {@code NDImageUtils.isCHW} are handled: when it reports
     * CHW the three colour planes are read one after another; otherwise the array is treated as
     * (height, width, channels) and the channels of a pixel are read from adjacent elements.
     * Previously the non-CHW case was indexed as if it were planar, which scrambled the colours.
     *
     * @param nDArray the pixel data
     * @return a wrapper around a new {@link BufferedImage#TYPE_INT_RGB} image
     * @throws UnsupportedOperationException if the array is 4-dimensional or has a single channel
     * @throws IllegalArgumentException if the data type is neither INT8 nor UINT8
     */
    @Override
    public Image fromNDArray(NDArray nDArray) {
        Shape shape = nDArray.getShape();
        if (shape.dimension() == 4) {
            throw new UnsupportedOperationException("Batch is not supported");
        }
        if (shape.get(0) == 1 || shape.get(2) == 1) {
            throw new UnsupportedOperationException("Grayscale image is not supported");
        }
        if (nDArray.getDataType() != DataType.UINT8 && nDArray.getDataType() != DataType.INT8) {
            throw new IllegalArgumentException("Datatype should be INT8 or UINT8");
        }

        final boolean channelFirst = NDImageUtils.isCHW(shape);
        final int height = (int) shape.get(channelFirst ? 1 : 0);
        final int width = (int) shape.get(channelFirst ? 2 : 1);
        final int area = width * height;

        // Distance between consecutive pixels, and between the channels of one pixel, in the
        // flat array. Assumes toUint8Array() yields the elements in row-major order.
        final int pixelStride = channelFirst ? 1 : (int) shape.get(2);
        final int channelStride = channelFirst ? area : 1;

        int[] raw = nDArray.toUint8Array();
        BufferedImage result = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
        for (int pixel = 0; pixel < area; pixel++) {
            int base = pixel * pixelStride;
            int red = raw[base] & 0xFF;
            int green = raw[base + channelStride] & 0xFF;
            int blue = raw[base + 2 * channelStride] & 0xFF;
            result.setRGB(pixel % width, pixel / width, (red << 16) | (green << 8) | blue);
        }
        return new BufferedImageWrapper(result);
    }

    /**
     * Encodes an image in the given format and writes it to a stream.
     *
     * <p>{@link ImageIO#write(java.awt.image.RenderedImage, String, OutputStream)} returns
     * {@code false} when no appropriate writer is found; that result is now reported as an
     * {@link IOException} instead of being silently ignored (which left nothing written).
     *
     * @param bufferedImage the image to encode
     * @param outputStream the destination stream
     * @param str the informal format name passed to {@link ImageIO}, e.g. {@code "png"}
     * @throws IOException if writing fails or no writer is available for the image and format
     */
    protected void save(BufferedImage bufferedImage, OutputStream outputStream, String str) throws IOException {
        boolean written = ImageIO.write(bufferedImage, str, outputStream);
        if (!written) {
            throw new IOException("No suitable image writer found for format: " + str);
        }
    }

    static {
        if (System.getProperty("apple.awt.UIElement") == null) {
            System.setProperty("apple.awt.UIElement", "true");
        }
    }
}
