package org.deeplearning4j.spark.models.sequencevectors.functions;

import java.util.Iterator;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.spark.models.sequencevectors.learning.SparkElementsLearningAlgorithm;
import org.deeplearning4j.spark.models.sequencevectors.learning.SparkSequenceLearningAlgorithm;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.parameterserver.distributed.VoidParameterServer;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.nd4j.parameterserver.distributed.messages.TrainingMessage;
import org.nd4j.parameterserver.distributed.training.TrainingDriver;
import org.nd4j.parameterserver.distributed.transport.RoutedTransport;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/spark/models/sequencevectors/functions/TrainingFunction.class */
/**
 * Spark {@link VoidFunction} applied to every {@link Sequence} of the training data.
 *
 * <p>For each incoming sequence the function maps its elements (and, when sequence-vector
 * training is enabled, its labels) onto the {@link ShallowSequenceElement}s held by the
 * broadcast vocabulary, asks the elements learning algorithm to frame the resulting sequence,
 * and hands the frame to the {@link VoidParameterServer} for distributed execution.
 *
 * <p>Only the three {@link Broadcast} handles are regular fields. Everything else is
 * {@code transient} and is (re)built lazily inside {@link #call(Sequence)} whenever it is
 * found to be {@code null}.
 *
 * @param <T> type of the elements carried by the incoming sequences
 */
public class TrainingFunction<T extends SequenceElement> implements VoidFunction<Sequence<T>> {
    private static final Logger log = LoggerFactory.getLogger(TrainingFunction.class);

    /** Broadcast vocabulary used to resolve incoming elements by storage id. */
    protected Broadcast<VocabCache<ShallowSequenceElement>> vocabCacheBroadcast;
    /** Broadcast vectors configuration (learning algorithm class names, training flags). */
    protected Broadcast<VectorsConfiguration> configurationBroadcast;
    /** Broadcast configuration passed to the parameter server on initialisation. */
    protected Broadcast<VoidConfiguration> paramServerConfigurationBroadcast;

    // Lazily resolved state; see call(Sequence).
    protected transient VoidParameterServer paramServer;
    protected transient VectorsConfiguration vectorsConfiguration;
    protected transient SparkElementsLearningAlgorithm elementsLearningAlgorithm;
    protected transient SparkSequenceLearningAlgorithm sequenceLearningAlgorithm;
    protected transient VocabCache<ShallowSequenceElement> shallowVocabCache;
    protected transient TrainingDriver<? extends TrainingMessage> driver;

    /**
     * Creates the function from the three broadcast handles it needs at execution time.
     *
     * @param vocabCacheBroadcast broadcast vocabulary; must not be {@code null}
     * @param vectorsConfigurationBroadcast broadcast vectors configuration; must not be {@code null}
     * @param paramServerConfigurationBroadcast broadcast parameter server configuration; must not be {@code null}
     * @throws NullPointerException if any argument is {@code null}
     */
    public TrainingFunction(@NonNull Broadcast<VocabCache<ShallowSequenceElement>> vocabCacheBroadcast, @NonNull Broadcast<VectorsConfiguration> vectorsConfigurationBroadcast, @NonNull Broadcast<VoidConfiguration> paramServerConfigurationBroadcast) {
        if (vocabCacheBroadcast == null) {
            throw new NullPointerException("vocabCacheBroadcast is marked non-null but is null");
        }
        if (vectorsConfigurationBroadcast == null) {
            throw new NullPointerException("vectorsConfigurationBroadcast is marked non-null but is null");
        }
        if (paramServerConfigurationBroadcast == null) {
            throw new NullPointerException("paramServerConfigurationBroadcast is marked non-null but is null");
        }
        this.vocabCacheBroadcast = vocabCacheBroadcast;
        this.configurationBroadcast = vectorsConfigurationBroadcast;
        this.paramServerConfigurationBroadcast = paramServerConfigurationBroadcast;
    }

