package org.deeplearning4j.ui.weights;

import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.image.BufferedImage;
import java.awt.image.ImageObserver;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.UUID;
import javax.imageio.ImageIO;
import lombok.NonNull;
import org.datavec.api.util.ClassPathResource;
import org.datavec.image.loader.ImageLoader;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.convolution.ConvolutionLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.ui.UiConnectionInfo;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.storage.mapdb.MapDBStatsStorage;
import org.deeplearning4j.util.UIDProvider;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/ui/weights/ConvolutionalIterationListener.class */
public class ConvolutionalIterationListener implements IterationListener {
    private int freq;
    private static final Logger log = LoggerFactory.getLogger(ConvolutionalIterationListener.class);
    private int minibatchNum;
    private boolean openBrowser;
    private String path;
    private boolean firstIteration;
    private Color borderColor;
    private Color bgColor;
    private final StatsStorageRouter ssr;
    private final String sessionID;
    private final String workerID;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/deeplearning4j/ui/weights/ConvolutionalIterationListener$Orientation.class */
    public enum Orientation {
        LANDSCAPE,
        PORTRAIT
    }

    public ConvolutionalIterationListener(UiConnectionInfo uiConnectionInfo, int i) {
        this(new MapDBStatsStorage(), i, true);
    }

    public ConvolutionalIterationListener(int i) {
        this(i, true);
    }

    public ConvolutionalIterationListener(int i, boolean z) {
        this(new MapDBStatsStorage(), i, z);
    }

    public ConvolutionalIterationListener(StatsStorageRouter statsStorageRouter, int i, boolean z) {
        this(statsStorageRouter, i, z, null, null);
    }

    public ConvolutionalIterationListener(StatsStorageRouter statsStorageRouter, int i, boolean z, String str, String str2) {
        this.freq = 10;
        this.minibatchNum = 0;
        this.openBrowser = true;
        this.firstIteration = true;
        this.borderColor = new Color(140, 140, 140);
        this.bgColor = new Color(255, 255, 255);
        this.ssr = statsStorageRouter;
        if (str == null) {
            this.sessionID = UUID.randomUUID().toString();
        } else {
            this.sessionID = str;
        }
        if (str2 == null) {
            this.workerID = UIDProvider.getJVMUID() + "_" + Thread.currentThread().getId();
        } else {
            this.workerID = str2;
        }
        this.freq = i;
        this.openBrowser = z;
        this.path = "http://localhost:" + UIServer.getInstance().getPort() + "/activations";
        if (z && (statsStorageRouter instanceof StatsStorage)) {
            UIServer.getInstance().attach((StatsStorage) statsStorageRouter);
        }
        System.out.println("ConvolutionIterationListener path: " + this.path);
    }

    public boolean invoked() {
        return false;
    }

    public void invoke() {
    }

    public void iterationDone(Model model, int i) {
        if (i % this.freq == 0) {
            ArrayList arrayList = new ArrayList();
            int i2 = 0;
            Random random = new Random();
            BufferedImage bufferedImage = null;
            if (model instanceof MultiLayerNetwork) {
                for (ConvolutionLayer convolutionLayer : ((MultiLayerNetwork) model).getLayers()) {
                    if (convolutionLayer.type() == Layer.Type.CONVOLUTIONAL) {
                        INDArray activate = convolutionLayer.activate();
                        int nextInt = random.nextInt(activate.shape()[0] - 1) + 1;
                        if (i2 == 0) {
                            try {
                                bufferedImage = restoreRGBImage(convolutionLayer.input().tensorAlongDimension(nextInt, new int[]{3, 2, 1}));
                            } catch (Exception e) {
                                throw new RuntimeException(e);
                            }
                        }
                        arrayList.add(activate.tensorAlongDimension(nextInt, new int[]{3, 2, 1}));
                        i2++;
                    }
                }
            } else if (model instanceof ComputationGraph) {
                for (ConvolutionLayer convolutionLayer2 : ((ComputationGraph) model).getLayers()) {
                    if (convolutionLayer2.type() == Layer.Type.CONVOLUTIONAL) {
                        INDArray activate2 = convolutionLayer2.activate();
                        int nextInt2 = random.nextInt(activate2.shape()[0] - 1) + 1;
                        if (i2 == 0) {
                            try {
                                bufferedImage = restoreRGBImage(convolutionLayer2.input().tensorAlongDimension(nextInt2, new int[]{3, 2, 1}));
                            } catch (Exception e2) {
                                throw new RuntimeException(e2);
                            }
                        }
                        arrayList.add(activate2.tensorAlongDimension(nextInt2, new int[]{3, 2, 1}));
                        i2++;
                    }
                }
            }
            this.ssr.putStaticInfo(new ConvolutionListenerPersistable(this.sessionID, this.workerID, System.currentTimeMillis(), rasterizeConvoLayers(arrayList, bufferedImage)));
            this.minibatchNum++;
        }
    }

