package org.deeplearning4j.spark.models.sequencevectors;

import java.util.Arrays;
import java.util.List;
import java.util.Set;
import lombok.NonNull;
import org.apache.spark.Accumulator;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.storage.StorageLevel;
import org.deeplearning4j.common.config.DL4JClassLoading;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.sequencevectors.SequenceVectors;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement;
import org.deeplearning4j.models.word2vec.Huffman;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.deeplearning4j.spark.models.sequencevectors.export.ExportContainer;
import org.deeplearning4j.spark.models.sequencevectors.export.SparkModelExporter;
import org.deeplearning4j.spark.models.sequencevectors.functions.CountFunction;
import org.deeplearning4j.spark.models.sequencevectors.functions.DistributedFunction;
import org.deeplearning4j.spark.models.sequencevectors.functions.ElementsFrequenciesAccumulator;
import org.deeplearning4j.spark.models.sequencevectors.functions.ExtraCountFunction;
import org.deeplearning4j.spark.models.sequencevectors.functions.ExtraElementsFrequenciesAccumulator;
import org.deeplearning4j.spark.models.sequencevectors.functions.ListSequenceConvertFunction;
import org.deeplearning4j.spark.models.sequencevectors.functions.PartitionTrainingFunction;
import org.deeplearning4j.spark.models.sequencevectors.functions.TrainingFunction;
import org.deeplearning4j.spark.models.sequencevectors.functions.VocabRddFunctionFlat;
import org.deeplearning4j.spark.models.sequencevectors.learning.SparkElementsLearningAlgorithm;
import org.deeplearning4j.spark.models.sequencevectors.learning.SparkSequenceLearningAlgorithm;
import org.deeplearning4j.spark.models.sequencevectors.primitives.ExtraCounter;
import org.nd4j.common.primitives.Counter;
import org.nd4j.parameterserver.distributed.VoidParameterServer;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.nd4j.parameterserver.distributed.enums.FaultToleranceStrategy;
import org.nd4j.parameterserver.distributed.transport.RoutedTransport;
import org.nd4j.parameterserver.distributed.util.NetworkInformation;
import org.nd4j.parameterserver.distributed.util.NetworkOrganizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/spark/models/sequencevectors/SparkSequenceVectors.class */
/**
 * Spark-driven trainer for {@link SequenceVectors}-style models.
 *
 * <p>Training is started with {@link #fitSequences(JavaRDD)} (or {@link #fitLists(JavaRDD)}),
 * which counts element frequencies over the corpus, builds a shallow vocabulary, initialises the
 * {@link VoidParameterServer}, runs the configured number of epochs over the corpus partitions
 * and finally hands the resulting vectors to the configured {@link SparkModelExporter}.
 *
 * @param <T> type of sequence element being trained
 */
public class SparkSequenceVectors<T extends SequenceElement> extends SequenceVectors<T> {
    private static final Logger log = LoggerFactory.getLogger(SparkSequenceVectors.class);

    /** Frequency accumulator used when shard addresses are provided explicitly. */
    protected Accumulator<Counter<Long>> elementsFreqAccum;
    /** Frequency accumulator used in auto-discovery mode; its value also carries network information. */
    protected Accumulator<ExtraCounter<Long>> elementsFreqAccumExtra;
    /** Storage level applied to the corpus RDD during training; {@code null} disables persisting. */
    protected StorageLevel storageLevel;
    protected Broadcast<VocabCache<T>> vocabCacheBroadcast;
    protected Broadcast<VocabCache<ShallowSequenceElement>> shallowVocabCacheBroadcast;
    protected Broadcast<VectorsConfiguration> configurationBroadcast;
    /** Set once {@link #broadcastEnvironment(JavaSparkContext)} has broadcast the configuration. */
    protected transient boolean isEnvironmentReady;
    protected transient VocabCache<ShallowSequenceElement> shallowVocabCache;
    /** {@code true} when the parameter server configuration has no shard addresses. */
    protected boolean isAutoDiscoveryMode;
    protected SparkModelExporter<T> exporter;
    protected SparkElementsLearningAlgorithm ela;
    protected SparkSequenceLearningAlgorithm sla;
    protected VoidConfiguration paramServerConfiguration;

