package com.omega.example.yolo.loss;

import com.omega.common.data.Tensor;
import com.omega.common.utils.MatrixOperation;
import com.omega.common.utils.MatrixUtils;
import com.omega.engine.loss.LossFunction;
import com.omega.engine.loss.LossType;
import com.omega.example.yolo.utils.YoloUtils;

/* loaded from: input_file:com/omega/example/yolo/loss/YoloLoss7.class */
/**
 * YOLO detection-head loss computed on the host.
 *
 * <p>The structure appears to follow darknet's {@code forward_yolo_layer}:
 * <ul>
 *   <li>Predictions are decoded with the "new coords" scheme (see {@link #getYoloBox}).</li>
 *   <li>Box gradients come from CIoU ({@code YoloUtils.dx_box_ciou}).</li>
 *   <li>Objectness is pushed toward 0 for unmatched anchors and toward 1 for anchors assigned to a truth box.</li>
 *   <li>Class gradients optionally use a focal-style scaling.</li>
 * </ul>
 *
 * <p>Unlike darknet, every gradient stored in {@link #diff} uses the sign convention {@code output - target}.
 *
 * <p>Memory layouts, as read by {@link #entryIndex} and {@link #floatToBox}:
 * <ul>
 *   <li>Prediction tensor: per image, {@code bbox_num} anchors. Each anchor has {@code 4 + 1 + class_number}
 *       planes of {@code height * width} values, in the order box (4), objectness (1), classes.</li>
 *   <li>Truth tensor: per image, {@code maxBox} records of 5 floats (four box values, then the class id). A record
 *       whose first value is 0 terminates the list.</li>
 * </ul>
 */
public class YoloLoss7 extends LossFunction {

    /** Number of object classes predicted per anchor. */
    private final int class_number;
    /** Number of anchors predicted by this head (how many leading entries of {@link #mask} are used). */
    private final int bbox_num;
    /** Total number of anchors described by {@link #anchors} (width/height pairs). */
    private final int total;
    /** 1x1 tensor; element 0 receives the scalar loss of the last {@link #loss(Tensor, Tensor)} call. */
    private Tensor loss;
    /** Gradient w.r.t. the predictions, shaped like the prediction tensor; values are {@code output - target}. */
    private Tensor diff;
    /** Indices into {@link #anchors} of the anchors used by this head. */
    private final int[] mask;
    /** Anchor sizes as {@code w0, h0, w1, h1, ...}; divided by {@link #orgW}/{@link #orgH} before use. */
    private final float[] anchors;
    /** Network input width. */
    private final int orgW;
    /** Network input height. */
    private final int orgH;
    /** Maximum number of truth records per image. */
    private final int maxBox;
    /** IoU above which an unassigned prediction receives no "no object" gradient. */
    private final float ignoreThresh;
    /** IoU above which the best IoU of a prediction is printed (debug only). */
    private final float truthThresh;
    public final LossType lossType = LossType.yolo;
    /** Number of prediction values per image; set by {@link #init}. */
    private int outputs = 0;
    /** Number of truth floats per image ({@code maxBox * 5}); set by {@link #init}. */
    private int truths = 0;
    /** Scale applied to the CIoU box gradient. */
    private float iou_normalizer = 0.05f;
    /** Absolute clip applied to each box-gradient component. */
    private float max_delta = 2.0f;
    /** Not read anywhere in this class. */
    private float cls_normalizer = 0.5f;
    /** Scale applied to objectness gradients. */
    private float obj_normalizer = 1.0f;
    /** When 1, objectness gradients are IoU-smoothed instead of hard-set; never changed from 0 in this class. */
    private int objectness_smooth = 0;
    /** Non-best anchors whose shape IoU with a truth exceeds this are also assigned to it (disabled when >= 1). */
    private float iou_thresh = 0.2f;
    /** When exactly 1, class gradients use the focal-style scaling in {@link #deltaYoloClass}. */
    private float focal_loss = 1.0f;