    private BufferedImage rasterizeConvoLayers(@NonNull List<INDArray> list, BufferedImage bufferedImage) {
        int height;
        if (list == null) {
            throw new NullPointerException("tensors3D");
        }
        int[] shape = list.get(0).shape();
        int i = shape[0];
        int i2 = shape[2];
        int i3 = shape[1];
        int i4 = 0;
        int i5 = 0;
        int i6 = 1;
        Orientation orientation = Orientation.LANDSCAPE;
        if (list.size() > 3) {
            orientation = Orientation.PORTRAIT;
        }
        ArrayList arrayList = new ArrayList();
        for (int i7 = 0; i7 < list.size(); i7++) {
            INDArray iNDArray = list.get(i7);
            BufferedImage bufferedImage2 = null;
            if (orientation == Orientation.LANDSCAPE) {
                i4 = (i2 + (1 * 2) + 2) * i;
                bufferedImage2 = renderMultipleImagesLandscape(iNDArray, i4, i3, i2);
                i5 += bufferedImage2.getWidth() + 80;
            } else if (orientation == Orientation.PORTRAIT) {
                i5 = (i3 + (1 * 2) + 2) * i;
                bufferedImage2 = renderMultipleImagesPortrait(iNDArray, i5, i3, i2);
                i4 += bufferedImage2.getHeight() + 80;
            }
            arrayList.add(bufferedImage2);
        }
        if (orientation == Orientation.LANDSCAPE) {
            i5 += 80 * 2;
        } else if (orientation == Orientation.PORTRAIT) {
            i4 = i4 + (80 * 2) + bufferedImage.getHeight() + (80 * 2);
        }
        BufferedImage bufferedImage3 = new BufferedImage(i5, i4, 1);
        Graphics2D createGraphics = bufferedImage3.createGraphics();
        createGraphics.setPaint(this.bgColor);
        createGraphics.fillRect(0, 0, bufferedImage3.getWidth(), bufferedImage3.getHeight());
        BufferedImage bufferedImage4 = null;
        BufferedImage bufferedImage5 = null;
        try {
            if (orientation == Orientation.LANDSCAPE) {
                try {
                    ClassPathResource classPathResource = new ClassPathResource("arrow_sing.PNG");
                    ClassPathResource classPathResource2 = new ClassPathResource("arrow_mul.PNG");
                    bufferedImage4 = ImageIO.read(classPathResource.getInputStream());
                    bufferedImage5 = ImageIO.read(classPathResource2.getInputStream());
                } catch (Exception e) {
                }
                createGraphics.drawImage(bufferedImage, (80 / 2) - (bufferedImage.getWidth() / 2), (i4 / 2) - (bufferedImage.getHeight() / 2), (ImageObserver) null);
                createGraphics.setPaint(this.borderColor);
                createGraphics.drawRect((80 / 2) - (bufferedImage.getWidth() / 2), (i4 / 2) - (bufferedImage.getHeight() / 2), bufferedImage.getWidth(), bufferedImage.getHeight());
                height = 1 + bufferedImage.getWidth();
                if (bufferedImage4 != null) {
                    createGraphics.drawImage(bufferedImage4, (height + (80 / 2)) - (bufferedImage4.getWidth() / 2), (i4 / 2) - (bufferedImage4.getHeight() / 2), (ImageObserver) null);
                }
            } else {
                try {
                    ClassPathResource classPathResource3 = new ClassPathResource("arrow_singi.PNG");
                    ClassPathResource classPathResource4 = new ClassPathResource("arrow_muli.PNG");
                    bufferedImage4 = ImageIO.read(classPathResource3.getInputStream());
                    bufferedImage5 = ImageIO.read(classPathResource4.getInputStream());
                } catch (Exception e2) {
                }
                createGraphics.drawImage(bufferedImage, (i5 / 2) - (bufferedImage.getWidth() / 2), (80 / 2) - (bufferedImage.getHeight() / 2), (ImageObserver) null);
                createGraphics.setPaint(this.borderColor);
                createGraphics.drawRect((i5 / 2) - (bufferedImage.getWidth() / 2), (80 / 2) - (bufferedImage.getHeight() / 2), bufferedImage.getWidth(), bufferedImage.getHeight());
                height = 1 + bufferedImage.getHeight();
                if (bufferedImage4 != null) {
                    createGraphics.drawImage(bufferedImage4, (i5 / 2) - (bufferedImage4.getWidth() / 2), (height + (80 / 2)) - (bufferedImage4.getHeight() / 2), (ImageObserver) null);
                }
            }
            i6 = height + 80;
        } catch (Exception e3) {
        }
        for (int i8 = 0; i8 < arrayList.size(); i8++) {
            BufferedImage bufferedImage6 = (BufferedImage) arrayList.get(i8);
            if (orientation == Orientation.LANDSCAPE) {
                createGraphics.drawImage(bufferedImage6, i6, 1, (ImageObserver) null);
                i6 += bufferedImage6.getWidth() + 80;
                if (bufferedImage4 != null && bufferedImage5 != null && i8 < arrayList.size() - 1 && bufferedImage5 != null) {
                    createGraphics.drawImage(bufferedImage5, (i6 - (80 / 2)) - (bufferedImage5.getWidth() / 2), (i4 / 2) - (bufferedImage5.getHeight() / 2), (ImageObserver) null);
                }
            } else if (orientation == Orientation.PORTRAIT) {
                createGraphics.drawImage(bufferedImage6, 1, i6, (ImageObserver) null);
                i6 += bufferedImage6.getHeight() + 80;
                if (bufferedImage4 != null && bufferedImage5 != null && i8 < arrayList.size() - 1 && bufferedImage5 != null) {
                    createGraphics.drawImage(bufferedImage5, (i5 / 2) - (bufferedImage5.getWidth() / 2), (i6 - (80 / 2)) - (bufferedImage5.getHeight() / 2), (ImageObserver) null);
                }
            }
        }
        return bufferedImage3;
    }

