package com.omega.example.yolo.loss;

import com.omega.common.data.Tensor;
import com.omega.common.utils.MatrixUtils;
import com.omega.engine.loss.LossFunction;
import com.omega.engine.loss.LossType;
import com.omega.example.yolo.utils.YoloUtils;

/* loaded from: input_file:com/omega/example/yolo/loss/YoloLoss3.class */
/**
 * YOLOv3-style detection loss evaluated on the host.
 *
 * <p>{@link #loss(Tensor, Tensor)} fills an internal gradient buffer with
 * {@code prediction - target} (the opposite sign of darknet's {@code delta}). It covers box,
 * objectness and class entries. The method prints training statistics to stdout and returns a
 * 1x1x1x1 tensor holding the summed squared gradient divided by the batch size.
 * {@link #diff(Tensor, Tensor)} returns that gradient, copying it to the device first when the
 * buffer is GPU-backed.
 *
 * <p>Layout expected by {@link #entryIndex} and {@link #floatToBox}:
 * <ul>
 *   <li>prediction: per batch item, {@code bbox_num} anchor slots, each made of
 *       {@code 4 + 1 + class_number} planes of {@code height * width} values
 *       (x, y, w, h, objectness, class scores);</li>
 *   <li>labels: per batch item, {@code maxBox} rows of {@code (x, y, w, h, class)}, with
 *       coordinates normalized by the image size. A row whose x is 0 terminates the list.</li>
 * </ul>
 * NOTE(review): x/y/objectness/class values are used as-is, so the prediction is presumably
 * already activated (sigmoid) upstream. Confirm this against the network's output layer.
 */
public class YoloLoss3 extends LossFunction {

    /** Number of object classes predicted per anchor slot. */
    private final int class_number;
    /** Anchor slots per grid cell at this output scale (number of {@link #mask} entries used). */
    private final int bbox_num;
    /** Number of (w, h) anchor pairs in {@link #anchors} considered when matching truth boxes. */
    private final int total;
    /** 1x1x1x1 tensor returned from {@link #loss(Tensor, Tensor)}. */
    private Tensor loss;
    /** Gradient buffer with the same shape as the prediction tensor. */
    private Tensor diff;
    /** Maps this scale's anchor slot to an anchor index in {@link #anchors}. */
    private final int[] mask;
    /** Flattened anchor sizes: {@code anchors[2n]} is the width, {@code anchors[2n + 1]} the height. */
    private final float[] anchors;
    /** Network input width, used to normalize anchor widths. */
    private final int orgW;
    /** Network input height, used to normalize anchor heights. */
    private final int orgH;
    /** Maximum number of labelled boxes stored per batch item. */
    private final int maxBox;
    /** IoU above which an unmatched prediction receives no "no object" penalty. */
    private final float ignoreThresh;
    /** IoU above which a prediction is reported on stdout. */
    private final float truthThresh;
    public final LossType lossType = LossType.yolo;
    /** Values per batch item in the prediction tensor; recomputed in {@link #init(Tensor)}. */
    private int outputs = 0;
    /** Values per batch item in the label tensor ({@code maxBox * 5}). */
    private int truths = 0;

    /**
     * Creates the loss for one YOLO output scale.
     *
     * @param classNumber  number of classes
     * @param bboxNum      anchor slots per grid cell at this scale
     * @param mask         anchor indices used by this scale; the first {@code bboxNum} are read
     * @param anchors      flattened (w, h) anchor sizes, at least {@code 2 * total} values
     * @param orgH         network input height used to normalize anchors
     * @param orgW         network input width used to normalize anchors
     * @param maxBox       maximum number of labelled boxes per batch item
     * @param total        number of anchor pairs in {@code anchors}
     * @param ignoreThresh IoU above which unmatched predictions are not pushed towards "no object"
     * @param truthThresh  IoU above which a prediction's IoU is printed
     */
    public YoloLoss3(int classNumber, int bboxNum, int[] mask, float[] anchors, int orgH, int orgW,
            int maxBox, int total, float ignoreThresh, float truthThresh) {
        this.class_number = classNumber;
        this.bbox_num = bboxNum;
        this.mask = mask;
        this.anchors = anchors;
        this.orgH = orgH;
        this.orgW = orgW;
        this.maxBox = maxBox;
        this.total = total;
        this.ignoreThresh = ignoreThresh;
        this.truthThresh = truthThresh;
    }