    /**
     * Fluent builder for {@link SparkSequenceVectors}.
     *
     * <p>Method names carry a {@code 2} suffix introduced by the decompiler; the original name is
     * noted on each method.
     *
     * @param <T> type of sequence element being trained
     */
    public static class Builder<T extends SequenceElement> {
        protected VectorsConfiguration configuration;
        protected SparkModelExporter<T> modelExporter;
        protected VoidConfiguration peersConfiguration;
        protected int workers;
        protected StorageLevel storageLevel;

        /** Creates a builder with fresh default parameter-server and vectors configurations. */
        @Deprecated
        public Builder() {
            this(new VoidConfiguration(), new VectorsConfiguration());
        }

        /**
         * Creates a builder with the given parameter-server configuration and a fresh
         * {@link VectorsConfiguration}.
         *
         * @param voidConfiguration parameter-server configuration, must not be {@code null}
         * @throws NullPointerException if {@code voidConfiguration} is {@code null}
         */
        public Builder(@NonNull VoidConfiguration voidConfiguration) {
            // The delegated constructor performs the null check with the same message.
            this(voidConfiguration, new VectorsConfiguration());
        }

        /**
         * Creates a builder from explicit configurations.
         *
         * @param voidConfiguration parameter-server configuration, must not be {@code null}
         * @param vectorsConfiguration vectors configuration, must not be {@code null}
         * @throws NullPointerException if either argument is {@code null}
         */
        public Builder(@NonNull VoidConfiguration voidConfiguration, @NonNull VectorsConfiguration vectorsConfiguration) {
            if (voidConfiguration == null) {
                throw new NullPointerException("psConfiguration is marked non-null but is null");
            }
            if (vectorsConfiguration == null) {
                throw new NullPointerException("w2vConfiguration is marked non-null but is null");
            }
            this.configuration = vectorsConfiguration;
            this.peersConfiguration = voidConfiguration;
        }

        /**
         * Sets the storage level used to persist the corpus RDD during training.
         * Decompiler note: renamed from {@code setStorageLevel}.
         *
         * @param storageLevel storage level; {@code null} means the corpus is not persisted
         * @return this builder
         */
        public Builder<T> setStorageLevel2(StorageLevel storageLevel) {
            this.storageLevel = storageLevel;
            return this;
        }

        /**
         * Sets the minimum word frequency on the vectors configuration.
         * Decompiler note: renamed from {@code minWordFrequency}.
         *
         * @return this builder
         */
        public Builder<T> minWordFrequency2(int i) {
            this.configuration.setMinWordFrequency(i);
            return this;
        }

        /**
         * Sets the number of workers; values above 1 make training repartition the corpus.
         * Decompiler note: renamed from {@code workers}.
         *
         * @return this builder
         */
        public Builder<T> workers2(int i) {
            this.workers = i;
            return this;
        }

        /**
         * Sets the learning rate on the vectors configuration.
         * Decompiler note: renamed from {@code setLearningRate}.
         *
         * @return this builder
         */
        public Builder<T> setLearningRate2(double d) {
            this.configuration.setLearningRate(d);
            return this;
        }

        /**
         * Replaces the parameter-server configuration.
         * Decompiler note: renamed from {@code setParameterServerConfiguration}.
         *
         * @param voidConfiguration new configuration, must not be {@code null}
         * @return this builder
         * @throws NullPointerException if {@code voidConfiguration} is {@code null}
         */
        public Builder<T> setParameterServerConfiguration2(@NonNull VoidConfiguration voidConfiguration) {
            if (voidConfiguration == null) {
                throw new NullPointerException("configuration is marked non-null but is null");
            }
            this.peersConfiguration = voidConfiguration;
            return this;
        }

        /**
         * Sets the exporter that receives the trained vectors; required by {@link #build2()}.
         * Decompiler note: renamed from {@code setModelExporter}.
         *
         * @param sparkModelExporter exporter, must not be {@code null}
         * @return this builder
         * @throws NullPointerException if {@code sparkModelExporter} is {@code null}
         */
        public Builder<T> setModelExporter2(@NonNull SparkModelExporter<T> sparkModelExporter) {
            if (sparkModelExporter == null) {
                throw new NullPointerException("modelExporter is marked non-null but is null");
            }
            this.modelExporter = sparkModelExporter;
            return this;
        }