    /** Mutable accumulators for the training statistics printed once per {@link #loss(Tensor, Tensor)} call. */
    private static final class LossStats {
        /** Sum of the CIoU values returned by {@link #deltaYoloBox} for assigned anchors. */
        float iouSum = 0.0f;
        /** Number of assignments whose CIoU is above 0.5. */
        float recall50 = 0.0f;
        /** Number of assignments whose CIoU is above 0.75. */
        float recall75 = 0.0f;
        /** Sum of the predicted score of the true class at assigned anchors. */
        float classSum = 0.0f;
        /** Sum of the objectness at assigned anchors. */
        float objSum = 0.0f;
        /** Sum of the objectness over every anchor of every image. */
        float noObjSum = 0.0f;
        /** Number of truth-to-anchor assignments. */
        int count = 0;
        /** Number of class-gradient computations. */
        int classCount = 0;
        /** Number of assigned anchors whose objectness is at least 0.7. */
        int confidentCount = 0;
    }

    /**
     * Creates the loss for one YOLO head.
     *
     * @param classNumber  number of object classes
     * @param bboxNum      number of anchors predicted by this head
     * @param mask         indices of this head's anchors within {@code anchors}
     * @param anchors      anchor sizes as {@code w0, h0, w1, h1, ...}
     * @param orgH         network input height (anchor heights are divided by it)
     * @param orgW         network input width (anchor widths are divided by it)
     * @param maxBox       maximum number of truth records per image
     * @param total        total number of anchors in {@code anchors}
     * @param ignoreThresh IoU above which an unassigned prediction gets no "no object" gradient
     * @param truthThresh  IoU above which the best IoU is printed (debug only)
     */
    public YoloLoss7(int classNumber, int bboxNum, int[] mask, float[] anchors, int orgH, int orgW, int maxBox,
            int total, float ignoreThresh, float truthThresh) {
        this.class_number = classNumber;
        this.bbox_num = bboxNum;
        this.mask = mask;
        this.anchors = anchors;
        this.orgH = orgH;
        this.orgW = orgW;
        this.maxBox = maxBox;
        this.total = total;
        this.ignoreThresh = ignoreThresh;
        this.truthThresh = truthThresh;
    }

    /**
     * Prepares {@link #diff} for a new batch.
     *
     * <p>If the prediction shape is unchanged, {@link #diff} is zeroed. Otherwise the loss and diff tensors and
     * the per-image sizes derived from the shape are (re)allocated.
     *
     * @param tensor prediction tensor of the current batch
     */
    public void init(Tensor tensor) {
        // Compare the full shape, not only the batch size: a different grid size changes outputs and diff length.
        boolean sameShape = this.diff != null
                && tensor.number == this.diff.number
                && tensor.channel == this.diff.channel
                && tensor.height == this.diff.height
                && tensor.width == this.diff.width;
        if (this.loss != null && sameShape) {
            MatrixUtils.zero(this.diff.data);
            return;
        }
        this.loss = new Tensor(1, 1, 1, 1);
        this.diff = new Tensor(tensor.number, tensor.channel, tensor.height, tensor.width, true);
        this.outputs = tensor.height * tensor.width * this.bbox_num * (this.class_number + 4 + 1);
        this.truths = this.maxBox * 5;
    }

    /**
     * Computes the gradient of the YOLO loss into {@link #diff} and returns the scalar loss.
     *
     * <p>The scalar loss is the squared L2 norm of {@link #diff} divided by the batch size. It is also printed,
     * together with training statistics.
     *
     * @param tensor  predictions; synced to the host first if they live on the GPU
     * @param tensor2 truth records (see the class documentation for the layout)
     * @return the 1x1 loss tensor
     */
    @Override // com.omega.engine.loss.LossFunction
    public Tensor loss(Tensor tensor, Tensor tensor2) {
        init(tensor);
        if (tensor.isHasGPU()) {
            tensor.syncHost();
        }
        LossStats stats = new LossStats();
        int stride = tensor.width * tensor.height;
        for (int b = 0; b < tensor.number; b++) {
            computeObjectnessDeltas(tensor, tensor2, b, stride, stats);
            assignTruths(tensor, tensor2, b, stride, stats);
            if (this.iou_thresh < 1.0f) {
                averageMultiAssignedDeltas(tensor, b, stride);
            }
        }
        double lossValue = Math.pow(mag_array(this.diff.data), 2.0d) / tensor.number;
        // Previously the value was only printed and the returned tensor always held 0.
        this.loss.data[0] = (float) lossValue;
        System.out.println("loss:" + lossValue);
        // Guard the averages against an empty assignment set (darknet clamps these counts to 1 as well).
        int countDiv = Math.max(stats.count, 1);
        int classDiv = Math.max(stats.classCount, 1);
        int anchorTotal = tensor.width * tensor.height * this.bbox_num * tensor.number;
        System.out.println("Avg IOU: " + (stats.iouSum / countDiv) + ", Class: " + (stats.classSum / classDiv)
                + ", Obj: " + (stats.objSum / countDiv) + ", No Obj: " + (stats.noObjSum / anchorTotal)
                + ", .5R: " + (stats.recall50 / countDiv) + ", .75R: " + (stats.recall75 / countDiv)
                + ",  count: " + stats.count + ", testCount:" + stats.confidentCount);
        return this.loss;
    }

