package com.omega.engine.nn.network;

import com.omega.common.data.Tensor;
import com.omega.engine.gpu.CUDAMemoryManager;
import com.omega.engine.loss.LossFactory;
import com.omega.engine.loss.LossFunction;
import com.omega.engine.loss.LossType;
import com.omega.engine.nn.layer.Layer;
import com.omega.engine.nn.layer.LayerType;
import com.omega.engine.updater.UpdaterType;
import jcuda.runtime.JCuda;

/* loaded from: input_file:com/omega/engine/nn/network/Yolo.class */
/**
 * YOLO-style network built on {@link OutputsNetwork}.
 *
 * <p>When {@code outputNum > 1}, {@link #init()} creates one loss function per output layer via
 * {@link LossFactory}. Otherwise a single {@code lossFunction} is used.
 */
public class Yolo extends OutputsNetwork {
    /** Per-output loss functions; populated by {@link #init()} only when {@code outputNum > 1}. */
    private LossFunction[] losses;
    /** Loss type given to the (LossType, UpdaterType) constructor; null when a LossFunction was supplied. */
    private LossType lossType;
    /** Reused buffer holding one loss tensor per output. */
    private Tensor[] loss;
    /** Reused buffer holding one loss-gradient tensor per output. */
    private Tensor[] lossDiff;
    /** Class count handed to LossFactory; 0 and 1 select the factory overloads without a class count. */
    private int class_num = 1;

    /**
     * Creates a network that uses an explicitly constructed loss function.
     *
     * @param lossFunction the loss function to use for the single-output case
     */
    public Yolo(LossFunction lossFunction) {
        this.lossFunction = lossFunction;
    }

    /**
     * Creates a network whose loss functions are built by {@link LossFactory} during {@link #init()}.
     *
     * @param lossType    loss type passed to the factory
     * @param updaterType updater stored on the network
     */
    public Yolo(LossType lossType, UpdaterType updaterType) {
        this.lossType = lossType;
        this.updater = updaterType;
    }

    /** Calls {@code init()} on each of the first {@code layerCount} layers. */
    public void initLayer() {
        for (int i = 0; i < this.layerCount; i++) {
            this.layerList.get(i).init();
        }
    }

    /**
     * Validates the layer stack, records the input geometry and builds the loss function(s).
     *
     * @throws Exception if there are no layers, the first layer is not an input layer, or a softmax
     *     output layer is combined with a non cross-entropy loss
     */
    @Override // com.omega.engine.nn.network.Network
    public void init() throws Exception {
        if (this.layerList.size() <= 0) {
            throw new Exception("layer size must greater than 2.");
        }
        this.layerCount = this.layerList.size();
        if (this.layerList.get(0).getLayerType() != LayerType.input) {
            throw new Exception("first layer must be input layer.");
        }
        // Either constructor may have been used, so only one of lossType / lossFunction is
        // guaranteed to be set here. Resolve from whichever is available instead of
        // dereferencing a possibly-null lossFunction.
        LossType resolvedLossType = resolveLossType();
        LayerType lastType = this.layerList.get(this.layerCount - 1).getLayerType();
        boolean softmaxOutput = lastType == LayerType.softmax || lastType == LayerType.softmax_cross_entropy;
        if (softmaxOutput && resolvedLossType != LossType.cross_entropy) {
            throw new Exception("The softmax function support only cross entropy loss function now.");
        }
        Layer layer = this.layerList.get(0);
        this.channel = layer.channel;
        this.height = layer.height;
        this.width = layer.width;
        boolean withoutClassCount = this.class_num == 1 || this.class_num == 0;
        if (this.outputNum > 1) {
            this.losses = withoutClassCount
                    ? LossFactory.create(resolvedLossType, this.outputLayers, this)
                    : LossFactory.create(resolvedLossType, this.outputLayers, this.class_num, this);
            // Reallocate if outputNum changed since a previous init() call.
            if (this.loss == null || this.loss.length != this.outputNum) {
                this.loss = new Tensor[this.outputNum];
            }
            if (this.lossDiff == null || this.lossDiff.length != this.outputNum) {
                this.lossDiff = new Tensor[this.outputNum];
            }
        } else if (this.lossType != null) {
            this.lossFunction = withoutClassCount
                    ? LossFactory.create(this.lossType)
                    : LossFactory.create(this.lossType, this.class_num);
        }
        // else: constructed with an explicit LossFunction, which is kept as supplied.
        System.out.println("the network is ready.");
    }

    /** Returns the configured loss type, falling back to the type of a directly supplied loss function. */
    private LossType resolveLossType() {
        if (this.lossType != null) {
            return this.lossType;
        }
        return this.lossFunction != null ? this.lossFunction.getLossType() : null;
    }

    /** Runs {@code forward()} on every layer in order. */
    private void forwardAllLayers() {
        for (int i = 0; i < this.layerCount; i++) {
            this.layerList.get(i).forward();
        }
    }

    /** Runs {@code back()} on every layer in reverse order, propagating the current learning rate. */
    private void backAllLayers() {
        for (int i = this.layerCount - 1; i >= 0; i--) {
            Layer layer = this.layerList.get(i);
            layer.learnRate = this.learnRate;
            layer.back();
        }
    }

