package org.deeplearning4j.spark.models.sequencevectors.functions;

import java.util.Iterator;
import lombok.NonNull;
import org.apache.spark.Accumulator;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.spark.models.sequencevectors.learning.SparkElementsLearningAlgorithm;
import org.nd4j.common.primitives.Counter;
import org.nd4j.common.primitives.Pair;
import org.nd4j.parameterserver.distributed.VoidParameterServer;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.nd4j.parameterserver.distributed.messages.TrainingMessage;
import org.nd4j.parameterserver.distributed.training.TrainingDriver;
import org.nd4j.parameterserver.distributed.transport.RoutedTransport;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/spark/models/sequencevectors/functions/CountFunction.class */
/**
 * Spark {@link Function} that counts the elements of a single {@link Sequence}.
 *
 * <p>For each sequence it builds a local {@link Counter} keyed by element storage id, adding one
 * count for every non-null element and every non-null sequence label. It then merges that counter
 * into the shared {@link Accumulator}. The sequence is returned paired with the number of non-null
 * elements it contains; labels are not included in that length.
 *
 * <p>Before counting, every call makes sure the elements learning algorithm named in the broadcast
 * {@link VectorsConfiguration} has been instantiated. It then calls
 * {@code VoidParameterServer.getInstance().init(...)} with that algorithm's training driver.
 *
 * @param <T> sequence element type
 */
public class CountFunction<T extends SequenceElement> implements Function<Sequence<T>, Pair<Sequence<T>, Long>> {
    private static final Logger log = LoggerFactory.getLogger(CountFunction.class);

    /** Receives one frequency counter (storage id -> count) per processed sequence. */
    protected Accumulator<Counter<Long>> accumulator;
    /**
     * Flag supplied by the constructor.
     * NOTE(review): not read anywhere in this class; sequence labels are counted regardless of
     * its value. Confirm whether callers expect it to gate label counting.
     */
    protected boolean fetchLabels;
    protected Broadcast<VoidConfiguration> voidConfigurationBroadcast;
    protected Broadcast<VectorsConfiguration> vectorsConfigurationBroadcast;
    // Transient (not serialized with the function); created lazily by initTrainingInfrastructure().
    protected transient SparkElementsLearningAlgorithm ela;
    protected transient TrainingDriver<? extends TrainingMessage> driver;

    /**
     * @param vectorsConfigurationBroadcast configuration whose elements learning algorithm class
     *                                      name is instantiated on first use
     * @param voidConfigurationBroadcast    configuration passed to the parameter server on init
     * @param accumulator                   sink for the per-sequence frequency counters
     * @param fetchLabels                   stored in {@link #fetchLabels}; currently unused here
     * @throws NullPointerException if any of the broadcasts or the accumulator is null
     */
    public CountFunction(@NonNull Broadcast<VectorsConfiguration> vectorsConfigurationBroadcast,
                    @NonNull Broadcast<VoidConfiguration> voidConfigurationBroadcast,
                    @NonNull Accumulator<Counter<Long>> accumulator, boolean fetchLabels) {
        if (vectorsConfigurationBroadcast == null) {
            throw new NullPointerException("vectorsConfigurationBroadcast is marked non-null but is null");
        }
        if (voidConfigurationBroadcast == null) {
            throw new NullPointerException("voidConfigurationBroadcast is marked non-null but is null");
        }
        if (accumulator == null) {
            throw new NullPointerException("accumulator is marked non-null but is null");
        }
        this.accumulator = accumulator;
        this.fetchLabels = fetchLabels;
        this.voidConfigurationBroadcast = voidConfigurationBroadcast;
        this.vectorsConfigurationBroadcast = vectorsConfigurationBroadcast;
    }

    /**
     * Counts the elements and labels of {@code sequence} into the accumulator.
     *
     * @param sequence sequence to count; null elements and null labels are skipped
     * @return the same sequence paired with its number of non-null elements
     * @throws Exception if the learning algorithm cannot be instantiated or the parameter server
     *                   fails to initialize
     */
    @Override
    public Pair<Sequence<T>, Long> call(Sequence<T> sequence) throws Exception {
        initTrainingInfrastructure();

        Counter<Long> localCounter = new Counter<>();
        // Recount rather than trusting any stored size: only non-null elements contribute.
        long seqLen = 0;
        for (T element : sequence.getElements()) {
            if (element == null) {
                continue;
            }
            localCounter.incrementCount(element.getStorageId(), 1.0);
            seqLen++;
        }

        // Labels add to the frequencies but not to the returned length.
        if (sequence.getSequenceLabels() != null) {
            for (T label : sequence.getSequenceLabels()) {
                // Skip null labels the same way null elements are skipped above (previously an NPE).
                if (label == null) {
                    continue;
                }
                localCounter.incrementCount(label.getStorageId(), 1.0);
            }
        }

        this.accumulator.add(localCounter);
        return Pair.makePair(sequence, seqLen);
    }

    /**
     * Lazily instantiates the elements learning algorithm, then initializes the parameter server
     * singleton with that algorithm's training driver.
     */
    private void initTrainingInfrastructure() {
        if (this.ela == null) {
            String algorithmClassName = this.vectorsConfigurationBroadcast.getValue().getElementsLearningAlgorithm();
            this.ela = (SparkElementsLearningAlgorithm) DL4JClassLoading.createNewInstance(algorithmClassName);
        }
        this.driver = this.ela.getTrainingDriver();
        // NOTE(review): this runs on every call with a fresh RoutedTransport. Presumably init() is a
        // no-op once the singleton is initialized; confirm before moving it into the lazy branch above.
        VoidParameterServer.getInstance().init(this.voidConfigurationBroadcast.getValue(), new RoutedTransport(), this.driver);
    }
}
