package org.deeplearning4j.ui.module.train;

import com.fasterxml.jackson.annotation.JsonIgnore;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.conf.layers.BaseOutputLayer;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder;

/**
 * Builds {@link GraphInfo} descriptions of DL4J network configurations.
 *
 * <p>A {@link GraphInfo} is a flat, index-based view of a network: every vertex (the network
 * input(s) plus each layer or graph vertex) occupies one position in a set of parallel lists, and
 * the edges of vertex {@code i} are stored as a list of indices into those same lists.
 */
public class TrainModuleUtils {

    /** Class-name suffix dropped when deriving a display type, e.g. {@code DenseLayer -> Dense}. */
    private static final String LAYER_SUFFIX = "Layer";

    /**
     * Parallel-list description of a network graph.
     *
     * <p>All five lists are indexed by vertex position: {@code vertexNames.get(i)},
     * {@code vertexTypes.get(i)}, {@code vertexInputs.get(i)}, {@code vertexInfo.get(i)} and
     * {@code originalVertexName.get(i)} describe the same vertex {@code i}. The integers in
     * {@code vertexInputs} are indices into these lists.
     *
     * <p>{@code originalVertexName} carries {@link JsonIgnore}, so Jackson skips it when
     * serializing. {@code equals}/{@code hashCode}/{@code toString} follow the Lombok
     * {@code @Data} conventions (including {@link #canEqual(Object)}, prime 59, null hash 43).
     */
    public static class GraphInfo {
        private List<String> vertexNames;
        private List<String> vertexTypes;
        private List<List<Integer>> vertexInputs;
        private List<Map<String, String>> vertexInfo;

        @JsonIgnore
        private List<String> originalVertexName;

        /**
         * @param vertexNames        display name of each vertex
         * @param vertexTypes        display type of each vertex
         * @param vertexInputs       for each vertex, the indices of its input vertices
         * @param vertexInfo         for each vertex, ordered display properties
         * @param originalVertexName for each vertex, the name/index it has in the source configuration
         */
        public GraphInfo(List<String> vertexNames, List<String> vertexTypes, List<List<Integer>> vertexInputs,
                        List<Map<String, String>> vertexInfo, List<String> originalVertexName) {
            this.vertexNames = vertexNames;
            this.vertexTypes = vertexTypes;
            this.vertexInputs = vertexInputs;
            this.vertexInfo = vertexInfo;
            this.originalVertexName = originalVertexName;
        }

        public List<String> getVertexNames() {
            return this.vertexNames;
        }

        public List<String> getVertexTypes() {
            return this.vertexTypes;
        }

        public List<List<Integer>> getVertexInputs() {
            return this.vertexInputs;
        }

        public List<Map<String, String>> getVertexInfo() {
            return this.vertexInfo;
        }

        public List<String> getOriginalVertexName() {
            return this.originalVertexName;
        }

        public void setVertexNames(List<String> vertexNames) {
            this.vertexNames = vertexNames;
        }

        public void setVertexTypes(List<String> vertexTypes) {
            this.vertexTypes = vertexTypes;
        }

        public void setVertexInputs(List<List<Integer>> vertexInputs) {
            this.vertexInputs = vertexInputs;
        }

        public void setVertexInfo(List<Map<String, String>> vertexInfo) {
            this.vertexInfo = vertexInfo;
        }

