package com.omega.engine.model;

import com.omega.common.data.Tensor;
import com.omega.engine.nn.layer.CBLLayer;
import com.omega.engine.nn.layer.ConvolutionLayer;
import com.omega.engine.nn.layer.Layer;
import com.omega.engine.nn.layer.normalization.BNLayer;
import com.omega.engine.nn.network.Network;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.List;

/**
 * Loads pretrained weights stored in the Darknet binary weights format into an Omega {@link Network}.
 *
 * <p>The file begins with a header of three little-endian 32-bit ints (major, minor, revision) followed by
 * a "seen" counter, which is 64 bits wide when {@code major * 10 + minor >= 2} and 32 bits otherwise. The
 * rest of the file is a flat stream of little-endian 32-bit floats that is consumed layer by layer, in the
 * order the layers appear in {@code network.layerList}.
 */
public class DarknetLoader {

    /** Number of floats decoded per bulk read; bounds the scratch buffer to 256 KiB. */
    private static final int READ_CHUNK_FLOATS = 64 * 1024;

    /**
     * Loads weights for at most {@code maxLayers} counted layers of {@code network}.
     *
     * <p>conv, cbl, full, pooling and route layers each count towards {@code maxLayers}; only conv, cbl and
     * full layers consume data from the file. Layers of any other type are skipped without being counted.
     * Any exception is printed and loading stops; nothing is rethrown to the caller.
     *
     * @param network   network whose {@code layerList} receives the weights
     * @param path      path of the Darknet weights file
     * @param maxLayers maximum number of counted layers to process
     * @param freeze    value assigned to the {@code freeze} flag of every loaded layer
     */
    public static void loadWeight(Network network, String path, int maxLayers, boolean freeze) {
        System.out.println("start load weight.");
        // try-with-resources closes the file on every path, including read failures.
        try (RandomAccessFile file = new RandomAccessFile(path, "r")) {
            readHeader(file);
            List<Layer> layers = network.layerList;
            int counted = 0;
            // The counter never decreases, so stopping once it reaches maxLayers is equivalent to
            // visiting the remaining layers and ignoring them.
            for (int index = 0; index < layers.size() && counted < maxLayers; index++) {
                Layer layer = layers.get(index);
                switch (layer.getLayerType()) {
                    case conv:
                        loadConvWeights(file, index, layer, layers, freeze);
                        counted++;
                        break;
                    case cbl:
                        loadCBLvWeights(file, index, layer, layers, freeze);
                        counted++;
                        break;
                    case full:
                        loadFullyWeights(file, index, layer, layers, freeze);
                        counted++;
                        break;
                    case pooling:
                    case route:
                        // Weightless layers still count towards maxLayers.
                        counted++;
                        break;
                    default:
                        break;
                }
            }
            System.out.println("load weight finish.");
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Loads weights for every weight-bearing layer of {@code network}, leaving them unfrozen.
     * Equivalent to {@code loadWeight(network, path, Integer.MAX_VALUE, false)}.
     *
     * @param network network whose {@code layerList} receives the weights
     * @param path    path of the Darknet weights file
     */
    public static void loadWeight(Network network, String path) {
        // Delegating also handles cbl layers. This overload used to skip them while continuing to read,
        // which misaligned the file offset for every layer that followed.
        loadWeight(network, path, Integer.MAX_VALUE, false);
    }

    /** Reads and prints the file header, leaving {@code file} positioned at the first weight. */
    private static void readHeader(RandomAccessFile file) throws IOException {
        int major = readInt(file);
        int minor = readInt(file);
        int revision = readInt(file);
        System.out.println("major:" + major);
        System.out.println("minor:" + minor);
        System.out.println("revision:" + revision);
        // From format version 0.2 onwards the "seen" counter is stored as 8 bytes instead of 4.
        if ((major * 10) + minor >= 2) {
            System.out.println("seen:" + readBigInt(file));
        } else {
            System.out.println("seen:" + readInt(file));
        }
    }

    /**
     * Reads the parameters of one convolution layer.
     *
     * <p>Order: {@code layer.bias}; then, when the layer has no bias of its own and the next layer is a
     * {@link BNLayer}, that layer's gamma, running mean and running variance (its beta is a copy of the bias
     * just read); finally {@code layer.weight}.
     *
     * @param file   weights file positioned at this layer's data
     * @param index  position of {@code layer} in {@code layers}
     * @param layer  the convolution layer
     * @param layers all layers of the network
     * @param freeze value assigned to the {@code freeze} flag of the loaded layer(s)
     * @throws IOException if the file cannot be read or ends prematurely
     */
    public static void loadConvWeights(RandomAccessFile file, int index, Layer layer, List<Layer> layers,
                                       boolean freeze) throws IOException {
        readFloat(file, layer.bias);
        BNLayer bn = followingBatchNorm(index, layer, layers);
        if (bn != null) {
            loadBatchNorm(file, bn, layer.bias, freeze);
        }
        readFloat(file, layer.weight);
        layer.freeze = freeze;
    }

    /**
     * Reads the parameters of a conv + batch-norm composite layer.
     *
     * <p>Order: conv bias, batch-norm gamma, running mean and running variance (beta is a copy of the conv
     * bias), then the conv weights.
     *
     * @param file   weights file positioned at this layer's data
     * @param index  position of {@code layer} in {@code layers} (unused; kept for a uniform signature)
     * @param layer  the layer; must be a {@link CBLLayer}
     * @param layers all layers of the network (unused; kept for a uniform signature)
     * @param freeze value assigned to the {@code freeze} flag of the batch-norm and composite layer
     * @throws IOException if the file cannot be read or ends prematurely
     */
    public static void loadCBLvWeights(RandomAccessFile file, int index, Layer layer, List<Layer> layers,
                                       boolean freeze) throws IOException {
        CBLLayer cbl = (CBLLayer) layer;
        ConvolutionLayer conv = cbl.getConvLayer();
        readFloat(file, conv.bias);
        loadBatchNorm(file, cbl.getBnLayer(), conv.bias, freeze);
        readFloat(file, conv.weight);
        layer.freeze = freeze;
    }

    /**
     * Reads the parameters of one fully connected layer.
     *
     * <p>Order: {@code layer.bias}, {@code layer.weight}; then, when the layer has no bias of its own and the
     * next layer is a {@link BNLayer}, that layer's gamma, running mean and running variance (its beta is a
     * copy of the bias, as for convolution layers).
     *
     * @param file   weights file positioned at this layer's data
     * @param index  position of {@code layer} in {@code layers}
     * @param layer  the fully connected layer
     * @param layers all layers of the network
     * @param freeze value assigned to the {@code freeze} flag of the loaded layer(s)
     * @throws IOException if the file cannot be read or ends prematurely
     */
    public static void loadFullyWeights(RandomAccessFile file, int index, Layer layer, List<Layer> layers,
                                        boolean freeze) throws IOException {
        readFloat(file, layer.bias);
        readFloat(file, layer.weight);
        // Previously this cast list.get(index + 1) to BNLayer unconditionally (ClassCastException).
        BNLayer bn = followingBatchNorm(index, layer, layers);
        if (bn != null) {
            loadBatchNorm(file, bn, layer.bias, freeze);
        }
        layer.freeze = freeze;
    }

    /** Returns the {@link BNLayer} right after a bias-less {@code layer}, or {@code null} if there is none. */
    private static BNLayer followingBatchNorm(int index, Layer layer, List<Layer> layers) {
        if (layer.hasBias || index + 1 >= layers.size()) {
            return null;
        }
        Layer next = layers.get(index + 1);
        return (next instanceof BNLayer) ? (BNLayer) next : null;
    }

    /**
     * Initializes {@code bn} and reads gamma, running mean and running variance, using a copy of
     * {@code bias} as beta. Shared by all loaders so the read order and beta handling stay identical.
     */
    private static void loadBatchNorm(RandomAccessFile file, BNLayer bn, Tensor bias, boolean freeze)
            throws IOException {
        bn.init();
        readFloat(file, bn.gamma);
        bn.beta = bias.copyGPU();
        bn.beta.syncHost();
        readFloat(file, bn.runingMean);
        readFloat(file, bn.runingVar);
        bn.freeze = freeze;
    }

    /** Reads exactly {@code byteCount} bytes and wraps them in a little-endian buffer. */
    private static ByteBuffer readLittleEndian(RandomAccessFile file, int byteCount) throws IOException {
        byte[] bytes = new byte[byteCount];
        file.readFully(bytes);
        return ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN);
    }

    /** Reads one little-endian 64-bit integer. */
    public static long readBigInt(RandomAccessFile file) throws IOException {
        return readLittleEndian(file, Long.BYTES).getLong();
    }

    /** Reads one little-endian 32-bit integer. */
    public static int readInt(RandomAccessFile file) throws IOException {
        return readLittleEndian(file, Integer.BYTES).getInt();
    }

    /** Reads one little-endian 32-bit float. */
    public static float readFloat(RandomAccessFile file) throws IOException {
        return readLittleEndian(file, Float.BYTES).getFloat();
    }

    /**
     * Skips {@code count} floats (4 bytes each).
     *
     * <p>Bytes are actually read (in bounded chunks) rather than sought past, so a file that is too short
     * fails here with an {@link java.io.EOFException} instead of at a later read.
     *
     * @throws IllegalArgumentException if {@code count} is negative
     */
    public static void skipFloat(RandomAccessFile file, int count) throws IOException {
        if (count < 0) {
            throw new IllegalArgumentException("count must be non-negative: " + count);
        }
        byte[] scratch = new byte[Math.min(count, READ_CHUNK_FLOATS) * Float.BYTES];
        int remaining = count;
        while (remaining > 0) {
            int chunk = Math.min(remaining, READ_CHUNK_FLOATS);
            file.readFully(scratch, 0, chunk * Float.BYTES);
            remaining -= chunk;
        }
    }

    /**
     * Fills {@code tensor.data} from the file, then copies it to the device if the tensor has GPU storage.
     */
    public static void readFloat(RandomAccessFile file, Tensor tensor) throws IOException {
        readFloat(file, tensor.data);
        if (tensor.isHasGPU()) {
            tensor.hostToDevice();
        }
    }

    /**
     * Fills {@code values} with consecutive little-endian floats from the file.
     *
     * <p>Data is read in bulk chunks: {@link RandomAccessFile} is unbuffered, so reading 4 bytes at a time
     * costs one system call per float.
     */
    public static void readFloat(RandomAccessFile file, float[] values) throws IOException {
        byte[] scratch = new byte[Math.min(values.length, READ_CHUNK_FLOATS) * Float.BYTES];
        int offset = 0;
        while (offset < values.length) {
            int chunk = Math.min(values.length - offset, READ_CHUNK_FLOATS);
            int byteCount = chunk * Float.BYTES;
            file.readFully(scratch, 0, byteCount);
            ByteBuffer.wrap(scratch, 0, byteCount)
                    .order(ByteOrder.LITTLE_ENDIAN)
                    .asFloatBuffer()
                    .get(values, offset, chunk);
            offset += chunk;
        }
    }
}