    /**
     * Prepares the gradient buffer for {@code tensor}.
     *
     * <p>The existing buffer is reused (and zeroed) only when its shape matches exactly.
     * Otherwise new buffers are allocated.
     *
     * @param tensor prediction tensor whose shape drives the buffer layout
     */
    public void init(Tensor tensor) {
        // outputs and the diff layout depend on channel/height/width as well as batch size,
        // so all four dimensions must match before the old buffer can be reused.
        boolean sameShape = this.loss != null
                && this.diff != null
                && tensor.number == this.diff.number
                && tensor.channel == this.diff.channel
                && tensor.height == this.diff.height
                && tensor.width == this.diff.width;
        if (sameShape) {
            MatrixUtils.zero(this.diff.data);
            return;
        }
        this.loss = new Tensor(1, 1, 1, 1);
        this.diff = new Tensor(tensor.number, tensor.channel, tensor.height, tensor.width, true);
        this.outputs = tensor.height * tensor.width * this.bbox_num * (this.class_number + 4 + 1);
        this.truths = this.maxBox * 5;
    }

    /**
     * Computes the gradient for one batch and returns the loss.
     *
     * @param tensor  predictions (synced to host first if GPU-backed)
     * @param tensor2 labels, {@code maxBox} rows of {@code (x, y, w, h, class)} per batch item
     * @return 1x1x1x1 tensor holding {@code sum(diff^2) / batchSize}
     */
    @Override
    public Tensor loss(Tensor tensor, Tensor tensor2) {
        init(tensor);
        if (tensor.isHasGPU()) {
            tensor.syncHost();
        }
        final int w = tensor.width;
        final int h = tensor.height;
        final int stride = w * h;

        float avgIou = 0.0f;
        float recall50 = 0.0f;
        float recall75 = 0.0f;
        float avgCat = 0.0f;
        float avgObj = 0.0f;
        float avgAnyObj = 0.0f;
        int count = 0;
        int confidentCount = 0;

        for (int b = 0; b < tensor.number; b++) {
            // Pass 1: push every objectness towards 0, except predictions whose decoded box
            // already overlaps some truth box by more than ignoreThresh.
            for (int row = 0; row < h; row++) {
                for (int col = 0; col < w; col++) {
                    for (int n = 0; n < this.bbox_num; n++) {
                        int cell = (n * stride) + (row * w) + col;
                        int boxIndex = entryIndex(b, w, h, cell, 0);
                        float[] pred = getYoloBox(tensor, this.anchors, this.mask[n], boxIndex, col, row,
                                w, h, this.orgW, this.orgH, stride);
                        float bestIou = bestTruthIou(tensor2, b, pred);
                        int objIndex = entryIndex(b, w, h, cell, 4);
                        avgAnyObj += tensor.data[objIndex];
                        this.diff.data[objIndex] = tensor.data[objIndex];
                        if (bestIou > this.ignoreThresh) {
                            this.diff.data[objIndex] = 0.0f;
                        }
                        if (bestIou > this.truthThresh) {
                            // NOTE(review): the darknet reference assigns positive targets here;
                            // this port only logs the IoU.
                            System.out.println(bestIou);
                        }
                    }
                }
            }
            // Pass 2: each truth box is assigned to the anchor whose shape matches best. When that
            // anchor belongs to this scale, its cell gets box, objectness and class targets.
            for (int t = 0; t < this.maxBox; t++) {
                float[] truth = floatToBox(tensor2, b, t, 1);
                if (truth[0] == 0.0f) {
                    break;
                }
                int bestN = bestAnchor(truth);
                int slot = intIndex(this.mask, bestN, this.bbox_num);
                if (slot < 0) {
                    continue; // best anchor is handled by another scale
                }
                int col = gridCell(truth[0], w);
                int row = gridCell(truth[1], h);
                int cell = (slot * stride) + (row * w) + col;
                // Small boxes get a larger box-regression weight (2 - w*h).
                float iou = deltaYoloBox(truth, tensor, this.anchors, bestN, entryIndex(b, w, h, cell, 0),
                        col, row, w, h, 2.0f - (truth[2] * truth[3]), stride);
                int objIndex = entryIndex(b, w, h, cell, 4);
                if (tensor.data[objIndex] >= 0.8f) {
                    confidentCount++;
                }
                avgObj += tensor.data[objIndex];
                this.diff.data[objIndex] = tensor.data[objIndex] - 1.0f;
                int cls = (int) tensor2.data[(b * this.truths) + (t * 5) + 4];
                avgCat = deltaYoloClass(tensor, entryIndex(b, w, h, cell, 5), cls, this.class_number, stride, avgCat);
                count++;
                if (iou > 0.5d) {
                    recall50 += 1.0f;
                }
                if (iou > 0.75d) {
                    recall75 += 1.0f;
                }
                avgIou += iou;
            }
        }
        double lossValue = Math.pow(mag_array(this.diff.data), 2.0d) / tensor.number;
        this.loss.data[0] = (float) lossValue;
        System.out.println("loss:" + lossValue);
        System.out.println("Avg IOU: " + ratio(avgIou, count) + ", Class: " + ratio(avgCat, count)
                + ", Obj: " + ratio(avgObj, count)
                + ", No Obj: " + ratio(avgAnyObj, stride * this.bbox_num * tensor.number)
                + ", .5R: " + ratio(recall50, count) + ", .75R: " + ratio(recall75, count)
                + ",  count: " + count + ", testCount:" + confidentCount);
        return this.loss;
    }

