package org.nd4j.parameterserver.distributed.training.impl;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.RandomUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.parameterserver.distributed.enums.ExecutionMode;
import org.nd4j.parameterserver.distributed.logic.completion.FrameCompletionHandler;
import org.nd4j.parameterserver.distributed.logic.completion.RequestDescriptor;
import org.nd4j.parameterserver.distributed.logic.storage.WordVectorStorage;
import org.nd4j.parameterserver.distributed.messages.VoidAggregation;
import org.nd4j.parameterserver.distributed.messages.aggregations.DotAggregation;
import org.nd4j.parameterserver.distributed.messages.complete.FrameCompleteMessage;
import org.nd4j.parameterserver.distributed.messages.intercom.DistributedSgDotMessage;
import org.nd4j.parameterserver.distributed.messages.requests.SkipGramRequestMessage;
import org.nd4j.parameterserver.distributed.training.BaseTrainer;
import org.nd4j.parameterserver.distributed.training.chains.SkipGramChain;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/parameterserver/distributed/training/impl/SkipGramTrainer.class */
/**
 * Trainer that executes distributed SkipGram rounds driven by {@link SkipGramRequestMessage}s.
 *
 * <p>{@link #startTraining} registers a {@link SkipGramChain} for the request, optionally draws
 * negative samples and dispatches a {@link DistributedSgDotMessage}. When the resulting
 * {@link DotAggregation} arrives, {@link #aggregationFinished} attaches it to the chain and calls
 * {@link #finishTraining}. That method applies the hierarchical-softmax and/or negative-sampling
 * updates to the word-vector storage, reports frame progress and drops the chain.
 */
public class SkipGramTrainer extends BaseTrainer<SkipGramRequestMessage> {
    private static final Logger log = LoggerFactory.getLogger(SkipGramTrainer.class);

    /** Half-width of the dot-product range that is mapped onto the exp table. */
    private static final float HS_MAX_EXP = 6.0f;

    /** In-flight rounds keyed by (originatorId, taskId); entries are removed by finishTraining. */
    protected Map<RequestDescriptor, SkipGramChain> chains = new ConcurrentHashMap<>();

    /** Number of rounds completed by finishTraining; used for periodic progress logging. */
    protected AtomicLong cntRounds = new AtomicLong(0);

    /**
     * Starts a SkipGram round for the given request.
     *
     * <p>The request is validated before anything is registered. Then:
     * <ul>
     *   <li>A chain is registered under (originatorId, taskId).</li>
     *   <li>If {@code negSamples > 0}, {@code negSamples + 1} negative-sampling rows are appended
     *       after the hierarchical-softmax points and stored on the message via
     *       {@code setNegatives}. The first of these rows is {@code w1}; the rest are random SYN_0
     *       rows different from {@code w1}.</li>
     *   <li>The dot-product message is dispatched. It is queued locally in AVERAGING mode and sent
     *       through the transport in SHARDED mode.</li>
     * </ul>
     * A {@code null} codes array is treated as "no hierarchical softmax".
     *
     * @param skipGramRequestMessage the request to train on
     * @throws RuntimeException if the points and codes arrays differ in length
     * @throws IllegalStateException if negative samples are requested but no row other than
     *     {@code w1} can ever be drawn
     */
    @Override
    public void startTraining(SkipGramRequestMessage skipGramRequestMessage) {
        int[] codes = skipGramRequestMessage.getCodes();
        int hsCount = codes == null ? 0 : codes.length;
        int[] points = ArrayUtils.nullToEmpty(skipGramRequestMessage.getPoints());

        // Validate before touching shared state so a malformed request leaves no stale chain behind.
        if (points.length != hsCount) {
            throw new RuntimeException("Mismatiching points/codes lengths here!");
        }

        // Draw negatives up front as well: a failure here must not leave a registered chain.
        int[] negatives = skipGramRequestMessage.getNegSamples() > 0
                ? sampleNegativeRows(skipGramRequestMessage)
                : null;

        SkipGramChain chain = new SkipGramChain(skipGramRequestMessage.getOriginatorId(),
                skipGramRequestMessage.getTaskId(), skipGramRequestMessage.getFrameId());
        chain.addElement(skipGramRequestMessage);
        this.chains.put(RequestDescriptor.createDescriptor(skipGramRequestMessage.getOriginatorId(),
                skipGramRequestMessage.getTaskId()), chain);

        if (negatives != null) {
            // Negative rows follow the HS points; finishTraining reads the dots in this same order.
            points = ArrayUtils.addAll(points, negatives);
            skipGramRequestMessage.setNegatives(negatives);
        }

        // NOTE(review): the empty int[] fills the constructor's first row-index argument exactly as
        // before; its meaning is defined by DistributedSgDotMessage and is not visible here.
        DistributedSgDotMessage dotMessage = new DistributedSgDotMessage(
                skipGramRequestMessage.getTaskId(), new int[0], points,
                skipGramRequestMessage.getW1(), skipGramRequestMessage.getW2(), codes, hsCount > 0,
                skipGramRequestMessage.getNegSamples(), (float) skipGramRequestMessage.getAlpha());
        dotMessage.setTargetId((short) -1);
        dotMessage.setOriginatorId(skipGramRequestMessage.getOriginatorId());

        if (this.voidConfiguration.getExecutionMode() == ExecutionMode.AVERAGING) {
            this.transport.putMessage(dotMessage);
        } else if (this.voidConfiguration.getExecutionMode() == ExecutionMode.SHARDED) {
            this.transport.sendMessage(dotMessage);
        }
    }

