package org.deeplearning4j.spark.models.word2vec;

import lombok.NonNull;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.storage.StorageLevel;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.sequencevectors.sequence.ShallowSequenceElement;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.spark.models.sequencevectors.SparkSequenceVectors;
import org.deeplearning4j.spark.models.sequencevectors.export.SparkModelExporter;
import org.deeplearning4j.spark.models.sequencevectors.functions.TokenizerFunction;
import org.deeplearning4j.spark.models.sequencevectors.learning.SparkElementsLearningAlgorithm;
import org.deeplearning4j.spark.models.sequencevectors.learning.SparkSequenceLearningAlgorithm;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/spark/models/word2vec/SparkWord2Vec.class */
public class SparkWord2Vec extends SparkSequenceVectors<VocabWord> {
    private static final Logger log = LoggerFactory.getLogger(SparkWord2Vec.class);

    /* loaded from: input_file:org/deeplearning4j/spark/models/word2vec/SparkWord2Vec$Builder.class */
    public static class Builder extends SparkSequenceVectors.Builder<VocabWord> {
        /**
         * Creates a builder without passing any configuration objects to the superclass builder.
         *
         * @deprecated flagged as deprecated in this file; the constructors that accept a
         *             {@link VoidConfiguration} are presumably the replacement — TODO confirm.
         */
        @Deprecated
        public Builder() {
        }

        /**
         * Creates a builder and hands the given configuration to the superclass builder.
         *
         * @param voidConfiguration configuration forwarded to the superclass constructor; must not be {@code null}
         * @throws NullPointerException if {@code voidConfiguration} is {@code null} (raised by the check
         *                              below, unless the superclass constructor fails on it first)
         */
        public Builder(@NonNull VoidConfiguration voidConfiguration) {
            super(voidConfiguration);
            // The super call must be the first statement, so this check can only run after
            // the superclass constructor has already been given the argument.
            if (voidConfiguration == null) {
                throw new NullPointerException("psConfiguration is marked non-null but is null");
            }
        }

        /**
         * Creates a builder and hands both configuration objects to the superclass builder.
         *
         * @param voidConfiguration    first argument forwarded to the superclass constructor; must not be {@code null}
         * @param vectorsConfiguration second argument forwarded to the superclass constructor; must not be {@code null}
         * @throws NullPointerException if either argument is {@code null} (raised by the checks below,
         *                              unless the superclass constructor fails on it first)
         */
        public Builder(@NonNull VoidConfiguration voidConfiguration, @NonNull VectorsConfiguration vectorsConfiguration) {
            super(voidConfiguration, vectorsConfiguration);
            // The super call must be the first statement, so both checks run only after
            // the superclass constructor has already been given the arguments.
            if (voidConfiguration == null) {
                throw new NullPointerException("psConfiguration is marked non-null but is null");
            }
            if (vectorsConfiguration == null) {
                throw new NullPointerException("configuration is marked non-null but is null");
            }
        }

        /**
         * Records the canonical class name of the given tokenizer factory in this builder's
         * configuration. When the factory reports a token pre-processor, that pre-processor's
         * canonical class name is recorded as well; otherwise the pre-processor setting is left
         * untouched.
         *
         * @param tokenizerFactory factory whose class name is stored; must not be {@code null}
         * @return this builder
         * @throws NullPointerException if {@code tokenizerFactory} is {@code null}
         */
        public Builder setTokenizerFactory(@NonNull TokenizerFactory tokenizerFactory) {
            if (tokenizerFactory == null) {
                throw new NullPointerException("tokenizerFactory is marked non-null but is null");
            }

            // NOTE(review): Class.getCanonicalName() yields null for anonymous and local classes —
            // confirm that whatever consumes this setting copes with that.
            final String factoryClassName = tokenizerFactory.getClass().getCanonicalName();
            this.configuration.setTokenizerFactory(factoryClassName);

            if (tokenizerFactory.getTokenPreProcessor() == null) {
                return this;
            }

            final String preProcessorClassName = tokenizerFactory.getTokenPreProcessor().getClass().getCanonicalName();
            this.configuration.setTokenPreProcessor(preProcessorClassName);
            return this;
        }