    /** Returns {@code sum / n}, or 0 when {@code n} is 0, so the stats never print NaN. */
    private static float ratio(float sum, int n) {
        return n == 0 ? 0.0f : sum / n;
    }

    /** Maps a normalized coordinate to a grid index clamped to {@code [0, size - 1]}. */
    private static int gridCell(float coord, int size) {
        int cell = (int) (coord * size);
        return Math.max(0, Math.min(cell, size - 1));
    }

    /** Highest IoU between {@code pred} and any truth box of batch item {@code batch}. */
    private float bestTruthIou(Tensor truth, int batch, float[] pred) {
        float best = 0.0f;
        for (int t = 0; t < this.maxBox; t++) {
            float[] box = floatToBox(truth, batch, t, 1);
            if (box[0] == 0.0f) {
                break;
            }
            float iou = YoloUtils.box_iou(pred, box);
            if (iou > best) {
                best = iou;
            }
        }
        return best;
    }

    /**
     * Returns the index of the anchor whose normalized (w, h) best matches the truth box.
     * Both boxes are placed at the origin, so only their shapes are compared.
     */
    private int bestAnchor(float[] truth) {
        float[] shape = {0.0f, 0.0f, truth[2], truth[3]};
        float bestIou = 0.0f;
        int best = 0;
        for (int n = 0; n < this.total; n++) {
            float[] anchor = {0.0f, 0.0f, this.anchors[2 * n] / this.orgW, this.anchors[(2 * n) + 1] / this.orgH};
            float iou = YoloUtils.box_iou(anchor, shape);
            if (iou > bestIou) {
                bestIou = iou;
                best = n;
            }
        }
        return best;
    }

    /**
     * Euclidean norm of {@code fArr}.
     *
     * @param fArr values to measure
     * @return {@code sqrt(sum(v * v))}
     */
    public float mag_array(float[] fArr) {
        float sum = 0.0f;
        for (float v : fArr) {
            sum += v * v;
        }
        return (float) Math.sqrt(sum);
    }

    /**
     * Writes class gradients for one assigned cell and accumulates the true-class score.
     *
     * <p>If the first class entry already holds a gradient, another truth box has claimed this
     * cell. In that case only the current class is pushed towards 1 (the earlier class targets are
     * kept), matching darknet's {@code delta_yolo_class}.
     *
     * @param tensor  predictions
     * @param index   flat index of the first class entry
     * @param cls     true class id
     * @param classes number of classes
     * @param stride  distance between consecutive class planes
     * @param avgCat  running sum of true-class scores
     * @return updated running sum
     */
    private float deltaYoloClass(Tensor tensor, int index, int cls, int classes, int stride, float avgCat) {
        if (this.diff.data[index] != 0.0f) {
            int target = index + (stride * cls);
            this.diff.data[target] = tensor.data[target] - 1.0f;
            return avgCat + tensor.data[target];
        }
        for (int c = 0; c < classes; c++) {
            int at = index + (stride * c);
            this.diff.data[at] = tensor.data[at] - (c == cls ? 1 : 0);
            if (c == cls) {
                avgCat += tensor.data[at];
            }
        }
        return avgCat;
    }

    /**
     * Writes the box-regression gradient for one assigned cell.
     *
     * <p>Targets are the cell-relative offset for x/y and the log of the size ratio to the anchor
     * for w/h. Each gradient is {@code scale * (prediction - target)}.
     *
     * @return IoU between the decoded prediction and the truth box
     */
    private float deltaYoloBox(float[] truth, Tensor tensor, float[] anchors, int n, int index, int col,
            int row, int lw, int lh, float scale, int stride) {
        float iou = YoloUtils.box_iou(
                getYoloBox(tensor, anchors, n, index, col, row, lw, lh, this.orgW, this.orgH, stride), truth);
        float tx = (truth[0] * lw) - col;
        float ty = (truth[1] * lh) - row;
        float tw = (float) Math.log((truth[2] * this.orgW) / anchors[2 * n]);
        float th = (float) Math.log((truth[3] * this.orgH) / anchors[(2 * n) + 1]);
        this.diff.data[index] = scale * (tensor.data[index] - tx);
        this.diff.data[index + stride] = scale * (tensor.data[index + stride] - ty);
        this.diff.data[index + (2 * stride)] = scale * (tensor.data[index + (2 * stride)] - tw);
        this.diff.data[index + (3 * stride)] = scale * (tensor.data[index + (3 * stride)] - th);
        return iou;
    }