    /** Switches to TEST mode, runs a forward pass and returns the primary output. */
    @Override // com.omega.engine.nn.network.Network
    public Tensor predict(Tensor tensor) {
        this.RUN_MODEL = RunModel.TEST;
        forward(tensor);
        return getOutput();
    }

    /** Sets the input data, runs every layer forward and returns the primary output. */
    @Override // com.omega.engine.nn.network.Network
    public Tensor forward(Tensor tensor) {
        setInputData(tensor);
        forwardAllLayers();
        return getOutput();
    }

    /** Installs a single loss gradient and back-propagates through every layer. */
    @Override // com.omega.engine.nn.network.Network
    public void back(Tensor tensor) {
        setLossDiff(tensor);
        backAllLayers();
    }

    /** Installs one loss gradient per output and back-propagates through every layer. */
    @Override // com.omega.engine.nn.network.OutputsNetwork
    public void back(Tensor[] tensorArr) {
        setLossDiff(tensorArr);
        backAllLayers();
    }

    /** Computes the single-output loss with {@code lossFunction}. */
    @Override // com.omega.engine.nn.network.Network
    public Tensor loss(Tensor tensor, Tensor tensor2) {
        return this.lossFunction.loss(tensor, tensor2);
    }

    /**
     * Computes the loss of each network output against {@code tensor}.
     *
     * @return the shared per-output loss buffer (overwritten on every call)
     */
    @Override // com.omega.engine.nn.network.OutputsNetwork
    public Tensor[] loss(Tensor tensor) {
        Tensor[] outputs = getOutputs();
        for (int i = 0; i < this.losses.length; i++) {
            this.loss[i] = this.losses[i].loss(outputs[i], tensor);
        }
        return this.loss;
    }

    /**
     * Computes the loss of each supplied output against {@code tensor}.
     *
     * @return the shared per-output loss buffer (overwritten on every call)
     */
    public Tensor[] loss(Tensor[] tensorArr, Tensor tensor) {
        for (int i = 0; i < this.losses.length; i++) {
            this.loss[i] = this.losses[i].loss(tensorArr[i], tensor);
        }
        return this.loss;
    }

    /** Clears accumulated gradients, then returns the single-output loss gradient. */
    @Override // com.omega.engine.nn.network.Network
    public Tensor lossDiff(Tensor tensor, Tensor tensor2) {
        clearGrad();
        return this.lossFunction.diff(tensor, tensor2);
    }

    /**
     * Clears accumulated gradients, then computes the loss gradient for each network output.
     *
     * @return the shared per-output gradient buffer (overwritten on every call)
     */
    @Override // com.omega.engine.nn.network.OutputsNetwork
    public Tensor[] lossDiff(Tensor tensor) {
        clearGrad();
        Tensor[] outputs = getOutputs();
        for (int i = 0; i < this.losses.length; i++) {
            this.lossDiff[i] = this.losses[i].diff(outputs[i], tensor);
        }
        return this.lossDiff;
    }

    @Override // com.omega.engine.nn.network.Network
    public NetworkType getNetworkType() {
        return NetworkType.YOLO;
    }

    /**
     * Runs a forward pass and returns all outputs.
     *
     * <p>NOTE(review): unlike {@link #predict(Tensor)}, this does not switch RUN_MODEL to TEST.
     * Confirm whether that is intentional.
     */
    @Override // com.omega.engine.nn.network.OutputsNetwork
    public Tensor[] predicts(Tensor tensor) {
        setInputData(tensor);
        forwardAllLayers();
        return getOutputs();
    }

    public int getClass_num() {
        return this.class_num;
    }

    public void setClass_num(int i) {
        this.class_num = i;
    }

    /**
     * Zeroes the shared CUDA workspace and each layer's {@code cache_delta}, then synchronizes the device.
     */
    @Override // com.omega.engine.nn.network.Network
    public void clearGrad() {
        // The factor 4 is presumably sizeof(float). Widen to long first so large workspaces do not
        // overflow an int byte count (cudaMemset takes a long count).
        long byteCount = (long) CUDAMemoryManager.workspace.getSize() * 4L;
        JCuda.cudaMemset(CUDAMemoryManager.workspace.getPointer(), 0, byteCount);
        for (int i = 0; i < this.layerCount; i++) {
            Layer layer = this.layerList.get(i);
            if (layer.cache_delta != null) {
                layer.cache_delta.clearGPU();
            }
        }
        JCuda.cudaDeviceSynchronize();
    }

    /** Computes the three-argument variant of the single-output loss. */
    @Override // com.omega.engine.nn.network.Network
    public Tensor loss(Tensor tensor, Tensor tensor2, Tensor tensor3) {
        return this.lossFunction.loss(tensor, tensor2, tensor3);
    }

    /**
     * Computes the three-argument variant of the single-output loss gradient.
     *
     * <p>NOTE(review): unlike the other lossDiff overloads, this does not call {@link #clearGrad()}.
     * Confirm whether callers rely on that.
     */
    @Override // com.omega.engine.nn.network.Network
    public Tensor lossDiff(Tensor tensor, Tensor tensor2, Tensor tensor3) {
        return this.lossFunction.diff(tensor, tensor2, tensor3);
    }
}
