package org.deeplearning4j.models.word2vec;

import com.google.common.util.concurrent.AtomicDouble;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/models/word2vec/VocabWord.class */
/**
 * A single vocabulary entry: the word, its accumulated frequency, its vocabulary index, and the
 * {@code codes}/{@code points} lists (presumably Huffman-tree codes and inner-node indices assigned
 * by callers — not established in this class).
 *
 * <p>Frequency updates go through an {@link AtomicDouble}; all other fields are plain,
 * unsynchronized state.
 *
 * <p>Note: {@link #equals(Object)} and {@link #hashCode()} are computed from mutable state
 * (frequency, index, codes, points, gradient), so an instance must not be mutated while it is used
 * as a key in a hash-based collection.
 */
public class VocabWord implements Comparable<VocabWord>, Serializable {
    private static final long serialVersionUID = 2223750736522624256L;

    /** Added to the denominator in {@link #getGradient(int, double)}; the float literal 1e-6f widened to double. */
    private static final double GRADIENT_EPSILON = 1e-6f;

    /** Scale applied to the adjusted gradient in {@link #getGradient(int, double)}; the float literal 0.1f widened to double. */
    private static final double GRADIENT_LEARNING_RATE = 0.1f;

    private String word;
    private INDArray historicalGradient;
    private AtomicDouble wordFrequency = new AtomicDouble(0.0d);
    private int index = -1;
    private List<Integer> codes = new ArrayList<Integer>();
    private List<Integer> points = new ArrayList<Integer>();
    private int codeLength = 0;

    /**
     * Creates a placeholder entry.
     *
     * @return a new entry whose word is {@code "none"} and whose frequency is 0
     */
    public static VocabWord none() {
        return new VocabWord(0.0d, "none");
    }

    /**
     * Creates an entry with an initial frequency.
     *
     * @param d   initial word frequency
     * @param str the word; must be non-null and non-empty
     * @throws IllegalArgumentException if {@code str} is null or empty
     */
    public VocabWord(double d, String str) {
        if (str == null || str.isEmpty()) {
            throw new IllegalArgumentException("Word must not be null or empty");
        }
        this.wordFrequency.set(d);
        this.word = str;
    }

    /**
     * Creates an entry with no word, frequency 0 and index -1. Intended for frameworks that fill the
     * fields afterwards (e.g. via {@link #setWord(String)} or {@link #read(DataInputStream)}).
     */
    public VocabWord() {
    }

    /**
     * Writes this entry's frequency to {@code dataOutputStream}. Only the frequency is written; the
     * word, index, codes, points and gradient are not.
     *
     * @param dataOutputStream destination stream
     * @throws IOException if the underlying write fails
     */
    public void write(DataOutputStream dataOutputStream) throws IOException {
        dataOutputStream.writeDouble(this.wordFrequency.get());
    }

    /**
     * Reads a frequency previously written by {@link #write(DataOutputStream)} into this entry.
     *
     * @param dataInputStream source stream
     * @return this instance, for chaining
     * @throws IOException if the underlying read fails
     */
    public VocabWord read(DataInputStream dataInputStream) throws IOException {
        this.wordFrequency.set(dataInputStream.readDouble());
        return this;
    }

    public String getWord() {
        return this.word;
    }

    public void setWord(String str) {
        this.word = str;
    }

    /** Atomically adds 1 to the word frequency. */
    public void increment() {
        increment(1);
    }

    /**
     * Atomically adds {@code i} to the word frequency.
     *
     * @param i amount to add (may be negative)
     */
    public void increment(int i) {
        this.wordFrequency.getAndAdd(i);
    }

    /** @return the vocabulary index, or -1 if none has been assigned */
    public int getIndex() {
        return this.index;
    }

    public void setIndex(int i) {
        this.index = i;
    }

    /** @return the current word frequency, or 0 if the frequency holder is absent */
    public double getWordFrequency() {
        if (this.wordFrequency == null) {
            return 0.0d;
        }
        return this.wordFrequency.get();
    }

    /** @return the live (not copied) codes list */
    public List<Integer> getCodes() {
        return this.codes;
    }

    public void setCodes(List<Integer> list) {
        this.codes = list;
    }