        /**
         * Records the canonical class name of the given learning algorithm as the
         * elements-learning-algorithm setting of this builder's configuration.
         *
         * @param sparkElementsLearningAlgorithm algorithm whose class name is stored; must not be {@code null}
         * @return this builder
         * @throws NullPointerException if {@code sparkElementsLearningAlgorithm} is {@code null}
         */
        public Builder setLearningAlgorithm(@NonNull SparkElementsLearningAlgorithm sparkElementsLearningAlgorithm) {
            if (sparkElementsLearningAlgorithm == null) {
                throw new NullPointerException("ela is marked non-null but is null");
            }

            // Only the class name is kept; the instance itself is not retained by this method.
            final String algorithmClassName = sparkElementsLearningAlgorithm.getClass().getCanonicalName();
            this.configuration.setElementsLearningAlgorithm(algorithmClassName);
            return this;
        }

        /**
         * Stores the given exporter in this builder's {@code modelExporter} field.
         * The trailing {@code 2} in the name comes from the decompiler rename noted below.
         *
         * @param sparkModelExporter exporter to keep; must not be {@code null}
         * @return this builder
         * @throws NullPointerException if {@code sparkModelExporter} is {@code null}
         */
        @Override // org.deeplearning4j.spark.models.sequencevectors.SparkSequenceVectors.Builder
        /* renamed from: setModelExporter, reason: merged with bridge method [inline-methods] */
        public SparkSequenceVectors.Builder<VocabWord> setModelExporter2(@NonNull SparkModelExporter<VocabWord> sparkModelExporter) {
            if (sparkModelExporter == null) {
                throw new NullPointerException("exporter is marked non-null but is null");
            }
            this.modelExporter = sparkModelExporter;
            return this;
        }

        /**
         * Forwards the value to the superclass {@code workers2} and returns this builder.
         * The trailing {@code 2} in the name comes from the decompiler rename noted below.
         *
         * @param i value passed unchanged to the superclass method (presumably the worker count — TODO confirm)
         * @return this builder
         */
        @Override // org.deeplearning4j.spark.models.sequencevectors.SparkSequenceVectors.Builder
        /* renamed from: workers, reason: merged with bridge method [inline-methods] */
        public SparkSequenceVectors.Builder<VocabWord> workers2(int i) {
            super.workers2(i);
            return this;
        }

        /**
         * Forwards the value to the superclass {@code epochs2} and returns this builder.
         * The trailing {@code 2} in the name comes from the decompiler rename noted below.
         *
         * @param i value passed unchanged to the superclass method (presumably the epoch count — TODO confirm)
         * @return this builder
         */
        @Override // org.deeplearning4j.spark.models.sequencevectors.SparkSequenceVectors.Builder
        /* renamed from: epochs, reason: merged with bridge method [inline-methods] */
        public SparkSequenceVectors.Builder<VocabWord> epochs2(int i) {
            super.epochs2(i);
            return this;
        }

        /**
         * Forwards the storage level to the superclass {@code setStorageLevel2} and returns this builder.
         * No null check is performed here. The trailing {@code 2} in the name comes from the
         * decompiler rename noted below.
         *
         * @param storageLevel value passed unchanged to the superclass method
         * @return this builder
         */
        @Override // org.deeplearning4j.spark.models.sequencevectors.SparkSequenceVectors.Builder
        /* renamed from: setStorageLevel, reason: merged with bridge method [inline-methods] */
        public SparkSequenceVectors.Builder<VocabWord> setStorageLevel2(StorageLevel storageLevel) {
            super.setStorageLevel2(storageLevel);
            return this;
        }