    /**
     * Builds the negative-sampling row list for a request.
     *
     * <p>Index 0 is {@code w1}. The remaining {@code negSamples} entries are drawn with
     * {@code RandomUtils.nextInt(0, rows)} over SYN_0 rows and rejected while equal to {@code w1}.
     *
     * @param message request providing {@code w1} and {@code negSamples}
     * @return array of length {@code negSamples + 1}
     * @throws IllegalStateException if every possible draw would equal {@code w1}
     */
    private int[] sampleNegativeRows(SkipGramRequestMessage message) {
        int rows = this.storage.getArray(WordVectorStorage.SYN_0).rows();
        int w1 = message.getW1();
        // nextInt(0, rows) can only yield 0 when rows <= 1, so with w1 == 0 the rejection loop
        // below would spin forever; fail fast instead.
        if (rows <= 1 && w1 == 0) {
            throw new IllegalStateException("Unable to draw negative samples: SYN_0 has " + rows
                    + " row(s) and w1 = " + w1);
        }
        int[] negatives = new int[message.getNegSamples() + 1];
        negatives[0] = w1;
        for (int n = 1; n < negatives.length; n++) {
            int candidate;
            do {
                candidate = RandomUtils.nextInt(0, rows);
            } while (candidate == w1);
            negatives[n] = candidate;
        }
        return negatives;
    }

    /**
     * Registers a chain for a request received from elsewhere, unless one already exists for its
     * (originatorId, taskId).
     *
     * @param skipGramRequestMessage the request; must not be {@code null}
     * @throws NullPointerException if {@code skipGramRequestMessage} is {@code null}
     */
    @Override
    public void pickTraining(@NonNull SkipGramRequestMessage skipGramRequestMessage) {
        if (skipGramRequestMessage == null) {
            throw new NullPointerException("message is marked non-null but is null");
        }
        RequestDescriptor descriptor = RequestDescriptor.createDescriptor(
                skipGramRequestMessage.getOriginatorId(), skipGramRequestMessage.getTaskId());
        // Atomic insert-if-missing; avoids the containsKey/put check-then-act window.
        this.chains.putIfAbsent(descriptor, new SkipGramChain(skipGramRequestMessage));
    }

    /** Returns the simple class name of the message type this trainer handles. */
    @Override
    public String targetMessageClass() {
        return SkipGramRequestMessage.class.getSimpleName();
    }

