package org.deeplearning4j.models.embeddings.learning.impl.elements;

import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.apache.commons.lang.math.RandomUtils;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.learning.ElementsLearningAlgorithm;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectorsImpl;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/models/embeddings/learning/impl/elements/SkipGram.class */
/**
 * Skip-gram {@link ElementsLearningAlgorithm}.
 *
 * <p>For every element of a sequence, trains the embedding of each element found inside a (randomly
 * shrunk) context window. Training uses updates over the centre element's codes/points against
 * {@code syn1}. When {@code negative > 0}, it additionally applies negative-sampling updates against
 * {@code syn1Neg}.
 *
 * <p>All weight matrices are obtained from an {@link InMemoryLookupTable} in {@link #configure} and
 * are updated in place. NOTE(review): no synchronization is performed here; confirm the threading
 * model with callers.
 */
public class SkipGram<T extends SequenceElement> implements ElementsLearningAlgorithm<T> {
    /** Multiplier of the linear congruential step used to advance {@code nextRandom}. */
    private static final long LCG_MULTIPLIER = 25214903917L;
    /** Increment of the linear congruential step used to advance {@code nextRandom}. */
    private static final long LCG_INCREMENT = 11L;

    protected VocabCache<T> vocabCache;
    protected WeightLookupTable<T> lookupTable;
    protected VectorsConfiguration configuration;
    /** Dot products outside {@code [-MAX_EXP, MAX_EXP]} are not looked up in {@link #expTable}. */
    protected static double MAX_EXP = 6.0d;
    /** Table read from the lookup table; presumably precomputed sigmoid values (TODO confirm). */
    protected double[] expTable;
    protected int window;
    protected boolean useAdaGrad;
    protected double negative;
    protected double sampling;
    /** Optional set of window sizes; when non-empty one is picked at random per sequence. */
    protected int[] variableWindows;
    protected INDArray syn0;
    protected INDArray syn1;
    protected INDArray syn1Neg;
    protected INDArray table;

    @Override
    public String getCodeName() {
        return "SkipGram";
    }

    /**
     * Captures the vocabulary, weight matrices and hyper-parameters used during training.
     *
     * @param vocabCache vocabulary of elements
     * @param weightLookupTable lookup table; must be an {@link InMemoryLookupTable} (it is cast)
     * @param vectorsConfiguration source of window, AdaGrad, negative and sampling settings
     * @throws NullPointerException if any argument is {@code null}
     * @throws ClassCastException if {@code weightLookupTable} is not an {@link InMemoryLookupTable}
     */
    @Override
    public void configure(@NonNull VocabCache<T> vocabCache, @NonNull WeightLookupTable<T> weightLookupTable,
                    @NonNull VectorsConfiguration vectorsConfiguration) {
        if (vocabCache == null) {
            throw new NullPointerException("vocabCache");
        }
        if (weightLookupTable == null) {
            throw new NullPointerException("lookupTable");
        }
        if (vectorsConfiguration == null) {
            throw new NullPointerException("configuration");
        }
        this.vocabCache = vocabCache;
        this.lookupTable = weightLookupTable;
        this.configuration = vectorsConfiguration;

        InMemoryLookupTable inMemory = (InMemoryLookupTable) weightLookupTable;
        this.expTable = inMemory.getExpTable();
        this.syn0 = inMemory.getSyn0();
        this.syn1 = inMemory.getSyn1();
        this.syn1Neg = inMemory.getSyn1Neg();
        this.table = inMemory.getTable();

        this.window = vectorsConfiguration.getWindow();
        this.useAdaGrad = vectorsConfiguration.isUseAdaGrad();
        this.negative = vectorsConfiguration.getNegative();
        this.sampling = vectorsConfiguration.getSampling();
        this.variableWindows = vectorsConfiguration.getVariableWindows();
    }

    /** Skip-gram needs no pre-training pass; this is intentionally a no-op. */
    @Override
    public void pretrain(SequenceIterator<T> sequenceIterator) {
    }

    /**
     * Randomly drops elements from {@code sequence} according to the configured {@code sampling}.
     *
     * <p>If {@code sampling <= 0} the input sequence itself is returned unchanged. Otherwise a new
     * sequence is built that copies the id and labels and keeps each element when its keep score is
     * at least a pseudo-random value in {@code [0, 1)} drawn from {@code nextRandom}.
     *
     * @param sequence sequence to subsample
     * @param nextRandom random state; advanced once per element
     * @return the input sequence, or a new, possibly shorter, sequence
     * @throws NullPointerException if an argument is {@code null}
     */
    public Sequence<T> applySubsampling(@NonNull Sequence<T> sequence, @NonNull AtomicLong nextRandom) {
        if (sequence == null) {
            throw new NullPointerException("sequence");
        }
        if (nextRandom == null) {
            throw new NullPointerException("nextRandom");
        }
        if (this.sampling <= 0.0d) {
            return sequence;
        }

        Sequence<T> result = new Sequence<>();
        result.setSequenceId(sequence.getSequenceId());
        if (sequence.getSequenceLabels() != null) {
            result.setSequenceLabels(sequence.getSequenceLabels());
        }
        if (sequence.getSequenceLabel() != null) {
            result.setSequenceLabel(sequence.getSequenceLabel());
        }

        // sampling * total occurrences does not depend on the element, so compute it once.
        double threshold = this.sampling * this.vocabCache.totalWordOccurrences();
        for (T element : sequence.getElements()) {
            double frequency = element.getElementFrequency();
            double keepScore = ((Math.sqrt(frequency / threshold) + 1.0d) * threshold) / frequency;
            // Subsampling advances the generator without Math.abs; only the low 16 bits are used.
            nextRandom.set(lcgStep(nextRandom.get()));
            if (keepScore >= (nextRandom.get() & 0xFFFF) / 65536.0d) {
                result.addElement(element);
            }
        }
        return result;
    }