        /**
         * Forwards the value to the superclass {@code minWordFrequency2} and returns this builder.
         * The trailing {@code 2} in the name comes from the decompiler rename noted below.
         *
         * @param i value passed unchanged to the superclass method (presumably a minimum
         *          word-frequency threshold — TODO confirm)
         * @return this builder
         */
        @Override // org.deeplearning4j.spark.models.sequencevectors.SparkSequenceVectors.Builder
        /* renamed from: minWordFrequency, reason: merged with bridge method [inline-methods] */
        public SparkSequenceVectors.Builder<VocabWord> minWordFrequency2(int i) {
            super.minWordFrequency2(i);
            return this;
        }

        /**
         * Forwards the value to the superclass {@code setLearningRate2} and returns this builder.
         * The trailing {@code 2} in the name comes from the decompiler rename noted below.
         *
         * @param d value passed unchanged to the superclass method (presumably the learning rate — TODO confirm)
         * @return this builder
         */
        @Override // org.deeplearning4j.spark.models.sequencevectors.SparkSequenceVectors.Builder
        /* renamed from: setLearningRate, reason: merged with bridge method [inline-methods] */
        public SparkSequenceVectors.Builder<VocabWord> setLearningRate2(double d) {
            super.setLearningRate2(d);
            return this;
        }

        /**
         * Null-checks the configuration, forwards it to the superclass
         * {@code setParameterServerConfiguration2}, and returns this builder.
         * The trailing {@code 2} in the name comes from the decompiler rename noted below.
         *
         * @param voidConfiguration configuration passed unchanged to the superclass method; must not be {@code null}
         * @return this builder
         * @throws NullPointerException if {@code voidConfiguration} is {@code null}
         */
        @Override // org.deeplearning4j.spark.models.sequencevectors.SparkSequenceVectors.Builder
        /* renamed from: setParameterServerConfiguration, reason: merged with bridge method [inline-methods] */
        public SparkSequenceVectors.Builder<VocabWord> setParameterServerConfiguration2(@NonNull VoidConfiguration voidConfiguration) {
            if (voidConfiguration == null) {
                throw new NullPointerException("configuration is marked non-null but is null");
            }
            super.setParameterServerConfiguration2(voidConfiguration);
            return this;
        }

        /**
         * Forwards the value to the superclass {@code iterations2} and returns this builder.
         * The trailing {@code 2} in the name comes from the decompiler rename noted below.
         *
         * @param i value passed unchanged to the superclass method (presumably the iteration count — TODO confirm)
         * @return this builder
         */
        @Override // org.deeplearning4j.spark.models.sequencevectors.SparkSequenceVectors.Builder
        /* renamed from: iterations, reason: merged with bridge method [inline-methods] */
        public SparkSequenceVectors.Builder<VocabWord> iterations2(int i) {
            super.iterations2(i);
            return this;
        }

        /**
         * Forwards the value to the superclass {@code subsampling2} and returns this builder.
         * The trailing {@code 2} in the name comes from the decompiler rename noted below.
         *
         * @param d value passed unchanged to the superclass method (presumably a sub-sampling
         *          threshold — TODO confirm)
         * @return this builder
         */
        @Override // org.deeplearning4j.spark.models.sequencevectors.SparkSequenceVectors.Builder
        /* renamed from: subsampling, reason: merged with bridge method [inline-methods] */
        public SparkSequenceVectors.Builder<VocabWord> subsampling2(double d) {
            super.subsampling2(d);
            return this;
        }

        /**
         * Forwards the value to the superclass {@code negativeSampling2} and returns this builder.
         * The trailing {@code 2} in the name comes from the decompiler rename noted below.
         *
         * @param j value passed unchanged to the superclass method (presumably the number of
         *          negative samples — TODO confirm)
         * @return this builder
         */
        @Override // org.deeplearning4j.spark.models.sequencevectors.SparkSequenceVectors.Builder
        /* renamed from: negativeSampling, reason: merged with bridge method [inline-methods] */
        public SparkSequenceVectors.Builder<VocabWord> negativeSampling2(long j) {
            super.negativeSampling2(j);
            return this;
        }

