package com.omega.engine.model;

import com.omega.common.utils.JsonUtils;
import com.omega.engine.nn.layer.AVGPoolingLayer;
import com.omega.engine.nn.layer.CBLLayer;
import com.omega.engine.nn.layer.ConvolutionLayer;
import com.omega.engine.nn.layer.FullyLayer;
import com.omega.engine.nn.layer.InputLayer;
import com.omega.engine.nn.layer.Layer;
import com.omega.engine.nn.layer.ParamsInit;
import com.omega.engine.nn.layer.PoolingLayer;
import com.omega.engine.nn.layer.RouteLayer;
import com.omega.engine.nn.layer.UPSampleLayer;
import com.omega.engine.nn.layer.YoloLayer;
import com.omega.engine.nn.layer.active.LeakyReluLayer;
import com.omega.engine.nn.layer.active.ReluLayer;
import com.omega.engine.nn.layer.active.SiLULayer;
import com.omega.engine.nn.layer.active.SigmodLayer;
import com.omega.engine.nn.layer.active.TanhLayer;
import com.omega.engine.nn.layer.normalization.BNLayer;
import com.omega.engine.nn.network.Network;
import com.omega.engine.pooling.PoolingType;
import com.omega.example.yolo.utils.YoloImageUtils;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:com/omega/engine/model/ModelLoader.class */
public class ModelLoader {

    /** Key under which each layer-config map records the network index of its last emitted layer. */
    private static final String LAST_INDEX_KEY = "lastIndex";