        public void setOriginalVertexName(List<String> originalVertexName) {
            this.originalVertexName = originalVertexName;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof GraphInfo)) {
                return false;
            }
            GraphInfo other = (GraphInfo) obj;
            // Lombok-style symmetry guard: a subclass can override canEqual to refuse equality.
            if (!other.canEqual(this)) {
                return false;
            }
            return nullSafeEquals(getVertexNames(), other.getVertexNames())
                            && nullSafeEquals(getVertexTypes(), other.getVertexTypes())
                            && nullSafeEquals(getVertexInputs(), other.getVertexInputs())
                            && nullSafeEquals(getVertexInfo(), other.getVertexInfo())
                            && nullSafeEquals(getOriginalVertexName(), other.getOriginalVertexName());
        }

        protected boolean canEqual(Object obj) {
            return obj instanceof GraphInfo;
        }

        @Override
        public int hashCode() {
            // Same formula Lombok generates, so hash values are unchanged.
            final int prime = 59;
            int result = 1;
            result = result * prime + hashOrSentinel(getVertexNames());
            result = result * prime + hashOrSentinel(getVertexTypes());
            result = result * prime + hashOrSentinel(getVertexInputs());
            result = result * prime + hashOrSentinel(getVertexInfo());
            result = result * prime + hashOrSentinel(getOriginalVertexName());
            return result;
        }

        @Override
        public String toString() {
            return "TrainModuleUtils.GraphInfo(vertexNames=" + getVertexNames() + ", vertexTypes=" + getVertexTypes() + ", vertexInputs=" + getVertexInputs() + ", vertexInfo=" + getVertexInfo() + ", originalVertexName=" + getOriginalVertexName() + ")";
        }

        /** Null-safe equality: two nulls are equal, otherwise delegates to {@code a.equals(b)}. */
        private static boolean nullSafeEquals(Object a, Object b) {
            return a == null ? b == null : a.equals(b);
        }

        /** Hash of {@code o}, or the Lombok sentinel 43 when {@code o} is null. */
        private static int hashOrSentinel(Object o) {
            return o == null ? 43 : o.hashCode();
        }
    }

    /**
     * Accumulates the five parallel lists of a {@link GraphInfo}, guaranteeing they stay the same
     * length (each {@code add*} call appends exactly one entry to every list).
     */
    private static final class GraphInfoBuilder {
        private final List<String> names = new ArrayList<>();
        private final List<String> originalNames = new ArrayList<>();
        private final List<String> types = new ArrayList<>();
        private final List<List<Integer>> inputs = new ArrayList<>();
        private final List<Map<String, String>> infos = new ArrayList<>();

        /** Appends one fully-specified vertex. */
        void add(String name, String originalName, String type, List<Integer> vertexInputs,
                        Map<String, String> info) {
            names.add(name);
            originalNames.add(originalName);
            types.add(type);
            inputs.add(vertexInputs);
            infos.add(info);
        }

        /** Appends a vertex whose only input is the most recently added vertex. */
        void addChained(String name, String originalName, String type, Map<String, String> info) {
            add(name, originalName, type, Collections.singletonList(names.size() - 1), info);
        }

        /** Appends an input vertex: no incoming edges and no display properties. */
        void addInput(String name, String originalName, String type) {
            add(name, originalName, type, Collections.<Integer>emptyList(), Collections.<String, String>emptyMap());
        }

        GraphInfo build() {
            return new GraphInfo(names, types, inputs, infos, originalNames);
        }
    }

    /**
     * Builds a chain-shaped {@link GraphInfo} for a multi-layer network.
     *
     * <p>Vertex 0 is a synthetic {@code "Input"} vertex whose original name is {@code null}. The
     * {@code k}-th layer (0-based) becomes vertex {@code k + 1}, takes vertex {@code k} as its only
     * input and has original name {@code String.valueOf(k)}. Layers without a configured name are
     * displayed as {@code "layer" + (k + 1)}.
     *
     * @param multiLayerConfiguration the configuration to describe
     * @return the input vertex followed by one vertex per layer
     */
    public static GraphInfo buildGraphInfo(MultiLayerConfiguration multiLayerConfiguration) {
        GraphInfoBuilder builder = new GraphInfoBuilder();
        builder.addInput("Input", null, "Input");

        int vertexIndex = 1; // index of the vertex being added; vertex 0 is the input
        for (NeuralNetConfiguration conf : multiLayerConfiguration.getConfs()) {
            Layer layer = conf.getLayer();
            String name = layer.getLayerName();
            if (name == null) {
                name = "layer" + vertexIndex;
            }
            builder.addChained(name, String.valueOf(vertexIndex - 1), layerTypeName(layer), getLayerInfo(conf, layer));
            vertexIndex++;
        }
        return builder.build();
    }

    /**
     * Builds a {@link GraphInfo} for a computation graph.
     *
     * <p>Network inputs come first (in {@code getNetworkInputs()} order; each uses its own name as
     * display name, original name and type), followed by the graph vertices in the iteration order
     * of {@code getVertices()}. Layer vertices are typed by their layer class and get layer
     * properties; other vertices are typed by their vertex class with no properties. A vertex with
     * no entry in {@code getVertexInputs()} gets an empty input list.
     *
     * @param computationGraphConfiguration the configuration to describe
     * @return the graph description
     */
    public static GraphInfo buildGraphInfo(ComputationGraphConfiguration computationGraphConfiguration) {
        Map<String, GraphVertex> vertices = computationGraphConfiguration.getVertices();
        Map<String, List<String>> vertexInputs = computationGraphConfiguration.getVertexInputs();
        List<String> networkInputs = computationGraphConfiguration.getNetworkInputs();

        // Assign every index up front: a vertex may name as input a vertex that is iterated later.
        Map<String, Integer> indexByName = new HashMap<>();
        int nextIndex = 0;
        for (String inputName : networkInputs) {
            indexByName.put(inputName, nextIndex++);
        }
        for (String vertexName : vertices.keySet()) {
            indexByName.put(vertexName, nextIndex++);
        }

        GraphInfoBuilder builder = new GraphInfoBuilder();
        for (String inputName : networkInputs) {
            builder.addInput(inputName, inputName, inputName);
        }

        for (Map.Entry<String, GraphVertex> entry : vertices.entrySet()) {
            String vertexName = entry.getKey();
            GraphVertex vertex = entry.getValue();

            List<Integer> inputIndices = new ArrayList<>();
            List<String> inputNames = vertexInputs.get(vertexName);
            if (inputNames != null) {
                for (String inputName : inputNames) {
                    // An unknown input name maps to a null index, as before.
                    inputIndices.add(indexByName.get(inputName));
                }
            }

            if (vertex instanceof LayerVertex) {
                NeuralNetConfiguration layerConf = ((LayerVertex) vertex).getLayerConf();
                Layer layer = layerConf.getLayer();
                builder.add(vertexName, vertexName, layerTypeName(layer), inputIndices, getLayerInfo(layerConf, layer));
            } else {
                // No properties are extracted for non-layer vertices yet.
                builder.add(vertexName, vertexName, vertex.getClass().getSimpleName(), inputIndices,
                                Collections.<String, String>emptyMap());
            }
        }
        return builder.build();
    }

    /**
     * Builds a {@link GraphInfo} for a single-layer configuration.
     *
     * <p>Vertex 0 is a synthetic {@code "Input"} vertex. A {@link VariationalAutoencoder} layer is
     * expanded into a chain of encoder, latent ({@code "z"}), decoder and reconstruction
     * ({@code "x"}) vertices; any other layer becomes a single vertex (original name {@code "0"},
     * default display name {@code "layer0"}) fed by the input.
     *
     * @param neuralNetConfiguration the configuration to describe
     * @return the graph description
     */
    public static GraphInfo buildGraphInfo(NeuralNetConfiguration neuralNetConfiguration) {
        GraphInfoBuilder builder = new GraphInfoBuilder();
        builder.addInput("Input", null, "Input");

        Layer layer = neuralNetConfiguration.getLayer();
        if (layer instanceof VariationalAutoencoder) {
            addVaeVertices(builder, (VariationalAutoencoder) layer);
        } else {
            String name = layer.getLayerName();
            if (name == null) {
                name = "layer0";
            }
            builder.addChained(name, "0", layerTypeName(layer), getLayerInfo(neuralNetConfiguration, layer));
        }
        return builder.build();
    }

    /**
     * Appends the VAE's encoder layers, latent layer, decoder layers and reconstruction layer as a
     * chain, each vertex taking the previous one as input.
     *
     * <p>Parameter counts assume dense connections ({@code (in + 1) * out}). The latent layer counts
     * twice that ({@code * 2}, matching the previous implementation), and the reconstruction layer
     * uses the output distribution's input size. Counts are computed in {@code long} so that large
     * layers do not overflow.
     */
    private static void addVaeVertices(GraphInfoBuilder builder, VariationalAutoencoder vae) {
        // Input size of the next vertex; tracking it also handles empty encoder/decoder arrays.
        int prevSize = vae.getNIn();

        int[] encoderSizes = vae.getEncoderLayerSizes();
        for (int e = 0; e < encoderSizes.length; e++) {
            int size = encoderSizes[e];
            builder.addChained("encoder_" + e, "e" + e, "VAE-Encoder",
                            vaeVertexInfo(prevSize, size, (prevSize + 1L) * size,
                                            "Activation Function", vae.getActivationFn().toString()));
            prevSize = size;
        }

        int nOut = vae.getNOut();
        builder.addChained("z", "pZX", "VAE-LatentVariable",
                        vaeVertexInfo(prevSize, nOut, (prevSize + 1L) * nOut * 2,
                                        "Activation Function", vae.getPzxActivationFn().toString()));
        prevSize = nOut;

        int[] decoderSizes = vae.getDecoderLayerSizes();
        for (int d = 0; d < decoderSizes.length; d++) {
            int size = decoderSizes[d];
            builder.addChained("decoder_" + d, "d" + d, "VAE-Decoder",
                            vaeVertexInfo(prevSize, size, (prevSize + 1L) * size,
                                            "Activation Function", vae.getActivationFn().toString()));
            prevSize = size;
        }

        int nIn = vae.getNIn();
        builder.addChained("x", "pXZ", "VAE-Reconstruction",
                        vaeVertexInfo(prevSize, nIn,
                                        (prevSize + 1L) * vae.getOutputDistribution().distributionInputSize(nIn),
                                        "Distribution", vae.getOutputDistribution().toString()));
    }

    /**
     * Ordered display properties of one VAE vertex: input size, layer size, parameter count, then
     * one extra entry ({@code lastKey -> lastValue}).
     */
    private static Map<String, String> vaeVertexInfo(int inputSize, int layerSize, long numParams,
                    String lastKey, String lastValue) {
        Map<String, String> info = new LinkedHashMap<>();
        info.put("Input Size", String.valueOf(inputSize));
        info.put("Layer Size", String.valueOf(layerSize));
        info.put("Num Parameters", String.valueOf(numParams));
        info.put(lastKey, lastValue);
        return info;
    }

    /**
     * Display type of a layer: its simple class name without a trailing {@code "Layer"}.
     * Equivalent to {@code getSimpleName().replaceAll("Layer$", "")} (class names contain no line
     * terminators) without compiling a regex on every call.
     */
    private static String layerTypeName(Layer layer) {
        String simpleName = layer.getClass().getSimpleName();
        if (simpleName.endsWith(LAYER_SUFFIX)) {
            return simpleName.substring(0, simpleName.length() - LAYER_SUFFIX.length());
        }
        return simpleName;
    }

    /**
     * Collects ordered display properties for a layer.
     *
     * <p>Feed-forward layers report sizes, parameter count and activation. In addition, convolution
     * and subsampling layers report kernel/stride/padding (plus pooling type for subsampling), and
     * output layers report their loss function when one is set.
     *
     * @param neuralNetConfiguration configuration passed to the layer's parameter initializer
     * @param layer                  the layer to describe
     * @return a new insertion-ordered map, possibly empty
     */
    private static Map<String, String> getLayerInfo(NeuralNetConfiguration neuralNetConfiguration, Layer layer) {
        Map<String, String> info = new LinkedHashMap<>();
        if (layer instanceof FeedForwardLayer) {
            FeedForwardLayer ffLayer = (FeedForwardLayer) layer;
            info.put("Input size", String.valueOf(ffLayer.getNIn()));
            info.put("Output size", String.valueOf(ffLayer.getNOut()));
            info.put("Num Parameters", String.valueOf(ffLayer.initializer().numParams(neuralNetConfiguration)));
            info.put("Activation Function", ffLayer.getActivationFn().toString());
        }

        if (layer instanceof ConvolutionLayer) {
            ConvolutionLayer convLayer = (ConvolutionLayer) layer;
            info.put("Kernel size", Arrays.toString(convLayer.getKernelSize()));
            info.put("Stride", Arrays.toString(convLayer.getStride()));
            info.put("Padding", Arrays.toString(convLayer.getPadding()));
        } else if (layer instanceof SubsamplingLayer) {
            SubsamplingLayer poolLayer = (SubsamplingLayer) layer;
            info.put("Kernel size", Arrays.toString(poolLayer.getKernelSize()));
            info.put("Stride", Arrays.toString(poolLayer.getStride()));
            info.put("Padding", Arrays.toString(poolLayer.getPadding()));
            info.put("Pooling Type", poolLayer.getPoolingType().toString());
        } else if (layer instanceof BaseOutputLayer) {
            BaseOutputLayer outputLayer = (BaseOutputLayer) layer;
            if (outputLayer.getLossFn() != null) {
                info.put("Loss Function", outputLayer.getLossFn().toString());
            }
        }
        return info;
    }
}