        /**
         * Null-checks the algorithm, forwards the instance to the superclass
         * {@code setElementsLearningAlgorithm2}, and returns this builder.
         * The trailing {@code 2} in the name comes from the decompiler rename noted below.
         *
         * @param sparkElementsLearningAlgorithm algorithm passed unchanged to the superclass method;
         *                                       must not be {@code null}
         * @return this builder
         * @throws NullPointerException if {@code sparkElementsLearningAlgorithm} is {@code null}
         */
        @Override // org.deeplearning4j.spark.models.sequencevectors.SparkSequenceVectors.Builder
        /* renamed from: setElementsLearningAlgorithm, reason: merged with bridge method [inline-methods] */
        public SparkSequenceVectors.Builder<VocabWord> setElementsLearningAlgorithm2(@NonNull SparkElementsLearningAlgorithm sparkElementsLearningAlgorithm) {
            if (sparkElementsLearningAlgorithm == null) {
                throw new NullPointerException("ela is marked non-null but is null");
            }
            super.setElementsLearningAlgorithm2(sparkElementsLearningAlgorithm);
            return this;
        }

        /**
         * Never completes normally: after the null check this override always throws
         * {@link UnsupportedOperationException}, so no sequence learning algorithm can be set
         * through this builder. The trailing {@code 2} in the name comes from the decompiler
         * rename noted below.
         *
         * @param sparkSequenceLearningAlgorithm only null-checked; otherwise unused
         * @return never returns
         * @throws NullPointerException          if {@code sparkSequenceLearningAlgorithm} is {@code null}
         * @throws UnsupportedOperationException for every non-null argument
         */
        @Override // org.deeplearning4j.spark.models.sequencevectors.SparkSequenceVectors.Builder
        /* renamed from: setSequenceLearningAlgorithm, reason: merged with bridge method [inline-methods] */
        public SparkSequenceVectors.Builder<VocabWord> setSequenceLearningAlgorithm2(@NonNull SparkSequenceLearningAlgorithm sparkSequenceLearningAlgorithm) {
            if (sparkSequenceLearningAlgorithm == null) {
                throw new NullPointerException("sla is marked non-null but is null");
            }
            throw new UnsupportedOperationException("This method isn't supported by Word2Vec");
        }

        /**
         * Forwards the flag to the superclass {@code useHierarchicSoftmax2} and returns this builder.
         * The trailing {@code 2} in the name comes from the decompiler rename noted below.
         *
         * @param z flag passed unchanged to the superclass method
         * @return this builder
         */
        @Override // org.deeplearning4j.spark.models.sequencevectors.SparkSequenceVectors.Builder
        /* renamed from: useHierarchicSoftmax, reason: merged with bridge method [inline-methods] */
        public SparkSequenceVectors.Builder<VocabWord> useHierarchicSoftmax2(boolean z) {
            super.useHierarchicSoftmax2(z);
            return this;
        }

        /**
         * Forwards the value to the superclass {@code layerSize2} and returns this builder.
         * The trailing {@code 2} in the name comes from the decompiler rename noted below.
         *
         * @param i value passed unchanged to the superclass method (presumably the vector
         *          dimensionality — TODO confirm)
         * @return this builder
         */
        @Override // org.deeplearning4j.spark.models.sequencevectors.SparkSequenceVectors.Builder
        /* renamed from: layerSize, reason: merged with bridge method [inline-methods] */
        public SparkSequenceVectors.Builder<VocabWord> layerSize2(int i) {
            super.layerSize2(i);
            return this;
        }

        /**
         * Creates a {@link SparkWord2Vec} from this builder's {@code peersConfiguration} and
         * {@code configuration}, then copies the exporter, storage level and worker setting
         * onto the new instance. The trailing {@code 2} in the name comes from the decompiler
         * rename noted below.
         *
         * @return the newly constructed model
         * @throws NullPointerException if either configuration held by this builder is {@code null}
         *                              (raised by the {@link SparkWord2Vec} constructor)
         */
        @Override // org.deeplearning4j.spark.models.sequencevectors.SparkSequenceVectors.Builder
        /* renamed from: build, reason: merged with bridge method [inline-methods] */
        public SparkSequenceVectors<VocabWord> build2() {
            final SparkWord2Vec model = new SparkWord2Vec(this.peersConfiguration, this.configuration);

            // Plain field copies; they do not depend on one another.
            model.workers = this.workers;
            model.storageLevel = this.storageLevel;
            model.exporter = this.modelExporter;

            return model;
        }
    }

