package org.deeplearning4j.models.word2vec;

import com.google.common.util.concurrent.AtomicDouble;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Objects;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/models/word2vec/VocabWord.class */
/**
 * A single vocabulary entry: a word, its (atomically updated) frequency, its index, and
 * per-word arrays {@code codes} / {@code points} plus optional {@code left}/{@code right}/{@code
 * parent} links. (Presumably these hold Huffman-tree data for hierarchical softmax — TODO
 * confirm against the tree-building code.)
 *
 * <p>Only {@link #increment(int)} / {@link #read(DataInputStream)} touch the frequency, and they
 * go through an {@link AtomicDouble}; no other field is synchronized.
 */
public class VocabWord implements Comparable<VocabWord>, Serializable {
    private static final long serialVersionUID = 2223750736522624256L;

    /** Label constant; not referenced inside this class (presumably used for synthetic tree nodes). */
    public static final String PARENT_NODE = "parent";

    /** Default length of the {@code codes} and {@code points} arrays. */
    private static final int DEFAULT_CODE_CAPACITY = 40;

    /**
     * Denominator guard used by {@link #getLearningRate(int, double)}. Written as a widened float
     * literal so the value is bit-identical to the historical constant 9.999999974752427E-7.
     */
    private static final double ADAGRAD_EPSILON = (double) 1e-6f;

    /**
     * Scale applied to the step returned by {@link #getLearningRate(int, double)}; a widened
     * float literal, bit-identical to the historical constant 0.10000000149011612.
     */
    private static final double LEARNING_RATE_SCALE = (double) 0.1f;

    private final AtomicDouble wordFrequency = new AtomicDouble(0.0d);
    private int index = -1;
    private VocabWord left;
    private VocabWord right;
    private VocabWord parent;
    private int[] codes = new int[DEFAULT_CODE_CAPACITY];
    private String word;
    /** Lazily created by {@link #getLearningRate(int, double)}; accumulates squared gradients. */
    private INDArray historicalGradient;
    private int[] points = new int[DEFAULT_CODE_CAPACITY];
    /** Presumably the number of meaningful entries in {@code codes}/{@code points} — TODO confirm. */
    private int codeLength = 0;

    /** Returns a fresh placeholder entry with word {@code "none"} and frequency 0. */
    public static VocabWord none() {
        return new VocabWord(0.0d, "none");
    }

    /**
     * Creates an entry for {@code word} with the given initial frequency.
     *
     * @param wordFrequency initial frequency value
     * @param word the word; must be non-null and non-empty
     * @throws IllegalArgumentException if {@code word} is null or empty
     */
    public VocabWord(double wordFrequency, String word) {
        if (word == null || word.isEmpty()) {
            throw new IllegalArgumentException("Word must not be null or empty");
        }
        this.wordFrequency.set(wordFrequency);
        this.word = word;
    }

    /** Creates an entry with no word, frequency 0 and index -1. */
    public VocabWord() {
    }

    /**
     * Writes only the frequency to {@code out}; no other field is persisted by this method.
     *
     * @throws IOException if the underlying stream fails
     */
    public void write(DataOutputStream out) throws IOException {
        out.writeDouble(this.wordFrequency.get());
    }

    /**
     * Reads a frequency written by {@link #write(DataOutputStream)} into this instance.
     *
     * @return this instance, for chaining
     * @throws IOException if the underlying stream fails
     */
    public VocabWord read(DataInputStream in) throws IOException {
        this.wordFrequency.set(in.readDouble());
        return this;
    }

    public String getWord() {
        return this.word;
    }

    /** Sets the word. Unlike the two-argument constructor, this performs no validation. */
    public void setWord(String word) {
        this.word = word;
    }

    /** Returns the internal codes array (not a copy). */
    public int[] getCodes() {
        return this.codes;
    }

    /** Stores the given array reference as-is (no defensive copy). */
    public void setCodes(int[] codes) {
        this.codes = codes;
    }

    /** Sets the parent link. The parent is intentionally ignored by equals/hashCode. */
    public void setParent(VocabWord parent) {
        this.parent = parent;
    }

    public VocabWord getLeft() {
        return this.left;
    }

    public void setLeft(VocabWord left) {
        this.left = left;
    }

    public VocabWord getRight() {
        return this.right;
    }

    public void setRight(VocabWord right) {
        this.right = right;
    }