        /**
         * Sets the number of training epochs on the vectors configuration.
         * Decompiler note: renamed from {@code epochs}.
         *
         * @return this builder
         */
        public Builder<T> epochs2(int i) {
            this.configuration.setEpochs(i);
            return this;
        }

        /**
         * Sets the number of iterations on the vectors configuration.
         * Decompiler note: renamed from {@code iterations}.
         *
         * @return this builder
         */
        public Builder<T> iterations2(int i) {
            this.configuration.setIterations(i);
            return this;
        }

        /**
         * Sets the sampling value on the vectors configuration.
         * Decompiler note: renamed from {@code subsampling}.
         *
         * @return this builder
         */
        public Builder<T> subsampling2(double d) {
            this.configuration.setSampling(d);
            return this;
        }

        /**
         * Enables or disables hierarchic softmax on the vectors configuration.
         * Decompiler note: renamed from {@code useHierarchicSoftmax}.
         *
         * @return this builder
         */
        public Builder<T> useHierarchicSoftmax2(boolean z) {
            this.configuration.setUseHierarchicSoftmax(z);
            return this;
        }

        /**
         * Sets the negative-sampling value on the vectors configuration; training treats values
         * greater than zero as "negative sampling enabled".
         * Decompiler note: renamed from {@code negativeSampling}.
         *
         * @return this builder
         */
        public Builder<T> negativeSampling2(long j) {
            this.configuration.setNegative(j);
            return this;
        }

        /**
         * Records the canonical class name of the given elements learning algorithm in the
         * vectors configuration (the instance itself is not retained).
         * Decompiler note: renamed from {@code setElementsLearningAlgorithm}.
         *
         * @param sparkElementsLearningAlgorithm algorithm, must not be {@code null}
         * @return this builder
         * @throws NullPointerException if the argument is {@code null}
         */
        public Builder<T> setElementsLearningAlgorithm2(@NonNull SparkElementsLearningAlgorithm sparkElementsLearningAlgorithm) {
            if (sparkElementsLearningAlgorithm == null) {
                throw new NullPointerException("ela is marked non-null but is null");
            }
            this.configuration.setElementsLearningAlgorithm(sparkElementsLearningAlgorithm.getClass().getCanonicalName());
            return this;
        }

        /**
         * Records the canonical class name of the given sequence learning algorithm in the
         * vectors configuration (the instance itself is not retained).
         * Decompiler note: renamed from {@code setSequenceLearningAlgorithm}.
         *
         * @param sparkSequenceLearningAlgorithm algorithm, must not be {@code null}
         * @return this builder
         * @throws NullPointerException if the argument is {@code null}
         */
        public Builder<T> setSequenceLearningAlgorithm2(@NonNull SparkSequenceLearningAlgorithm sparkSequenceLearningAlgorithm) {
            if (sparkSequenceLearningAlgorithm == null) {
                throw new NullPointerException("sla is marked non-null but is null");
            }
            this.configuration.setSequenceLearningAlgorithm(sparkSequenceLearningAlgorithm.getClass().getCanonicalName());
            return this;
        }

        /**
         * Sets the vector dimensionality on the vectors configuration.
         * Decompiler note: renamed from {@code layerSize}.
         *
         * @param i layer size, must be at least 1
         * @return this builder
         * @throws DL4JInvalidConfigException if {@code i < 1}
         */
        public Builder<T> layerSize2(int i) {
            if (i < 1) {
                throw new DL4JInvalidConfigException("LayerSize should be positive value");
            }
            this.configuration.setLayersSize(i);
            return this;
        }

        /**
         * Builds the trainer from the values collected so far.
         * Decompiler note: renamed from {@code build}.
         *
         * @return configured {@link SparkSequenceVectors}
         * @throws IllegalStateException if no model exporter was set
         */
        public SparkSequenceVectors<T> build2() {
            if (this.modelExporter == null) {
                throw new IllegalStateException("ModelExporter is undefined!");
            }
            SparkSequenceVectors<T> sparkSequenceVectors = new SparkSequenceVectors<>(this.configuration);
            sparkSequenceVectors.exporter = this.modelExporter;
            sparkSequenceVectors.paramServerConfiguration = this.peersConfiguration;
            sparkSequenceVectors.storageLevel = this.storageLevel;
            sparkSequenceVectors.workers = this.workers;
            return sparkSequenceVectors;
        }
    }