    /**
     * Trains on one sequence: optional subsampling, then a skip-gram pass over every position.
     *
     * @param sequence sequence to learn from
     * @param nextRandom shared random state; advanced during training
     * @param learningRate learning rate (alpha) applied to the updates
     * @return the value returned by the last {@code skipGram} call (currently always {@code 0.0})
     * @throws NullPointerException if {@code sequence} or {@code nextRandom} is {@code null}
     */
    @Override
    public double learnSequence(@NonNull Sequence<T> sequence, @NonNull AtomicLong nextRandom, double learningRate) {
        if (sequence == null) {
            throw new NullPointerException("sequence");
        }
        if (nextRandom == null) {
            throw new NullPointerException("nextRandom");
        }
        Sequence<T> effective = this.sampling > 0.0d ? applySubsampling(sequence, nextRandom) : sequence;

        int currentWindow = this.window;
        if (this.variableWindows != null && this.variableWindows.length != 0) {
            currentWindow = this.variableWindows[RandomUtils.nextInt(this.variableWindows.length)];
        }
        if (currentWindow <= 0) {
            // No context positions exist, and "% currentWindow" below would divide by zero.
            return 0.0d;
        }

        List<T> elements = effective.getElements();
        double score = 0.0d;
        for (int position = 0; position < elements.size(); position++) {
            nextRandom.set(Math.abs(lcgStep(nextRandom.get())));
            // The (int) cast can yield a negative value; a negative shrink would widen the window
            // past currentWindow, so the remainder is forced into [0, currentWindow).
            int shrink = nonNegativeMod((int) nextRandom.get(), currentWindow);
            score = skipGram(position, elements, shrink, nextRandom, learningRate, currentWindow);
        }
        return score;
    }

    @Override
    public boolean isEarlyTerminationHit() {
        return false;
    }

    /**
     * Runs {@link #iterateSample} for the element at {@code position} against every in-bounds
     * neighbour of a window of radius {@code currentWindow}, shrunk by {@code shrink} on each side.
     *
     * @return the value of the last {@code iterateSample} call, or {@code 0.0}
     */
    private double skipGram(int position, List<T> elements, int shrink, AtomicLong nextRandom,
                    double learningRate, int currentWindow) {
        if (elements.isEmpty()) {
            return 0.0d;
        }
        T centre = elements.get(position);
        if (centre == null) {
            return 0.0d;
        }
        double score = 0.0d;
        int end = currentWindow * 2 + 1 - shrink;
        for (int offset = shrink; offset < end; offset++) {
            if (offset == currentWindow) {
                continue; // the centre position itself
            }
            int contextPosition = position - currentWindow + offset;
            if (contextPosition >= 0 && contextPosition < elements.size()) {
                score = iterateSample(centre, elements.get(contextPosition), nextRandom, learningRate);
            }
        }
        return score;
    }

    /**
     * Performs one training step for the pair ({@code w1}, {@code w2}).
     *
     * <p>The {@code syn0} row of {@code w2} is updated using the codes/points of {@code w1} (against
     * {@code syn1}) and, when {@code negative > 0}, negative samples (against {@code syn1Neg}).
     *
     * <p>Pairs are skipped when:
     * <ul>
     *   <li>either element is {@code null};</li>
     *   <li>{@code w2} has a negative index;</li>
     *   <li>both elements have the same index;</li>
     *   <li>either label is {@code "STOP"} or {@link WordVectorsImpl#DEFAULT_UNK}.</li>
     * </ul>
     *
     * @param w1 element whose codes/points and index drive the output-layer updates
     * @param w2 element whose {@code syn0} row is updated
     * @param nextRandom random state; advanced once per negative sample
     * @param alpha learning rate
     * @return always {@code 0.0}; no score is computed
     * @throws IllegalStateException if one of {@code w1}'s points is outside {@code syn1}
     */
    public double iterateSample(T w1, T w2, AtomicLong nextRandom, double alpha) {
        if (w1 == null || w2 == null || w2.getIndex() < 0 || w1.getIndex() == w2.getIndex()
                        || isIgnoredLabel(w1.getLabel()) || isIgnoredLabel(w2.getLabel())) {
            return 0.0d;
        }
        INDArray l1 = this.syn0.slice(w2.getIndex());
        // Accumulated error for l1; applied once at the end so both phases see the same l1.
        INDArray neu1e = Nd4j.create(this.configuration.getLayersSize());

        applyHierarchicalUpdates(w1, l1, neu1e, alpha);
        if (this.negative > 0.0d) {
            applyNegativeSamplingUpdates(w1, l1, neu1e, nextRandom, alpha);
        }

        Nd4j.getBlasWrapper().level1().axpy(this.lookupTable.layerSize(), 1.0d, neu1e, l1);
        return 0.0d;
    }