    /**
     * First pass for image {@code b}: sets the objectness gradient of every anchor prediction toward 0.
     *
     * <p>The gradient is cleared (or smoothed) when the predicted box overlaps a truth by more than
     * {@link #ignoreThresh} and at least one class score exceeds 0.25.
     */
    private void computeObjectnessDeltas(Tensor output, Tensor truth, int b, int stride, LossStats stats) {
        int w = output.width;
        int h = output.height;
        for (int row = 0; row < h; row++) {
            for (int col = 0; col < w; col++) {
                for (int n = 0; n < this.bbox_num; n++) {
                    int location = n * stride + row * w + col;
                    int boxIndex = entryIndex(b, w, h, location, 0);
                    int objIndex = entryIndex(b, w, h, location, 4);
                    int classIndex = entryIndex(b, w, h, location, 5);
                    float[] pred = getYoloBox(output, this.anchors, this.mask[n], boxIndex, col, row, w, h,
                            this.orgW, this.orgH, stride);
                    float bestMatchIou = 0.0f;
                    float bestIou = 0.0f;
                    for (int t = 0; t < this.maxBox; t++) {
                        float[] truthBox = floatToBox(truth, b, t, 1);
                        if (truthBox[0] == 0.0f) {
                            break;
                        }
                        int classId = truthClass(truth, b, t);
                        if (classId >= this.class_number || classId < 0) {
                            System.err.println("error class.");
                            // Skip records whose class id cannot be represented (darknet skips them as well).
                            continue;
                        }
                        float objectness = output.data[objIndex];
                        if (Float.isNaN(objectness) || Float.isInfinite(objectness)) {
                            output.data[objIndex] = 0.0f;
                        }
                        int classMatch = compareYoloClass(output, this.class_number, classIndex, stride, objectness,
                                classId, 0.25f);
                        float iou = YoloUtils.box_iou(pred, truthBox);
                        if (iou > bestMatchIou && classMatch == 1) {
                            bestMatchIou = iou;
                        }
                        if (iou > bestIou) {
                            bestIou = iou;
                        }
                    }
                    stats.noObjSum += output.data[objIndex];
                    // "No object" gradient: target 0, so output - target = output.
                    this.diff.data[objIndex] = this.obj_normalizer * output.data[objIndex];
                    if (bestMatchIou > this.ignoreThresh) {
                        if (this.objectness_smooth == 1) {
                            // NOTE(review): darknet keeps the larger (target - output) delta; under this class's
                            // (output - target) convention the comparison may need reversing. Unused while
                            // objectness_smooth stays 0.
                            float smoothed = this.obj_normalizer * (output.data[objIndex] - bestMatchIou);
                            if (smoothed > this.diff.data[objIndex]) {
                                this.diff.data[objIndex] = smoothed;
                            }
                        } else {
                            this.diff.data[objIndex] = 0.0f;
                        }
                    }
                    if (bestIou > this.truthThresh) {
                        System.out.println(bestIou);
                    }
                }
            }
        }
    }

