package org.deeplearning4j.models.embeddings.learning.impl.sequence;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.learning.ElementsLearningAlgorithm;
import org.deeplearning4j.models.embeddings.learning.SequenceLearningAlgorithm;
import org.deeplearning4j.models.embeddings.learning.impl.elements.CBOW;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/models/embeddings/learning/impl/sequence/DM.class */
/**
 * "PV-DM" sequence learning algorithm for paragraph vectors.
 *
 * <p>Training and inference both delegate the per-element update to an internal {@link CBOW}
 * instance. For each element of a sequence, the indexes of its neighbours inside a (randomly
 * shrunk) window are collected, the indexes of the sequence labels (if any) are appended, and the
 * result is handed to {@link CBOW#iterateSample}. Batched ops queued by the CBOW are executed once
 * the configured batch size is reached, and any remainder is flushed by {@link #finish()}.
 *
 * @param <T> sequence element type
 */
public class DM<T extends SequenceElement> implements SequenceLearningAlgorithm<T> {
    private VocabCache<T> vocabCache;
    private WeightLookupTable<T> lookupTable;
    private VectorsConfiguration configuration;
    protected int window;
    protected boolean useAdaGrad;
    protected double negative;
    protected double sampling;
    protected double[] expTable;
    protected INDArray syn0;
    protected INDArray syn1;
    protected INDArray syn1Neg;
    protected INDArray table;
    private CBOW<T> cbow = new CBOW<>();
    private static final Logger log = LoggerFactory.getLogger(DM.class);
    protected static double MAX_EXP = 6.0d;

    /** Returns the internal CBOW instance that performs the per-element updates. */
    @Override
    public ElementsLearningAlgorithm<T> getElementsLearningAlgorithm() {
        return cbow;
    }

    /** Returns the code name of this algorithm, {@code "PV-DM"}. */
    @Override
    public String getCodeName() {
        return "PV-DM";
    }

    /**
     * Binds this algorithm, and its internal CBOW, to a vocabulary, lookup table and configuration,
     * and caches the hyper-parameters and weight matrices read from them.
     *
     * @param vocabCache vocabulary to train against
     * @param weightLookupTable lookup table holding the weights; must be an {@link InMemoryLookupTable}
     * @param vectorsConfiguration training configuration
     * @throws NullPointerException if any argument is {@code null}
     * @throws ClassCastException if {@code weightLookupTable} is not an {@link InMemoryLookupTable}
     */
    @Override
    public void configure(@NonNull VocabCache<T> vocabCache, @NonNull WeightLookupTable<T> weightLookupTable,
                    @NonNull VectorsConfiguration vectorsConfiguration) {
        if (vocabCache == null) {
            throw new NullPointerException("vocabCache");
        }
        if (weightLookupTable == null) {
            throw new NullPointerException("lookupTable");
        }
        if (vectorsConfiguration == null) {
            throw new NullPointerException("configuration");
        }
        this.vocabCache = vocabCache;
        this.lookupTable = weightLookupTable;
        this.configuration = vectorsConfiguration;
        this.cbow.configure(vocabCache, weightLookupTable, vectorsConfiguration);

        this.window = vectorsConfiguration.getWindow();
        this.useAdaGrad = vectorsConfiguration.isUseAdaGrad();
        this.negative = vectorsConfiguration.getNegative();
        this.sampling = vectorsConfiguration.getSampling();

        InMemoryLookupTable<T> inMemoryTable = (InMemoryLookupTable<T>) weightLookupTable;
        this.syn0 = inMemoryTable.getSyn0();
        this.syn1 = inMemoryTable.getSyn1();
        this.syn1Neg = inMemoryTable.getSyn1Neg();
        this.expTable = inMemoryTable.getExpTable();
        this.table = inMemoryTable.getTable();
    }

    /** No pre-training step is required for PV-DM; this is a no-op. */
    @Override
    public void pretrain(SequenceIterator<T> sequenceIterator) {
    }

    /**
     * Trains on a single sequence.
     *
     * <p>The sequence is first subsampled by the CBOW. Then {@link #dm} runs once per remaining
     * element, with the sequence labels appended to every context.
     *
     * @param sequence sequence to learn; must carry a label
     * @param nextRandom shared random state, advanced in place
     * @param learningRate learning rate for this pass
     * @return always {@code 0.0}
     * @throws IllegalStateException if {@code sequence.getSequenceLabel()} is {@code null}
     */
    @Override
    public double learnSequence(Sequence<T> sequence, AtomicLong nextRandom, double learningRate) {
        Sequence<T> subsampled = cbow.applySubsampling(sequence, nextRandom);
        // Labels are taken from the original (not the subsampled) sequence.
        List<T> labels = new ArrayList<>(sequence.getSequenceLabels());

        if (sequence.getSequenceLabel() == null) {
            throw new IllegalStateException("Label is NULL");
        }
        if (subsampled.isEmpty() || labels.isEmpty()) {
            return 0.0d;
        }

        for (int position = 0; position < subsampled.size(); position++) {
            dm(position, subsampled, nextWindowShift(nextRandom), nextRandom, learningRate, labels, false, null);
        }
        return 0.0d;
    }

