/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.spark.models.embeddings.word2vec;

import lombok.Getter;
import lombok.NonNull;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.io.Serializable;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * Singleton holder for negative-sampling state: the {@code syn1Neg} weight matrix and the
 * sampling table built from a {@link VocabCache}.
 *
 * <p>State is populated at most once per instance by {@link #initHolder}; subsequent calls are
 * no-ops. Initialization is {@code synchronized}; the published fields are {@code volatile}, and
 * the sampling table is fully built before it is assigned, so readers of {@link #getTable()}
 * never observe a partially-filled table.
 */
public class NegativeHolder implements Serializable {
    private static final NegativeHolder ourInstance = new NegativeHolder();

    /**
     * Returns the process-wide instance.
     *
     * @return the singleton holder
     */
    public static NegativeHolder getInstance() {
        return ourInstance;
    }

    /** Weight matrix of shape (numWords, layerSize), zero-initialized by {@link #initHolder}. */
    @Getter
    private volatile INDArray syn1Neg;

    /** FLOAT vector whose entries are vocabulary indices, filled by {@link #makeTable}. */
    @Getter
    private volatile INDArray table;

    /*
     * A plain transient boolean rather than a field-initialized AtomicBoolean: Java
     * deserialization does not run field initializers of a Serializable class, so a transient
     * object field would be null afterwards and initHolder would throw an NPE. A boolean
     * defaults to false, matching the "not yet initialized" state. All access happens inside
     * the synchronized initHolder, so no atomic is needed.
     */
    private transient boolean wasInit;
    private transient VocabCache<VocabWord> vocab;

    private NegativeHolder() {

    }

    /**
     * Initializes {@code syn1Neg} and the sampling table on the first call; later calls do
     * nothing.
     *
     * @param vocabCache vocabulary used to size {@code syn1Neg} and to build the table
     * @param expTable   only its length is used: the table size is
     *                   {@code max(expTable.length, 100000)}
     * @param layerSize  number of columns of {@code syn1Neg}
     */
    public synchronized void initHolder(@NonNull VocabCache<VocabWord> vocabCache, double[] expTable, int layerSize) {
        if (wasInit) {
            return;
        }
        this.vocab = vocabCache;
        this.syn1Neg = Nd4j.zeros(vocabCache.numWords(), layerSize);
        makeTable(Math.max(expTable.length, 100000), 0.75);
        wasInit = true;
    }

    /**
     * Builds the sampling table: each slot holds a vocabulary index, and word {@code w} occupies
     * a share of slots proportional to {@code frequency(w)^power} relative to the sum of that
     * quantity over all words. The finished table is assigned to {@link #table} only once fully
     * populated.
     *
     * @param tableSize number of slots in the table
     * @param power     exponent applied to each word frequency
     */
    protected void makeTable(int tableSize, double power) {
        int vocabSize = vocab.numWords();
        // Fill a local array and publish it at the end, so concurrent readers of the volatile
        // field never see a half-built table.
        INDArray newTable = Nd4j.create(DataType.FLOAT, tableSize);

        // Normalizer: sum over all words of frequency^power.
        double trainWordsPow = 0.0;
        for (String w : vocab.words()) {
            trainWordsPow += Math.pow(vocab.wordFrequency(w), power);
        }

        int wordIdx = 0;
        String word = vocab.wordAtIndex(wordIdx);
        // d1 is the cumulative probability mass covered by words 0..wordIdx.
        double d1 = Math.pow(vocab.wordFrequency(word), power) / trainWordsPow;
        for (int i = 0; i < tableSize; i++) {
            newTable.putScalar(i, wordIdx);
            double mul = i * 1.0 / (double) tableSize;
            if (mul > d1) {
                // Advance to the next word, but never past the last valid index.
                if (wordIdx < vocabSize - 1) {
                    wordIdx++;
                }
                word = vocab.wordAtIndex(wordIdx);
                // A missing word contributes no mass; the next iteration will try to advance again.
                if (word != null) {
                    d1 += Math.pow(vocab.wordFrequency(word), power) / trainWordsPow;
                }
            }
        }
        table = newTable;
    }


}
