package org.deeplearning4j.rl4j.network;

import java.util.Collection;
import java.util.Iterator;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.rl4j.agent.learning.update.Features;
import org.deeplearning4j.rl4j.agent.learning.update.FeaturesLabels;
import org.deeplearning4j.rl4j.agent.learning.update.Gradients;
import org.deeplearning4j.rl4j.network.BaseNetwork;
import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/rl4j/network/ComputationGraphHandler.class */
/**
 * {@link INetworkHandler} implementation that delegates to a DL4J {@link ComputationGraph}.
 *
 * <p>Network inputs are built in one of two ways, chosen at construction time:
 * <ul>
 *   <li>with a {@link ChannelToNetworkInputMapper}, which produces the full input array; or</li>
 *   <li>with a single channel/feature index, in which case the network receives exactly one input.</li>
 * </ul>
 */
public class ComputationGraphHandler implements INetworkHandler {
    private final ComputationGraph model;
    /** True when the model's output layer at index 0 is an {@link RnnOutputLayer}. */
    private final boolean recurrent;
    private final ComputationGraphConfiguration configuration;
    /** Names used to look up the label arrays in a {@link FeaturesLabels}, in network-output order. */
    private final String[] labelNames;
    /** Name under which this network's gradient is stored in / read from a {@link Gradients}. */
    private final String gradientName;
    /** Channel index used to build the single network input; only read when no mapper is set. */
    private final int inputFeatureIdx;
    /** Optional input mapper; {@code null} means {@link #inputFeatureIdx} is used instead. */
    private final ChannelToNetworkInputMapper channelToNetworkInputMapper;

    /**
     * Creates a handler whose network inputs are produced by a mapper.
     * The input feature index is set to 0.
     *
     * @param computationGraph the model to wrap
     * @param strArr the label names, one per label array handed to the model
     * @param str the name of the gradient handled by this instance
     * @param channelToNetworkInputMapper the mapper used to build the network inputs
     */
    public ComputationGraphHandler(ComputationGraph computationGraph, String[] strArr, String str, ChannelToNetworkInputMapper channelToNetworkInputMapper) {
        this(computationGraph, strArr, str, 0, channelToNetworkInputMapper);
    }

    /**
     * Creates a handler whose single network input is taken from one channel/feature index.
     * No mapper is used.
     *
     * @param computationGraph the model to wrap
     * @param strArr the label names, one per label array handed to the model
     * @param str the name of the gradient handled by this instance
     * @param i the index of the channel/feature to feed to the network
     */
    public ComputationGraphHandler(ComputationGraph computationGraph, String[] strArr, String str, int i) {
        this(computationGraph, strArr, str, i, null);
    }

    /** Canonical constructor shared by the public ones so the field initialisation lives in one place. */
    private ComputationGraphHandler(ComputationGraph graph, String[] labels, String gradient, int featureIdx, ChannelToNetworkInputMapper mapper) {
        this.model = graph;
        this.recurrent = graph.getOutputLayer(0) instanceof RnnOutputLayer;
        this.configuration = graph.getConfiguration();
        this.labelNames = labels;
        this.gradientName = gradient;
        this.inputFeatureIdx = featureIdx;
        this.channelToNetworkInputMapper = mapper;
    }

    /** Calls {@code onGradientCalculation} on every listener registered on the model, if any. */
    @Override // org.deeplearning4j.rl4j.network.INetworkHandler
    public void notifyGradientCalculation() {
        Collection<TrainingListener> listeners = this.model.getListeners();
        if (listeners == null) {
            return;
        }
        for (TrainingListener listener : listeners) {
            listener.onGradientCalculation(this.model);
        }
    }

    /**
     * Calls {@code iterationDone} on every listener registered on the model, if any, passing the
     * iteration and epoch counts currently stored in the model configuration.
     */
    @Override // org.deeplearning4j.rl4j.network.INetworkHandler
    public void notifyIterationDone() {
        BaseNetwork.ModelCounters modelCounters = getModelCounters();
        Collection<TrainingListener> listeners = this.model.getListeners();
        if (listeners == null) {
            return;
        }
        for (TrainingListener listener : listeners) {
            listener.iterationDone(this.model, modelCounters.getIterationCount(), modelCounters.getEpochCount());
        }
    }

    /**
     * Fits the model on the supplied features and labels.
     *
     * @param featuresLabels the features and the named labels to train on
     */
    @Override // org.deeplearning4j.rl4j.network.INetworkHandler
    public void performFit(FeaturesLabels featuresLabels) {
        INDArray[] inputs = buildInputs(featuresLabels.getFeatures());
        INDArray[] labels = buildLabels(featuresLabels);
        this.model.fit(inputs, labels);
    }

    /**
     * Sets the model's inputs and labels, then asks it to compute the gradient and score.
     * The parameters are not updated here.
     *
     * @param featuresLabels the features and the named labels to compute the gradient from
     */
    @Override // org.deeplearning4j.rl4j.network.INetworkHandler
    public void performGradientsComputation(FeaturesLabels featuresLabels) {
        this.model.setInputs(buildInputs(featuresLabels.getFeatures()));
        this.model.setLabels(buildLabels(featuresLabels));
        this.model.computeGradientAndScore();
    }

    /**
     * Stores the model's current gradient in {@code gradients} under this handler's gradient name.
     *
     * @param gradients the container to fill
     */
    @Override // org.deeplearning4j.rl4j.network.INetworkHandler
    public void fillGradientsResponse(Gradients gradients) {
        gradients.putGradient(this.gradientName, this.model.gradient());
    }