    /**
     * Frames a single sequence and submits it to the parameter server.
     *
     * <p>On the first invocation (or whenever a lazily held field is {@code null}) this resolves
     * the broadcast values, instantiates and configures the learning algorithms named in the
     * vectors configuration, and initialises the parameter server with the training driver
     * supplied by the elements learning algorithm.
     *
     * @param sequence sequence to train on; elements whose storage id is unknown to the
     *                 vocabulary are dropped
     * @throws ND4JIllegalStateException if no learning algorithm is configured, or if no
     *                                   elements learning algorithm is available
     * @throws Exception if instantiating an algorithm or talking to the parameter server fails
     */
    @Override
    public void call(Sequence<T> sequence) throws Exception {
        if (this.vectorsConfiguration == null) {
            this.vectorsConfiguration = this.configurationBroadcast.getValue();
        }
        // The vocabulary has to be resolved before any algorithm is configured with it.
        if (this.shallowVocabCache == null) {
            this.shallowVocabCache = this.vocabCacheBroadcast.getValue();
        }

        // Create AND configure the elements algorithm in one place, before it is asked for its
        // training driver, so the instance used for framing is never left unconfigured.
        String elementsAlgorithmName = this.vectorsConfiguration.getElementsLearningAlgorithm();
        if (this.elementsLearningAlgorithm == null && elementsAlgorithmName != null) {
            SparkElementsLearningAlgorithm created = (SparkElementsLearningAlgorithm) DL4JClassLoading.createNewInstance(elementsAlgorithmName, new Object[0]);
            // No lookup table is supplied here, hence the null second argument.
            created.configure(this.shallowVocabCache, null, this.vectorsConfiguration);
            this.elementsLearningAlgorithm = created;
        }

        String sequenceAlgorithmName = this.vectorsConfiguration.getSequenceLearningAlgorithm();
        if (this.sequenceLearningAlgorithm == null && sequenceAlgorithmName != null) {
            SparkSequenceLearningAlgorithm created = (SparkSequenceLearningAlgorithm) DL4JClassLoading.createNewInstance(sequenceAlgorithmName, new Object[0]);
            created.configure(this.shallowVocabCache, null, this.vectorsConfiguration);
            this.sequenceLearningAlgorithm = created;
        }

        if (this.elementsLearningAlgorithm == null && this.sequenceLearningAlgorithm == null) {
            throw new ND4JIllegalStateException("No LearningAlgorithms specified!");
        }
        // Both the training driver and the framing below come from the elements algorithm, so
        // fail with a clear message instead of a NullPointerException when it is missing.
        if (this.elementsLearningAlgorithm == null) {
            throw new ND4JIllegalStateException("No elements learning algorithm specified: it is required to obtain the training driver and to frame sequences");
        }

        if (this.paramServer == null) {
            VoidParameterServer server = VoidParameterServer.getInstance();
            this.driver = this.elementsLearningAlgorithm.getTrainingDriver();
            server.init(this.paramServerConfigurationBroadcast.getValue(), new RoutedTransport(), this.driver);
            // Publish the server only after init succeeded so a failed init is retried.
            this.paramServer = server;
        }

        // The incoming sequence is detached from the vocabulary: rebuild it from the vocabulary's
        // own elements, skipping everything the vocabulary does not know.
        Sequence<ShallowSequenceElement> merged = new Sequence<>();
        for (T element : sequence.getElements()) {
            ShallowSequenceElement known = this.shallowVocabCache.tokenFor(element.getStorageId().longValue());
            if (known != null) {
                merged.addElement(known);
            }
        }

        // Labels are only transferred when sequence vectors are actually being trained.
        if (this.sequenceLearningAlgorithm != null && this.vectorsConfiguration.isTrainSequenceVectors()) {
            for (T label : sequence.getSequenceLabels()) {
                ShallowSequenceElement known = this.shallowVocabCache.tokenFor(label.getStorageId().longValue());
                if (known != null) {
                    merged.addSequenceLabel(known);
                }
            }
        }

        // Test the merged sequence, not the input: that is what gets framed, and it can be empty
        // even when the input was not (all elements unknown to the vocabulary).
        if (merged.size() > 0) {
            // NOTE(review): 119 and 0.025 are hard-coded frameSequence arguments; presumably a
            // random seed and a learning rate - confirm and consider making them configurable.
            this.paramServer.execDistributed(this.elementsLearningAlgorithm.frameSequence(merged, new AtomicLong(119L), 0.025d));
        } else {
            log.warn("Skipping empty sequence...");
        }
    }
}
