package com.omega.engine.optimizer;

import com.omega.common.data.Tensor;
import com.omega.common.task.TaskEngine;
import com.omega.common.utils.JsonUtils;
import com.omega.common.utils.LabelUtils;
import com.omega.common.utils.MatrixOperation;
import com.omega.common.utils.RandomUtils;
import com.omega.engine.check.BaseCheck;
import com.omega.engine.nn.data.BaseData;
import com.omega.engine.nn.layer.YoloLayer;
import com.omega.engine.nn.network.Network;
import com.omega.engine.nn.network.RunModel;
import com.omega.engine.nn.network.Yolo;
import com.omega.engine.optimizer.lr.GDDecay;
import com.omega.engine.optimizer.lr.HalfDecay;
import com.omega.engine.optimizer.lr.LRDecay;
import com.omega.engine.optimizer.lr.LearnRateUpdate;
import com.omega.example.transformer.utils.BPETokenizer;
import com.omega.example.yolo.data.BaseDataLoader;
import com.omega.example.yolo.data.DetectionDataLoader;
import com.omega.example.yolo.model.YoloBox;
import com.omega.example.yolo.model.YoloDetection;
import com.omega.example.yolo.utils.YoloDataLoader;
import com.omega.example.yolo.utils.YoloDecode;
import com.omega.example.yolo.utils.YoloImageUtils;
import com.omega.example.yolo.utils.YoloUtils;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

/* loaded from: input_file:com/omega/engine/optimizer/Optimizer.class */
public abstract class Optimizer {
    // NOTE(review): not read anywhere in the visible part of this class.
    private String sid;
    // Batch counter (constructors start it at 1). Drives warm-up (batchIndex < burnIn) and the
    // POLY / STEP / EXP / SIG learning-rate schedules.
    public int batchIndex;
    // Training-round counter (constructors start it at 1). Printed by the test* methods and read by
    // the LR_DECAY, GD_GECAY, COSINE, COSINE_ANNEALING and SMART_HALF schedules.
    public int trainIndex;
    // Batch size; the constructors reset it to 1. Read by the POLY schedule.
    public int batchSize;
    // Number of training samples; divisor in the POLY schedule (constructors set it to 0).
    public int dataSize;
    // NOTE(review): loss / lossDiff are not read in the visible part of this class.
    public Tensor loss;
    public Tensor lossDiff;
    // Set from a constructor argument; used as the COSINE period and in the POLY progress term.
    // Presumably the total number of training rounds — TODO confirm.
    public int trainTime;
    // Set by the 6/7-argument constructors (default 10000); not read in the visible code.
    public int minTrainTime;
    // Default 1.0f; not read in the visible code.
    public float currentError;
    // Error threshold passed to the constructors; not read in the visible code.
    public float error;
    // Network being optimised; its learnRate is what the updateLR(...) methods rewrite.
    public Network network;
    private TaskEngine trainEngine;
    // Selects the schedule applied by the updateLR(...) methods (default NONE).
    public LearnRateUpdate learnRateUpdate;
    private BaseData trainingData;
    private BaseData testData;
    // NOTE(review): not read in the visible code; testAndLoss takes its BaseCheck as a parameter.
    public BaseCheck check;
    // When true, learning rate is ramped as lr * (batchIndex / burnIn)^power while batchIndex < burnIn.
    private boolean warmUp;
    // Number of warm-up batches (default 300).
    public int burnIn;
    // Exponent for the warm-up ramp and the POLY and RANDOM schedules (default 4).
    public int power;
    // Multiplier applied every `step` batches by the STEP schedule (default 0.1).
    public float scale;
    // Step interval for STEP and midpoint for SIG (default 500).
    public int step;
    // Base for EXP and slope for SIG (default 0.9999).
    public float gama;
    // Base learning rate, captured from network.learnRate by the constructors.
    public float lr;
    // NOTE(review): not read in the visible code.
    public boolean isOnline;
    // Number of initial rounds during which COSINE ramps instead of decaying (default 5).
    public int lrStartTime;
    // Peak learning rate, captured from network.learnRate by the constructors; used by the
    // LR_DECAY, GD_GECAY, COSINE and COSINE_ANNEALING schedules.
    public float max_lr;
    // Floor reached by COSINE_ANNEALING (default 0).
    public float min_lr;
    // COSINE_ANNEALING period in rounds (default 100).
    public int trainMax;
    // Best loss seen by the SMART_HALF branch of updateLR(float).
    public float min_loss;
    // Plateau counter for updateLR(float)'s SMART_HALF branch; HALF decays when counter % 10 == 0.
    public int counter;
    // NOTE(review): not read in the visible code; presumably milestones for updateLR(int[]) — confirm.
    public int[] lr_step;