    /** Collects the label arrays from {@code featuresLabels}, in the order given by {@link #labelNames}. */
    private INDArray[] buildLabels(FeaturesLabels featuresLabels) {
        INDArray[] labels = new INDArray[this.labelNames.length];
        for (int idx = 0; idx < labels.length; idx++) {
            labels[idx] = featuresLabels.getLabels(this.labelNames[idx]);
        }
        return labels;
    }

    /** Snapshots the iteration and epoch counts currently stored in the model configuration. */
    private BaseNetwork.ModelCounters getModelCounters() {
        return new BaseNetwork.ModelCounters(this.configuration.getIterationCount(), this.configuration.getEpochCount());
    }

    /**
     * Applies the gradient registered under this handler's gradient name to the model: the
     * gradient is first passed through the model's updater, then subtracted from the model
     * parameters, and finally the configuration's iteration count is incremented by one.
     *
     * @param gradients the container holding the gradient to apply
     * @param j value handed to the updater after being narrowed to {@code int}
     *          (presumably the batch size - confirm against {@code INetworkHandler})
     */
    @Override // org.deeplearning4j.rl4j.network.INetworkHandler
    public void applyGradient(Gradients gradients, long j) {
        BaseNetwork.ModelCounters modelCounters = getModelCounters();
        int iterationCount = modelCounters.getIterationCount();
        int epochCount = modelCounters.getEpochCount();
        Gradient gradient = gradients.getGradient(this.gradientName);
        // NOTE: the (int) cast silently truncates values above Integer.MAX_VALUE.
        this.model.getUpdater().update(gradient, iterationCount, epochCount, (int) j, LayerWorkspaceMgr.noWorkspaces());
        this.model.params().subi(gradient.gradient());
        this.configuration.setIterationCount(iterationCount + 1);
    }

    /**
     * Computes the model output for one observation using {@code rnnTimeStep}.
     *
     * @param observation the observation to feed to the network
     * @return the arrays returned by the model
     */
    @Override // org.deeplearning4j.rl4j.network.INetworkHandler
    public INDArray[] recurrentStepOutput(Observation observation) {
        return this.model.rnnTimeStep(buildInputs(observation));
    }

    /**
     * Computes the model output for one observation using {@code output}.
     *
     * @param observation the observation to feed to the network
     * @return the arrays returned by the model
     */
    @Override // org.deeplearning4j.rl4j.network.INetworkHandler
    public INDArray[] stepOutput(Observation observation) {
        return this.model.output(buildInputs(observation));
    }

    /**
     * Computes the model output for a set of features using {@code output}.
     *
     * @param features the features to feed to the network
     * @return the arrays returned by the model
     */
    @Override // org.deeplearning4j.rl4j.network.INetworkHandler
    public INDArray[] batchOutput(Features features) {
        return this.model.output(buildInputs(features));
    }

    /** Clears the model's stored recurrent state by calling {@code rnnClearPreviousState}. */
    @Override // org.deeplearning4j.rl4j.network.INetworkHandler
    public void resetState() {
        this.model.rnnClearPreviousState();
    }

    /**
     * Creates a new handler around a clone of the model. The label names, gradient name and the
     * input-building strategy (mapper, or feature index when no mapper is set) are reused as-is.
     *
     * @return the new handler
     */
    @Override // org.deeplearning4j.rl4j.network.INetworkHandler
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public INetworkHandler m28clone() {
        ComputationGraph clonedModel = this.model.clone();
        if (this.channelToNetworkInputMapper != null) {
            return new ComputationGraphHandler(clonedModel, this.labelNames, this.gradientName, this.channelToNetworkInputMapper);
        }
        return new ComputationGraphHandler(clonedModel, this.labelNames, this.gradientName, this.inputFeatureIdx);
    }

    /**
     * Sets this model's parameters from those of another handler's model.
     *
     * @param iNetworkHandler the handler to copy from; must be a {@code ComputationGraphHandler},
     *                        otherwise the cast below throws a {@link ClassCastException}
     */
    @Override // org.deeplearning4j.rl4j.network.INetworkHandler
    public void copyFrom(INetworkHandler iNetworkHandler) {
        ComputationGraphHandler source = (ComputationGraphHandler) iNetworkHandler;
        this.model.setParams(source.model.params());
    }

    /**
     * Builds the network inputs for an observation.
     *
     * @param observation the observation to convert
     * @return the mapper's result when a mapper is set; otherwise a one-element array holding the
     *         observation's channel data at the configured feature index
     */
    protected INDArray[] buildInputs(Observation observation) {
        if (this.channelToNetworkInputMapper == null) {
            return new INDArray[]{observation.getChannelData(this.inputFeatureIdx)};
        }
        return this.channelToNetworkInputMapper.getNetworkInputs(observation);
    }

    /**
     * Builds the network inputs for a set of features.
     *
     * @param features the features to convert
     * @return the mapper's result when a mapper is set; otherwise a one-element array holding the
     *         feature at the configured feature index
     */
    protected INDArray[] buildInputs(Features features) {
        if (this.channelToNetworkInputMapper == null) {
            return new INDArray[]{features.get(this.inputFeatureIdx)};
        }
        return this.channelToNetworkInputMapper.getNetworkInputs(features);
    }

    /**
     * @return {@code true} when the wrapped model's output layer at index 0 was an
     *         {@link RnnOutputLayer} at construction time
     */
    @Override // org.deeplearning4j.rl4j.network.INetworkHandler
    public boolean isRecurrent() {
        return this.recurrent;
    }
}