    private BufferedImage renderMultipleImagesPortrait(INDArray iNDArray, int i, int i2, int i3) {
        int[] shape = iNDArray.shape();
        int i4 = ((shape[0] / shape[2]) * (shape[1] + 1 + 2)) + 2 + 20 + i2;
        BufferedImage bufferedImage = new BufferedImage(i, i4, 10);
        Graphics2D createGraphics = bufferedImage.createGraphics();
        createGraphics.setPaint(this.bgColor);
        createGraphics.fillRect(0, 0, bufferedImage.getWidth(), bufferedImage.getHeight());
        int i5 = 0;
        int i6 = 0;
        int i7 = 0;
        int i8 = i / 5;
        for (int i9 = 0; i9 < iNDArray.shape()[0]; i9++) {
            INDArray tensorAlongDimension = iNDArray.tensorAlongDimension(i9, new int[]{2, 1});
            int i10 = tensorAlongDimension.shape()[0];
            int i11 = tensorAlongDimension.shape()[1];
            int i12 = i11 + (1 * 2) + 2;
            int i13 = i10 + (1 * 2) + 2;
            BufferedImage renderImageGrayscale = renderImageGrayscale(tensorAlongDimension);
            if (i5 + i13 > i) {
                i6 += i12;
                i5 = 0;
            }
            createGraphics.drawImage(renderImageGrayscale, i5 + 1, i6 + 1, (ImageObserver) null);
            createGraphics.setPaint(this.borderColor);
            createGraphics.drawRect(i5, i6, tensorAlongDimension.shape()[0], tensorAlongDimension.shape()[1]);
            if (i9 % 7 == 0 && i9 != 0 && i7 < 5 && i11 != i3 && i10 != i2) {
                int i14 = (i8 * i7) + i3;
                int i15 = (i8 * i7) + i2;
                createGraphics.drawImage(renderImageGrayscale, i15 - 1, (i4 - i2) - 1, i2, i3, (ImageObserver) null);
                createGraphics.drawRect(i15 - 2, (i4 - i2) - 2, i2, i3);
                createGraphics.drawLine(i5 + i10, i6 + i11, i15 - 2, (i4 - i2) - 2);
                i7++;
            }
            i5 += i13;
        }
        return bufferedImage;
    }

