package org.campagnelab.dl.framework.tools;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Predicate;
import org.campagnelab.dl.framework.domains.DomainDescriptor;
import org.campagnelab.dl.framework.domains.prediction.Prediction;
import org.campagnelab.dl.framework.domains.prediction.PredictionInterpreter;
import org.campagnelab.dl.framework.models.ModelOutputHelper;
import org.deeplearning4j.nn.api.Model;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/campagnelab/dl/framework/tools/PredictWithModel.class */
/**
 * Runs a model over a sequence of records and turns each model output into a
 * {@link Prediction} using the {@link PredictionInterpreter}s that the domain descriptor
 * associates with the computational graph's output names.
 *
 * @param <RecordType> the record type handled by the domain descriptor
 */
public class PredictWithModel<RecordType> {
    DomainDescriptor<RecordType> domainDescriptor;
    ModelOutputHelper outputHelper = new ModelOutputHelper();
    /**
     * Interpreters indexed like the computational graph's output names. An entry may be
     * null, in which case the corresponding model output is skipped when building predictions.
     */
    private final PredictionInterpreter[] interpretors;

    /**
     * Creates a predictor and looks up one prediction interpreter per computational-graph output.
     *
     * @param domainDescriptor descriptor supplying the computational graph, feature mappers and
     *                         prediction interpreters
     */
    public PredictWithModel(DomainDescriptor<RecordType> domainDescriptor) {
        this.domainDescriptor = domainDescriptor;
        String[] outputNames = domainDescriptor.getComputationalGraph().getOutputNames();
        this.interpretors = new PredictionInterpreter[outputNames.length];
        for (int outputIndex = 0; outputIndex < outputNames.length; outputIndex++) {
            this.interpretors[outputIndex] =
                    domainDescriptor.getPredictionInterpreter(outputNames[outputIndex]);
        }
    }

    /**
     * Makes predictions for records from {@code it} without a per-record callback.
     *
     * @param it        source of records to predict on
     * @param model     the model used to compute outputs
     * @param consumer  receives the predictions for each record
     * @param predicate tested with the number of records processed so far; iteration stops
     *                  as soon as it returns {@code true}
     */
    public void makePredictions(Iterator<RecordType> it, Model model, Consumer<List<Prediction>> consumer, Predicate<Integer> predicate) {
        makePredictions(it, model, ignoredRecord -> {
            // no per-record callback requested
        }, consumer, predicate);
    }

    /**
     * Makes predictions for records from {@code it}, one record at a time.
     *
     * <p>For each record, {@code consumer} is called first. Then the model is evaluated and every
     * output that has a non-null interpreter is converted to a {@link Prediction}. Each prediction
     * is tagged with its output index and the zero-based record index. The list of predictions is
     * passed to {@code consumer2}. A new list is allocated for every record, so
     * {@code consumer2} may safely retain it.
     *
     * @param it        source of records to predict on
     * @param model     the model used to compute outputs
     * @param consumer  called with each record before predictions are computed for it
     * @param consumer2 receives the predictions for each record
     * @param predicate tested with the number of records processed so far; iteration stops
     *                  as soon as it returns {@code true}
     */
    public void makePredictions(Iterator<RecordType> it, Model model, Consumer<RecordType> consumer, Consumer<List<Prediction>> consumer2, Predicate<Integer> predicate) {
        int recordIndex = 0;
        while (it.hasNext()) {
            RecordType current = it.next();
            consumer.accept(current);
            this.outputHelper.predictForNextRecord(model, current, this.domainDescriptor.featureMappers());

            // Fresh list per record: a previously delivered list is never cleared under the consumer.
            List<Prediction> predictions = new ArrayList<>();
            for (int outputIndex = 0; outputIndex < this.domainDescriptor.getNumModelOutputs(); outputIndex++) {
                INDArray output = this.outputHelper.getOutput(outputIndex);
                if (this.interpretors[outputIndex] != null) {
                    Prediction prediction = this.interpretors[outputIndex].interpret(current, output);
                    prediction.outputIndex = outputIndex;
                    prediction.index = recordIndex;
                    predictions.add(prediction);
                }
            }
            consumer2.accept(predictions);

            recordIndex++;
            if (predicate.test(recordIndex)) {
                return;
            }
        }
    }
}
