package org.campagnelab.dl.framework.training;

import it.unimi.dsi.logging.ProgressLogger;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.parallelism.ParallelWrapper;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

/* loaded from: input_file:org/campagnelab/dl/framework/training/ParallelTrainerOnGPU.class */
/**
 * {@link Trainer} implementation that delegates fitting to a DL4J
 * {@link ParallelWrapper} built once, at construction time, around the
 * computation graph handed to the constructor.
 *
 * <p>The wrapper configuration is fixed: prefetch buffer of 64, 4 workers,
 * averaging frequency of 1, score reporting after averaging disabled and
 * legacy averaging disabled.
 *
 * <p>NOTE(review): the class name suggests GPU execution, but nothing in this
 * class selects a device; presumably that is decided by the ND4J backend in
 * use — confirm.
 */
public class ParallelTrainerOnGPU implements Trainer {
    /** Value passed to {@code ParallelWrapper.Builder#prefetchBuffer}. */
    private static final int PREFETCH_BUFFER_SIZE = 64;
    /** Value passed to {@code ParallelWrapper.Builder#workers}. */
    private static final int WORKER_COUNT = 4;
    /** Value passed to {@code ParallelWrapper.Builder#averagingFrequency}. */
    private static final int AVERAGING_FREQUENCY = 1;

    /** Wrapper around the graph given to the constructor; used for every {@code train} call. */
    ParallelWrapper wrapper;
    /** Count reported to the progress logger and returned by each {@code train} call. */
    int numExamplesPerIterator;
    /** Mini-batch size supplied at construction; stored but not read within this class. */
    int miniBatchSize;

    /**
     * Builds the parallel wrapper around the given graph and records the sizing parameters.
     *
     * @param computationGraph the model to wrap
     * @param i                the mini-batch size (stored in {@code miniBatchSize})
     * @param i2               the number of examples per iterator (stored in
     *                         {@code numExamplesPerIterator})
     */
    public ParallelTrainerOnGPU(ComputationGraph computationGraph, int i, int i2) {
        this.wrapper = new ParallelWrapper.Builder(computationGraph)
                .prefetchBuffer(PREFETCH_BUFFER_SIZE)
                .workers(WORKER_COUNT)
                .averagingFrequency(AVERAGING_FREQUENCY)
                .reportScoreAfterAveraging(false)
                .useLegacyAveraging(false)
                .build();
        this.numExamplesPerIterator = i2;
        this.miniBatchSize = i;
    }

    /**
     * Fits the wrapped model on the supplied iterator, then advances the progress logger.
     *
     * <p>The {@code computationGraph} argument is not used here: fitting always goes through
     * the wrapper created in the constructor. The examples are not counted by this method;
     * the value reported and returned is the count configured at construction.
     *
     * @param computationGraph     ignored by this implementation
     * @param multiDataSetIterator the data to fit on
     * @param progressLogger       logger updated with the configured example count after fitting
     * @return the configured number of examples per iterator
     */
    @Override
    public int train(ComputationGraph computationGraph, MultiDataSetIterator multiDataSetIterator, ProgressLogger progressLogger) {
        this.wrapper.fit(multiDataSetIterator);
        final int examplesReported = this.numExamplesPerIterator;
        progressLogger.update(examplesReported);
        return examplesReported;
    }
}