    private BufferedImage renderMultipleImagesLandscape(INDArray iNDArray, int i, int i2, int i3) {
        int[] shape = iNDArray.shape();
        int i4 = ((shape[0] / shape[1]) * (shape[1] + 1 + 2)) + 2 + 20 + i2;
        BufferedImage bufferedImage = new BufferedImage(i4, i, 10);
        Graphics2D createGraphics = bufferedImage.createGraphics();
        createGraphics.setPaint(this.bgColor);
        createGraphics.fillRect(0, 0, bufferedImage.getWidth(), bufferedImage.getHeight());
        int i5 = 0;
        int i6 = 0;
        int i7 = 0;
        int i8 = i / 5;
        for (int i9 = 0; i9 < iNDArray.shape()[0]; i9++) {
            INDArray tensorAlongDimension = iNDArray.tensorAlongDimension(i9, new int[]{2, 1});
            int i10 = tensorAlongDimension.shape()[0];
            int i11 = tensorAlongDimension.shape()[1];
            int i12 = i11 + (1 * 2) + 2;
            int i13 = i10 + (1 * 2) + 2;
            BufferedImage renderImageGrayscale = renderImageGrayscale(tensorAlongDimension);
            if (i6 + i12 > i) {
                i5 += i13;
                i6 = 0;
            }
            createGraphics.drawImage(renderImageGrayscale, i5 + 1, i6 + 1, (ImageObserver) null);
            createGraphics.setPaint(this.borderColor);
            createGraphics.drawRect(i5, i6, tensorAlongDimension.shape()[0], tensorAlongDimension.shape()[1]);
            if (i9 % 5 == 0 && i9 != 0 && i7 < 5 && i11 != i3 && i10 != i2) {
                int i14 = (i8 * i7) + i3;
                createGraphics.drawImage(renderImageGrayscale, (i4 - i2) - 1, i14 - 1, i2, i3, (ImageObserver) null);
                createGraphics.drawRect((i4 - i2) - 2, i14 - 2, i2, i3);
                createGraphics.drawLine(i5 + i10, i6 + i11, (i4 - i2) - 2, (i14 - 2) + i3);
                i7++;
            }
            i6 += i12;
        }
        return bufferedImage;
    }

    private BufferedImage restoreRGBImage(INDArray iNDArray) {
        INDArray tensorAlongDimension;
        INDArray iNDArray2;
        INDArray iNDArray3;
        if (iNDArray.shape()[0] == 3) {
            iNDArray3 = iNDArray.tensorAlongDimension(2, new int[]{2, 1});
            iNDArray2 = iNDArray.tensorAlongDimension(1, new int[]{2, 1});
            tensorAlongDimension = iNDArray.tensorAlongDimension(0, new int[]{2, 1});
        } else {
            tensorAlongDimension = iNDArray.tensorAlongDimension(0, new int[]{2, 1});
            iNDArray2 = tensorAlongDimension;
            iNDArray3 = tensorAlongDimension;
        }
        BufferedImage bufferedImage = new BufferedImage(iNDArray3.columns(), iNDArray3.rows(), 1);
        for (int i = 0; i < iNDArray3.columns(); i++) {
            for (int i2 = 0; i2 < iNDArray3.rows(); i2++) {
                bufferedImage.setRGB(i, i2, new Color((int) (255.0d * iNDArray3.getRow(i2).getDouble(i)), (int) (255.0d * iNDArray2.getRow(i2).getDouble(i)), (int) (255.0d * tensorAlongDimension.getRow(i2).getDouble(i))).getRGB());
            }
        }
        return bufferedImage;
    }

    private BufferedImage renderImageGrayscale(INDArray iNDArray) {
        BufferedImage bufferedImage = new BufferedImage(iNDArray.columns(), iNDArray.rows(), 10);
        for (int i = 0; i < iNDArray.columns(); i++) {
            for (int i2 = 0; i2 < iNDArray.rows(); i2++) {
                bufferedImage.getRaster().setSample(i, i2, 0, (int) (255.0d * iNDArray.getRow(i2).getDouble(i)));
            }
        }
        return bufferedImage;
    }

    private void writeImageGrayscale(INDArray iNDArray, File file) {
        try {
            ImageIO.write(renderImageGrayscale(iNDArray), "png", file);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    private void writeImage(INDArray iNDArray, File file) {
        try {
            ImageIO.write(ImageLoader.toImage(iNDArray), "png", file);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    private void writeRows(INDArray iNDArray, File file) {
        try {
            PrintWriter printWriter = new PrintWriter(file);
            for (int i = 0; i < iNDArray.rows(); i++) {
                printWriter.println("Row [" + i + "]: " + iNDArray.getRow(i));
            }
            printWriter.flush();
            printWriter.close();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
}
