package org.deeplearning4j.spark.models.sequencevectors.functions;

import lombok.NonNull;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.spark.models.sequencevectors.export.ExportContainer;
import org.nd4j.parameterserver.distributed.VoidParameterServer;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/spark/models/sequencevectors/functions/DistributedFunction.class */
/**
 * Spark map function that turns a sequence element into an {@link ExportContainer}
 * holding the element together with a vector fetched from the
 * {@link VoidParameterServer}.
 *
 * <p>For every input element the function looks up the matching
 * {@link ShallowSequenceElement} in the broadcast vocabulary by storage id, copies its
 * index onto the input element, and asks the parameter server for the vector at that index.
 *
 * @param <T> concrete element type being exported
 */
public class DistributedFunction<T extends SequenceElement> implements Function<T, ExportContainer<T>> {
    private static final Logger log = LoggerFactory.getLogger(DistributedFunction.class);

    /** Broadcast parameter-server configuration; stored here but not read by this class. */
    protected Broadcast<VoidConfiguration> configurationBroadcast;

    /** Broadcast vectors configuration; stored here but not read by this class. */
    protected Broadcast<VectorsConfiguration> vectorsConfigurationBroadcast;

    /** Broadcast vocabulary used to resolve storage ids to element indexes. */
    protected Broadcast<VocabCache<ShallowSequenceElement>> shallowVocabBroadcast;

    /**
     * Locally cached value of {@link #shallowVocabBroadcast}. Declared {@code transient},
     * so it is not serialized with the function; {@link #call} fills it in lazily.
     */
    protected transient VocabCache<ShallowSequenceElement> shallowVocabCache;

    /**
     * Creates the function from the three broadcasts it depends on.
     *
     * @param configurationBroadcast        broadcast parameter-server configuration
     * @param vectorsConfigurationBroadcast broadcast vectors configuration
     * @param shallowVocabBroadcast         broadcast vocabulary of shallow elements
     * @throws NullPointerException if any argument is {@code null}
     */
    public DistributedFunction(@NonNull Broadcast<VoidConfiguration> configurationBroadcast,
                               @NonNull Broadcast<VectorsConfiguration> vectorsConfigurationBroadcast,
                               @NonNull Broadcast<VocabCache<ShallowSequenceElement>> shallowVocabBroadcast) {
        if (configurationBroadcast == null) {
            throw new NullPointerException("configurationBroadcast is marked non-null but is null");
        }
        if (vectorsConfigurationBroadcast == null) {
            throw new NullPointerException("vectorsConfigurationBroadcast is marked non-null but is null");
        }
        if (shallowVocabBroadcast == null) {
            throw new NullPointerException("shallowVocabBroadcast is marked non-null but is null");
        }
        this.configurationBroadcast = configurationBroadcast;
        this.vectorsConfigurationBroadcast = vectorsConfigurationBroadcast;
        this.shallowVocabBroadcast = shallowVocabBroadcast;
    }

    /**
     * Builds the export container for a single element.
     *
     * @param t element to export; its index is overwritten with the index of the matching
     *          vocabulary entry
     * @return container holding {@code t} and the vector returned by the parameter server
     *         for that index
     * @throws NullPointerException if the vocabulary has no entry for the element's storage id
     * @throws Exception            declared by {@link Function#call}
     */
    @Override
    public ExportContainer<T> call(T t) throws Exception {
        // Resolve the broadcast once per deserialized instance and reuse it afterwards.
        if (this.shallowVocabCache == null) {
            this.shallowVocabCache = this.shallowVocabBroadcast.getValue();
        }

        long storageId = t.getStorageId().longValue();
        ShallowSequenceElement shallowElement = this.shallowVocabCache.tokenFor(storageId);
        if (shallowElement == null) {
            // Still a NullPointerException, as before, but now it names the offending id
            // instead of failing on the dereference below with no message.
            throw new NullPointerException("No vocabulary entry found for storage id " + storageId);
        }

        t.setIndex(shallowElement.getIndex());

        ExportContainer<T> exportContainer = new ExportContainer<>();
        exportContainer.setElement(t);
        exportContainer.setArray(VoidParameterServer.getInstance().getVector(shallowElement.getIndex()));
        return exportContainer;
    }
}