    /**
     * Attaches a finished dot-product aggregation to its chain and completes the round.
     *
     * @param voidAggregation the aggregation; expected to be a {@link DotAggregation}
     * @throws NullPointerException if {@code voidAggregation} is {@code null}
     * @throws RuntimeException if no chain is registered for the aggregation's originator/task
     * @throws ClassCastException if the aggregation is not a {@link DotAggregation}
     */
    @Override
    public void aggregationFinished(@NonNull VoidAggregation voidAggregation) {
        if (voidAggregation == null) {
            throw new NullPointerException("aggregation is marked non-null but is null");
        }
        SkipGramChain chain = this.chains.get(RequestDescriptor.createDescriptor(
                voidAggregation.getOriginatorId(), voidAggregation.getTaskId()));
        if (chain == null) {
            throw new RuntimeException("sI_" + ((int) this.transport.getShardIndex())
                    + " Unable to find chain for specified originatorId: ["
                    + voidAggregation.getOriginatorId() + "]; taskId: ["
                    + voidAggregation.getTaskId() + "]");
        }
        chain.addElement((DotAggregation) voidAggregation);
        finishTraining(voidAggregation.getOriginatorId(), voidAggregation.getTaskId());
    }

    /**
     * Completes a round by doing the following, in order:
     * <ol>
     *   <li>Applies the SkipGram updates.</li>
     *   <li>Reports frame progress to the completion handler.</li>
     *   <li>Bumps the round counter.</li>
     * </ol>
     * The chain is removed even if one of these steps throws, so failed rounds do not accumulate
     * in {@link #chains}.
     *
     * @param originatorId originator of the request
     * @param taskId task identifier of the request
     * @throws RuntimeException if no chain is registered for (originatorId, taskId)
     */
    @Override
    public void finishTraining(long originatorId, long taskId) {
        RequestDescriptor chainDescriptor = RequestDescriptor.createDescriptor(originatorId, taskId);
        SkipGramChain chain = this.chains.get(chainDescriptor);
        if (chain == null) {
            throw new RuntimeException("Unable to find chain for specified taskId: [" + taskId + "]");
        }
        try {
            applySkipGramUpdates(chain);
            reportFrameProgress(chain, taskId);
            long rounds = this.cntRounds.incrementAndGet();
            if (rounds % 100000 == 0) {
                log.info("{} training rounds finished...", Long.valueOf(rounds));
            }
        } finally {
            this.chains.remove(chainDescriptor);
        }
    }