    /**
     * Orders entries by ascending word frequency. This ordering is NOT consistent with
     * {@link #equals(Object)}: distinct words with the same frequency compare as 0.
     */
    @Override
    public int compareTo(VocabWord vocabWord) {
        return Double.compare(getWordFrequency(), vocabWord.getWordFrequency());
    }

    /**
     * Applies an AdaGrad-style update for one position and returns the scaled step magnitude.
     *
     * <p>Adds {@code d * d} to the accumulated squared gradient at position {@code i}, then returns
     * {@code |d| / (sqrt(accumulated) + 1e-6) * 0.1}. The result is always non-negative; the sign of
     * {@code d} is discarded. The accumulator is created lazily, sized to the current length of the
     * codes list.
     *
     * @param i position in the accumulator (presumably a code index — verify against callers)
     * @param d raw gradient value
     * @return the adjusted, non-negative gradient magnitude
     */
    public double getGradient(int i, double d) {
        if (this.historicalGradient == null) {
            this.historicalGradient = Nd4j.zeros(getCodes().size());
        }
        this.historicalGradient.putScalar(i, this.historicalGradient.getDouble(i) + d * d);
        // Re-read after the write so the value matches whatever precision the array stores.
        double accumulated = this.historicalGradient.getDouble(i);
        return (FastMath.abs(d) / (FastMath.sqrt(accumulated) + GRADIENT_EPSILON)) * GRADIENT_LEARNING_RATE;
    }

    /** @return the live (not copied) points list */
    public List<Integer> getPoints() {
        return this.points;
    }

    public void setPoints(List<Integer> list) {
        this.points = list;
    }

    public int getCodeLength() {
        return this.codeLength;
    }

    /**
     * Sets the code length and zero-pads the codes and points lists up to exactly {@code i}
     * elements. Lists already at least {@code i} long are left untouched.
     *
     * @param i the new code length
     */
    public void setCodeLength(int i) {
        this.codeLength = i;
        padWithZeros(this.codes, i);
        padWithZeros(this.points, i);
    }

    /** Appends zeros to {@code list} until it holds at least {@code size} elements. */
    private static void padWithZeros(List<Integer> list, int size) {
        while (list.size() < size) {
            list.add(0);
        }
    }

    /**
     * Value equality over word, frequency, index, codes, points, code length and gradient.
     * Frequencies are compared by value; {@link AtomicDouble} itself only has identity equality.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof VocabWord)) {
            return false;
        }
        VocabWord other = (VocabWord) obj;
        return this.codeLength == other.codeLength
                && this.index == other.index
                && Double.compare(getWordFrequency(), other.getWordFrequency()) == 0
                && nullSafeEquals(this.codes, other.codes)
                && nullSafeEquals(this.points, other.points)
                && nullSafeEquals(this.word, other.word)
                && nullSafeEquals(this.historicalGradient, other.historicalGradient);
    }

    /** Hash consistent with {@link #equals(Object)}; the frequency is hashed by value. */
    @Override
    public int hashCode() {
        // Double.hashCode agrees with Double.compare(...) == 0 used in equals (NaN, -0.0 included).
        int result = Double.valueOf(getWordFrequency()).hashCode();
        result = 31 * result + this.index;
        result = 31 * result + nullSafeHash(this.codes);
        result = 31 * result + nullSafeHash(this.word);
        result = 31 * result + nullSafeHash(this.historicalGradient);
        result = 31 * result + nullSafeHash(this.points);
        result = 31 * result + this.codeLength;
        return result;
    }

    /** Null-tolerant equality helper. */
    private static boolean nullSafeEquals(Object a, Object b) {
        return a == null ? b == null : a.equals(b);
    }

    /** Null-tolerant hash helper; null hashes to 0. */
    private static int nullSafeHash(Object o) {
        return o == null ? 0 : o.hashCode();
    }

    @Override
    public String toString() {
        return "VocabWord{wordFrequency=" + this.wordFrequency + ", index=" + this.index + ", codes=" + this.codes + ", word='" + this.word + "', historicalGradient=" + this.historicalGradient + ", points=" + this.points + ", codeLength=" + this.codeLength + '}';
    }
}