    /** Updates {@code syn1} rows at {@code w1}'s points and accumulates l1's error into {@code neu1e}. */
    private void applyHierarchicalUpdates(T w1, INDArray l1, INDArray neu1e, double alpha) {
        for (int i = 0; i < w1.getCodeLength(); i++) {
            int code = w1.getCodes().get(i).intValue();
            int point = w1.getPoints().get(i).intValue();
            // The row is taken from syn1, so that is the matrix the bound must be checked against.
            if (point >= this.syn1.rows() || point < 0) {
                throw new IllegalStateException("Illegal point " + point);
            }
            INDArray syn1Row = this.syn1.slice(point);
            double dot = Nd4j.getBlasWrapper().dot(l1, syn1Row);
            // Written as a negated range test so that NaN is skipped as well.
            if (!(dot >= -MAX_EXP && dot < MAX_EXP)) {
                continue;
            }
            int idx = expTableIndex(dot);
            if (idx >= this.expTable.length) {
                continue;
            }
            double f = this.expTable[idx];
            double g = this.useAdaGrad ? w1.getGradient(i, (1 - code) - f, alpha) : ((1 - code) - f) * alpha;
            Nd4j.getBlasWrapper().level1().axpy(syn1Row.length(), g, syn1Row, neu1e);
            Nd4j.getBlasWrapper().level1().axpy(syn1Row.length(), g, l1, syn1Row);
        }
    }

    /**
     * Negative sampling: one positive update for {@code w1}, then up to {@code negative} draws from
     * {@link #table}. Draws that hit {@code w1} itself, or fall outside {@code syn1Neg}, are skipped.
     */
    private void applyNegativeSamplingUpdates(T w1, INDArray l1, INDArray neu1e, AtomicLong nextRandom,
                    double alpha) {
        int layerSize = this.lookupTable.layerSize();
        int target = w1.getIndex();
        for (int sample = 0; sample < this.negative + 1.0d; sample++) {
            int label;
            if (sample == 0) {
                label = 1;
            } else {
                nextRandom.set(Math.abs(lcgStep(nextRandom.get())));
                target = this.table.getInt(new int[] {Math.abs(((int) (nextRandom.get() >> 16)) % this.table.length())});
                if (target <= 0) {
                    int numWords = this.vocabCache.numWords();
                    if (numWords <= 1) {
                        continue; // "% (numWords - 1)" would divide by zero or be meaningless
                    }
                    target = nonNegativeMod((int) nextRandom.get(), numWords - 1) + 1;
                }
                if (target == w1.getIndex()) {
                    continue; // never use the positive element as its own negative sample
                }
                label = 0;
            }
            if (target >= this.syn1Neg.rows() || target < 0) {
                continue;
            }

            INDArray syn1NegRow = this.syn1Neg.slice(target);
            double dot = Nd4j.getBlasWrapper().dot(l1, syn1NegRow);
            double g;
            if (dot > MAX_EXP) {
                g = this.useAdaGrad ? this.lookupTable.getGradient(target, label - 1) : (label - 1) * alpha;
            } else if (dot < -MAX_EXP) {
                g = label * (this.useAdaGrad ? this.lookupTable.getGradient(target, alpha) : alpha);
            } else {
                int idx = expTableIndex(dot);
                if (idx >= this.expTable.length) {
                    continue;
                }
                g = this.useAdaGrad
                                ? this.lookupTable.getGradient(target, label - this.expTable[idx])
                                : (label - this.expTable[idx]) * alpha;
            }
            Nd4j.getBlasWrapper().level1().axpy(layerSize, g, syn1NegRow, neu1e);
            Nd4j.getBlasWrapper().level1().axpy(layerSize, g, l1, syn1NegRow);
        }
    }

    /** Maps a dot product in {@code [-MAX_EXP, MAX_EXP)} to an index into {@link #expTable}. */
    private int expTableIndex(double dot) {
        return (int) ((dot + MAX_EXP) * ((this.expTable.length / MAX_EXP) / 2.0d));
    }

    /** Returns {@code true} for labels that must never take part in training. */
    private static boolean isIgnoredLabel(String label) {
        return "STOP".equals(label) || WordVectorsImpl.DEFAULT_UNK.equals(label);
    }

    /** One step of the linear congruential generator used for {@code nextRandom}. */
    private static long lcgStep(long state) {
        return state * LCG_MULTIPLIER + LCG_INCREMENT;
    }

    /** Remainder of {@code value / modulus} in {@code [0, modulus)}; {@code modulus} must be positive. */
    private static int nonNegativeMod(int value, int modulus) {
        int r = value % modulus;
        return r < 0 ? r + modulus : r;
    }
}