    /**
     * Applies hierarchical-softmax (HS) and negative-sampling gradients for one chain.
     *
     * <p>Assumes the aggregated dots are ordered like the points array built by startTraining:
     * first one dot per HS point, then {@code negSamples + 1} dots for the negatives. The
     * accumulated error vector is added to the SYN_0 row of {@code w2} once at the end, and only
     * if at least one gradient was applied.
     */
    private void applySkipGramUpdates(SkipGramChain chain) {
        SkipGramRequestMessage request = chain.getRequestMessage();
        double alpha = request.getAlpha();
        INDArray expTable = this.storage.getArray(WordVectorStorage.EXP_TABLE);
        INDArray dots = chain.getDotAggregation().getAccumulatedResult();
        INDArray syn0 = this.storage.getArray(WordVectorStorage.SYN_0);
        INDArray syn1 = this.storage.getArray(WordVectorStorage.SYN_1);
        INDArray syn1Neg = this.storage.getArray(WordVectorStorage.SYN_1_NEGATIVE);
        INDArray neu1e = Nd4j.create(syn0.columns());

        // Maps a dot in [-HS_MAX_EXP, HS_MAX_EXP) onto an exp-table index.
        double expScale = ((float) expTable.length()) / HS_MAX_EXP / 2.0d;

        int[] codes = request.getCodes();
        int hsCount = codes == null ? 0 : codes.length;
        boolean updated = false;

        for (int i = 0; i < hsCount; i++) {
            float dot = dots.getFloat(i);
            // Written as a negated in-range test so NaN dots are skipped too.
            if (!(dot >= -HS_MAX_EXP && dot < HS_MAX_EXP)) {
                continue;
            }
            int idx = (int) ((dot + HS_MAX_EXP) * expScale);
            if (idx < 0 || idx >= expTable.length()) {
                continue;
            }
            double g = ((1 - codes[i]) - expTable.getFloat(idx)) * alpha;
            int point = request.getPoints()[i];
            updated = true;
            Nd4j.getBlasWrapper().axpy(Double.valueOf(g), syn1.getRow(point), neu1e);
            Nd4j.getBlasWrapper().axpy(Double.valueOf(g), syn0.getRow(request.getW2()), syn1.getRow(point));
        }

        if (request.getNegSamples() > 0) {
            int[] negatives = request.getNegatives();
            for (int cnt = 0; cnt <= request.getNegSamples(); cnt++) {
                // Negative-sampling dots come right after the hsCount HS dots.
                float dot = dots.getFloat(hsCount + cnt);
                // Entry 0 is the positive word (w1); the rest are negatives.
                float label = cnt == 0 ? 1.0f : 0.0f;
                double g;
                if (dot > HS_MAX_EXP) {
                    g = (label - 1.0f) * alpha;
                } else if (dot < -HS_MAX_EXP) {
                    g = (label - 0.0f) * alpha;
                } else {
                    int idx = (int) ((dot + HS_MAX_EXP) * expScale);
                    if (idx < 0 || idx >= expTable.length()) {
                        continue;
                    }
                    g = (label - expTable.getDouble(idx)) * alpha;
                }
                updated = true;
                Nd4j.getBlasWrapper().axpy(Double.valueOf(g), syn1Neg.getRow(negatives[cnt]), neu1e);
                Nd4j.getBlasWrapper().axpy(Double.valueOf(g), syn0.getRow(request.getW2()), syn1Neg.getRow(negatives[cnt]));
            }
        }

        if (updated) {
            Nd4j.getBlasWrapper().axpy(Double.valueOf(1.0d), neu1e, syn0.getRow(request.getW2()));
        }
    }

    /**
     * Notifies the completion handler about this chain's task, if its frame is tracked here.
     *
     * <p>When the frame becomes complete, a {@link FrameCompleteMessage} is sent to the frame
     * originator. A missing completed-frame descriptor is logged as "double spending".
     */
    private void reportFrameProgress(SkipGramChain chain, long taskId) {
        RequestDescriptor frameDescriptor =
                RequestDescriptor.createDescriptor(chain.getOriginatorId(), chain.getFrameId());
        if (!this.completionHandler.isTrackingFrame(frameDescriptor)) {
            log.info("sI_{} isn't tracking this frame: Originator: {}, frameId: {}, taskId: {}",
                    new Object[] {Short.valueOf(this.transport.getShardIndex()),
                            Long.valueOf(chain.getOriginatorId()), Long.valueOf(chain.getFrameId()),
                            Long.valueOf(taskId)});
            return;
        }
        this.completionHandler.notifyFrame(Long.valueOf(chain.getOriginatorId()),
                Long.valueOf(chain.getFrameId()), Long.valueOf(chain.getTaskId()));
        if (!this.completionHandler.isCompleted(frameDescriptor)) {
            return;
        }
        FrameCompletionHandler.FrameDescriptor completedFrameInfo =
                this.completionHandler.getCompletedFrameInfo(frameDescriptor);
        if (completedFrameInfo == null) {
            log.warn("Frame double spending detected");
            return;
        }
        FrameCompleteMessage frameCompleteMessage = new FrameCompleteMessage(chain.getFrameId());
        frameCompleteMessage.setOriginatorId(completedFrameInfo.getFrameOriginatorId());
        this.transport.sendMessage(frameCompleteMessage);
    }
}