    /** Position of {@code value} among the first {@code length} entries of {@code values}, or -1. */
    private int intIndex(int[] values, int value, int length) {
        for (int k = 0; k < length; k++) {
            if (values[k] == value) {
                return k;
            }
        }
        return -1;
    }

    /**
     * Reads the (x, y, w, h) of truth row {@code t} of batch item {@code batch}.
     * NOTE(review): {@code stride} multiplies the whole flat offset; callers here pass 1.
     */
    private float[] floatToBox(Tensor tensor, int batch, int t, int stride) {
        int base = (batch * this.truths) + (t * 5);
        return new float[]{
                tensor.data[base * stride],
                tensor.data[(base + 1) * stride],
                tensor.data[(base + 2) * stride],
                tensor.data[(base + 3) * stride]};
    }

    /**
     * Decodes one predicted box into normalized (x, y, w, h).
     *
     * @param tensor  predictions
     * @param anchors flattened anchor sizes
     * @param n       anchor index
     * @param index   flat index of the box's x entry
     * @param col     grid column
     * @param row     grid row
     * @param lw      grid width
     * @param lh      grid height
     * @param w       network input width
     * @param h       network input height
     * @param stride  distance between consecutive entry planes
     * @return {@code {(col + x) / lw, (row + y) / lh, exp(tw) * anchorW / w, exp(th) * anchorH / h}}
     */
    public static float[] getYoloBox(Tensor tensor, float[] anchors, int n, int index, int col, int row,
            int lw, int lh, int w, int h, int stride) {
        float x = (col + tensor.data[index]) / lw;
        float y = (row + tensor.data[index + stride]) / lh;
        float bw = (float) ((Math.exp(tensor.data[index + (2 * stride)]) * anchors[2 * n]) / w);
        float bh = (float) ((Math.exp(tensor.data[index + (3 * stride)]) * anchors[(2 * n) + 1]) / h);
        return new float[]{x, y, bw, bh};
    }

    /**
     * Flat index of entry {@code entry} (0-3 box, 4 objectness, 5+ classes) for the anchor slot
     * and cell encoded in {@code location = slot * w * h + cellOffset}.
     */
    private int entryIndex(int batch, int w, int h, int location, int entry) {
        int plane = w * h;
        int slot = location / plane;
        int cellOffset = location % plane;
        return (batch * this.outputs) + (slot * plane * (4 + this.class_number + 1)) + (entry * plane) + cellOffset;
    }

    /** Returns the gradient from the last {@link #loss(Tensor, Tensor)}, uploaded to the device if needed. */
    @Override
    public Tensor diff(Tensor tensor, Tensor tensor2) {
        if (this.diff.isHasGPU()) {
            this.diff.hostToDevice();
        }
        return this.diff;
    }

    @Override
    public LossType getLossType() {
        return LossType.yolo;
    }

    /** Not supported by this loss; returns {@code null}. */
    @Override
    public Tensor[] loss(Tensor[] tensorArr, Tensor tensor) {
        return null;
    }

    /** Not supported by this loss; returns {@code null}. */
    @Override
    public Tensor[] diff(Tensor[] tensorArr, Tensor tensor) {
        return null;
    }

    /**
     * Debug helper: prints the objectness value when a cell's objectness index equals
     * {@code target}.
     *
     * @param tensor  predictions
     * @param anchors number of anchor slots to scan
     * @param batch   batch item
     * @param target  flat index to look for
     */
    public void test(Tensor tensor, int anchors, int batch, int target) {
        for (int loc = 0; loc < tensor.height * tensor.width; loc++) {
            int row = loc / tensor.width;
            int col = loc % tensor.width;
            for (int n = 0; n < anchors; n++) {
                int objIndex = entryIndex(batch, tensor.width, tensor.height,
                        (n * tensor.width * tensor.height) + (row * tensor.width) + col, 4);
                float value = tensor.data[objIndex];
                if (objIndex == target) {
                    System.out.println("test:" + value + "=" + tensor.data[target]);
                }
            }
        }
    }

    /** Not supported by this loss; returns {@code null}. */
    @Override
    public Tensor loss(Tensor tensor, Tensor tensor2, Tensor tensor3) {
        return null;
    }

    /** Not supported by this loss; returns {@code null}. */
    @Override
    public Tensor diff(Tensor tensor, Tensor tensor2, Tensor tensor3) {
        return null;
    }

    /** Not supported by this loss; returns {@code null}. */
    @Override
    public Tensor loss(Tensor tensor, Tensor tensor2, int i) {
        return null;
    }

    /** Not supported by this loss; returns {@code null}. */
    @Override
    public Tensor diff(Tensor tensor, Tensor tensor2, int i) {
        return null;
    }
}
