package com.omega.engine.nn.network;

import com.omega.common.data.Tensor;
import com.omega.engine.loss.LossFactory;
import com.omega.engine.loss.LossFunction;
import com.omega.engine.loss.LossType;
import com.omega.engine.nn.layer.Layer;
import com.omega.engine.nn.layer.LayerType;
import com.omega.engine.nn.layer.SoftmaxWithCrossEntropyLayer;
import com.omega.engine.updater.UpdaterType;

/* loaded from: input_file:com/omega/engine/nn/network/RNN.class */
/**
 * Network of type {@link NetworkType#RNN} built from a sequential list of layers.
 *
 * <p>{@link #forward(Tensor)} runs every layer's {@code forward()} in list order, and
 * {@link #back(Tensor)} runs every layer's {@code back()} in reverse list order after pushing
 * the network learning rate into each layer.
 */
public class RNN extends Network {

    /**
     * Number of time steps supplied at construction; defaults to {@code 1}. It is not read
     * inside this class. NOTE(review): presumably consumed by recurrent layers or callers.
     */
    public int time = 1;

    /**
     * Creates a network with the given loss function. The updater is not set by this
     * constructor.
     *
     * @param lossFunction loss used by {@link #loss(Tensor, Tensor)} and
     *     {@link #lossDiff(Tensor, Tensor)}
     */
    public RNN(LossFunction lossFunction) {
        this.lossFunction = lossFunction;
    }

    /**
     * Creates a network with the given loss function and updater.
     *
     * @param lossFunction loss used by the loss / lossDiff methods
     * @param updaterType parameter-update strategy stored in {@code updater}
     */
    public RNN(LossFunction lossFunction, UpdaterType updaterType) {
        this.lossFunction = lossFunction;
        this.updater = updaterType;
    }

    /**
     * Creates a network whose loss function is built by {@link LossFactory#create}.
     *
     * @param lossType loss kind passed to {@link LossFactory#create}
     * @param updaterType parameter-update strategy stored in {@code updater}
     */
    public RNN(LossType lossType, UpdaterType updaterType) {
        this(LossFactory.create(lossType), updaterType);
    }

    /**
     * Creates a network whose loss function is built by {@link LossFactory#create}, with an
     * explicit number of time steps.
     *
     * @param lossType loss kind passed to {@link LossFactory#create}
     * @param updaterType parameter-update strategy stored in {@code updater}
     * @param time value stored in {@link #time}
     */
    public RNN(LossType lossType, UpdaterType updaterType, int time) {
        this(lossType, updaterType);
        this.time = time;
    }

    /**
     * Validates the layer list, then caches the layer count and the input / output dimensions.
     *
     * <p>All validation happens before any field is assigned, so a failed call leaves the
     * network's cached dimensions untouched.
     *
     * @throws Exception if no layers were added, if the first layer is not an input layer, or
     *     if the last layer is a softmax / softmax-cross-entropy layer while the loss function
     *     is not cross entropy
     */
    @Override
    public void init() throws Exception {
        if (this.layerList.size() == 0) {
            // Only an empty layer list is rejected; a lone input layer is accepted.
            throw new Exception("layer size must be greater than 0.");
        }

        Layer firstLayer = this.layerList.get(0);
        if (firstLayer.getLayerType() != LayerType.input) {
            throw new Exception("first layer must be input layer.");
        }

        LayerType lastType = this.layerList.get(this.layerList.size() - 1).getLayerType();
        boolean softmaxOutput =
                lastType == LayerType.softmax || lastType == LayerType.softmax_cross_entropy;
        if (softmaxOutput && this.lossFunction.getLossType() != LossType.cross_entropy) {
            throw new Exception("The softmax function support only cross entropy loss function now.");
        }

        // layerCount is assigned before getLastLayer() in case the base class derives the
        // last layer from it.
        this.layerCount = this.layerList.size();
        this.channel = firstLayer.channel;
        this.height = firstLayer.height;
        this.width = firstLayer.width;

        Layer outputLayer = getLastLayer();
        this.oChannel = outputLayer.oChannel;
        this.oHeight = outputLayer.oHeight;
        this.oWidth = outputLayer.oWidth;

        System.out.println("the network is ready.");
    }