    /**
     * Second pass for image {@code b}: assigns each truth record to anchors at the grid cell containing its centre.
     *
     * <p>The truth always goes to the best-matching anchor by shape IoU. When {@link #iou_thresh} is below 1, it
     * also goes to every other anchor of this head whose shape IoU with the truth exceeds {@link #iou_thresh}.
     */
    private void assignTruths(Tensor output, Tensor truth, int b, int stride, LossStats stats) {
        int w = output.width;
        int h = output.height;
        for (int t = 0; t < this.maxBox; t++) {
            float[] truthBox = floatToBox(truth, b, t, 1);
            if (truthBox[0] == 0.0f) {
                break;
            }
            if (truthBox[0] < 0.0f || truthBox[1] < 0.0f || truthBox[0] > 1.0f || truthBox[1] > 1.0f
                    || truthBox[2] < 0.0f || truthBox[3] < 0.0f) {
                System.err.println("wrong label:[" + truthBox[0] + ":" + truthBox[1] + ":" + truthBox[2] + ":"
                        + truthBox[3] + "].");
            }
            int classId = truthClass(truth, b, t);
            if (classId >= this.class_number || classId < 0) {
                // An out-of-range class id would make deltaYoloClass write outside this anchor's class planes.
                System.err.println("error class.");
                continue;
            }
            // Clamp so that x or y == 1.0 (or an invalid label) still maps to a cell inside the grid.
            int col = Math.min(Math.max((int) (truthBox[0] * w), 0), w - 1);
            int row = Math.min(Math.max((int) (truthBox[1] * h), 0), h - 1);
            float[] truthShape = {0.0f, 0.0f, truthBox[2], truthBox[3]};

            float bestIou = 0.0f;
            int bestN = 0;
            for (int n = 0; n < this.total; n++) {
                float iou = YoloUtils.box_iou(anchorShape(n), truthShape);
                if (iou > bestIou) {
                    bestIou = iou;
                    bestN = n;
                }
            }
            int bestMaskN = intIndex(this.mask, bestN, this.bbox_num);
            if (bestMaskN >= 0) {
                assignToAnchor(output, truthBox, classId, b, bestN, bestMaskN, col, row, stride, stats);
            }
            for (int n = 0; n < this.total; n++) {
                int maskN = intIndex(this.mask, n, this.bbox_num);
                if (maskN >= 0 && n != bestN && this.iou_thresh < 1.0f
                        && YoloUtils.box_iou(anchorShape(n), truthShape) > this.iou_thresh) {
                    assignToAnchor(output, truthBox, classId, b, n, maskN, col, row, stride, stats);
                }
            }
        }
    }

    /** Returns the shape-only box {@code {0, 0, w, h}} of anchor {@code n}, normalised by the input size. */
    private float[] anchorShape(int n) {
        return new float[]{0.0f, 0.0f, this.anchors[2 * n] / this.orgW, this.anchors[2 * n + 1] / this.orgH};
    }

    /** Reads the class id (5th float) of truth record {@code t} of image {@code b}. */
    private int truthClass(Tensor truth, int b, int t) {
        return (int) truth.data[t * 5 + b * this.truths + 4];
    }

    /**
     * Writes box, objectness (target 1) and class gradients for one truth assigned to one anchor, and updates
     * the statistics.
     *
     * @param n     anchor index into {@link #anchors}
     * @param maskN position of that anchor within this head's {@link #mask}
     */
    private void assignToAnchor(Tensor output, float[] truthBox, int classId, int b, int n, int maskN, int col,
            int row, int stride, LossStats stats) {
        int w = output.width;
        int h = output.height;
        int location = maskN * stride + row * w + col;
        float iou = deltaYoloBox(truthBox, output, this.anchors, n, entryIndex(b, w, h, location, 0), col, row, w, h,
                stride);
        int objIndex = entryIndex(b, w, h, location, 4);
        if (output.data[objIndex] >= 0.7f) {
            stats.confidentCount++;
        }
        stats.objSum += output.data[objIndex];
        if (this.objectness_smooth != 1 || this.diff.data[objIndex] == 0.0f) {
            this.diff.data[objIndex] = this.obj_normalizer * (output.data[objIndex] - 1.0f);
        }
        stats.classSum = deltaYoloClass(output, entryIndex(b, w, h, location, 5), classId, this.class_number, stride,
                stats.classSum);
        stats.count++;
        stats.classCount++;
        if (iou > 0.5d) {
            stats.recall50 += 1.0f;
        }
        if (iou > 0.75d) {
            stats.recall75 += 1.0f;
        }
        stats.iouSum += iou;
    }

