package org.deeplearning4j.models.embeddings.learning.impl.sequence;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.learning.SequenceLearningAlgorithm;
import org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/models/embeddings/learning/impl/sequence/DBOW.class */
public class DBOW<T extends SequenceElement> implements SequenceLearningAlgorithm<T> {
    protected VocabCache<T> vocabCache;
    protected WeightLookupTable<T> lookupTable;
    protected VectorsConfiguration configuration;
    /** Context window size copied from the configuration in {@link #configure}. */
    protected int window;
    /** AdaGrad flag copied from the configuration in {@link #configure}. */
    protected boolean useAdaGrad;
    /** Negative-sampling setting copied from the configuration in {@link #configure}. */
    protected double negative;
    /** Delegate used for subsampling and for the per (element, label) sample updates. */
    protected SkipGram<T> skipGram = new SkipGram<>();
    private static final Logger log = LoggerFactory.getLogger(DBOW.class);

    /**
     * Returns the identifier of this algorithm.
     *
     * @return the constant {@code "DBOW"}
     */
    @Override
    public String getCodeName() {
        return "DBOW";
    }

    /**
     * Stores the vocabulary, lookup table and configuration, caches the window / AdaGrad /
     * negative-sampling settings, and configures the internal SkipGram delegate with the same
     * three objects.
     *
     * @param vocabCache           vocabulary to train against
     * @param weightLookupTable    weight table that training updates
     * @param vectorsConfiguration source of window, AdaGrad and negative-sampling settings
     * @throws NullPointerException if any argument is {@code null}
     */
    @Override
    public void configure(@NonNull VocabCache<T> vocabCache, @NonNull WeightLookupTable<T> weightLookupTable, @NonNull VectorsConfiguration vectorsConfiguration) {
        if (vocabCache == null) {
            throw new NullPointerException("vocabCache");
        }
        if (weightLookupTable == null) {
            throw new NullPointerException("lookupTable");
        }
        if (vectorsConfiguration == null) {
            throw new NullPointerException("configuration");
        }
        this.vocabCache = vocabCache;
        this.lookupTable = weightLookupTable;
        // Keep the configuration itself too; previously this field was declared but never set.
        this.configuration = vectorsConfiguration;
        this.window = vectorsConfiguration.getWindow();
        this.useAdaGrad = vectorsConfiguration.isUseAdaGrad();
        this.negative = vectorsConfiguration.getNegative();
        this.skipGram.configure(vocabCache, weightLookupTable, vectorsConfiguration);
    }

    /**
     * No-op: DBOW needs no pretraining pass.
     *
     * @param sequenceIterator ignored
     */
    @Override
    public void pretrain(SequenceIterator<T> sequenceIterator) {
    }

    /**
     * Trains on a single sequence by delegating to {@link #dbow}.
     *
     * @param sequence     sequence to learn from; must carry a label
     * @param nextRandom   shared random state, passed through to the SkipGram delegate
     * @param learningRate learning rate for this step
     * @return always {@code 0.0}
     * @throws NullPointerException  if {@code sequence} or {@code nextRandom} is {@code null}
     * @throws IllegalStateException if the sequence has no label
     */
    @Override
    public double learnSequence(@NonNull Sequence<T> sequence, @NonNull AtomicLong nextRandom, double learningRate) {
        if (sequence == null) {
            throw new NullPointerException("sequence");
        }
        if (nextRandom == null) {
            throw new NullPointerException("nextRandom");
        }
        // Truncating the long to int and applying % can yield a negative offset, and % throws
        // when window == 0. floorMod keeps the offset in [0, window); a non-positive window
        // yields 0 instead of an ArithmeticException.
        int windowOffset = this.window > 0
                ? (int) Math.floorMod(nextRandom.get(), (long) this.window)
                : 0;
        dbow(0, sequence, windowOffset, nextRandom, learningRate);
        return 0.0d;
    }

    /**
     * This implementation has no early-termination condition.
     *
     * @return always {@code false}
     */
    @Override
    public boolean isEarlyTerminationHit() {
        return false;
    }

    /**
     * Runs one DBOW pass over a sequence. After subsampling the sequence's elements through
     * the SkipGram delegate, it calls {@code skipGram.iterateSample(element, label, ...)} for
     * every (label, non-null element) pair.
     *
     * @param iteration    not read by this implementation; kept for subclasses
     * @param sequence     sequence to train on
     * @param windowOffset not read by this implementation; kept for subclasses
     * @param nextRandom   random state forwarded to the SkipGram delegate
     * @param learningRate learning rate forwarded to the SkipGram delegate
     * @throws IllegalStateException if {@code sequence.getSequenceLabel()} is {@code null}
     */
    protected void dbow(int iteration, Sequence<T> sequence, int windowOffset, AtomicLong nextRandom, double learningRate) {
        // Check the precondition before calling into the SkipGram delegate, so an invalid
        // sequence is rejected without doing any subsampling work.
        if (sequence.getSequenceLabel() == null) {
            throw new IllegalStateException("Label is NULL");
        }
        List<T> elements = this.skipGram.applySubsampling(sequence, nextRandom).getElements();
        List<T> sequenceLabels = sequence.getSequenceLabels();
        if (elements.isEmpty() || sequenceLabels == null || sequenceLabels.isEmpty()) {
            return;
        }
        // Iterate over a snapshot of the labels, typed as T so it matches SkipGram<T>.
        List<T> labels = new ArrayList<>(sequenceLabels);
        for (T label : labels) {
            for (T element : elements) {
                if (element != null) {
                    this.skipGram.iterateSample(element, label, nextRandom, learningRate);
                }
            }
        }
    }
}