    protected SparkWord2Vec() {
        this.configuration = new VectorsConfiguration();
        this.configuration.setTokenizerFactory(DefaultTokenizerFactory.class.getCanonicalName());
    }

    /**
     * Creates an instance that keeps the two given configuration objects.
     *
     * @param voidConfiguration    stored in {@code paramServerConfiguration}; must not be {@code null}
     * @param vectorsConfiguration stored in {@code configuration}; must not be {@code null}
     * @throws NullPointerException if either argument is {@code null}
     */
    public SparkWord2Vec(@NonNull VoidConfiguration voidConfiguration, @NonNull VectorsConfiguration vectorsConfiguration) {
        if (voidConfiguration == null) {
            throw new NullPointerException("psConfiguration is marked non-null but is null");
        }
        if (vectorsConfiguration == null) {
            throw new NullPointerException("configuration is marked non-null but is null");
        }
        // Both references are stored as given; no copy is made.
        this.configuration = vectorsConfiguration;
        this.paramServerConfiguration = voidConfiguration;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Returns whatever the superclass {@code getShallowVocabCache} returns; this override
     * adds no behaviour of its own.
     *
     * @return the vocabulary cache supplied by the superclass
     */
    @Override // org.deeplearning4j.spark.models.sequencevectors.SparkSequenceVectors
    public VocabCache<ShallowSequenceElement> getShallowVocabCache() {
        return super.getShallowVocabCache();
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Runs the superclass validation, then additionally requires that the configuration
     * names a tokenizer factory.
     *
     * @throws DL4JInvalidConfigException if the configured tokenizer factory is {@code null}
     */
    @Override // org.deeplearning4j.spark.models.sequencevectors.SparkSequenceVectors
    public void validateConfiguration() {
        super.validateConfiguration();

        final boolean tokenizerDefined = this.configuration.getTokenizerFactory() != null;
        if (tokenizerDefined) {
            return;
        }

        throw new DL4JInvalidConfigException("TokenizerFactory is undefined. Can't train Word2Vec without it.");
    }

    /**
     * Not supported by this class: every call throws {@link UnsupportedOperationException}.
     *
     * @deprecated always throws; {@link #fitSentences(JavaRDD)} is the training entry point
     *             defined in this class.
     * @throws UnsupportedOperationException on every call
     */
    @Override // org.deeplearning4j.spark.models.sequencevectors.SparkSequenceVectors
    @Deprecated
    public void fit() {
        throw new UnsupportedOperationException("To use fit() method, please consider using standalone implementation");
    }

    /**
     * Trains on an RDD of raw strings. The steps are: validate the configuration, broadcast
     * the environment through a {@link JavaSparkContext} wrapped around the RDD's own context,
     * map every string through a {@link TokenizerFunction} built from the configuration
     * broadcast, and hand the mapped RDD to the superclass {@code fitSequences}.
     *
     * @param javaRDD strings to train on; must not be {@code null}
     * @throws NullPointerException       if {@code javaRDD} is {@code null}
     * @throws DL4JInvalidConfigException if {@link #validateConfiguration()} rejects the configuration
     */
    public void fitSentences(JavaRDD<String> javaRDD) {
        // Fail fast with a descriptive message, as the other public entry points in this
        // class do, instead of hitting a bare NPE on javaRDD.context() further down.
        if (javaRDD == null) {
            throw new NullPointerException("javaRDD must not be null");
        }

        validateConfiguration();

        // The environment must be broadcast before the TokenizerFunction is created,
        // because that function is constructed from this.configurationBroadcast.
        final JavaSparkContext sparkContext = new JavaSparkContext(javaRDD.context());
        broadcastEnvironment(sparkContext);

        super.fitSequences(javaRDD.map(new TokenizerFunction(this.configurationBroadcast)));
    }
}