    /**
     * Third pass for image {@code b}: for every anchor with a non-zero objectness gradient, divides its box
     * gradient by the number of distinct truth classes assigned to it (see {@link #averagesYoloDeltas}).
     */
    private void averageMultiAssignedDeltas(Tensor output, int b, int stride) {
        int w = output.width;
        int h = output.height;
        for (int row = 0; row < h; row++) {
            for (int col = 0; col < w; col++) {
                for (int n = 0; n < this.bbox_num; n++) {
                    int location = n * stride + row * w + col;
                    int boxIndex = entryIndex(b, w, h, location, 0);
                    int objIndex = entryIndex(b, w, h, location, 4);
                    int classIndex = entryIndex(b, w, h, location, 5);
                    if (this.diff.data[objIndex] != 0.0f) {
                        averagesYoloDeltas(classIndex, boxIndex, stride, this.class_number, this.diff);
                    }
                }
            }
        }
    }

    /**
     * Returns the Euclidean (L2) norm of {@code fArr}; squares are accumulated in single precision.
     */
    public float mag_array(float[] fArr) {
        float sumOfSquares = 0.0f;
        for (float v : fArr) {
            sumOfSquares += v * v;
        }
        return (float) Math.sqrt(sumOfSquares);
    }

    /**
     * Subtracts the scaled, sanitised and clipped CIoU gradient of the predicted box from {@link #diff}.
     *
     * <p>The stored value is {@code output - target}. The CIoU is computed before a zero predicted width or
     * height is replaced by 1 for the gradient computation.
     *
     * @param truth   truth box (first four values of a truth record)
     * @param output  prediction tensor
     * @param anchors anchor sizes
     * @param n       anchor index into {@code anchors}
     * @param index   flat index of the box x entry
     * @param col     grid column
     * @param row     grid row
     * @param gridW   grid width
     * @param gridH   grid height
     * @param stride  distance between consecutive entries ({@code gridW * gridH})
     * @return the CIoU between the prediction and the truth (used for the recall statistics)
     */
    private float deltaYoloBox(float[] truth, Tensor output, float[] anchors, int n, int index, int col, int row,
            int gridW, int gridH, int stride) {
        float[] pred = getYoloBox(output, anchors, n, index, col, row, gridW, gridH, this.orgW, this.orgH, stride);
        float ciou = YoloUtils.box_ciou(pred, truth);
        if (pred[2] == 0.0f) {
            pred[2] = 1.0f;
        }
        if (pred[3] == 0.0f) {
            pred[3] = 1.0f;
        }
        float[] grad = MatrixOperation.multiplication(YoloUtils.dx_box_ciou(pred, truth), this.iou_normalizer);
        fix_nan_inf(grad);
        clip(grad, this.max_delta);
        for (int k = 0; k < 4; k++) {
            this.diff.data[index + k * stride] -= grad[k];
        }
        return ciou;
    }

    /**
     * Divides the four box gradients of an anchor by the number of classes assigned to it.
     *
     * <p>A class counts as assigned when its delta is negative. Because {@code delta} stores
     * {@code output - target}, the target class of a truth (target 1) is the one with a negative delta. The
     * original {@code > 0} test counted the non-target classes instead.
     */
    private void averagesYoloDeltas(int classIndex, int boxIndex, int stride, int classes, Tensor delta) {
        int assignedClasses = 0;
        for (int c = 0; c < classes; c++) {
            if (delta.data[classIndex + stride * c] < 0.0f) {
                assignedClasses++;
            }
        }
        if (assignedClasses > 0) {
            for (int k = 0; k < 4; k++) {
                delta.data[boxIndex + k * stride] /= assignedClasses;
            }
        }
    }