    /**
     * Runs one PV-DM update for the element at {@code position}.
     *
     * <p>The context consists of the indexes of neighbours at offsets {@code windowShift - window}
     * up to {@code window - windowShift} (exclusive of the element itself, clipped to the sequence
     * bounds), followed by the indexes of {@code labels}. After the sample is handed to the CBOW,
     * its pending batch is executed if it has reached the configured batch size.
     *
     * @param position index of the target element within {@code sequence}
     * @param sequence sequence being processed
     * @param windowShift how much to shrink the window on each side; callers in this class pass a
     *     value in {@code [0, window)}
     * @param nextRandom shared random state, forwarded to the CBOW
     * @param learningRate learning rate for this update
     * @param labels label elements appended to the context; may be {@code null}
     * @param isInference forwarded to the CBOW as its inference flag
     * @param inferenceVector vector forwarded to the CBOW (the vector being inferred, or {@code null})
     */
    public void dm(int position, Sequence<T> sequence, int windowShift, AtomicLong nextRandom, double learningRate,
                    List<T> labels, boolean isInference, INDArray inferenceVector) {
        T currentElement = sequence.getElementByIndex(position);

        List<Integer> contextIndexes = new ArrayList<>();
        int end = (2 * window + 1) - windowShift;
        for (int offset = windowShift; offset < end; offset++) {
            if (offset == window) {
                continue; // this offset is the target element itself
            }
            int neighbour = position - window + offset;
            if (neighbour >= 0 && neighbour < sequence.size()) {
                contextIndexes.add(sequence.getElementByIndex(neighbour).getIndex());
            }
        }

        int numLabels = 0;
        if (labels != null) {
            for (T label : labels) {
                contextIndexes.add(label.getIndex());
            }
            numLabels = labels.size();
        }

        int[] context = contextIndexes.stream().mapToInt(Integer::intValue).toArray();
        cbow.iterateSample(currentElement, context, nextRandom, learningRate, isInference, numLabels,
                        configuration.isTrainElementsVectors(), inferenceVector);

        executeBatchIfAtLeast(configuration.getBatchSize());
    }

    /** PV-DM never requests early termination; always returns {@code false}. */
    @Override
    public boolean isEarlyTerminationHit() {
        return false;
    }

    /**
     * Infers a vector for a (possibly unseen) sequence.
     *
     * <p>The vector is initialised with values from {@code Nd4j.rand} (seeded from the configuration
     * seed and the sequence hash), minus 0.5, divided by the layer size. It is then refined for
     * {@code iterations} passes over the sequence via {@link #dm} in inference mode. Any batched ops
     * still pending are flushed before the vector is returned.
     *
     * @param sequence sequence to infer a vector for
     * @param nextRandomSeed initial value of the random state
     * @param learningRate initial learning rate
     * @param minLearningRate learning-rate floor used by the decay step
     * @param iterations number of passes over the sequence
     * @return the inferred {@code 1 x layerSize} vector, or {@code null} if {@code sequence} is empty
     */
    @Override
    public INDArray inferSequence(Sequence<T> sequence, long nextRandomSeed, double learningRate,
                    double minLearningRate, int iterations) {
        AtomicLong nextRandom = new AtomicLong(nextRandomSeed);
        if (sequence.isEmpty()) {
            return null;
        }

        int layerSize = lookupTable.layerSize();
        INDArray inferred = Nd4j.rand(new int[] {1, layerSize},
                        Nd4j.getRandomFactory().getNewRandomInstance(configuration.getSeed() * sequence.hashCode(),
                                        layerSize + 1))
                        .subi(0.5d).divi(layerSize);

        for (int iteration = 0; iteration < iterations; iteration++) {
            for (int position = 0; position < sequence.size(); position++) {
                dm(position, sequence, nextWindowShift(nextRandom), nextRandom, learningRate, null, true, inferred);
            }
            // NOTE(review): this is not a linear decay towards minLearningRate (the distance to the
            // floor is divided by (iterations - iteration) each pass); kept as-is to preserve results.
            learningRate = ((learningRate - minLearningRate) / (iterations - iteration)) + minLearningRate;
        }

        // Apply any updates to `inferred` that are still queued in the CBOW batch.
        finish();
        return inferred;
    }

    /** Executes and clears any ops still pending in the CBOW batch, even if it is not full. */
    @Override
    public void finish() {
        if (cbow == null) {
            return;
        }
        executeBatchIfAtLeast(1);
    }

    /**
     * Advances the linear congruential random state and derives a window shrink amount from it.
     *
     * @param nextRandom random state, updated in place
     * @return a value in {@code [0, window)}
     * @throws ArithmeticException if {@code window} is zero
     */
    private int nextWindowShift(AtomicLong nextRandom) {
        nextRandom.set(Math.abs((nextRandom.get() * 25214903917L) + 11));
        // Truncating to int can yield a negative value; floorMod keeps the shift non-negative.
        return Math.floorMod((int) nextRandom.get(), window);
    }

    /** Executes and clears the CBOW batch when it holds at least {@code threshold} ops. */
    private void executeBatchIfAtLeast(int threshold) {
        if (cbow.getBatch() == null || cbow.getBatch().size() < threshold) {
            return;
        }
        Nd4j.getExecutioner().exec(cbow.getBatch());
        cbow.getBatch().clear();
    }
}