    /**
     * Loads a JSON layer-configuration file and appends the described layers to {@code network}.
     *
     * <p>Any failure (missing or unreadable file, malformed config, unsupported options) is printed
     * via {@link Throwable#printStackTrace()} and swallowed, so the network may be left partially
     * built.
     *
     * @param network the network to append layers to
     * @param str path of the JSON config file
     */
    public static void loadConfigToModel(Network network, String str) {
        try {
            List<Map<String, Object>> loadData = loadData(str);
            if (loadData == null) {
                throw new RuntimeException("load the config file error.");
            }
            addLayer(loadData, network);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Appends one group of layers per config entry, dispatching on the entry's {@code "layerType"}.
     *
     * <p>Each builder stores the network index of the last layer it added under
     * {@code "lastIndex"} in the entry's map, which {@link #addRouteLayer} uses to resolve
     * references. Unknown layer types are skipped; their log line shows a {@code null} shape.
     *
     * @param list parsed config entries, in network order
     * @param network the network to append layers to
     * @throws IllegalArgumentException if an entry has no {@code "layerType"} or misses a required key
     */
    public static void addLayer(List<Map<String, Object>> list, Network network) {
        for (int i = 0; i < list.size(); i++) {
            Map<String, Object> map = list.get(i);
            String layerType = requireParam(map, "layerType").toString();
            int[] outputShape;
            switch (layerType) {
                case "fully":
                    outputShape = addFullyLayers(map, network);
                    break;
                case "convolutional":
                    outputShape = addConvLayers(map, network);
                    break;
                case "cbl":
                    outputShape = addCBLs(map, network);
                    break;
                case "maxpool":
                    outputShape = addMaxPoolingLayer(map, network);
                    break;
                case "meanpool":
                    outputShape = addMeanPoolingLayer(map, network);
                    break;
                case "avgpool":
                    outputShape = addAvgPoolingLayer(map, network);
                    break;
                case "input":
                    outputShape = addInputLayer(map, network);
                    break;
                case "route":
                    outputShape = addRouteLayer(map, network, list, i);
                    break;
                case "upsample":
                    outputShape = addUpsampleLayer(map, network);
                    break;
                case "yolo":
                    outputShape = addYoloLayer(map, network);
                    break;
                default:
                    // Unsupported types are skipped; the log line below reports a null shape.
                    outputShape = null;
                    break;
            }
            System.out.println(layerType + "(" + i + "):" + JsonUtils.toJson(outputShape));
        }
    }

    /**
     * Parses an integer config value.
     *
     * <p>Parsed via {@code double} first because numbers from the Gson-parsed config arrive as
     * {@link Double}s, whose string form looks like {@code "3.0"}. The fraction is truncated.
     *
     * @param str textual number
     * @return the value truncated to an {@code int}
     * @throws NumberFormatException if {@code str} is not a number
     */
    public static int getInt(String str) {
        return (int) Double.parseDouble(str);
    }

    /**
     * Parses a float config value.
     *
     * @param str textual number
     * @return the parsed value
     * @throws NumberFormatException if {@code str} is not a number
     */
    public static float getFloat(String str) {
        return Float.parseFloat(str);
    }

    /**
     * Adds an {@link InputLayer} built from the required {@code channel}, {@code height} and
     * {@code width} keys.
     *
     * @return the input layer's output shape
     */
    public static int[] addInputLayer(Map<String, Object> map, Network network) {
        InputLayer inputLayer = new InputLayer(
                intParam(map, "channel"), intParam(map, "height"), intParam(map, "width"));
        addAndRecord(network, map, inputLayer);
        return inputLayer.outputShape();
    }

    /**
     * Adds a max-pooling layer. Requires {@code size} and {@code stride}. {@code padding} is
     * optional and defaults to {@code size - 1}.
     *
     * @return the pooling layer's output shape
     * @throws RuntimeException if the network has no previous layer
     */
    public static int[] addMaxPoolingLayer(Map<String, Object> map, Network network) {
        Layer lastLayer = requireLastLayer(network, "pooling");
        int size = intParam(map, "size");
        int stride = intParam(map, "stride");
        int padding = optIntParam(map, "padding", size - 1);
        PoolingLayer poolingLayer = new PoolingLayer(lastLayer.oChannel, lastLayer.oWidth,
                lastLayer.oHeight, size, size, stride, padding, PoolingType.MAX_POOLING);
        addAndRecord(network, map, poolingLayer);
        return poolingLayer.outputShape();
    }

    /**
     * Adds a mean-pooling layer. Requires {@code size} and {@code stride}.
     *
     * @return the pooling layer's output shape
     * @throws RuntimeException if the network has no previous layer
     */
    public static int[] addMeanPoolingLayer(Map<String, Object> map, Network network) {
        Layer lastLayer = requireLastLayer(network, "pooling");
        int size = intParam(map, "size");
        int stride = intParam(map, "stride");
        PoolingLayer poolingLayer = new PoolingLayer(lastLayer.oChannel, lastLayer.oWidth,
                lastLayer.oHeight, size, size, stride, PoolingType.MEAN_POOLING);
        addAndRecord(network, map, poolingLayer);
        return poolingLayer.outputShape();
    }

    /**
     * Adds an {@link AVGPoolingLayer} sized from the previous layer's output. No config keys are
     * read.
     *
     * @return the pooling layer's output shape
     * @throws RuntimeException if the network has no previous layer
     */
    public static int[] addAvgPoolingLayer(Map<String, Object> map, Network network) {
        Layer lastLayer = requireLastLayer(network, "pooling");
        AVGPoolingLayer avgPoolingLayer =
                new AVGPoolingLayer(lastLayer.oChannel, lastLayer.oWidth, lastLayer.oHeight);
        addAndRecord(network, map, avgPoolingLayer);
        return avgPoolingLayer.outputShape();
    }

    /**
     * Adds a convolution layer, optionally followed by a {@link BNLayer} and an activation layer.
     *
     * <p>Required keys: {@code kernel}, {@code size}, {@code stride}, {@code pad}. Optional keys:
     * {@code freeze} (1 freezes conv and BN), {@code batch_normalize} (&gt; 0 adds BN and disables
     * the conv's last constructor flag), {@code activation} (absent or {@code "none"} adds no
     * activation layer).
     *
     * @return the convolution layer's output shape
     * @throws RuntimeException if the network has no previous layer or the activation is unsupported
     */
    public static int[] addConvLayers(Map<String, Object> map, Network network) {
        Layer lastLayer = requireLastLayer(network, "convolution");
        int kernel = intParam(map, "kernel");
        int size = intParam(map, "size");
        int stride = intParam(map, "stride");
        int pad = intParam(map, "pad");
        boolean freeze = optIntParam(map, "freeze", 0) == 1;
        boolean batchNormalize = optIntParam(map, "batch_normalize", 0) > 0;
        String activation = optStringParam(map, "activation");
        // NOTE(review): last argument is presumably "has bias"; it is off when batch norm follows.
        ConvolutionLayer convolutionLayer = new ConvolutionLayer(lastLayer.oChannel, kernel,
                lastLayer.oWidth, lastLayer.oHeight, size, size, pad, stride, !batchNormalize);
        if (freeze) {
            convolutionLayer.freeze = true;
        }
        addAndRecord(network, map, convolutionLayer);
        if (batchNormalize) {
            BNLayer bnLayer = new BNLayer(convolutionLayer);
            bnLayer.preLayer = convolutionLayer;
            if (freeze) {
                bnLayer.freeze = true;
            }
            addAndRecord(network, map, bnLayer);
        }
        // NOTE(review): the activation is built from the conv layer even when a BN layer sits in
        // between; confirm Network.addLayer rewires preLayer accordingly.
        Layer activationLayer = makeActivation(activation, convolutionLayer);
        if (activationLayer != null) {
            addAndRecord(network, map, activationLayer);
        }
        return convolutionLayer.outputShape();
    }

    /**
     * Adds a {@link CBLLayer}. Requires {@code kernel}, {@code size}, {@code stride}, {@code pad}.
     * An optional {@code activation} of {@code "leaky"} is passed on as {@code "leaky_relu"}.
     *
     * @return the CBL layer's output shape
     * @throws RuntimeException if the network has no previous layer
     */
    public static int[] addCBLs(Map<String, Object> map, Network network) {
        Layer lastLayer = requireLastLayer(network, "convolution");
        int kernel = intParam(map, "kernel");
        int size = intParam(map, "size");
        int stride = intParam(map, "stride");
        int pad = intParam(map, "pad");
        String activation = optStringParam(map, "activation");
        if ("leaky".equals(activation)) {
            activation = "leaky_relu";
        }
        // NOTE(review): height/width and stride/pad are passed in a different order than to
        // ConvolutionLayer above; confirm this matches CBLLayer's constructor signature.
        CBLLayer cblLayer = new CBLLayer(lastLayer.oChannel, kernel, lastLayer.oHeight,
                lastLayer.oWidth, size, size, stride, pad, activation, network);
        addAndRecord(network, map, cblLayer);
        return cblLayer.outputShape();
    }

    /**
     * Adds a fully connected layer over the flattened previous output, optionally followed by a
     * {@link BNLayer} and an activation layer.
     *
     * <p>Required key: {@code output}. Optional keys: {@code freeze}, {@code batch_normalize},
     * {@code activation}, with the same meaning as in {@link #addConvLayers}.
     *
     * @return the fully connected layer's output shape
     * @throws RuntimeException if the network has no previous layer or the activation is unsupported
     */
    public static int[] addFullyLayers(Map<String, Object> map, Network network) {
        Layer lastLayer = requireLastLayer(network, "fully");
        int inputSize = lastLayer.oChannel * lastLayer.oHeight * lastLayer.oWidth;
        int outputSize = intParam(map, "output");
        boolean freeze = optIntParam(map, "freeze", 0) == 1;
        boolean batchNormalize = optIntParam(map, "batch_normalize", 0) > 0;
        String activation = optStringParam(map, "activation");
        FullyLayer fullyLayer = new FullyLayer(inputSize, outputSize, !batchNormalize);
        if (freeze) {
            fullyLayer.freeze = true;
        }
        addAndRecord(network, map, fullyLayer);
        if (batchNormalize) {
            BNLayer bnLayer = new BNLayer();
            bnLayer.preLayer = fullyLayer;
            if (freeze) {
                bnLayer.freeze = true;
            }
            addAndRecord(network, map, bnLayer);
        }
        Layer activationLayer = makeActivation(activation, fullyLayer);
        if (activationLayer != null) {
            addAndRecord(network, map, activationLayer);
        }
        return fullyLayer.outputShape();
    }

    /**
     * Adds a {@link RouteLayer} whose inputs are the last layers emitted by earlier config entries.
     *
     * <p>{@code layers} holds config-entry references. A negative value is relative to entry
     * {@code i}; a non-negative value is an absolute entry index. Optional {@code group}
     * (default 1) and {@code group_id} (default 0) are passed to the route layer.
     *
     * @param map this entry's config
     * @param network the network to append to
     * @param list all config entries
     * @param i index of this entry in {@code list}
     * @return the route layer's output shape
     * @throws IllegalArgumentException if a referenced entry has not produced a layer
     */
    public static int[] addRouteLayer(Map<String, Object> map, Network network,
            List<Map<String, Object>> list, int i) {
        int groups = optIntParam(map, "group", 1);
        int groupId = optIntParam(map, "group_id", 0);
        int[] references = toIntArray(requireParam(map, "layers"));
        Layer[] sources = new Layer[references.length];
        for (int k = 0; k < references.length; k++) {
            int entryIndex = references[k] < 0 ? i + references[k] : references[k];
            Object lastIndex = list.get(entryIndex).get(LAST_INDEX_KEY);
            if (lastIndex == null) {
                throw new IllegalArgumentException("route entry " + i + " references config entry "
                        + entryIndex + ", which has not produced a layer");
            }
            sources[k] = network.layerList.get(((Number) lastIndex).intValue());
        }
        RouteLayer routeLayer = new RouteLayer(sources, groups, groupId);
        addAndRecord(network, map, routeLayer);
        return routeLayer.outputShape();
    }

    /**
     * Adds an {@link UPSampleLayer} with the required {@code stride}.
     *
     * @return the upsample layer's output shape
     * @throws RuntimeException if the network has no previous layer
     */
    public static int[] addUpsampleLayer(Map<String, Object> map, Network network) {
        Layer lastLayer = requireLastLayer(network, "upsample");
        UPSampleLayer upSampleLayer = new UPSampleLayer(lastLayer.oChannel, lastLayer.oHeight,
                lastLayer.oWidth, intParam(map, "stride"));
        addAndRecord(network, map, upSampleLayer);
        return upSampleLayer.outputShape();
    }

    /**
     * Adds a {@link YoloLayer}.
     *
     * <p>Required keys: {@code classes}, {@code num}, {@code maxBox}, {@code ignore_thresh},
     * {@code truth_thresh}, {@code anchors}. Optional keys: {@code scale_x_y} (float, default
     * 1.0), {@code active} (default 1), {@code mask} (default {@code 0..num-1}).
     *
     * @return the yolo layer's output shape
     */
    public static int[] addYoloLayer(Map<String, Object> map, Network network) {
        int classes = intParam(map, "classes");
        int num = intParam(map, "num");
        int maxBox = intParam(map, "maxBox");
        float ignoreThresh = floatParam(map, "ignore_thresh");
        float truthThresh = floatParam(map, "truth_thresh");
        // scale_x_y may be fractional (e.g. 1.05), so it must not be truncated to an int.
        float scaleXY = optFloatParam(map, "scale_x_y", 1.0f);
        int active = optIntParam(map, "active", 1);
        float[] anchors = toFloatArray(requireParam(map, "anchors"));
        Object maskValue = map.get("mask");
        int[] mask;
        if (maskValue != null) {
            mask = toIntArray(maskValue);
        } else {
            // Without an explicit mask, every one of the num anchors is used once, in order.
            mask = new int[num];
            for (int k = 0; k < num; k++) {
                mask[k] = k;
            }
        }
        YoloLayer yoloLayer = new YoloLayer(classes, mask.length, mask, anchors, maxBox, num,
                ignoreThresh, truthThresh, active, scaleXY);
        addAndRecord(network, map, yoloLayer);
        return yoloLayer.outputShape();
    }

    /**
     * Creates the activation layer named {@code str} on top of {@code layer}. It also sets
     * {@code layer.paramsInit} to the matching initialization scheme.
     *
     * @param str activation name: relu, sigmoid, leaky, tanh, silu or none. {@code null} is
     *     treated like {@code "none"}.
     * @param layer the layer the activation follows
     * @return the new activation layer, or {@code null} for {@code "none"}/{@code null}
     * @throws RuntimeException if the activation name is not supported
     */
    public static Layer makeActivation(String str, Layer layer) {
        if (str == null) {
            return null;
        }
        Layer activationLayer;
        switch (str) {
            case "relu":
                activationLayer = new ReluLayer(layer);
                layer.paramsInit = ParamsInit.relu;
                break;
            case "sigmoid":
                activationLayer = new SigmodLayer(layer);
                layer.paramsInit = ParamsInit.sigmoid;
                break;
            case "leaky":
                activationLayer = new LeakyReluLayer(layer);
                layer.paramsInit = ParamsInit.leaky_relu;
                break;
            case "tanh":
                activationLayer = new TanhLayer(layer);
                layer.paramsInit = ParamsInit.tanh;
                break;
            case "silu":
                activationLayer = new SiLULayer(layer);
                layer.paramsInit = ParamsInit.silu;
                break;
            case "none":
                activationLayer = null;
                break;
            default:
                throw new RuntimeException("not support this active function.");
        }
        return activationLayer;
    }

    /**
     * Reads a UTF-8 JSON array of layer-config objects.
     *
     * <p>Numbers are produced by Gson's untyped parsing; callers here treat them as
     * {@link Number}/{@link Double}.
     *
     * @param str path of the config file
     * @return the parsed entries, or {@code null} if the file is missing, unreadable or not valid
     *     JSON (the failure is printed via {@link Throwable#printStackTrace()})
     */
    public static List<Map<String, Object>> loadData(String str) {
        try {
            File file = new File(str);
            if (!file.exists()) {
                throw new RuntimeException("the config file is not exists.");
            }
            String json = new String(Files.readAllBytes(file.toPath()), StandardCharsets.UTF_8);
            @SuppressWarnings("unchecked") // Gson yields a raw ArrayList of maps for a JSON array of objects.
            List<Map<String, Object>> entries = JsonUtils.gson.fromJson(json, ArrayList.class);
            return entries;
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }

    /** Returns the network's last layer, failing if {@code kind} would be the first layer. */
    private static Layer requireLastLayer(Network network, String kind) {
        Layer lastLayer = network.getLastLayer();
        if (lastLayer == null) {
            throw new RuntimeException("the " + kind + " layer cant be the fisrt layer.");
        }
        return lastLayer;
    }

    /** Adds {@code layer} to the network and records its index as the entry's last layer. */
    private static void addAndRecord(Network network, Map<String, Object> map, Layer layer) {
        network.addLayer(layer);
        map.put(LAST_INDEX_KEY, Integer.valueOf(layer.index));
    }

    /** Returns the value for {@code key}, failing with a descriptive message when it is absent. */
    private static Object requireParam(Map<String, Object> map, String key) {
        Object value = map.get(key);
        if (value == null) {
            throw new IllegalArgumentException("missing required config key '" + key
                    + "' for layer type " + map.get("layerType"));
        }
        return value;
    }

    /** Reads a required integer value. */
    private static int intParam(Map<String, Object> map, String key) {
        return getInt(requireParam(map, key).toString());
    }

    /** Reads an optional integer value, returning {@code defaultValue} when absent. */
    private static int optIntParam(Map<String, Object> map, String key, int defaultValue) {
        Object value = map.get(key);
        return value == null ? defaultValue : getInt(value.toString());
    }

    /** Reads a required float value. */
    private static float floatParam(Map<String, Object> map, String key) {
        return getFloat(requireParam(map, key).toString());
    }

    /** Reads an optional float value, returning {@code defaultValue} when absent. */
    private static float optFloatParam(Map<String, Object> map, String key, float defaultValue) {
        Object value = map.get(key);
        return value == null ? defaultValue : getFloat(value.toString());
    }

    /** Reads an optional string value, or {@code null} when absent. */
    private static String optStringParam(Map<String, Object> map, String key) {
        Object value = map.get(key);
        return value == null ? null : value.toString();
    }

    /** Converts a JSON number or list of numbers to an {@code int[]} (fractions truncated). */
    private static int[] toIntArray(Object value) {
        if (value instanceof Number) {
            return new int[] {((Number) value).intValue()};
        }
        List<?> values = (List<?>) value;
        int[] result = new int[values.size()];
        for (int k = 0; k < result.length; k++) {
            result[k] = ((Number) values.get(k)).intValue();
        }
        return result;
    }

    /** Converts a JSON number or list of numbers to a {@code float[]}. */
    private static float[] toFloatArray(Object value) {
        if (value instanceof Number) {
            return new float[] {((Number) value).floatValue()};
        }
        List<?> values = (List<?>) value;
        float[] result = new float[values.size()];
        for (int k = 0; k < result.length; k++) {
            result[k] = ((Number) values.get(k)).floatValue();
        }
        return result;
    }
}