    /**
     * Writes class gradients ({@code output - target}) for one truth at one anchor.
     *
     * <p>If the anchor already holds a class delta for {@code classId} from an earlier truth, only that class's
     * delta is reset to {@code output - 1}, and the other classes keep their deltas. The original test compared
     * against exactly 1.0, which class deltas practically never equal. Otherwise all class deltas are written;
     * when {@link #focal_loss} is 1 they are scaled by a focal-style factor.
     *
     * @param classIndex flat index of the first class entry
     * @param classId    target class
     * @param classes    number of classes
     * @param stride     distance between consecutive class entries
     * @param f          running sum of target-class scores
     * @return {@code f} plus the predicted score of {@code classId}
     */
    private float deltaYoloClass(Tensor tensor, int classIndex, int classId, int classes, int stride, float f) {
        int target = classIndex + stride * classId;
        if (this.diff.data[target] != 0.0f) {
            float delta = tensor.data[target] - 1.0f;
            if (!Float.isNaN(delta) && !Float.isInfinite(delta)) {
                this.diff.data[target] = delta;
            }
            return f + tensor.data[target];
        }
        if (this.focal_loss == 1.0f) {
            float pt = tensor.data[target] + 1.0E-15f;
            float grad = (float) ((-(1.0f - pt)) * ((((2.0f * pt) * Math.log(pt)) + pt) - 1.0d));
            for (int c = 0; c < classes; c++) {
                int idx = classIndex + stride * c;
                this.diff.data[idx] = tensor.data[idx] - (c == classId ? 1 : 0);
                this.diff.data[idx] = this.diff.data[idx] * 0.25f * grad;
                if (c == classId) {
                    f += tensor.data[idx];
                }
            }
        } else {
            for (int c = 0; c < classes; c++) {
                int idx = classIndex + stride * c;
                float delta = tensor.data[idx] - (c == classId ? 1 : 0);
                if (!Float.isNaN(delta) && !Float.isInfinite(delta)) {
                    this.diff.data[idx] = delta;
                }
                if (c == classId) {
                    f += tensor.data[idx];
                }
            }
        }
        return f;
    }

    /**
     * Returns 1 if any class score at {@code classIndex} exceeds {@code threshold}, else 0.
     * {@code objectness} and {@code classId} are not used.
     */
    private int compareYoloClass(Tensor tensor, int classes, int classIndex, int stride, float objectness,
            int classId, float threshold) {
        for (int c = 0; c < classes; c++) {
            if (tensor.data[classIndex + stride * c] > threshold) {
                return 1;
            }
        }
        return 0;
    }

    /** Returns the first position {@code < length} in {@code values} equal to {@code value}, or -1. */
    private int intIndex(int[] values, int value, int length) {
        for (int k = 0; k < length; k++) {
            if (values[k] == value) {
                return k;
            }
        }
        return -1;
    }

    /**
     * Reads the four box values of truth record {@code t} of image {@code b}.
     *
     * <p>{@code scale} multiplies the whole flat offset, not just the component index. All callers pass 1.
     */
    private float[] floatToBox(Tensor tensor, int b, int t, int scale) {
        int base = b * this.truths + t * 5;
        return new float[]{tensor.data[(base + 0) * scale], tensor.data[(base + 1) * scale],
                tensor.data[(base + 2) * scale], tensor.data[(base + 3) * scale]};
    }

    /**
     * Decodes one predicted box.
     *
     * <p>The centre is {@code ((col + tx) / gridW, (row + ty) / gridH)}. The size is
     * {@code (tw^2 * 4 * anchorW / netW, th^2 * 4 * anchorH / netH)}.
     *
     * @param tensor  prediction tensor
     * @param anchors anchor sizes as {@code w0, h0, w1, h1, ...}
     * @param n       anchor index
     * @param index   flat index of the x entry; y, w and h follow at multiples of {@code stride}
     * @param col     grid column
     * @param row     grid row
     * @param gridW   grid width
     * @param gridH   grid height
     * @param netW    network input width
     * @param netH    network input height
     * @param stride  distance between consecutive entries
     * @return {@code {x, y, w, h}}
     */
    public static float[] getYoloBox(Tensor tensor, float[] anchors, int n, int index, int col, int row, int gridW,
            int gridH, int netW, int netH, int stride) {
        float tx = tensor.data[index + (0 * stride)];
        float ty = tensor.data[index + (1 * stride)];
        float tw = tensor.data[index + (2 * stride)];
        float th = tensor.data[index + (3 * stride)];
        return new float[]{(col + tx) / gridW, (row + ty) / gridH, tw * tw * 4.0f * anchors[2 * n] / netW,
                th * th * 4.0f * anchors[(2 * n) + 1] / netH};
    }

