package com.omega.example.yolo.test;

import com.omega.common.data.Tensor;
import com.omega.common.utils.ImageUtils;
import com.omega.common.utils.JsonUtils;
import com.omega.engine.gpu.CUDAMemoryManager;
import com.omega.engine.gpu.CUDAModules;
import com.omega.engine.loss.LossType;
import com.omega.engine.model.ModelLoader;
import com.omega.engine.nn.data.BaseData;
import com.omega.engine.nn.data.DataSet;
import com.omega.engine.nn.network.Network;
import com.omega.engine.nn.network.Yolo;
import com.omega.engine.optimizer.MBSGDOptimizer;
import com.omega.engine.optimizer.lr.LearnRateUpdate;
import com.omega.engine.updater.UpdaterType;
import com.omega.example.yolo.data.DataType;
import com.omega.example.yolo.data.DetectionDataLoader;
import com.omega.example.yolo.data.YoloDataTransform2;
import com.omega.example.yolo.model.YoloBox;
import com.omega.example.yolo.model.YoloDetection;
import com.omega.example.yolo.utils.AnchorBoxUtils;
import com.omega.example.yolo.utils.LabelFileType;
import com.omega.example.yolo.utils.LabelType;
import com.omega.example.yolo.utils.LabelUtils;
import com.omega.example.yolo.utils.YoloDataLoader;
import com.omega.example.yolo.utils.YoloImageUtils;
import com.omega.example.yolo.utils.YoloLabelUtils;
import java.util.ArrayList;
import java.util.List;

/* loaded from: input_file:com/omega/example/yolo/test/YoloV2Test.class */
/**
 * Ad-hoc experiments for training YOLOv2-tiny style detectors with the omega engine and for
 * rendering their detections to images.
 *
 * <p>Dataset, label and model-config locations are hard-coded to a local Windows layout under
 * {@code H:\voc}; edit them before running on another machine.
 */
public class YoloV2Test {

    /** Root directory of the banana-detection dataset (trailing separator included). */
    private static final String BANANA_DIR = "H:\\voc\\banana-detection\\";

    /** Width and height passed to the banana-detection loaders and label conversion. */
    private static final int BANANA_IMG_SIZE = 256;

    /** Placeholder for a YOLOv3 experiment; currently does nothing. */
    public void yolov3() {
    }

    /**
     * Writes every image of {@code dataSet} to a PNG file with its detections drawn on it.
     *
     * <p>Image {@code k} is written to {@code outputPrefix + k + ".png"}. Only detections that are
     * non-null and have objectness strictly greater than 0 are drawn.
     *
     * @param outputPrefix path prefix for the generated PNG files
     * @param dataSet      images to render; each sample's data is converted with
     *                     {@code ImageUtils.color2rgb2}
     * @param unused       not read by this method
     * @param boxes        detection results, one {@link YoloBox} per sample (see the index
     *                     remapping note in the body)
     * @param batchSize    presumably the batch size used when {@code boxes} was produced — TODO
     *                     confirm; must be non-zero (it is used as a modulus)
     * @param format       forwarded unchanged to {@code ImageUtils.color2rgb2}
     * @param width        image width; also scales the x coordinates of every bbox
     * @param height       image height; also scales the y coordinates of every bbox
     */
    public static void showImg(String outputPrefix, DataSet dataSet, int unused, List<YoloBox> boxes,
            int batchSize, boolean format, int width, int height) {
        ImageUtils imageUtils = new ImageUtils();
        int remainder = dataSet.number % batchSize;
        int tailStart = dataSet.number - remainder;
        for (int sample = 0; sample < dataSet.number; sample++) {
            float[] pixels = dataSet.getOnceData(sample).data;
            // Samples in the trailing partial batch are looked up (batchSize - remainder) slots
            // further on in `boxes`. Presumably the final batch was padded when the detections
            // were produced — NOTE(review): confirm against the caller that builds `boxes`.
            int boxIndex = sample >= tailStart ? sample + (batchSize - remainder) : sample;
            int[][] rects = toPixelRects(boxes.get(boxIndex), width, height);
            imageUtils.createRGBImage(outputPrefix + sample + ".png", "png",
                    ImageUtils.color2rgb2(pixels, width, height, format), width, height, rects);
        }
    }

    /**
     * Converts the drawable detections of {@code box} into pixel rectangles.
     *
     * <p>Each bbox is read as {@code [centerX, centerY, w, h]} and scaled by {@code width} /
     * {@code height}. Each returned row is {@code {0, x1, y1, x2, y2}}; the leading 0 is presumably a
     * class or colour slot consumed by {@code createRGBImage} — TODO confirm.
     */
    private static int[][] toPixelRects(YoloBox box, int width, int height) {
        List<int[]> rects = new ArrayList<>();
        for (int d = 0; d < box.getDets().size(); d++) {
            YoloDetection det = box.getDets().get(d);
            // Positive form of the test on purpose: a NaN objectness is excluded.
            if (det != null && det.getObjectness() > 0.0f) {
                rects.add(new int[] {
                    0,
                    (int) ((det.getBbox()[0] - (det.getBbox()[2] / 2.0f)) * width),
                    (int) ((det.getBbox()[1] - (det.getBbox()[3] / 2.0f)) * height),
                    (int) ((det.getBbox()[0] + (det.getBbox()[2] / 2.0f)) * width),
                    (int) ((det.getBbox()[1] + (det.getBbox()[3] / 2.0f)) * height)
                });
            }
        }
        return rects.toArray(new int[0][]);
    }

