package com.omega.engine.nn.network;

import com.omega.common.data.Tensor;
import com.omega.engine.nn.layer.Layer;
import com.omega.engine.updater.UpdaterFactory;
import java.util.ArrayList;
import java.util.List;

/* loaded from: input_file:com/omega/engine/nn/network/OutputsNetwork.class */
/**
 * A {@link Network} that tracks more than one output layer.
 *
 * <p>Every layer registered through {@link #addLayer(Layer)} whose {@code isOutput} flag is set
 * is additionally recorded in {@link #outputLayers}, in registration order. {@link #getOutputs()}
 * and {@link #setLossDiff(Tensor[])} use that same order, so index {@code i} of their arrays
 * corresponds to the {@code i}-th registered output layer.
 */
public abstract class OutputsNetwork extends Network {
    /** Output layers in the order they were passed to {@link #addLayer(Layer)}. */
    public List<Layer> outputLayers = new ArrayList<>();

    /** Count of output layers; incremented by {@link #addLayer(Layer)} alongside {@link #outputLayers}. */
    public int outputNum = 0;

    /**
     * Produces one tensor per output for the given input.
     * NOTE(review): exact semantics are defined by subclasses; presumably runs inference — confirm.
     *
     * @param tensor the network input
     * @return the per-output result tensors
     */
    public abstract Tensor[] predicts(Tensor tensor);

    /**
     * Computes per-output loss tensors.
     * NOTE(review): the argument is presumably the label/target tensor — confirm against subclasses.
     *
     * @param tensor the tensor the loss is computed against
     * @return the per-output loss tensors
     */
    public abstract Tensor[] loss(Tensor tensor);

    /**
     * Computes per-output loss-gradient tensors.
     * NOTE(review): the argument is presumably the label/target tensor — confirm against subclasses.
     *
     * @param tensor the tensor the loss gradient is computed against
     * @return the per-output loss-gradient tensors
     */
    public abstract Tensor[] lossDiff(Tensor tensor);

    /**
     * Runs the backward pass from the given per-output tensors.
     * NOTE(review): presumably these are the loss gradients from {@link #lossDiff(Tensor)} — confirm.
     *
     * @param tensorArr per-output tensors driving the backward pass
     */
    public abstract void back(Tensor[] tensorArr);

    /**
     * Collects the current output of every registered output layer.
     *
     * <p>The array normally has {@link #outputNum} slots. Because {@code outputNum} and
     * {@link #outputLayers} are both publicly mutable and may drift apart, the array is sized to
     * the larger of the two so that every output layer always has a slot; slots beyond
     * {@code outputLayers.size()} are left {@code null}.
     *
     * @return the outputs, indexed in the same order as {@link #outputLayers}
     */
    public Tensor[] getOutputs() {
        int layerCount = this.outputLayers.size();
        Tensor[] outputs = new Tensor[Math.max(this.outputNum, layerCount)];
        for (int i = 0; i < layerCount; i++) {
            outputs[i] = this.outputLayers.get(i).getOutput();
        }
        return outputs;
    }

    /**
     * Registers a layer with this network.
     *
     * <p>The layer is bound to this network, given the next free index, and assigned an updater
     * created from this network's {@code updater} / {@code updaterParams} settings. If the layer's
     * {@code isOutput} flag is set it is also appended to {@link #outputLayers} and counted in
     * {@link #outputNum}.
     *
     * @param layer the layer to append
     */
    @Override // com.omega.engine.nn.network.Network
    public void addLayer(Layer layer) {
        layer.setNetwork(this);
        layer.setIndex(this.layerList.size());
        layer.setUpdater(UpdaterFactory.create(this.updater, this.updaterParams));
        // The first two layers (index 0 and 1) never propagate downwards.
        // NOTE(review): presumably because nothing trainable sits below them — confirm.
        if (layer.index <= 1) {
            layer.PROPAGATE_DOWN = false;
        }
        if (layer.isOutput) {
            this.outputNum++;
            this.outputLayers.add(layer);
        }
        this.layerList.add(layer);
    }

    /**
     * Assigns a delta tensor to each output layer: {@code tensorArr[i]} goes to the {@code i}-th
     * entry of {@link #outputLayers}. Extra trailing entries in {@code tensorArr} are ignored.
     *
     * <p>The array is validated before any layer is touched, so a too-short array can no longer
     * leave some layers updated and others stale.
     *
     * @param tensorArr one delta per output layer, in {@link #outputLayers} order
     * @throws IllegalArgumentException if there are output layers and {@code tensorArr} is
     *     {@code null} or holds fewer entries than there are output layers
     */
    public void setLossDiff(Tensor[] tensorArr) {
        int layerCount = this.outputLayers.size();
        if (layerCount == 0) {
            // Nothing to assign; the argument is not inspected at all in this case.
            return;
        }
        if (tensorArr == null || tensorArr.length < layerCount) {
            throw new IllegalArgumentException(
                    "Expected at least " + layerCount + " loss diff tensor(s) but got "
                            + (tensorArr == null ? "null" : String.valueOf(tensorArr.length)));
        }
        for (int i = 0; i < layerCount; i++) {
            this.outputLayers.get(i).setDelta(tensorArr[i]);
        }
    }
}