    /**
     * Returns the flat index of entry {@code entry} (0-3 box, 4 objectness, 5+ classes) for prediction
     * {@code location} (= {@code anchor * w * h + row * w + col}) of image {@code batch}.
     */
    private int entryIndex(int batch, int w, int h, int location, int entry) {
        int area = w * h;
        int anchor = location / area;
        return batch * this.outputs + anchor * area * (4 + this.class_number + 1) + entry * area + location % area;
    }

    /**
     * Returns the gradient computed by the last {@link #loss(Tensor, Tensor)} call, copying it to the device first
     * when {@link #diff} has GPU storage. The arguments are ignored.
     */
    @Override // com.omega.engine.loss.LossFunction
    public Tensor diff(Tensor tensor, Tensor tensor2) {
        if (this.diff.isHasGPU()) {
            this.diff.hostToDevice();
        }
        return this.diff;
    }

    @Override // com.omega.engine.loss.LossFunction
    public LossType getLossType() {
        return LossType.yolo;
    }

    /** Not supported by this loss; returns {@code null}. */
    @Override // com.omega.engine.loss.LossFunction
    public Tensor[] loss(Tensor[] tensorArr, Tensor tensor) {
        return null;
    }

    /** Not supported by this loss; returns {@code null}. */
    @Override // com.omega.engine.loss.LossFunction
    public Tensor[] diff(Tensor[] tensorArr, Tensor tensor) {
        return null;
    }

    /**
     * Debug helper: walks the objectness entries of image {@code batch} for {@code anchorCount} anchors and prints
     * the value at flat index {@code probeIndex} when it is visited.
     */
    public void test(Tensor tensor, int anchorCount, int batch, int probeIndex) {
        int w = tensor.width;
        int h = tensor.height;
        for (int cell = 0; cell < h * w; cell++) {
            int row = cell / w;
            int col = cell % w;
            for (int n = 0; n < anchorCount; n++) {
                int objIndex = entryIndex(batch, w, h, (n * w * h) + (row * w) + col, 4);
                float value = tensor.data[objIndex];
                if (objIndex == probeIndex) {
                    System.out.println("test:" + value + "=" + tensor.data[probeIndex]);
                }
            }
        }
    }

    /** Not supported by this loss; returns {@code null}. */
    @Override // com.omega.engine.loss.LossFunction
    public Tensor loss(Tensor tensor, Tensor tensor2, Tensor tensor3) {
        return null;
    }

    /** Not supported by this loss; returns {@code null}. */
    @Override // com.omega.engine.loss.LossFunction
    public Tensor diff(Tensor tensor, Tensor tensor2, Tensor tensor3) {
        return null;
    }

    /** Returns 0 for NaN or infinite inputs, otherwise {@code f}. */
    public static float fix_nan_inf(float f) {
        if (Float.isNaN(f) || Float.isInfinite(f)) {
            return 0.0f;
        }
        return f;
    }

    /** Replaces every NaN or infinite element of {@code fArr} with 0, in place. */
    public static void fix_nan_inf(float[] fArr) {
        for (int i = 0; i < fArr.length; i++) {
            if (Float.isNaN(fArr[i]) || Float.isInfinite(fArr[i])) {
                fArr[i] = 0.0f;
            }
        }
    }

    /** Clamps {@code f} to {@code [-f2, f2]}. */
    public static float clip(float f, float f2) {
        if (f > f2) {
            return f2;
        }
        if (f < -f2) {
            return -f2;
        }
        return f;
    }

    /** Clamps every element of {@code fArr} to {@code [-f, f]}, in place. */
    public static void clip(float[] fArr, float f) {
        for (int i = 0; i < fArr.length; i++) {
            if (fArr[i] > f) {
                fArr[i] = f;
            } else if (fArr[i] < -f) {
                fArr[i] = -f;
            }
        }
    }

    /** Not supported by this loss; returns {@code null}. */
    @Override // com.omega.engine.loss.LossFunction
    public Tensor loss(Tensor tensor, Tensor tensor2, int i) {
        return null;
    }

    /** Not supported by this loss; returns {@code null}. */
    @Override // com.omega.engine.loss.LossFunction
    public Tensor diff(Tensor tensor, Tensor tensor2, int i) {
        return null;
    }
}