    /**
     * Creates a trainer with a fresh default {@link VectorsConfiguration}.
     * Decompiler note: access modifier changed from {@code protected}.
     */
    public SparkSequenceVectors() {
        this(new VectorsConfiguration());
    }

    /**
     * Creates a trainer with the given configuration. Defaults: {@code MEMORY_ONLY} storage level,
     * environment not yet broadcast, auto-discovery mode on.
     *
     * @param vectorsConfiguration vectors configuration, must not be {@code null}
     * @throws NullPointerException if {@code vectorsConfiguration} is {@code null}
     */
    protected SparkSequenceVectors(@NonNull VectorsConfiguration vectorsConfiguration) {
        this.storageLevel = StorageLevel.MEMORY_ONLY();
        this.isEnvironmentReady = false;
        this.isAutoDiscoveryMode = true;
        if (vectorsConfiguration == null) {
            throw new NullPointerException("configuration is marked non-null but is null");
        }
        this.configuration = vectorsConfiguration;
    }

    /**
     * Returns the shallow vocabulary built by the last training run, or {@code null} if none was built.
     * Decompiler note: access modifier changed from {@code protected}.
     */
    public VocabCache<ShallowSequenceElement> getShallowVocabCache() {
        return this.shallowVocabCache;
    }

    /**
     * Not supported by the Spark implementation.
     *
     * @throws UnsupportedOperationException always
     */
    @Deprecated
    public void fit() {
        throw new UnsupportedOperationException("To use fit() method, please consider using standalone implementation");
    }

    /**
     * Checks that the current configuration can actually train and export something.
     * Decompiler note: access modifier changed from {@code protected}.
     *
     * @throws DL4JInvalidConfigException if neither hierarchic softmax nor negative sampling is
     *         enabled, if no learning algorithm is configured, or if no exporter is set
     */
    public void validateConfiguration() {
        // Training enables negative sampling only when the value is > 0, so any non-positive
        // value (not just exactly zero) means negative sampling is effectively disabled.
        if (!this.configuration.isUseHierarchicSoftmax() && this.configuration.getNegative() <= 0.0d) {
            throw new DL4JInvalidConfigException("Both HierarchicSoftmax and NegativeSampling are disabled. Nothing to learn here.");
        }
        if (this.configuration.getElementsLearningAlgorithm() == null && this.configuration.getSequenceLearningAlgorithm() == null) {
            throw new DL4JInvalidConfigException("No LearningAlgorithm was set. Nothing to learn here.");
        }
        if (this.exporter == null) {
            throw new DL4JInvalidConfigException("SparkModelExporter is undefined. No sense for training, if model won't be exported.");
        }
    }

    /**
     * Broadcasts the vectors configuration once; subsequent calls are no-ops.
     * Decompiler note: access modifier changed from {@code protected}.
     *
     * @param javaSparkContext context used to create the broadcast
     */
    public void broadcastEnvironment(JavaSparkContext javaSparkContext) {
        if (this.isEnvironmentReady) {
            return;
        }
        this.configurationBroadcast = javaSparkContext.broadcast(this.configuration);
        this.isEnvironmentReady = true;
    }

    /**
     * Trains on a corpus given as lists of elements by converting each list with
     * {@link ListSequenceConvertFunction} and delegating to {@link #fitSequences(JavaRDD)}.
     *
     * @param javaRDD corpus of element lists
     */
    public void fitLists(JavaRDD<List<T>> javaRDD) {
        fitSequences(javaRDD.map(new ListSequenceConvertFunction()));
    }