    /**
     * Trains a YOLOv2-tiny model on the banana-detection training split and validates on its
     * validation split.
     *
     * <p>Labels of both splits are converted with {@code YoloLabelUtils.formatToYoloV3} before
     * training. Any exception is printed and not rethrown.
     */
    public void yolov2_tiny() {
        try {
            YoloDataLoader trainLoader = new YoloDataLoader(BANANA_DIR + "bananas_train\\images",
                    BANANA_DIR + "bananas_train\\label.csv", 1000, 3, BANANA_IMG_SIZE, BANANA_IMG_SIZE,
                    5, LabelType.csv_v3, true);
            YoloDataLoader valLoader = new YoloDataLoader(BANANA_DIR + "bananas_val\\images",
                    BANANA_DIR + "bananas_val\\label.csv", 100, 3, BANANA_IMG_SIZE, BANANA_IMG_SIZE,
                    5, LabelType.csv_v3, true);
            DataSet trainSet = YoloLabelUtils.formatToYoloV3(trainLoader.getDataSet(),
                    BANANA_IMG_SIZE, BANANA_IMG_SIZE);
            DataSet valSet = YoloLabelUtils.formatToYoloV3(valLoader.getDataSet(),
                    BANANA_IMG_SIZE, BANANA_IMG_SIZE);

            Yolo yolo = new Yolo(LossType.yolov2, UpdaterType.adamw);
            yolo.CUDNN = true;
            yolo.learnRate = 0.01f;
            ModelLoader.loadConfigToModel(yolo, BANANA_DIR + "yolov2-tiny-banana.cfg");

            // The casts are kept from the original; they may be needed to select the intended
            // constructor / overload.
            MBSGDOptimizer optimizer = new MBSGDOptimizer((Network) yolo, 1000, 0.001f, 64,
                    LearnRateUpdate.SMART_HALF, false);
            optimizer.trainObjectRecognitionOutputs((BaseData) trainSet, (BaseData) valSet, false);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains a tiny detector on the VOC training images and validates on the VOC test images.
     *
     * <p>Only the training loader is given a {@code YoloDataTransform2}. The network uses
     * {@code LossType.yolov2} but is configured from {@code yolov3-tiny-voc.cfg} with
     * {@code DataType.yolov3} labels — NOTE(review): confirm this pairing is intended. Not invoked by
     * {@link #main}. Any exception is printed and not rethrown.
     */
    public void yolov2_tiny_voc() {
        try {
            DetectionDataLoader trainLoader = new DetectionDataLoader("H:\\voc\\train\\imgs",
                    "H:\\voc\\train\\labels\\yolov3.txt", LabelFileType.txt,
                    YoloImageUtils.YOLO_IMG_SIZE, YoloImageUtils.YOLO_IMG_SIZE, 20, 24, DataType.yolov3,
                    new YoloDataTransform2(20, DataType.yolov3, 90));
            DetectionDataLoader testLoader = new DetectionDataLoader("H:\\voc\\test\\imgs",
                    "H:\\voc\\test\\labels\\yolov3.txt", LabelFileType.txt,
                    YoloImageUtils.YOLO_IMG_SIZE, YoloImageUtils.YOLO_IMG_SIZE, 20, 24, DataType.yolov3);

            Yolo yolo = new Yolo(LossType.yolov2, UpdaterType.adamw);
            yolo.CUDNN = true;
            yolo.learnRate = 1.0E-4f;
            ModelLoader.loadConfigToModel(yolo, "H:/voc/train/yolov3-tiny-voc.cfg");

            MBSGDOptimizer optimizer = new MBSGDOptimizer((Network) yolo, 3000, 0.001f, 24,
                    LearnRateUpdate.SMART_HALF, false);
            optimizer.trainObjectRecognitionOutputs(trainLoader, testLoader);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Loads the banana validation boxes from CSV, computes 6 anchor boxes from them and prints the
     * anchors' data as JSON.
     *
     * <p>The tensor is sized for 100 boxes, presumably matching the 100-image validation split —
     * TODO confirm. Any exception is printed and not rethrown.
     */
    public void getAnchors() {
        try {
            Tensor boxes = new Tensor(100, 1, 1, 4);
            LabelUtils.loadBoxCSV(BANANA_DIR + "bananas_val\\label.csv", boxes);
            System.out.println(JsonUtils.toJson(AnchorBoxUtils.getAnchorBox(boxes, 6).data));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Initialises the CUDA context and runs {@link #yolov2_tiny()}.
     *
     * <p>{@code CUDAMemoryManager.free()} is called exactly once on every exit path, including when
     * an {@link Error} propagates. The decompiled original could call it up to three times when
     * {@code free()} itself failed.
     */
    public static void main(String[] args) {
        try {
            CUDAModules.initContext();
            new YoloV2Test().yolov2_tiny();
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            CUDAMemoryManager.free();
        }
    }
}
