package org.campagnelab.dl.framework.models;

import java.util.Arrays;
import java.util.Iterator;
import org.campagnelab.dl.framework.mappers.FeatureMapper;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/campagnelab/dl/framework/models/ModelOutputHelper.class */
/**
 * Runs a DL4J model ({@link MultiLayerNetwork} or {@link ComputationGraph}) on either the next
 * batch of an iterator or a single record, and keeps the resulting output arrays so they can be
 * read back with {@link #getOutput(int)}.
 *
 * <p>Each prediction replaces the outputs stored by the previous one. The stored state is a plain
 * mutable field with no synchronization, so an instance should not be shared between threads
 * without external locking.
 *
 * @param <RecordType> type of the records that the supplied {@link FeatureMapper}s turn into inputs
 */
public class ModelOutputHelper<RecordType> {
    /**
     * Outputs of the most recent prediction, one entry per model output; {@code null} until the
     * first prediction has been made.
     */
    private INDArray[] resultGraph;

    /**
     * Predicts on the next element of {@code it}, dispatching on the concrete model type.
     *
     * @param model a {@link MultiLayerNetwork} (then {@code it} must yield {@link DataSet}s) or a
     *     {@link ComputationGraph} (then {@code it} must yield {@link MultiDataSet}s)
     * @param it iterator positioned before the batch to predict on; one element is consumed
     * @throws IllegalArgumentException if {@code model} is neither supported type
     */
    @SuppressWarnings({"unchecked", "rawtypes"}) // element type is dictated by the model type, see above
    public void predictForNext(Model model, Iterator it) {
        if (model instanceof MultiLayerNetwork) {
            predictForNext((MultiLayerNetwork) model, (Iterator<DataSet>) it);
        } else if (model instanceof ComputationGraph) {
            predictForNext((ComputationGraph) model, (Iterator<MultiDataSet>) it);
        } else {
            // Previously ignored silently, which left stale outputs visible through getOutput().
            throw unsupportedModel(model);
        }
    }

    /**
     * Predicts on a single record, building the model inputs with the given feature mappers.
     *
     * <p>For a {@link MultiLayerNetwork} only {@code featureMappers[0]} is used. For a
     * {@link ComputationGraph} one input array is built per mapper, in the order given.
     *
     * @param model the model to run
     * @param record the record to map to features and predict on
     * @param featureMappers mappers that convert {@code record} into input arrays
     * @throws IllegalArgumentException if {@code model} is not a supported type, or if a
     *     {@link MultiLayerNetwork} is given without any feature mapper
     */
    @SuppressWarnings("rawtypes") // FeatureMapper is used raw throughout this API
    public void predictForNextRecord(Model model, RecordType record, FeatureMapper... featureMappers) {
        if (model instanceof MultiLayerNetwork) {
            if (featureMappers.length == 0) {
                throw new IllegalArgumentException("at least one feature mapper is required");
            }
            INDArray features = mapRecord(record, featureMappers[0]);
            storeSingleOutput(((MultiLayerNetwork) model).output(features, false));
        } else if (model instanceof ComputationGraph) {
            INDArray[] inputs = new INDArray[featureMappers.length];
            for (int i = 0; i < featureMappers.length; i++) {
                inputs[i] = mapRecord(record, featureMappers[i]);
            }
            this.resultGraph = ((ComputationGraph) model).output(false, inputs);
        } else {
            throw unsupportedModel(model);
        }
    }

    /**
     * Runs {@code computationGraph} in inference mode ({@code train = false}) on the features of
     * the next {@link MultiDataSet} and stores every output it returns.
     *
     * @param computationGraph the graph to run
     * @param it iterator positioned before the batch; one element is consumed
     */
    public void predictForNext(ComputationGraph computationGraph, Iterator<MultiDataSet> it) {
        this.resultGraph = computationGraph.output(false, it.next().getFeatures());
    }

    /**
     * Runs {@code multiLayerNetwork} in inference mode ({@code train = false}) on the features of
     * the next {@link DataSet} and stores its single output at index 0.
     *
     * @param multiLayerNetwork the network to run
     * @param it iterator positioned before the batch; one element is consumed
     */
    public void predictForNext(MultiLayerNetwork multiLayerNetwork, Iterator<DataSet> it) {
        storeSingleOutput(multiLayerNetwork.output(it.next().getFeatures(), false));
    }

    /**
     * Returns output {@code index} of the most recent prediction.
     *
     * @param index which model output to return (0 for a {@link MultiLayerNetwork})
     * @return the stored output array, or {@code null} for a slot cleared by a later
     *     single-output prediction
     * @throws IllegalStateException if no prediction has been made yet
     * @throws ArrayIndexOutOfBoundsException if {@code index} is outside the stored outputs
     */
    public INDArray getOutput(int index) {
        if (this.resultGraph == null) {
            throw new IllegalStateException("no prediction has been made yet");
        }
        return this.resultGraph[index];
    }

    /**
     * Allocates a 1 x {@code mapper.numberOfFeatures()} zero array and lets {@code mapper} fill it
     * from {@code record}. Index 0 is passed to the mapper, presumably the row within the array
     * (TODO confirm against FeatureMapper).
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private INDArray mapRecord(RecordType record, FeatureMapper mapper) {
        INDArray row = Nd4j.zeros(1, mapper.numberOfFeatures());
        mapper.prepareToNormalize(record, 0);
        mapper.mapFeatures(record, row, 0);
        return row;
    }

    /**
     * Stores a single-output prediction at index 0.
     *
     * <p>If an output array already exists, its other slots are cleared. Otherwise a one-slot array
     * is allocated; before this fix, that case threw a NullPointerException.
     */
    private void storeSingleOutput(INDArray output) {
        if (this.resultGraph == null || this.resultGraph.length == 0) {
            this.resultGraph = new INDArray[1];
        } else {
            Arrays.fill(this.resultGraph, null);
        }
        this.resultGraph[0] = output;
    }

    /** Builds the exception reported for a model type this helper cannot run. */
    private static IllegalArgumentException unsupportedModel(Model model) {
        String typeName = model == null ? "null" : model.getClass().getCanonicalName();
        return new IllegalArgumentException("model is not of supported type: " + typeName);
    }
}