    /** Atomically adds 1 to the frequency. */
    public void increment() {
        increment(1);
    }

    /** Atomically adds {@code by} to the frequency. */
    public void increment(int by) {
        this.wordFrequency.getAndAdd(by);
    }

    public int getIndex() {
        return this.index;
    }

    public void setIndex(int index) {
        this.index = index;
    }

    public double getWordFrequency() {
        return this.wordFrequency.get();
    }

    /**
     * Orders entries by frequency value (ascending, via {@link Double#compare}). Note this ordering
     * is not consistent with {@link #equals(Object)}: equal frequencies do not imply equal entries.
     */
    @Override
    public int compareTo(VocabWord other) {
        return Double.compare(this.wordFrequency.get(), other.wordFrequency.get());
    }

    /**
     * Adds {@code gradient^2} to the accumulated squared gradient at {@code position} and returns
     * {@code |gradient| / (sqrt(accumulated) + 1e-6) * 0.1} (resembles an AdaGrad-scaled step).
     *
     * <p>The accumulator is created on first use with length {@code getCodes().length}; if
     * {@code codes} is later replaced with a longer array, larger positions will be out of range.
     *
     * @param position slot in the accumulator to update
     * @param gradient the gradient value for this step
     * @return the scaled step size
     */
    public double getLearningRate(int position, double gradient) {
        if (this.historicalGradient == null) {
            this.historicalGradient = Nd4j.zeros(getCodes().length);
        }
        this.historicalGradient.putScalar(
                position, this.historicalGradient.getDouble(position) + Math.pow(gradient, 2.0d));
        // Re-read the stored value so any storage rounding of the array is reflected.
        double accumulated = this.historicalGradient.getDouble(position);
        return (FastMath.abs(gradient) / (FastMath.sqrt(accumulated) + ADAGRAD_EPSILON))
                * LEARNING_RATE_SCALE;
    }

    /** Returns the internal points array (not a copy). */
    public int[] getPoints() {
        return this.points;
    }

    /** Stores the given array reference as-is (no defensive copy). */
    public void setPoints(int[] points) {
        this.points = points;
    }

    public int getCodeLength() {
        return this.codeLength;
    }

    public void setCodeLength(int codeLength) {
        this.codeLength = codeLength;
    }

    @Override
    public String toString() {
        return "VocabWord{wordFrequency=" + this.wordFrequency
                + ", index=" + this.index
                + ", codes=" + Arrays.toString(this.codes)
                + ", word='" + this.word + "'"
                + ", historicalGradient=" + this.historicalGradient
                + ", points=" + Arrays.toString(this.points)
                + '}';
    }

    /**
     * Value equality over word, index, frequency value, codes, points, historical gradient and the
     * {@code left}/{@code right} subtrees.
     *
     * <p>The frequency is compared by value: {@link AtomicDouble} does not override {@code
     * equals}, so comparing the holders directly would be an identity check. The {@code parent}
     * back-reference is excluded; including it (as the downward links are) would recurse forever
     * when a child's parent link points back at the node being compared.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof VocabWord)) {
            return false;
        }
        VocabWord other = (VocabWord) obj;
        return this.index == other.index
                && Double.compare(this.wordFrequency.get(), other.wordFrequency.get()) == 0
                && Arrays.equals(this.codes, other.codes)
                && Arrays.equals(this.points, other.points)
                && Objects.equals(this.word, other.word)
                && Objects.equals(this.historicalGradient, other.historicalGradient)
                && Objects.equals(this.left, other.left)
                && Objects.equals(this.right, other.right);
    }

    /**
     * Hash over the same fields as {@link #equals(Object)} except the frequency, which is left out
     * so that {@link #increment()} does not change the hash; equal objects still hash equally.
     * The hash still depends on other mutable fields (index, codes, points, ...), so avoid
     * mutating an instance while it is a key in a hash-based collection.
     */
    @Override
    public int hashCode() {
        int result = this.index;
        result = 31 * result + Objects.hashCode(this.left);
        result = 31 * result + Objects.hashCode(this.right);
        result = 31 * result + Arrays.hashCode(this.codes);
        result = 31 * result + Objects.hashCode(this.word);
        result = 31 * result + Objects.hashCode(this.historicalGradient);
        result = 31 * result + Arrays.hashCode(this.points);
        return result;
    }
}