    /**
     * Trains on the given corpus of sequences and exports the result through the configured exporter.
     *
     * <p>The corpus is repartitioned when {@code workers > 1} and persisted when a storage level is
     * set; it is unpersisted again when training ends, whether normally or with an exception.
     *
     * @param javaRDD corpus of sequences
     * @throws DL4JInvalidConfigException if {@link #validateConfiguration()} rejects the configuration
     */
    public void fitSequences(JavaRDD<Sequence<T>> javaRDD) {
        validateConfiguration();
        if (this.ela == null) {
            this.ela = (SparkElementsLearningAlgorithm) DL4JClassLoading.createNewInstance(this.configuration.getElementsLearningAlgorithm());
        }

        JavaRDD<Sequence<T>> corpus = javaRDD;
        if (this.workers > 1) {
            log.info("Repartitioning corpus to {} parts...", this.workers);
            // RDDs are immutable: repartition() returns a NEW RDD, so the result must be used.
            corpus = corpus.repartition(this.workers);
        }
        if (this.storageLevel != null) {
            corpus.persist(this.storageLevel);
        }
        try {
            trainAndExport(corpus);
        } finally {
            // Release cached blocks even if counting, training or export failed.
            if (this.storageLevel != null) {
                corpus.unpersist();
            }
        }
    }

    /** Counts elements, builds the vocabulary, runs the training epochs and exports the vectors. */
    private void trainAndExport(JavaRDD<Sequence<T>> corpus) {
        Broadcast<VoidConfiguration> paramServerBroadcast;
        long sequenceCount;
        Counter<Long> counter;

        JavaSparkContext javaSparkContext = new JavaSparkContext(corpus.context());
        broadcastEnvironment(javaSparkContext);

        if (this.paramServerConfiguration == null) {
            this.paramServerConfiguration = VoidConfiguration.builder().numberOfShards(2).unicastPort(40123).multicastPort(40124).build();
            this.paramServerConfiguration.setFaultToleranceStrategy(FaultToleranceStrategy.NONE);
        }

        // No explicit shard addresses means the shards have to be discovered from the cluster.
        this.isAutoDiscoveryMode = this.paramServerConfiguration.getShardAddresses() == null || this.paramServerConfiguration.getShardAddresses().isEmpty();
        if (this.isAutoDiscoveryMode) {
            log.info("Trying auto discovery mode...");
            this.elementsFreqAccumExtra = corpus.context().accumulator(new ExtraCounter(), new ExtraElementsFrequenciesAccumulator());
            sequenceCount = corpus.map(new ExtraCountFunction(this.elementsFreqAccumExtra, this.configuration.isTrainSequenceVectors())).count();
            counter = (Counter) this.elementsFreqAccumExtra.value();

            Set<NetworkInformation> networkInformation = ((ExtraCounter) counter).getNetworkInformation();
            log.info("availableHosts: {}", networkInformation);
            if (networkInformation.size() > 1) {
                NetworkOrganizer networkOrganizer = new NetworkOrganizer(networkInformation, this.paramServerConfiguration.getNetworkMask());
                this.paramServerConfiguration.setShardAddresses(networkOrganizer.getSubset(this.paramServerConfiguration.getNumberOfShards()));
                if (this.paramServerConfiguration.getFaultToleranceStrategy() != FaultToleranceStrategy.NONE) {
                    this.paramServerConfiguration.setBackupAddresses(networkOrganizer.getSubset(this.paramServerConfiguration.getNumberOfShards(), this.paramServerConfiguration.getShardAddresses()));
                }
            } else {
                // At most one host reported: fall back to a single local shard without fault tolerance.
                this.paramServerConfiguration.setShardAddresses(Arrays.asList("127.0.0.1:" + this.paramServerConfiguration.getPortSupplier().getPort()));
                this.paramServerConfiguration.setFaultToleranceStrategy(FaultToleranceStrategy.NONE);
            }
            log.info("Got Shards so far: {}", this.paramServerConfiguration.getShardAddresses());
            paramServerBroadcast = finalizeAndBroadcastParamServerConfiguration(javaSparkContext);
        } else {
            // Shards are known up front, so the configuration can be broadcast before counting.
            paramServerBroadcast = finalizeAndBroadcastParamServerConfiguration(javaSparkContext);
            this.elementsFreqAccum = corpus.context().accumulator(new Counter(), new ElementsFrequenciesAccumulator());
            sequenceCount = corpus.map(new CountFunction(this.configurationBroadcast, paramServerBroadcast, this.elementsFreqAccum, this.configuration.isTrainSequenceVectors())).count();
            counter = (Counter) this.elementsFreqAccum.value();
        }

        long totalEntries = (long) counter.totalCount();
        long uniqueElements = counter.size();
        log.info("Total number of sequences: {}; Total number of elements entries: {}; Total number of unique elements: {}", sequenceCount, totalEntries, uniqueElements);

        this.shallowVocabCache = buildShallowVocabCache(counter);
        this.shallowVocabCacheBroadcast = javaSparkContext.broadcast(this.shallowVocabCache);

        JavaRDD distinct = corpus.flatMap(new VocabRddFunctionFlat(this.configurationBroadcast, paramServerBroadcast)).distinct();
        // count() is an action: it forces evaluation of the lazily-defined vocabulary RDD.
        distinct.count();

        VoidParameterServer.getInstance().init(this.paramServerConfiguration, new RoutedTransport(), this.ela.getTrainingDriver());
        try {
            VoidParameterServer.getInstance().initializeSeqVec(this.configuration.getLayersSize(), (int) uniqueElements, 119L, this.configuration.getLayersSize() / this.paramServerConfiguration.getNumberOfShards(), this.paramServerConfiguration.isUseHS(), this.paramServerConfiguration.isUseNS());

            PartitionTrainingFunction partitionTrainingFunction = new PartitionTrainingFunction(this.shallowVocabCacheBroadcast, this.configurationBroadcast, paramServerBroadcast);
            for (int epoch = 0; epoch < this.configuration.getEpochs(); epoch++) {
                corpus.foreachPartition(partitionTrainingFunction);
            }

            JavaRDD<ExportContainer<T>> exportRdd = distinct.map(new DistributedFunction(paramServerBroadcast, this.configurationBroadcast, this.shallowVocabCacheBroadcast));
            if (this.exporter != null) {
                this.exporter.export(exportRdd);
            }
            log.info("Training finish, starting cleanup...");
        } finally {
            // Always stop the parameter server once it has been initialised, even on failure.
            VoidParameterServer.getInstance().shutdown();
        }
    }

