package org.deeplearning4j.spark.models.sequencevectors.functions;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import lombok.NonNull;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.spark.models.sequencevectors.learning.SparkElementsLearningAlgorithm;
import org.nd4j.parameterserver.distributed.VoidParameterServer;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.nd4j.parameterserver.distributed.messages.TrainingMessage;
import org.nd4j.parameterserver.distributed.training.TrainingDriver;
import org.nd4j.parameterserver.distributed.transport.RoutedTransport;

/**
 * Spark {@link FlatMapFunction} that expands each {@link Sequence} into the individual items it
 * contributes to vocabulary construction.
 *
 * <p>For every input sequence this function emits all of the sequence's elements. It also emits
 * the sequence's labels when the broadcast {@link VectorsConfiguration} reports
 * {@code isTrainSequenceVectors()} and the sequence has at least one label.
 *
 * <p>Each invocation also calls {@link VoidParameterServer#getInstance()}{@code .init(...)} with
 * the broadcast {@link VoidConfiguration}, a new {@link RoutedTransport}, and the training driver
 * of the configured elements learning algorithm.
 *
 * @param <T> the sequence element type
 */
public class VocabRddFunctionFlat<T extends SequenceElement> implements FlatMapFunction<Sequence<T>, T> {
    /** Broadcast vectors configuration; resolved lazily into {@link #configuration}. */
    protected Broadcast<VectorsConfiguration> vectorsConfigurationBroadcast;
    /** Broadcast parameter-server configuration handed to {@link VoidParameterServer}. */
    protected Broadcast<VoidConfiguration> paramServerConfigurationBroadcast;

    // The fields below are transient, so they are not serialized with the function.
    // call() populates them lazily on first use.
    protected transient VectorsConfiguration configuration;
    protected transient SparkElementsLearningAlgorithm ela;
    protected transient TrainingDriver<? extends TrainingMessage> driver;

    /**
     * Creates the function from the two broadcast configurations it reads in {@link #call}.
     *
     * @param vectorsConfigurationBroadcast broadcast vectors configuration; must not be null
     * @param paramServerConfigurationBroadcast broadcast parameter-server configuration; must not be null
     * @throws NullPointerException if either argument is null
     */
    public VocabRddFunctionFlat(@NonNull Broadcast<VectorsConfiguration> vectorsConfigurationBroadcast,
                    @NonNull Broadcast<VoidConfiguration> paramServerConfigurationBroadcast) {
        // Explicit checks preserve the exact messages whether or not Lombok processes @NonNull.
        this.vectorsConfigurationBroadcast = Objects.requireNonNull(vectorsConfigurationBroadcast,
                        "vectorsConfigurationBroadcast is marked non-null but is null");
        this.paramServerConfigurationBroadcast = Objects.requireNonNull(paramServerConfigurationBroadcast,
                        "paramServerConfigurationBroadcast is marked non-null but is null");
    }

    /**
     * Flattens one sequence into the items it contributes to the vocabulary.
     *
     * @param sequence the sequence to expand
     * @return an iterator over the sequence's elements, followed by its labels when
     *         sequence-vector training is enabled and the sequence has labels
     * @throws Exception as permitted by {@link FlatMapFunction}
     */
    public Iterator<T> call(Sequence<T> sequence) throws Exception {
        if (configuration == null) {
            configuration = vectorsConfigurationBroadcast.getValue();
        }
        if (ela == null) {
            // The algorithm class is named by the configuration and instantiated reflectively.
            ela = (SparkElementsLearningAlgorithm) DL4JClassLoading
                            .createNewInstance(configuration.getElementsLearningAlgorithm());
        }
        driver = ela.getTrainingDriver();

        // NOTE(review): init(...) runs for every sequence, with a fresh RoutedTransport each time.
        // Whether repeated calls are no-ops depends on VoidParameterServer, which is not visible
        // here. Confirm that before guarding this with a one-time check.
        VoidParameterServer.getInstance().init(paramServerConfigurationBroadcast.getValue(),
                        new RoutedTransport(), driver);

        List<T> elements = new ArrayList<>(sequence.getElements());
        if (configuration.isTrainSequenceVectors() && !sequence.getSequenceLabels().isEmpty()) {
            elements.addAll(sequence.getSequenceLabels());
        }
        return elements.iterator();
    }
}