    /** Returns {@link NetworkType#RNN}. */
    @Override
    public NetworkType getNetworkType() {
        return NetworkType.RNN;
    }

    /**
     * Switches the network to {@link RunModel#TEST} and runs a forward pass.
     *
     * @param input network input
     * @return the value of {@code getOutput()} after the forward pass
     */
    @Override
    public Tensor predict(Tensor input) {
        this.RUN_MODEL = RunModel.TEST;
        forward(input);
        return getOutput();
    }

    /**
     * Sets the input data and calls {@code forward()} on the first {@code layerCount} layers
     * in order. {@code layerCount} is set by {@link #init()}.
     *
     * @param input network input
     * @return the value of {@code getOutput()} after all layers have run
     */
    @Override
    public Tensor forward(Tensor input) {
        setInputData(input);
        for (int i = 0; i < this.layerCount; i++) {
            this.layerList.get(i).forward();
        }
        return getOutput();
    }

    /**
     * Stores the loss gradient, then walks the layers from last to first. Each layer receives
     * the network's current learning rate before its {@code back()} is called.
     *
     * @param lossGrad gradient of the loss, handed to {@code setLossDiff}
     */
    @Override
    public void back(Tensor lossGrad) {
        setLossDiff(lossGrad);
        for (int i = this.layerCount - 1; i >= 0; i--) {
            Layer layer = this.layerList.get(i);
            layer.learnRate = this.learnRate;
            layer.back();
        }
    }

    /**
     * Computes the loss between {@code output} and {@code label}.
     *
     * <p>When the last layer is a {@link SoftmaxWithCrossEntropyLayer}, the label is also
     * handed to that layer via {@code setCurrentLabel} before the loss is computed.
     *
     * @param output network output
     * @param label target values
     * @return the result of {@code lossFunction.loss(output, label)}
     */
    @Override
    public Tensor loss(Tensor output, Tensor label) {
        Layer last = getLastLayer();
        if (last.getLayerType() == LayerType.softmax_cross_entropy) {
            ((SoftmaxWithCrossEntropyLayer) last).setCurrentLabel(label);
        }
        return this.lossFunction.loss(output, label);
    }

    /**
     * Computes the loss gradient by delegating to {@code lossFunction.diff(output, label)}.
     *
     * @param output network output
     * @param label target values
     * @return the loss gradient returned by the loss function
     */
    @Override
    public Tensor lossDiff(Tensor output, Tensor label) {
        return this.lossFunction.diff(output, label);
    }

    /** Does nothing; this network keeps no gradient state of its own to clear. */
    @Override
    public void clearGrad() {
    }

    /**
     * Three-argument loss overload. All arguments are forwarded unchanged to
     * {@code lossFunction.loss}.
     *
     * <p>NOTE(review): unlike the two-argument overload, this does not call
     * {@code setCurrentLabel} on a softmax-cross-entropy output layer. Confirm whether that
     * is intended.
     *
     * @param output network output
     * @param label target values
     * @param extra third argument; its meaning is defined by the {@link LossFunction}
     *     implementation
     * @return the loss tensor
     */
    @Override
    public Tensor loss(Tensor output, Tensor label, Tensor extra) {
        return this.lossFunction.loss(output, label, extra);
    }

    /**
     * Three-argument loss-gradient overload. All arguments are forwarded unchanged to
     * {@code lossFunction.diff}.
     *
     * @param output network output
     * @param label target values
     * @param extra third argument; its meaning is defined by the {@link LossFunction}
     *     implementation
     * @return the loss gradient
     */
    @Override
    public Tensor lossDiff(Tensor output, Tensor label, Tensor extra) {
        return this.lossFunction.diff(output, label, extra);
    }
}