    /** Syncs shard count and HS/NS flags into the parameter-server configuration, then broadcasts it. */
    private Broadcast<VoidConfiguration> finalizeAndBroadcastParamServerConfiguration(JavaSparkContext javaSparkContext) {
        this.paramServerConfiguration.setNumberOfShards(this.paramServerConfiguration.getShardAddresses().size());
        this.paramServerConfiguration.setUseHS(this.configuration.isUseHierarchicSoftmax());
        this.paramServerConfiguration.setUseNS(this.configuration.getNegative() > 0.0d);
        return javaSparkContext.broadcast(this.paramServerConfiguration);
    }

    /**
     * Builds a vocabulary of {@link ShallowSequenceElement}s from element frequencies and applies
     * Huffman indexes to it.
     *
     * @param counter element key to frequency mapping
     * @return vocabulary with one shallow element per key in {@code counter}
     */
    protected VocabCache<ShallowSequenceElement> buildShallowVocabCache(Counter<Long> counter) {
        AbstractCache abstractCache = new AbstractCache();
        for (Long l : counter.keySet()) {
            abstractCache.addToken(new ShallowSequenceElement(counter.getCount(l), l.longValue()));
        }
        Huffman huffman = new Huffman(abstractCache.vocabWords());
        huffman.build();
        huffman.applyIndexes(abstractCache);
        return abstractCache;
    }

    /**
     * Returns the current value of whichever frequency accumulator matches the active mode
     * (auto-discovery or explicit shards).
     */
    protected Counter<Long> getCounter() {
        return this.isAutoDiscoveryMode ? (Counter) this.elementsFreqAccumExtra.value() : (Counter) this.elementsFreqAccum.value();
    }

    /** Returns the exporter that receives the trained vectors. */
    public SparkModelExporter<T> getExporter() {
        return this.exporter;
    }

    /** Sets the exporter that receives the trained vectors. */
    public void setExporter(SparkModelExporter<T> sparkModelExporter) {
        this.exporter = sparkModelExporter;
    }
}