    /*
     * Decompiler rendering of the synthetic "switch map" javac generates for a switch over the
     * LearnRateUpdate enum. The static initializer maps each constant's ordinal() to a dense case
     * label 1..13, and the updateLR(...) methods switch on that label. Ordinals that are never
     * assigned keep the array default 0, so they match none of the numbered cases. Each
     * NoSuchFieldError catch lets initialization continue if a constant no longer exists in the
     * enum at run time.
     */
    /* renamed from: com.omega.engine.optimizer.Optimizer$2, reason: invalid class name */
    /* loaded from: input_file:com/omega/engine/optimizer/Optimizer$2.class */
    static /* synthetic */ class AnonymousClass2 {
        // Index: LearnRateUpdate.ordinal(); value: case label used by the updateLR(...) switches.
        static final /* synthetic */ int[] $SwitchMap$com$omega$engine$optimizer$lr$LearnRateUpdate = new int[LearnRateUpdate.values().length];

        static {
            try {
                $SwitchMap$com$omega$engine$optimizer$lr$LearnRateUpdate[LearnRateUpdate.LR_DECAY.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$com$omega$engine$optimizer$lr$LearnRateUpdate[LearnRateUpdate.GD_GECAY.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$com$omega$engine$optimizer$lr$LearnRateUpdate[LearnRateUpdate.NONE.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
            try {
                $SwitchMap$com$omega$engine$optimizer$lr$LearnRateUpdate[LearnRateUpdate.CONSTANT.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
            }
            try {
                $SwitchMap$com$omega$engine$optimizer$lr$LearnRateUpdate[LearnRateUpdate.COSINE.ordinal()] = 5;
            } catch (NoSuchFieldError e5) {
            }
            try {
                $SwitchMap$com$omega$engine$optimizer$lr$LearnRateUpdate[LearnRateUpdate.COSINE_ANNEALING.ordinal()] = 6;
            } catch (NoSuchFieldError e6) {
            }
            try {
                $SwitchMap$com$omega$engine$optimizer$lr$LearnRateUpdate[LearnRateUpdate.RANDOM.ordinal()] = 7;
            } catch (NoSuchFieldError e7) {
            }
            try {
                $SwitchMap$com$omega$engine$optimizer$lr$LearnRateUpdate[LearnRateUpdate.POLY.ordinal()] = 8;
            } catch (NoSuchFieldError e8) {
            }
            try {
                $SwitchMap$com$omega$engine$optimizer$lr$LearnRateUpdate[LearnRateUpdate.STEP.ordinal()] = 9;
            } catch (NoSuchFieldError e9) {
            }
            try {
                $SwitchMap$com$omega$engine$optimizer$lr$LearnRateUpdate[LearnRateUpdate.EXP.ordinal()] = 10;
            } catch (NoSuchFieldError e10) {
            }
            try {
                $SwitchMap$com$omega$engine$optimizer$lr$LearnRateUpdate[LearnRateUpdate.SIG.ordinal()] = 11;
            } catch (NoSuchFieldError e11) {
            }
            try {
                $SwitchMap$com$omega$engine$optimizer$lr$LearnRateUpdate[LearnRateUpdate.HALF.ordinal()] = 12;
            } catch (NoSuchFieldError e12) {
            }
            try {
                $SwitchMap$com$omega$engine$optimizer$lr$LearnRateUpdate[LearnRateUpdate.SMART_HALF.ordinal()] = 13;
            } catch (NoSuchFieldError e13) {
            }
        }
    }

    /**
     * Training entry point implemented by concrete optimizers.
     * NOTE(review): the role of the data set (training vs. evaluation) is defined by subclasses.
     */
    public abstract void train(BaseData baseData);

    /**
     * Two-data-set training entry point implemented by concrete optimizers.
     * NOTE(review): presumably (training data, evaluation data) — confirm against subclasses.
     */
    public abstract void train(BaseData baseData, BaseData baseData2);

    /**
     * Three-data-set training entry point implemented by concrete optimizers.
     * NOTE(review): presumably (training, validation, test) data — confirm against subclasses.
     */
    public abstract void train(BaseData baseData, BaseData baseData2, BaseData baseData3);

    public Optimizer(Network network, int i, int i2, float f, boolean z) throws Exception {
        this.batchIndex = 1;
        this.trainIndex = 1;
        this.batchSize = 1;
        this.dataSize = 0;
        this.trainTime = 10;
        this.minTrainTime = 10000;
        this.currentError = 1.0f;
        this.error = 0.01f;
        this.learnRateUpdate = LearnRateUpdate.NONE;
        this.warmUp = false;
        this.burnIn = 300;
        this.power = 4;
        this.scale = 0.1f;
        this.step = 500;
        this.gama = 0.9999f;
        this.lr = 0.1f;
        this.isOnline = false;
        this.lrStartTime = 5;
        this.max_lr = 1.0E-4f;
        this.min_lr = 0.0f;
        this.trainMax = 100;
        this.min_loss = Float.POSITIVE_INFINITY;
        this.counter = 0;
        this.network = network;
        this.trainTime = i2;
        this.error = f;
        this.lr = network.learnRate;
        this.max_lr = network.learnRate;
        this.warmUp = z;
        this.network.init();
    }

    public Optimizer(Network network, int i, int i2, int i3, float f, boolean z) throws Exception {
        this.batchIndex = 1;
        this.trainIndex = 1;
        this.batchSize = 1;
        this.dataSize = 0;
        this.trainTime = 10;
        this.minTrainTime = 10000;
        this.currentError = 1.0f;
        this.error = 0.01f;
        this.learnRateUpdate = LearnRateUpdate.NONE;
        this.warmUp = false;
        this.burnIn = 300;
        this.power = 4;
        this.scale = 0.1f;
        this.step = 500;
        this.gama = 0.9999f;
        this.lr = 0.1f;
        this.isOnline = false;
        this.lrStartTime = 5;
        this.max_lr = 1.0E-4f;
        this.min_lr = 0.0f;
        this.trainMax = 100;
        this.min_loss = Float.POSITIVE_INFINITY;
        this.counter = 0;
        this.network = network;
        this.trainTime = i2;
        this.minTrainTime = i3;
        this.error = f;
        this.warmUp = z;
        this.lr = network.learnRate;
        this.max_lr = network.learnRate;
        this.network.init();
    }

    public Optimizer(Network network, int i, int i2, int i3, float f, boolean z, LearnRateUpdate learnRateUpdate) throws Exception {
        this.batchIndex = 1;
        this.trainIndex = 1;
        this.batchSize = 1;
        this.dataSize = 0;
        this.trainTime = 10;
        this.minTrainTime = 10000;
        this.currentError = 1.0f;
        this.error = 0.01f;
        this.learnRateUpdate = LearnRateUpdate.NONE;
        this.warmUp = false;
        this.burnIn = 300;
        this.power = 4;
        this.scale = 0.1f;
        this.step = 500;
        this.gama = 0.9999f;
        this.lr = 0.1f;
        this.isOnline = false;
        this.lrStartTime = 5;
        this.max_lr = 1.0E-4f;
        this.min_lr = 0.0f;
        this.trainMax = 100;
        this.min_loss = Float.POSITIVE_INFINITY;
        this.counter = 0;
        this.network = network;
        this.trainTime = i2;
        this.minTrainTime = i3;
        this.error = f;
        this.warmUp = z;
        this.learnRateUpdate = learnRateUpdate;
        this.lr = network.learnRate;
        this.max_lr = network.learnRate;
        this.network.init();
    }

    /** Stores the task engine associated with training. */
    public void setTrainEngine(TaskEngine engine) {
        this.trainEngine = engine;
    }

    /** Returns the stored task engine, or {@code null} if none has been set. */
    public TaskEngine getTrainEngine() {
        return this.trainEngine;
    }

    /** Returns the stored training data set, or {@code null} if none has been set. */
    public BaseData getTrainingData() {
        return this.trainingData;
    }

    /** Stores the training data set. */
    public void setTrainingData(BaseData data) {
        this.trainingData = data;
    }

    /** Returns the stored test data set, or {@code null} if none has been set. */
    public BaseData getTestData() {
        return this.testData;
    }

    /** Stores the test data set. */
    public void setTestData(BaseData data) {
        this.testData = data;
    }

    /**
     * Rewrites {@code network.learnRate} for the current step according to {@link #learnRateUpdate}.
     *
     * <p>While warm-up is enabled and {@code batchIndex < burnIn}, the rate is set to
     * {@code lr * (batchIndex / burnIn)^power} and the configured schedule is skipped.
     *
     * @param lrSteps rounds at which SMART_HALF multiplies the rate by 0.1 (once per matching
     *                entry); when {@code null}, SMART_HALF halves the rate every 200 rounds instead
     */
    public void updateLR(int[] lrSteps) {
        if (this.warmUp && this.batchIndex < this.burnIn) {
            float progress = (this.batchIndex * 1.0f) / this.burnIn;
            this.network.learnRate = (float) (this.lr * Math.pow(progress * 1.0f, this.power));
            return;
        }

        switch (this.learnRateUpdate) {
            case LR_DECAY:
                this.network.learnRate = LRDecay.decayedLR(this.max_lr, this.network.learnRate, this.trainIndex, 5);
                break;
            case GD_GECAY:
                this.network.learnRate = GDDecay.decayedLR(this.max_lr, this.trainIndex);
                break;
            case COSINE:
                if (this.trainIndex < this.lrStartTime) {
                    // Ramp phase: scales the current rate by trainIndex / lrStartTime.
                    this.network.learnRate = (this.trainIndex * this.network.learnRate) / this.lrStartTime;
                } else {
                    double cosine = Math.cos((this.trainIndex * Math.PI) / this.trainTime);
                    this.network.learnRate = ((float) (0.5d * this.max_lr * (cosine + 1.0d))) * this.network.learnRate;
                }
                break;
            case COSINE_ANNEALING: {
                double cosine = Math.cos((this.trainIndex * Math.PI) / this.trainMax);
                this.network.learnRate = (float) (this.min_lr + (0.5d * (this.max_lr - this.min_lr) * (cosine + 1.0d)));
                break;
            }
            case RANDOM:
                this.network.learnRate = ((float) Math.pow(RandomUtils.getInstance().nextFloat(), this.power)) * this.lr;
                break;
            case POLY: {
                float progress = (((this.batchIndex * 1.0f) / this.trainTime) / this.dataSize) * this.batchSize;
                this.network.learnRate = (float) (this.lr * Math.pow(1.0f - progress, this.power));
                break;
            }
            case STEP:
                // Integer division: the rate is multiplied by `scale` once every `step` batches.
                this.network.learnRate = (float) (this.lr * Math.pow(this.scale, this.batchIndex / this.step));
                break;
            case EXP:
                this.network.learnRate = (float) (this.lr * Math.pow(this.gama, this.batchIndex));
                break;
            case SIG:
                this.network.learnRate = (float) (this.lr / (1.0d + Math.pow(Math.E, this.gama * (this.batchIndex - this.step))));
                break;
            case HALF:
                if (this.counter % 10 == 0) {
                    this.network.learnRate = HalfDecay.decayedLR(this.network.learnRate);
                }
                break;
            case SMART_HALF:
                if (lrSteps == null) {
                    if (this.trainIndex % 200 == 0) {
                        this.network.learnRate *= 0.5f;
                    }
                } else {
                    for (int milestone : lrSteps) {
                        if (milestone == this.trainIndex) {
                            this.network.learnRate *= 0.1f;
                        }
                    }
                }
                break;
            default:
                // NONE, CONSTANT and any unmapped value leave the rate unchanged.
                break;
        }
    }

    /**
     * Computes a learning rate for the current step without writing {@code network.learnRate}.
     *
     * <p>During warm-up (enabled and {@code batchIndex < burnIn}) the result is
     * {@code lr * (batchIndex / burnIn)^power}. Otherwise the schedule selected by
     * {@link #learnRateUpdate} applies:
     * <ul>
     *   <li>LR_DECAY, COSINE, HALF and SMART_HALF derive from {@code currentLr}.</li>
     *   <li>POLY, STEP, EXP and SIG derive from {@code baseLr}.</li>
     *   <li>RANDOM uses the {@link #lr} field.</li>
     *   <li>GD_GECAY and COSINE_ANNEALING use {@link #max_lr} (and {@link #min_lr}).</li>
     * </ul>
     *
     * @param lrSteps   SMART_HALF milestones (x0.1 per entry equal to {@code trainIndex}); when
     *                  {@code null}, SMART_HALF halves every 200 rounds instead
     * @param currentLr rate to adjust; returned unchanged for NONE, CONSTANT and unmapped schedules
     * @param baseLr    reference rate for POLY, STEP, EXP and SIG
     * @return the learning rate for this step
     */
    public float updateLR(int[] lrSteps, float currentLr, float baseLr) {
        if (this.warmUp && this.batchIndex < this.burnIn) {
            return (float) (this.lr * Math.pow(((this.batchIndex * 1.0f) / this.burnIn) * 1.0f, this.power));
        }

        switch (this.learnRateUpdate) {
            case LR_DECAY:
                return LRDecay.decayedLR(this.max_lr, currentLr, this.trainIndex, 5);
            case GD_GECAY:
                return GDDecay.decayedLR(this.max_lr, this.trainIndex);
            case COSINE:
                if (this.trainIndex < this.lrStartTime) {
                    return (this.trainIndex * currentLr) / this.lrStartTime;
                }
                return ((float) (0.5d * this.max_lr * (Math.cos((this.trainIndex * Math.PI) / this.trainTime) + 1.0d))) * currentLr;
            case COSINE_ANNEALING:
                return (float) (this.min_lr + (0.5d * (this.max_lr - this.min_lr) * (Math.cos((this.trainIndex * Math.PI) / this.trainMax) + 1.0d)));
            case RANDOM:
                return ((float) Math.pow(RandomUtils.getInstance().nextFloat(), this.power)) * this.lr;
            case POLY: {
                float progress = (((this.batchIndex * 1.0f) / this.trainTime) / this.dataSize) * this.batchSize;
                return (float) (baseLr * Math.pow(1.0f - progress, this.power));
            }
            case STEP:
                // Integer division: one `scale` factor per completed `step` batches.
                return (float) (baseLr * Math.pow(this.scale, this.batchIndex / this.step));
            case EXP:
                return (float) (baseLr * Math.pow(this.gama, this.batchIndex));
            case SIG:
                return (float) (baseLr / (1.0d + Math.pow(Math.E, this.gama * (this.batchIndex - this.step))));
            case HALF:
                if (this.counter % 10 == 0) {
                    return HalfDecay.decayedLR(currentLr);
                }
                return currentLr;
            case SMART_HALF: {
                float next = currentLr;
                if (lrSteps != null) {
                    for (int milestone : lrSteps) {
                        if (milestone == this.trainIndex) {
                            next *= 0.1f;
                        }
                    }
                } else if (this.trainIndex % 200 == 0) {
                    next *= 0.5f;
                }
                return next;
            }
            default:
                // NONE, CONSTANT and unmapped values keep the incoming rate.
                break;
        }
        return currentLr;
    }

    /**
     * Rewrites {@code network.learnRate} for the current step, feeding a loss value into the
     * SMART_HALF plateau detector.
     *
     * <p>During warm-up (enabled and {@code batchIndex < burnIn}) the rate is set to
     * {@code lr * (batchIndex / burnIn)^power} and no schedule runs. Otherwise the schedule chosen
     * by {@link #learnRateUpdate} is applied; COSINE and COSINE_ANNEALING use the same formulas as
     * the other {@code updateLR} overloads.
     *
     * <p>SMART_HALF here is loss-driven:
     * <ul>
     *   <li>A loss not above {@link #min_loss} becomes the new best and resets {@link #counter}.</li>
     *   <li>Any other loss increments the counter.</li>
     *   <li>When the counter reaches 9, the rate goes through {@code HalfDecay.decayedLR} and the
     *       counter resets.</li>
     * </ul>
     *
     * @param validLoss loss for the current step; only read by SMART_HALF
     */
    public void updateLR(float validLoss) {
        if (this.warmUp && this.batchIndex < this.burnIn) {
            this.network.learnRate = (float) (this.lr * Math.pow(((this.batchIndex * 1.0f) / this.burnIn) * 1.0f, this.power));
            return;
        }

        switch (this.learnRateUpdate) {
            case LR_DECAY:
                this.network.learnRate = LRDecay.decayedLR(this.max_lr, this.network.learnRate, this.trainIndex, 5);
                break;
            case GD_GECAY:
                this.network.learnRate = GDDecay.decayedLR(this.max_lr, this.trainIndex);
                break;
            case COSINE:
                if (this.trainIndex >= this.lrStartTime) {
                    // Compute the phase in floating point: trainIndex / trainTime in int arithmetic
                    // truncates to 0 for every round before the last. Keep +1 inside the
                    // 0.5 * max_lr factor, as the other updateLR overloads do.
                    this.network.learnRate = ((float) (0.5d * this.max_lr * (Math.cos((this.trainIndex * Math.PI) / this.trainTime) + 1.0d))) * this.network.learnRate;
                } else {
                    this.network.learnRate = (this.trainIndex * this.network.learnRate) / this.lrStartTime;
                }
                break;
            case COSINE_ANNEALING:
                // Anneal from max_lr (trainIndex = 0) down to min_lr (trainIndex = trainMax).
                this.network.learnRate = (float) (this.min_lr + (0.5d * (this.max_lr - this.min_lr) * (Math.cos((this.trainIndex * Math.PI) / this.trainMax) + 1.0d)));
                break;
            case RANDOM:
                this.network.learnRate = ((float) Math.pow(RandomUtils.getInstance().nextFloat(), this.power)) * this.lr;
                break;
            case POLY: {
                float progress = (((this.batchIndex * 1.0f) / this.trainTime) / this.dataSize) * this.batchSize;
                this.network.learnRate = (float) (this.lr * Math.pow(1.0f - progress, this.power));
                break;
            }
            case STEP:
                // Integer division: the rate is multiplied by `scale` once every `step` batches.
                this.network.learnRate = (float) (this.lr * Math.pow(this.scale, this.batchIndex / this.step));
                break;
            case EXP:
                this.network.learnRate = (float) (this.lr * Math.pow(this.gama, this.batchIndex));
                break;
            case SIG:
                this.network.learnRate = (float) (this.lr / (1.0d + Math.pow(Math.E, this.gama * (this.batchIndex - this.step))));
                break;
            case HALF:
                if (this.counter % 10 == 0) {
                    this.network.learnRate = HalfDecay.decayedLR(this.network.learnRate);
                }
                break;
            case SMART_HALF:
                if (validLoss <= this.min_loss) {
                    System.out.println("Validation loss decreased (" + this.min_loss + " --> " + validLoss + ")");
                    this.min_loss = validLoss;
                    this.counter = 0;
                } else {
                    this.counter++;
                    System.out.println("Validation loss did not decrease (" + this.min_loss + " < " + validLoss + ") update counter:" + this.counter);
                }
                // Plateau: 9 consecutive non-improving calls trigger a decay.
                if (this.counter >= 9) {
                    this.network.learnRate = HalfDecay.decayedLR(this.network.learnRate);
                    this.counter = 0;
                }
                break;
            default:
                // NONE, CONSTANT and unmapped values leave the rate unchanged.
                break;
        }
    }

    /**
     * Measures classification accuracy one sample at a time and prints it.
     *
     * <p>Each prediction is decoded with {@code LabelUtils.vectorTolabel} and compared against
     * {@code baseData.labels[index]}.
     *
     * @param baseData data set to evaluate
     * @return fraction of correctly classified samples (NaN when {@code baseData.number} is 0)
     */
    public float test(BaseData baseData) {
        float hits = 0.0f;
        Tensor sample = new Tensor(1, baseData.channel, baseData.height, baseData.width, true);
        this.network.RUN_MODEL = RunModel.TEST;
        int index = 0;
        while (index < baseData.number) {
            baseData.getOnceData(index, sample);
            Tensor output = this.network.predict(sample);
            output.syncHost();
            if (baseData.labels[index].equals(LabelUtils.vectorTolabel(output.data, baseData.labelSet))) {
                hits += 1.0f;
            }
            index++;
        }
        float accuracy = hits / baseData.number;
        System.out.println("准确率:" + (accuracy * 100.0f) + "%");
        return accuracy;
    }

    /**
     * Measures batched classification accuracy on {@code baseData} and prints it with timing.
     *
     * <p>The batch count is {@code number / batchSize} rounded down, so a trailing partial batch
     * is not evaluated.
     *
     * @param baseData  data set to evaluate
     * @param batchSize samples per batch
     * @return accumulated {@code accuracyTrueCount} divided by (batches x batchSize). That is the
     *         per-sample accuracy, assuming accuracyTrueCount returns the number of correct samples
     *         in a batch. NaN when no full batch exists.
     */
    public float test(BaseData baseData, int batchSize) {
        float correct = 0.0f;
        long startTime = System.nanoTime();
        this.network.RUN_MODEL = RunModel.TEST;
        Tensor input = new Tensor(batchSize, baseData.channel, baseData.height, baseData.width, true);
        Tensor label = new Tensor(batchSize, baseData.label.channel, baseData.label.height, baseData.label.width);
        int batchCount = new BigDecimal(baseData.number).divide(new BigDecimal(batchSize), 0, BigDecimal.ROUND_DOWN).intValue();
        for (int batch = 0; batch < batchCount; batch++) {
            baseData.getBatchData(batch, batchSize, input, label);
            input.hostToDevice();
            // Bind the prediction so it can be copied back to the host and then scored.
            Tensor predict = this.network.predict(input);
            predict.syncHost();
            correct += accuracyTrueCount(predict, label, baseData.labelSet);
        }
        float accuracy = (correct / batchCount) / batchSize;
        System.out.println("training[" + this.trainIndex + "] vail accuracy:{" + (accuracy * 100.0f) + "%} [costTime:" + ((System.nanoTime() - startTime) / 1000000.0d) + "ms.]");
        return accuracy;
    }

    /**
     * Evaluates accuracy and loss over the full batches of {@code baseData} and prints both.
     *
     * <p>The batch count is {@code number / batchSize} rounded down; a trailing partial batch is
     * skipped.
     *
     * @param baseData  data set to evaluate
     * @param batchSize samples per batch
     * @return mean over batches of (summed batch loss / batchSize)
     */
    public float testAndLoss(BaseData baseData, int batchSize) {
        float correct = 0.0f;
        float lossTotal = 0.0f;
        long startTime = System.nanoTime();
        this.network.RUN_MODEL = RunModel.TEST;
        Tensor input = new Tensor(batchSize, baseData.channel, baseData.height, baseData.width, true);
        Tensor label = new Tensor(batchSize, baseData.label.channel, baseData.label.height, baseData.label.width);
        int batchCount = new BigDecimal(baseData.number).divide(new BigDecimal(batchSize), 0, BigDecimal.ROUND_DOWN).intValue();
        for (int batch = 0; batch < batchCount; batch++) {
            baseData.getBatchData(batch, batchSize, input, label);
            input.hostToDevice();
            label.hostToDevice();
            Tensor predict = this.network.predict(input);
            Tensor batchLoss = this.network.loss(predict, label);
            lossTotal += MatrixOperation.sum(batchLoss.syncHost()) / batchSize;
            predict.syncHost();
            correct += accuracyTrueCount(predict, label, baseData.labelSet);
        }
        float accuracy = (correct / batchCount) / batchSize;
        float meanLoss = lossTotal / batchCount;
        System.out.println("training[" + this.trainIndex + "] vail accuracy:{" + (accuracy * 100.0f) + "%} vail loss:{" + meanLoss + "}  [costTime:" + ((System.nanoTime() - startTime) / 1000000.0d) + "ms.]");
        return meanLoss;
    }

    /**
     * Evaluates accuracy and loss over the full batches of {@code baseData}, reusing the
     * caller-supplied input and label tensors, and prints both.
     *
     * @param baseData  data set to evaluate (batch count rounded down)
     * @param input     reusable input tensor filled by {@code getBatchData}
     * @param label     reusable label tensor filled by {@code getBatchData}
     * @param batchSize samples per batch
     * @return summed loss divided by (batches x batchSize)
     */
    public float testAndLoss(BaseData baseData, Tensor input, Tensor label, int batchSize) {
        float correct = 0.0f;
        float lossTotal = 0.0f;
        long startTime = System.nanoTime();
        this.network.RUN_MODEL = RunModel.TEST;
        int batchCount = new BigDecimal(baseData.number).divide(new BigDecimal(batchSize), 0, BigDecimal.ROUND_DOWN).intValue();
        for (int batch = 0; batch < batchCount; batch++) {
            baseData.getBatchData(batch, batchSize, input, label);
            input.hostToDevice();
            label.hostToDevice();
            Tensor predict = this.network.predict(input);
            Tensor batchLoss = this.network.loss(predict, label);
            // GPU-backed losses must be copied to the host before summing.
            lossTotal += batchLoss.isHasGPU() ? MatrixOperation.sum(batchLoss.syncHost()) : MatrixOperation.sum(batchLoss.data);
            predict.syncHost();
            correct += accuracyTrueCount(predict, label, baseData.labelSet);
        }
        float accuracy = (correct / batchCount) / batchSize;
        float meanLoss = (lossTotal / batchCount) / batchSize;
        System.out.println("training[" + this.trainIndex + "] vail accuracy:{" + (accuracy * 100.0f) + "%} vail loss:{" + meanLoss + "}  [costTime:" + ((System.nanoTime() - startTime) / 1000000.0d) + "ms.]");
        return meanLoss;
    }

    /**
     * Evaluates loss and a {@link BaseCheck}-defined accuracy over {@code baseDataLoader} and
     * prints both.
     *
     * <p>The batch count is {@code number / batchSize} rounded up, so a trailing partial batch is
     * loaded too.
     *
     * @param baseDataLoader data source
     * @param input          reusable input tensor filled by {@code loadData}
     * @param label          reusable label tensor filled by {@code loadData}
     * @param batchSize      samples per batch
     * @param baseCheck      scorer whose per-batch results are summed and divided by
     *                       {@code baseDataLoader.number} for the printed accuracy
     * @return summed loss divided by the number of batches
     */
    public float testAndLoss(BaseDataLoader baseDataLoader, Tensor input, Tensor label, int batchSize, BaseCheck baseCheck) {
        float lossTotal = 0.0f;
        long startTime = System.nanoTime();
        this.network.RUN_MODEL = RunModel.TEST;
        int batchCount = new BigDecimal(baseDataLoader.number).divide(new BigDecimal(batchSize), 0, BigDecimal.ROUND_UP).intValue();
        float correct = 0.0f;
        for (int batch = 0; batch < batchCount; batch++) {
            baseDataLoader.loadData(batch, batchSize, input, label);
            input.hostToDevice();
            label.hostToDevice();
            Tensor predict = this.network.predict(input);
            Tensor batchLoss = this.network.loss(predict, label);
            lossTotal += batchLoss.isHasGPU() ? MatrixOperation.sum(batchLoss.syncHost()) : MatrixOperation.sum(batchLoss.data);
            predict.syncHost();
            correct += baseCheck.check(predict, label, baseDataLoader.labelSet, true);
        }
        float meanLoss = lossTotal / batchCount;
        System.out.println("test[" + this.trainIndex + "] vail loss:{" + meanLoss + "} (accuracy:" + ((correct / baseDataLoader.number) * 100.0f) + "%) [costTime:" + ((System.nanoTime() - startTime) / 1000000.0d) + "ms.]");
        return meanLoss;
    }

    /**
     * Computes and prints the mean loss of the network over {@code baseData}.
     *
     * <p>The batch count is {@code number / batchSize} rounded up.
     *
     * @param baseData  data set to evaluate
     * @param input     reusable input tensor filled by {@code getBatchData}
     * @param label     reusable label tensor filled by {@code getBatchData}
     * @param batchSize samples per batch
     * @return summed loss divided by (batches x batchSize)
     */
    public float testObjectRecognition(BaseData baseData, Tensor input, Tensor label, int batchSize) {
        float lossTotal = 0.0f;
        long startTime = System.nanoTime();
        this.network.RUN_MODEL = RunModel.TEST;
        int batchCount = new BigDecimal(baseData.number).divide(new BigDecimal(batchSize), 0, BigDecimal.ROUND_UP).intValue();
        for (int batch = 0; batch < batchCount; batch++) {
            baseData.getBatchData(batch, batchSize, input, label);
            input.hostToDevice();
            label.hostToDevice();
            Tensor output = this.network.predict(input);
            Tensor batchLoss = this.network.loss(output, label);
            lossTotal += batchLoss.isHasGPU() ? MatrixOperation.sum(batchLoss.syncHost()) : MatrixOperation.sum(batchLoss.data);
        }
        float meanLoss = (lossTotal / batchCount) / batchSize;
        System.out.println("test[" + this.trainIndex + "] vail loss:{" + meanLoss + "} [costTime:" + ((System.nanoTime() - startTime) / 1000000.0d) + "ms.]");
        return meanLoss;
    }

    /**
     * Computes and prints the mean loss of the network over a {@link BaseDataLoader}.
     *
     * <p>The batch count is {@code number / batchSize} rounded up.
     *
     * @param baseDataLoader data source
     * @param input          reusable input tensor filled by {@code loadData}
     * @param label          reusable label tensor filled by {@code loadData}
     * @param batchSize      samples per batch
     * @return summed loss divided by (batches x batchSize)
     */
    public float testObjectRecognition(BaseDataLoader baseDataLoader, Tensor input, Tensor label, int batchSize) {
        float lossTotal = 0.0f;
        long startTime = System.nanoTime();
        this.network.RUN_MODEL = RunModel.TEST;
        int batchCount = new BigDecimal(baseDataLoader.number).divide(new BigDecimal(batchSize), 0, BigDecimal.ROUND_UP).intValue();
        for (int batch = 0; batch < batchCount; batch++) {
            baseDataLoader.loadData(batch, batchSize, input, label);
            input.hostToDevice();
            label.hostToDevice();
            Tensor output = this.network.predict(input);
            Tensor batchLoss = this.network.loss(output, label);
            lossTotal += batchLoss.isHasGPU() ? MatrixOperation.sum(batchLoss.syncHost()) : MatrixOperation.sum(batchLoss.data);
        }
        float meanLoss = (lossTotal / batchCount) / batchSize;
        System.out.println("test[" + this.trainIndex + "] vail loss:{" + meanLoss + "} [costTime:" + ((System.nanoTime() - startTime) / 1000000.0d) + "ms.]");
        return meanLoss;
    }

    /**
     * Computes and prints the mean loss of the network over a {@link DetectionDataLoader}.
     *
     * <p>Unlike the other overloads, no explicit {@code hostToDevice()} call is made here.
     * NOTE(review): presumably the detection loader places data on the device itself — confirm.
     *
     * @param detectionDataLoader data source (batch count rounded up)
     * @param input               reusable input tensor filled by {@code loadData}
     * @param label               reusable label tensor filled by {@code loadData}
     * @param batchSize           samples per batch
     * @return summed loss divided by (batches x batchSize)
     */
    public float testObjectRecognition(DetectionDataLoader detectionDataLoader, Tensor input, Tensor label, int batchSize) {
        float lossTotal = 0.0f;
        long startTime = System.nanoTime();
        this.network.RUN_MODEL = RunModel.TEST;
        int batchCount = new BigDecimal(detectionDataLoader.number).divide(new BigDecimal(batchSize), 0, BigDecimal.ROUND_UP).intValue();
        for (int batch = 0; batch < batchCount; batch++) {
            detectionDataLoader.loadData(batch, batchSize, input, label);
            Tensor output = this.network.predict(input);
            Tensor batchLoss = this.network.loss(output, label);
            lossTotal += batchLoss.isHasGPU() ? MatrixOperation.sum(batchLoss.syncHost()) : MatrixOperation.sum(batchLoss.data);
        }
        float meanLoss = (lossTotal / batchCount) / batchSize;
        System.out.println("test[" + this.trainIndex + "] vail loss:{" + meanLoss + "} [costTime:" + ((System.nanoTime() - startTime) / 1000000.0d) + "ms.]");
        return meanLoss;
    }

    /**
     * Runs the multi-output YOLO forward pass and loss over {@code baseData}, printing only timing.
     *
     * <p>Requires {@code this.network} to be a {@link Yolo}; the cast happens after the run mode is
     * switched to TEST.
     *
     * @param baseData  data set to evaluate (batch count rounded up)
     * @param input     reusable input tensor filled by {@code getBatchData}
     * @param label     reusable label tensor filled by {@code getBatchData}
     * @param batchSize samples per batch
     * @return always 0. NOTE(review): the loss values are not aggregated here.
     */
    public float testObjectRecognitionOutputs(BaseData baseData, Tensor input, Tensor label, int batchSize) {
        long startTime = System.nanoTime();
        this.network.RUN_MODEL = RunModel.TEST;
        int batchCount = new BigDecimal(baseData.number).divide(new BigDecimal(batchSize), 0, BigDecimal.ROUND_UP).intValue();
        Yolo yolo = (Yolo) this.network;
        for (int batch = 0; batch < batchCount; batch++) {
            baseData.getBatchData(batch, batchSize, input, label);
            input.hostToDevice();
            label.hostToDevice();
            Tensor[] outputs = yolo.predicts(input);
            yolo.loss(outputs, label);
        }
        System.out.println("test[" + this.trainIndex + "] [costTime:" + ((System.nanoTime() - startTime) / 1000000.0d) + "ms.]");
        return 0.0f;
    }

    /**
     * Runs the YOLO forward pass and loss over a {@link BaseDataLoader}, printing only timing.
     *
     * <p>Uses the multi-output {@code predicts} path when {@code yolo.outputNum > 1}, otherwise
     * the single-output {@code predict}.
     *
     * @param baseDataLoader data source (batch count rounded up)
     * @param input          reusable input tensor filled by {@code loadData}
     * @param label          reusable label tensor filled by {@code loadData}
     * @param batchSize      samples per batch
     * @return always 0. NOTE(review): the loss values are not aggregated here.
     */
    public float testObjectRecognitionOutputs(BaseDataLoader baseDataLoader, Tensor input, Tensor label, int batchSize) {
        long startTime = System.nanoTime();
        this.network.RUN_MODEL = RunModel.TEST;
        int batchCount = new BigDecimal(baseDataLoader.number).divide(new BigDecimal(batchSize), 0, BigDecimal.ROUND_UP).intValue();
        Yolo yolo = (Yolo) this.network;
        boolean multiOutput = yolo.outputNum > 1;
        for (int batch = 0; batch < batchCount; batch++) {
            baseDataLoader.loadData(batch, batchSize, input, label);
            input.hostToDevice();
            label.hostToDevice();
            if (!multiOutput) {
                yolo.loss(yolo.predict(input), label);
            } else {
                yolo.loss(yolo.predicts(input), label);
            }
        }
        System.out.println("test[" + this.trainIndex + "] [costTime:" + ((System.nanoTime() - startTime) / 1000000.0d) + "ms.]");
        return 0.0f;
    }

    /**
     * Runs the YOLO forward pass and loss over a {@link DetectionDataLoader}, printing only timing.
     *
     * <p>Uses {@code predicts} when {@code yolo.outputNum > 1}, otherwise {@code predict}. No
     * explicit {@code hostToDevice()} is made. NOTE(review): presumably the loader handles device
     * placement — confirm.
     *
     * @param detectionDataLoader data source (batch count rounded up)
     * @param input               reusable input tensor filled by {@code loadData}
     * @param label               reusable label tensor filled by {@code loadData}
     * @param batchSize           samples per batch
     * @return always 0. NOTE(review): the loss values are not aggregated here.
     */
    public float testObjectRecognitionOutputs(DetectionDataLoader detectionDataLoader, Tensor input, Tensor label, int batchSize) {
        long startTime = System.nanoTime();
        this.network.RUN_MODEL = RunModel.TEST;
        int batchCount = new BigDecimal(detectionDataLoader.number).divide(new BigDecimal(batchSize), 0, BigDecimal.ROUND_UP).intValue();
        Yolo yolo = (Yolo) this.network;
        for (int batch = 0; batch < batchCount; batch++) {
            detectionDataLoader.loadData(batch, batchSize, input, label);
            if (yolo.outputNum <= 1) {
                yolo.loss(yolo.predict(input), label);
            } else {
                yolo.loss(yolo.predicts(input), label);
            }
        }
        System.out.println("test[" + this.trainIndex + "] [costTime:" + ((System.nanoTime() - startTime) / 1000000.0d) + "ms.]");
        return 0.0f;
    }

    /**
     * Runs the YOLO forward pass over every batch of {@code detectionDataLoader} and collects the
     * output tensors.
     *
     * <p>With {@code yolo.outputNum > 1}, every tensor returned by {@code predicts} is appended per
     * batch; otherwise the single {@code predict} output is appended.
     *
     * <p>NOTE(review): if the network reuses its output tensors between calls, entries for
     * different batches may alias the same object (holding the last batch's values). Confirm, and
     * copy the outputs if callers need per-batch results.
     *
     * @param detectionDataLoader data source (batch count = number / batchSize rounded up)
     * @param input               reusable input tensor filled by {@code loadData}
     * @param batchSize           samples per batch
     * @return outputs in batch order; never {@code null}
     */
    public List<Tensor> predictObjectRecognitionOutputs(DetectionDataLoader detectionDataLoader, Tensor input, int batchSize) {
        // Typed list: the previous raw ArrayList was returned as List<Tensor> via an unchecked conversion.
        List<Tensor> outputs = new ArrayList<>();
        long startTime = System.nanoTime();
        this.network.RUN_MODEL = RunModel.TEST;
        int batchCount = new BigDecimal(detectionDataLoader.number).divide(new BigDecimal(batchSize), 0, BigDecimal.ROUND_UP).intValue();
        Yolo yolo = (Yolo) this.network;
        for (int batch = 0; batch < batchCount; batch++) {
            detectionDataLoader.loadData(batch, batchSize, input);
            if (yolo.outputNum > 1) {
                for (Tensor output : yolo.predicts(input)) {
                    outputs.add(output);
                }
            } else {
                outputs.add(yolo.predict(input));
            }
        }
        System.out.println("test[" + this.trainIndex + "] [costTime:" + ((System.nanoTime() - startTime) / 1000000.0d) + "ms.]");
        return outputs;
    }

    /**
     * Debug pass: runs the multi-output YOLO forward over {@code baseData} and prints every value
     * at entry 4 of each anchor/cell that exceeds 0.1.
     *
     * <p>Entry 4 follows the four box entries in the {@link #entryIndex} layout.
     * NOTE(review): presumably the objectness score — confirm.
     *
     * @param baseData  data set to scan (batch count = number / batchSize rounded up)
     * @param batchSize samples per batch
     * @return always 0
     */
    public float testObjectRecognitionOutputs(BaseData baseData, int batchSize) {
        long startTime = System.nanoTime();
        this.network.RUN_MODEL = RunModel.TEST;
        int batchCount = new BigDecimal(baseData.number).divide(new BigDecimal(batchSize), 0, BigDecimal.ROUND_UP).intValue();
        Tensor input = new Tensor(batchSize, baseData.channel, baseData.height, baseData.width, true);
        Yolo yolo = (Yolo) this.network;
        for (int batch = 0; batch < batchCount; batch++) {
            baseData.getBatchData(batch, batchSize, input);
            input.hostToDevice();
            Tensor[] outputs = yolo.predicts(input);
            for (int layerIdx = 0; layerIdx < yolo.outputLayers.size(); layerIdx++) {
                YoloLayer layer = (YoloLayer) yolo.outputLayers.get(layerIdx);
                Tensor out = outputs[layerIdx];
                int width = out.width;
                int height = out.height;
                for (int n = 0; n < out.number; n++) {
                    // `cell` is row * width + col, so it is used directly as the in-plane offset.
                    for (int cell = 0; cell < height * width; cell++) {
                        for (int anchor = 0; anchor < layer.bbox_num; anchor++) {
                            int location = anchor * width * height + cell;
                            float score = out.data[entryIndex(n, width, height, location, 4, layer.outputs, layer.class_number)];
                            if (score > 0.1f) {
                                System.out.println(score);
                            }
                        }
                    }
                }
            }
        }
        System.out.println("test[" + this.trainIndex + "] [costTime:" + ((System.nanoTime() - startTime) / 1000000.0d) + "ms.]");
        return 0.0f;
    }

    /**
     * Computes a flat index into a YOLO layer output buffer.
     *
     * <p>The result is {@code i * i6 + (i4 / (i2*i3)) * (i2*i3) * (4 + i7 + 1) + i5 * (i2*i3) + i4 % (i2*i3)}.
     * Callers in this class pass {@code (sample, width, height, location, entry, layer.outputs,
     * layer.class_number)}.
     *
     * @param i  sample index within the batch
     * @param i2 plane width
     * @param i3 plane height
     * @param i4 location (anchor plane offset plus cell offset)
     * @param i5 entry (channel) within the anchor's block
     * @param i6 number of values per sample
     * @param i7 class count; each anchor block spans {@code 4 + i7 + 1} channels
     * @return flat index into the output array
     */
    public static int entryIndex(int i, int i2, int i3, int i4, int i5, int i6, int i7) {
        final int plane = i2 * i3;
        final int anchor = i4 / plane;   // which anchor block the location falls in
        final int cell = i4 % plane;     // offset within that block's plane
        final int channelsPerAnchor = 4 + i7 + 1;
        final int sampleOffset = i * i6;
        final int anchorOffset = anchor * plane * channelsPerAnchor;
        final int entryOffset = i5 * plane;
        return sampleOffset + anchorOffset + entryOffset + cell;
    }

    /**
     * Decodes single-output YOLO predictions for every sample in {@code baseData}, using the
     * caller-supplied {@code tensor} as the batch input buffer.
     *
     * <p>For the final, partially filled batch only the last {@code number % i} rows are copied
     * into the result, at offset {@code i - number % i} of the decoded batch. This is the same
     * convention as the other {@code showObjectRecognition} overloads. It assumes
     * {@code getBatchData} back-fills the last batch so that the not-yet-seen samples sit at its
     * tail (TODO confirm against BaseData). The copy length always equals the slots left in the
     * result, so the result array cannot be overrun.
     *
     * @param baseData dataset to decode
     * @param tensor   input buffer of batch size {@code i}
     * @param i        batch size
     * @return one decoded detection grid per sample, indexed by sample
     */
    public float[][][] showObjectRecognition(BaseData baseData, Tensor tensor, int i) {
        this.network.RUN_MODEL = RunModel.TEST;
        float[][][] fArr = new float[baseData.number][YoloDecode.grid_size * YoloDecode.grid_size * YoloDecode.bbox_num][YoloDecode.class_number + 1 + 4];
        // divide(..., scale 0, rounding mode 0 == ROUND_UP): batchCount = ceil(number / i)
        int batchCount = new BigDecimal(baseData.number).divide(new BigDecimal(i), 0, 0).intValue();
        for (int batch = 0; batch < batchCount; batch++) {
            baseData.getBatchData(batch, i, tensor);
            tensor.hostToDevice();
            Tensor predict = this.network.predict(tensor);
            predict.syncHost();
            float[][][] detection = YoloDecode.getDetection(predict, baseData.width, baseData.height);
            int dest = batch * i;
            if (dest + i > baseData.number) {
                // Final partial batch: copy exactly the remaining (number % i) samples from its tail.
                int remaining = baseData.number % i;
                System.arraycopy(detection, i - remaining, fArr, dest, remaining);
            } else {
                System.arraycopy(detection, 0, fArr, dest, i);
            }
        }
        return fArr;
    }

    /**
     * Runs the multi-output YOLO network over {@code baseData} in batches of {@code i}. The decoded
     * detections of every output layer are merged into one {@link YoloBox} per batch slot.
     *
     * <p>Detections are decoded by {@code YoloUtils.getYoloDetections} with a 0.5 threshold.
     * No NMS is applied in this overload.
     *
     * <p>NOTE(review): every batch, including a padded final batch, appends {@code input.number}
     * entries. The list therefore holds ceil(number / i) * i boxes rather than {@code number};
     * confirm that callers expect this. This overload also passes {@code (width, height)} where
     * the DetectionDataLoader overloads pass {@code (height, width)}. Verify this against the
     * getYoloDetections signature (it makes no difference for square inputs).
     *
     * @param baseData dataset to run inference on
     * @param i        batch size
     * @return per-slot detection boxes, in batch order
     */
    public List<YoloBox> showObjectRecognitionYoloV3(BaseData baseData, int i) {
        this.network.RUN_MODEL = RunModel.TEST;
        List<YoloBox> boxes = new ArrayList<>();
        // divide(..., scale 0, rounding mode 0 == ROUND_UP): batchCount = ceil(number / i)
        int batchCount = new BigDecimal(baseData.number).divide(new BigDecimal(i), 0, 0).intValue();
        Tensor input = new Tensor(i, baseData.channel, baseData.height, baseData.width, true);
        Yolo yolo = (Yolo) this.network;
        for (int batch = 0; batch < batchCount; batch++) {
            baseData.getBatchData(batch, i, input);
            input.hostToDevice();
            Tensor[] predicts = yolo.predicts(input);
            YoloBox[] batchBoxes = new YoloBox[input.number];
            for (int layerIdx = 0; layerIdx < yolo.outputLayers.size(); layerIdx++) {
                YoloLayer layer = (YoloLayer) yolo.outputLayers.get(layerIdx);
                YoloDetection[][] dets = YoloUtils.getYoloDetections(predicts[layerIdx], layer.anchors, layer.mask, layer.bbox_num, layer.outputs, layer.class_number, baseData.width, baseData.height, 0.5f);
                for (int n = 0; n < dets.length; n++) {
                    if (batchBoxes[n] != null) {
                        // addAll copies the elements, so no intermediate ArrayList is needed.
                        batchBoxes[n].getDets().addAll(Arrays.asList(dets[n]));
                    } else {
                        batchBoxes[n] = new YoloBox(dets[n]);
                    }
                }
            }
            boxes.addAll(Arrays.asList(batchBoxes));
        }
        return boxes;
    }

    /**
     * Runs the multi-output YOLO network over the samples of {@code detectionDataLoader} in batches
     * of {@code i}. The detections of every output layer are merged into one {@link YoloBox} per
     * batch slot.
     *
     * <p>Each layer's detections are decoded with a 0.5 threshold, then pruned with
     * {@link #nmsSort} (IoU 0.7) before merging.
     *
     * <p>NOTE(review): NMS runs per output layer, so overlapping boxes from different scales are
     * not suppressed against each other. Darknet runs NMS once over all layers' detections;
     * confirm this is intended. Every batch appends {@code input.number} entries, including a
     * padded final batch.
     *
     * @param detectionDataLoader source of input batches
     * @param i                   batch size
     * @return per-slot detection boxes, in batch order
     */
    public List<YoloBox> showObjectRecognitionYoloV3(DetectionDataLoader detectionDataLoader, int i) {
        this.network.RUN_MODEL = RunModel.TEST;
        List<YoloBox> boxes = new ArrayList<>();
        // divide(..., scale 0, rounding mode 0 == ROUND_UP): batchCount = ceil(number / i)
        int batchCount = new BigDecimal(detectionDataLoader.number).divide(new BigDecimal(i), 0, 0).intValue();
        Tensor input = new Tensor(i, this.network.channel, this.network.height, this.network.width, true);
        Yolo yolo = (Yolo) this.network;
        for (int batch = 0; batch < batchCount; batch++) {
            detectionDataLoader.loadData(batch, i, input);
            Tensor[] predicts = yolo.predicts(input);
            YoloBox[] batchBoxes = new YoloBox[input.number];
            for (int layerIdx = 0; layerIdx < yolo.outputLayers.size(); layerIdx++) {
                YoloLayer layer = (YoloLayer) yolo.outputLayers.get(layerIdx);
                YoloDetection[][] dets = YoloUtils.getYoloDetections(predicts[layerIdx], layer.anchors, layer.mask, layer.bbox_num, layer.outputs, layer.class_number, this.network.height, this.network.width, 0.5f);
                for (int n = 0; n < dets.length; n++) {
                    nmsSort(dets[n], dets[n].length, layer.class_number, 0.7f);
                    if (batchBoxes[n] != null) {
                        batchBoxes[n].getDets().addAll(Arrays.asList(dets[n]));
                    } else {
                        batchBoxes[n] = new YoloBox(dets[n]);
                    }
                }
            }
            boxes.addAll(Arrays.asList(batchBoxes));
        }
        return boxes;
    }

    /**
     * YOLOv7 variant of {@code showObjectRecognitionYoloV3(DetectionDataLoader, int)}. It decodes
     * each output layer with {@code YoloUtils.getYoloDetectionsV7} (threshold 0.5) and prunes it
     * with {@link #nmsSort} (IoU 0.7). The results are merged into one {@link YoloBox} per batch slot.
     *
     * <p>NOTE(review): NMS runs per output layer, not across all scales. Every batch appends
     * {@code input.number} entries, including a padded final batch; confirm both are intended.
     *
     * @param detectionDataLoader source of input batches
     * @param i                   batch size
     * @return per-slot detection boxes, in batch order
     */
    public List<YoloBox> showObjectRecognitionYoloV7(DetectionDataLoader detectionDataLoader, int i) {
        this.network.RUN_MODEL = RunModel.TEST;
        List<YoloBox> boxes = new ArrayList<>();
        // divide(..., scale 0, rounding mode 0 == ROUND_UP): batchCount = ceil(number / i)
        int batchCount = new BigDecimal(detectionDataLoader.number).divide(new BigDecimal(i), 0, 0).intValue();
        Tensor input = new Tensor(i, this.network.channel, this.network.height, this.network.width, true);
        Yolo yolo = (Yolo) this.network;
        for (int batch = 0; batch < batchCount; batch++) {
            detectionDataLoader.loadData(batch, i, input);
            Tensor[] predicts = yolo.predicts(input);
            YoloBox[] batchBoxes = new YoloBox[input.number];
            for (int layerIdx = 0; layerIdx < yolo.outputLayers.size(); layerIdx++) {
                YoloLayer layer = (YoloLayer) yolo.outputLayers.get(layerIdx);
                YoloDetection[][] dets = YoloUtils.getYoloDetectionsV7(predicts[layerIdx], layer.anchors, layer.mask, layer.bbox_num, layer.outputs, layer.class_number, this.network.height, this.network.width, 0.5f);
                for (int n = 0; n < dets.length; n++) {
                    nmsSort(dets[n], dets[n].length, layer.class_number, 0.7f);
                    if (batchBoxes[n] != null) {
                        batchBoxes[n].getDets().addAll(Arrays.asList(dets[n]));
                    } else {
                        batchBoxes[n] = new YoloBox(dets[n]);
                    }
                }
            }
            boxes.addAll(Arrays.asList(batchBoxes));
        }
        return boxes;
    }

    /**
     * Greedy per-class non-maximum suppression over the first {@code i} entries of
     * {@code yoloDetectionArr}, following the structure of darknet's {@code do_nms_sort}.
     *
     * <p>Null entries and entries with zero objectness are first swapped to the tail. The live
     * prefix {@code [0, total)} is then, for each class, sorted by descending class probability.
     * Any later box whose IoU with a kept box exceeds {@code f} has that class probability set to 0.
     * Sorting is restricted to the live prefix, so the comparator never receives a null (or a
     * stale tail entry).
     *
     * <p>The array is reordered in place and the probability arrays are mutated.
     *
     * @param yoloDetectionArr detections to prune; may contain nulls
     * @param i                number of leading entries to consider
     * @param i2               number of classes
     * @param f                IoU threshold above which a lower-scoring box is suppressed
     */
    public void nmsSort(YoloDetection[] yoloDetectionArr, int i, int i2, float f) {
        // Partition: move null / zero-objectness entries of [0, i) to the tail.
        int last = i - 1;
        for (int k = 0; k <= last; k++) {
            YoloDetection det = yoloDetectionArr[k];
            if (det == null || det.getObjectness() == 0.0f) {
                yoloDetectionArr[k] = yoloDetectionArr[last];
                yoloDetectionArr[last] = det;
                last--;
                k--; // re-examine the entry that was just swapped into slot k
            }
        }
        final int total = last + 1;

        // Descending by probability of the current sort class (or by objectness when it is negative).
        Comparator<YoloDetection> byScoreDesc = new Comparator<YoloDetection>() {
            @Override
            public int compare(YoloDetection x, YoloDetection y) {
                float diff = y.getSortClass() >= 0
                        ? x.getProb()[y.getSortClass()] - y.getProb()[y.getSortClass()]
                        : x.getObjectness() - y.getObjectness();
                if (diff < 0.0f) {
                    return 1;
                }
                return diff > 0.0f ? -1 : 0;
            }
        };

        for (int cls = 0; cls < i2; cls++) {
            for (int k = 0; k < total; k++) {
                yoloDetectionArr[k].setSortClass(cls);
            }
            // Only the live prefix is sorted; the tail may hold nulls.
            Arrays.sort(yoloDetectionArr, 0, total, byScoreDesc);
            for (int a = 0; a < total; a++) {
                if (yoloDetectionArr[a].getProb()[cls] == 0.0f) {
                    continue;
                }
                float[] keptBox = yoloDetectionArr[a].getBbox();
                for (int b = a + 1; b < total; b++) {
                    if (YoloUtils.box_iou(keptBox, yoloDetectionArr[b].getBbox()) > f) {
                        yoloDetectionArr[b].getProb()[cls] = 0.0f;
                    }
                }
            }
        }
    }

    /**
     * Decodes single-output YOLO predictions for every sample in {@code baseData}, allocating its
     * own input buffer of batch size {@code i}.
     *
     * <p>For the final partial batch, the last {@code number % i} rows of the decoded batch are
     * copied into the remaining result slots.
     *
     * @param baseData dataset to decode
     * @param i        batch size
     * @return one decoded detection grid per sample, indexed by sample
     */
    public float[][][] showObjectRecognition(BaseData baseData, int i) {
        this.network.RUN_MODEL = RunModel.TEST;
        final int boxesPerImage = YoloDecode.grid_size * YoloDecode.grid_size * YoloDecode.bbox_num;
        final int valuesPerBox = YoloDecode.class_number + 1 + 4;
        final float[][][] results = new float[baseData.number][boxesPerImage][valuesPerBox];
        // divide(..., scale 0, rounding mode 0 == ROUND_UP): batchCount = ceil(number / i)
        final int batchCount = new BigDecimal(baseData.number).divide(new BigDecimal(i), 0, 0).intValue();
        final Tensor input = new Tensor(i, baseData.channel, baseData.height, baseData.width, true);
        for (int batch = 0; batch < batchCount; batch++) {
            baseData.getBatchData(batch, i, input);
            input.hostToDevice();
            final Tensor output = this.network.predict(input);
            output.syncHost();
            final float[][][] decoded = YoloDecode.getDetection(output, baseData.width, baseData.height);
            final int dest = batch * i;
            final boolean lastPartial = dest + i > baseData.number;
            if (!lastPartial) {
                System.arraycopy(decoded, 0, results, dest, i);
                continue;
            }
            final int fresh = baseData.number % i;
            System.arraycopy(decoded, i - fresh, results, dest, fresh);
        }
        return results;
    }

    /**
     * Decodes single-output YOLO predictions for every sample served by {@code yoloDataLoader}.
     * Start and finish messages, including elapsed whole seconds, are printed to stdout.
     *
     * <p>For the final partial batch, the last {@code number % i} rows of the decoded batch are
     * copied into the remaining result slots.
     *
     * @param yoloDataLoader source of input batches and image dimensions
     * @param i              batch size
     * @param i2             extra argument forwarded to {@code YoloDecode.getDetection}
     * @return one decoded detection grid per sample, indexed by sample
     */
    public float[][][] showObjectRecognition(YoloDataLoader yoloDataLoader, int i, int i2) {
        System.out.println("start object recognition.");
        final long startMillis = System.currentTimeMillis();
        this.network.RUN_MODEL = RunModel.TEST;
        final int boxesPerImage = YoloDecode.grid_size * YoloDecode.grid_size * YoloDecode.bbox_num;
        final int valuesPerBox = YoloDecode.class_number + 1 + 4;
        final float[][][] results = new float[yoloDataLoader.number][boxesPerImage][valuesPerBox];
        // divide(..., scale 0, rounding mode 0 == ROUND_UP): batchCount = ceil(number / i)
        final int batchCount = new BigDecimal(yoloDataLoader.number).divide(new BigDecimal(i), 0, 0).intValue();
        final Tensor input = new Tensor(i, this.network.channel, this.network.height, this.network.width, true);
        for (int batch = 0; batch < batchCount; batch++) {
            yoloDataLoader.loadData(batch, i, input);
            input.hostToDevice();
            final Tensor output = this.network.predict(input);
            output.syncHost();
            final float[][][] decoded = YoloDecode.getDetection(output, yoloDataLoader.getDataSet().width, yoloDataLoader.getDataSet().height, i2);
            final int dest = batch * i;
            if (dest + i <= yoloDataLoader.number) {
                System.arraycopy(decoded, 0, results, dest, i);
            } else {
                final int fresh = yoloDataLoader.number % i;
                System.arraycopy(decoded, i - fresh, results, dest, fresh);
            }
        }
        final long elapsedSeconds = (System.currentTimeMillis() - startMillis) / 1000;
        System.out.println("finish object recognition[" + elapsedSeconds + "s].");
        return results;
    }

    /**
     * Returns the percentage of samples for which the label decoded from {@code tensor2} equals
     * the label decoded from {@code tensor}, using {@code LabelUtils.vectorTolabel}.
     *
     * @param tensor  first batch of label vectors
     * @param tensor2 second batch of label vectors, compared against {@code tensor}
     * @param strArr  label vocabulary passed to {@code vectorTolabel}
     * @return matches / {@code tensor.number} * 100
     */
    public float accuracy(Tensor tensor, Tensor tensor2, String[] strArr) {
        float hits = 0.0f;
        for (int n = 0; n < tensor.number; n++) {
            final boolean same = LabelUtils.vectorTolabel(tensor2.getByNumber(n), strArr)
                    .equals(LabelUtils.vectorTolabel(tensor.getByNumber(n), strArr));
            if (same) {
                hits += 1.0f;
            }
        }
        final float ratio = hits / tensor.number;
        return ratio * 100.0f;
    }

    /**
     * Returns the percentage of samples whose arg-max index in {@code tensor2} equals the arg-max
     * index in {@code tensor}.
     *
     * @param tensor  first batch of score vectors
     * @param tensor2 second batch of score vectors
     * @return matches / {@code tensor.number} * 100
     */
    public float accuracy(Tensor tensor, Tensor tensor2) {
        float hits = 0.0f;
        for (int n = 0; n < tensor.number; n++) {
            final int secondIdx = MatrixOperation.maxIndex(tensor2.getByNumber(n));
            final int firstIdx = MatrixOperation.maxIndex(tensor.getByNumber(n));
            hits += (secondIdx == firstIdx) ? 1.0f : 0.0f;
        }
        final float ratio = hits / tensor.number;
        return ratio * 100.0f;
    }

    /**
     * Sequence accuracy for a step-major layout. Row {@code t * i2 + s} holds step {@code t} of
     * sequence {@code s}. A sequence counts as correct only if the arg-max of every step agrees
     * between {@code tensor} and {@code tensor2}. Each step's pair of indices is printed as
     * {@code "a:b"}.
     *
     * @param tensor  first batch of score vectors
     * @param tensor2 second batch of score vectors
     * @param i       number of steps per sequence
     * @param i2      number of sequences
     * @return fully-matching sequences / {@code i2} * 100
     */
    public float accuracy(Tensor tensor, Tensor tensor2, int i, int i2) {
        final int steps = i;
        final int sequences = i2;
        float correct = 0.0f;
        for (int seq = 0; seq < sequences; seq++) {
            boolean allMatch = true;
            for (int step = 0; step < steps; step++) {
                final int row = step * sequences + seq;
                final int firstIdx = MatrixOperation.maxIndex(tensor.getByNumber(row));
                final int secondIdx = MatrixOperation.maxIndex(tensor2.getByNumber(row));
                System.out.println(firstIdx + ":" + secondIdx);
                allMatch &= (secondIdx == firstIdx);
            }
            if (allMatch) {
                correct += 1.0f;
            }
        }
        return correct / sequences * 100.0f;
    }

    /**
     * Sequence accuracy for a batch-major layout, where row {@code s * i + t} is step {@code t} of
     * sequence {@code s}. A step counts as a mismatch when the arg-max of {@code tensor3} differs
     * from that of {@code tensor2}, unless the {@code tensor3} arg-max equals the ignore index
     * {@code i3}.
     *
     * <p>The sequence with the highest per-step score (the last one on ties) is printed as
     * {@code max_score}, {@code itxt} (tokens from {@code tensor.data}), {@code ptxt}
     * ({@code tensor2} arg-max) and {@code ltxt} ({@code tensor3} arg-max).
     *
     * @param tensor  token ids, read as {@code (int) tensor.data[row]}
     * @param tensor2 score vectors printed as ptxt
     * @param tensor3 score vectors printed as ltxt
     * @param i       steps per sequence
     * @param i2      number of sequences
     * @param strArr  vocabulary used to render indices as text
     * @param i3      index in {@code tensor3} whose steps never count as mismatches
     * @return fully-matching sequences / {@code i2} * 100
     */
    public float accuracyBatchFisrt(Tensor tensor, Tensor tensor2, Tensor tensor3, int i, int i2, String[] strArr, int i3) {
        final int seqLen = i;
        final int sequences = i2;
        final int ignoreIdx = i3;
        float correct = 0.0f;
        int bestScore = 0;
        String bestInput = "";
        String bestPred = "";
        String bestLabel = "";
        for (int seq = 0; seq < sequences; seq++) {
            boolean allMatch = true;
            int score = seqLen;
            final StringBuilder inputText = new StringBuilder();
            final StringBuilder predText = new StringBuilder();
            final StringBuilder labelText = new StringBuilder();
            for (int step = 0; step < seqLen; step++) {
                final int row = seq * seqLen + step;
                final int predIdx = MatrixOperation.maxIndex(tensor2.getByNumber(row));
                final int labelIdx = MatrixOperation.maxIndex(tensor3.getByNumber(row));
                final int tokenId = (int) tensor.data[row];
                predText.append(strArr[predIdx]);
                labelText.append(strArr[labelIdx]);
                inputText.append(strArr[tokenId]);
                final boolean mismatch = labelIdx != ignoreIdx && labelIdx != predIdx;
                if (mismatch) {
                    allMatch = false;
                    score--;
                }
            }
            if (score >= bestScore) {
                bestScore = score;
                bestInput = inputText.toString();
                bestPred = predText.toString();
                bestLabel = labelText.toString();
            }
            if (allMatch) {
                correct += 1.0f;
            }
        }
        System.out.println("max_score:" + bestScore);
        System.out.println("itxt:" + bestInput);
        System.out.println("ptxt:" + bestPred);
        System.out.println("ltxt:" + bestLabel);
        return correct / sequences * 100.0f;
    }

    /**
     * BPE variant of the batch-major sequence accuracy. The scoring is the same as the
     * {@code String[]} overload with an ignore index. The best-scoring sequence (the last one on
     * ties) is rendered with {@code bPETokenizer.toText} and printed as {@code max_score},
     * {@code itxt}, {@code ptxt} and {@code ltxt}.
     *
     * @param tensor       token ids, read as {@code (int) tensor.data[row]}
     * @param tensor2      score vectors printed as ptxt
     * @param tensor3      score vectors printed as ltxt
     * @param i            steps per sequence
     * @param i2           number of sequences
     * @param bPETokenizer tokenizer used to render id lists as text
     * @param i3           index in {@code tensor3} whose steps never count as mismatches
     * @return fully-matching sequences / {@code i2} * 100
     */
    public float accuracyBatchFisrt(Tensor tensor, Tensor tensor2, Tensor tensor3, int i, int i2, BPETokenizer bPETokenizer, int i3) {
        final int seqLen = i;
        final int sequences = i2;
        final int ignoreIdx = i3;
        float correct = 0.0f;
        int bestScore = 0;
        String bestInput = "";
        String bestPred = "";
        String bestLabel = "";
        for (int seq = 0; seq < sequences; seq++) {
            boolean allMatch = true;
            int score = seqLen;
            final ArrayList<Integer> inputIds = new ArrayList<Integer>();
            final ArrayList<Integer> predIds = new ArrayList<Integer>();
            final ArrayList<Integer> labelIds = new ArrayList<Integer>();
            for (int step = 0; step < seqLen; step++) {
                final int row = seq * seqLen + step;
                final int predIdx = MatrixOperation.maxIndex(tensor2.getByNumber(row));
                final int labelIdx = MatrixOperation.maxIndex(tensor3.getByNumber(row));
                inputIds.add(Integer.valueOf((int) tensor.data[row]));
                predIds.add(Integer.valueOf(predIdx));
                labelIds.add(Integer.valueOf(labelIdx));
                if (labelIdx != ignoreIdx && labelIdx != predIdx) {
                    allMatch = false;
                    score--;
                }
            }
            if (score >= bestScore) {
                bestScore = score;
                bestInput = bPETokenizer.toText(inputIds);
                bestPred = bPETokenizer.toText(predIds);
                bestLabel = bPETokenizer.toText(labelIds);
            }
            if (allMatch) {
                correct += 1.0f;
            }
        }
        System.out.println("max_score:" + bestScore);
        System.out.println("itxt:" + bestInput);
        System.out.println("ptxt:" + bestPred);
        System.out.println("ltxt:" + bestLabel);
        return correct / sequences * 100.0f;
    }

    /**
     * Batch-major sequence accuracy with no ignore index. Every step whose {@code tensor3}
     * arg-max differs from the {@code tensor2} arg-max is a mismatch. The best-scoring sequence
     * (the last one on ties) is printed as {@code max_score}, {@code itxt}, {@code ptxt} and
     * {@code ltxt}.
     *
     * @param tensor  token ids, read as {@code (int) tensor.data[row]}
     * @param tensor2 score vectors printed as ptxt
     * @param tensor3 score vectors printed as ltxt
     * @param i       steps per sequence
     * @param i2      number of sequences
     * @param strArr  vocabulary used to render indices as text
     * @return fully-matching sequences / {@code i2} * 100
     */
    public float accuracyBatchFisrt(Tensor tensor, Tensor tensor2, Tensor tensor3, int i, int i2, String[] strArr) {
        final int seqLen = i;
        final int sequences = i2;
        float correct = 0.0f;
        int bestScore = 0;
        String bestInput = "";
        String bestPred = "";
        String bestLabel = "";
        for (int seq = 0; seq < sequences; seq++) {
            boolean allMatch = true;
            int score = seqLen;
            final StringBuilder inputText = new StringBuilder();
            final StringBuilder predText = new StringBuilder();
            final StringBuilder labelText = new StringBuilder();
            for (int step = 0; step < seqLen; step++) {
                final int row = seq * seqLen + step;
                final int predIdx = MatrixOperation.maxIndex(tensor2.getByNumber(row));
                final int labelIdx = MatrixOperation.maxIndex(tensor3.getByNumber(row));
                final int tokenId = (int) tensor.data[row];
                predText.append(strArr[predIdx]);
                labelText.append(strArr[labelIdx]);
                inputText.append(strArr[tokenId]);
                if (labelIdx != predIdx) {
                    allMatch = false;
                    score--;
                }
            }
            if (score >= bestScore) {
                bestScore = score;
                bestInput = inputText.toString();
                bestPred = predText.toString();
                bestLabel = labelText.toString();
            }
            if (allMatch) {
                correct += 1.0f;
            }
        }
        System.out.println("max_score:" + bestScore);
        System.out.println("itxt:" + bestInput);
        System.out.println("ptxt:" + bestPred);
        System.out.println("ltxt:" + bestLabel);
        return correct / sequences * 100.0f;
    }

    /**
     * Sums the per-sample loss from {@code testLoss(float[], float[])} over a batch and prints
     * the per-sample values as JSON, prefixed with {@code cpu_loss:}.
     *
     * @param tensor  per-sample logits
     * @param tensor2 per-sample target vectors
     * @return sum of per-sample losses
     */
    public float testLoss(Tensor tensor, Tensor tensor2) {
        final float[] perSample = new float[tensor.number];
        float total = 0.0f;
        for (int n = 0; n < tensor.number; n++) {
            perSample[n] = testLoss(tensor.getByNumber(n), tensor2.getByNumber(n));
            total += perSample[n];
        }
        System.out.println("cpu_loss:" + JsonUtils.toJson(perSample));
        return total;
    }

    /**
     * Counts the samples for which the label decoded from {@code tensor2} equals the label
     * decoded from {@code tensor}.
     *
     * @param tensor  first batch of label vectors
     * @param tensor2 second batch of label vectors
     * @param strArr  label vocabulary passed to {@code vectorTolabel}
     * @return number of matching samples
     */
    public int accuracyTrueCount(Tensor tensor, Tensor tensor2, String[] strArr) {
        int matches = 0;
        for (int n = 0; n < tensor.number; n++) {
            final boolean same = LabelUtils.vectorTolabel(tensor2.getByNumber(n), strArr)
                    .equals(LabelUtils.vectorTolabel(tensor.getByNumber(n), strArr));
            matches += same ? 1 : 0;
        }
        return matches;
    }

    /**
     * Softmax cross-entropy of logits {@code fArr} against target weights {@code fArr2}:
     * {@code -sum_k fArr2[k] * log(softmax(fArr)[k])}, computed with max-subtraction for
     * numerical stability.
     *
     * <p>This is the same computation as {@link #testLoss2(float[], float[])}. Delegating keeps
     * the instance and static entry points from drifting apart.
     *
     * @param fArr  logits
     * @param fArr2 target weights (e.g. a one-hot vector), same length as {@code fArr}
     * @return the cross-entropy loss
     */
    public float testLoss(float[] fArr, float[] fArr2) {
        return testLoss2(fArr, fArr2);
    }

    /**
     * Softmax cross-entropy of logits {@code fArr} against target weights {@code fArr2}:
     * {@code -sum_k fArr2[k] * ((fArr[k] - max) - log(sum_j exp(fArr[j] - max)))}.
     *
     * @param fArr  logits
     * @param fArr2 target weights, same length as {@code fArr}
     * @return the cross-entropy loss
     */
    public static float testLoss2(float[] fArr, float[] fArr2) {
        final float maxLogit = MatrixOperation.max(fArr);
        float expSum = 0.0f;
        for (int k = 0; k < fArr.length; k++) {
            expSum = (float) (expSum + Math.exp(fArr[k] - maxLogit));
        }
        // log of the normaliser is loop-invariant; compute it once.
        final double logNorm = Math.log(expSum);
        float loss = 0.0f;
        for (int k = 0; k < fArr.length; k++) {
            final double logProb = (fArr[k] - maxLogit) - logNorm;
            loss += (float) (-logProb * fArr2[k]);
        }
        return loss;
    }

    /**
     * Manual check: prints {@link #testLoss2} for a fixed 10-logit vector against a one-hot
     * target at index 3.
     */
    public static void main(String[] strArr) {
        final float[] logits = {
            0.6079413f, -1.1546507f, 1.444119f, 1.5811894f, 1.131686f,
            1.5374337f, 0.39088273f, -0.19011068f, -0.010914803f, -1.4776193f
        };
        final float[] oneHot = new float[10];
        oneHot[3] = 1.0f;
        final float loss = testLoss2(logits, oneHot);
        System.out.println(loss);
    }

    /**
     * @return the current {@code warmUp} flag
     */
    public boolean isWarmUp() {
        return warmUp;
    }

    /**
     * Sets the {@code warmUp} flag.
     *
     * @param z new value of the flag
     */
    public void setWarmUp(boolean z) {
        warmUp = z;
    }

    /**
     * Sets the {@code isOnline} flag.
     *
     * @param z new value of the flag
     */
    public void online(boolean z) {
        isOnline = z;
    }

    /**
     * @return the stored {@code sid} value (may be {@code null} if never set)
     */
    public String getSid() {
        return sid;
    }

    /**
     * Stores the {@code sid} value.
     *
     * @param str new value
     */
    public void setSid(String str) {
        sid = str;
    }
}
