package com.gengoai.apollo.ml.model.embedding;

import com.gengoai.LogUtils;
import com.gengoai.ParamMap;
import com.gengoai.ParameterDef;
import com.gengoai.Stopwatch;
import com.gengoai.apollo.math.linalg.DenseMatrix;
import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.model.Params;
import com.gengoai.apollo.ml.observation.Observation;
import com.gengoai.apollo.ml.observation.Sequence;
import com.gengoai.apollo.ml.observation.VariableNameSpace;
import com.gengoai.collection.disk.DiskMap;
import com.gengoai.concurrent.AtomicDouble;
import com.gengoai.function.Functional;
import com.gengoai.io.Resources;
import com.gengoai.tuple.IntPair;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import java.util.logging.Logger;
import lombok.NonNull;
import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;

/**
 * Trains GloVe (Global Vectors) word embeddings.
 *
 * <p>{@link #estimate(DataSet)} runs in two phases:
 * <ol>
 *   <li>a symmetric word-word co-occurrence table is accumulated in a temporary {@link DiskMap};
 *       every pair of words within {@code windowSize} positions of each other contributes
 *       {@code 1 / distance};</li>
 *   <li>a "main" vector, a "context" vector and a bias per word are fit to {@code log(count)} with
 *       AdaGrad, weighting each entry by {@code f(x) = min(1, (x / xMax)^alpha)}.</li>
 * </ol>
 * The final embedding of a word is the sum of its main and context vectors.
 */
public class Glove extends TrainableWordEmbedding<Glove.Parameters, Glove> {
    private static final long serialVersionUID = 1;
    private static final Logger log = Logger.getLogger(Glove.class.getName());
    /** Exponent of the GloVe weighting function {@code (count / xMax)^alpha}. */
    public static final ParameterDef<Double> alpha = ParameterDef.doubleParam("alpha");
    /** Co-occurrence count above which the GloVe weighting function is fixed at 1. */
    public static final ParameterDef<Integer> xMax = ParameterDef.intParam("xMax");

    /**
     * Accumulated (distance-weighted) co-occurrence counts below this threshold are discarded
     * before training.
     */
    private static final double MIN_COOCCURRENCE_COUNT = 5.0;

    /** One non-zero entry of the co-occurrence table. */
    public static class Cooccurrence {
        /** Accumulated distance-weighted count for the pair. */
        public final double count;
        /** Vector-store index of the main word. */
        public final int word1;
        /** Vector-store index of the context word. */
        public final int word2;

        public Cooccurrence(int word1, int word2, double count) {
            this.word1 = word1;
            this.word2 = word2;
            this.count = count;
        }
    }

    /**
     * Fit parameters for {@link Glove}: {@code alpha} (default 0.75), {@code learningRate}
     * (default 0.05), {@code xMax} (default 100) and {@code maxIterations} (default 25), in addition
     * to those inherited from {@link WordEmbeddingFitParameters}.
     */
    public static class Parameters extends WordEmbeddingFitParameters<Parameters> {
        public final ParamMap<Parameters>.Parameter<Double> alpha = parameter(Glove.alpha, Double.valueOf(0.75d));
        public final ParamMap<Parameters>.Parameter<Double> learningRate = parameter(Params.Optimizable.learningRate, Double.valueOf(0.05d));
        public final ParamMap<Parameters>.Parameter<Integer> xMax = parameter(Glove.xMax, 100);
        public final ParamMap<Parameters>.Parameter<Integer> maxIterations = parameter(Params.Optimizable.maxIterations, 25);
    }

    /** Creates a trainer with default {@link Parameters}. */
    public Glove() {
        super(new Parameters());
    }

    /**
     * Creates a trainer with the given parameters.
     *
     * @param parameters the fit parameters
     * @throws NullPointerException if {@code parameters} is null
     */
    public Glove(@NonNull Parameters parameters) {
        super(parameters);
        if (parameters == null) {
            throw new NullPointerException("parameters is marked non-null but is null");
        }
    }

    /**
     * Creates a trainer with default {@link Parameters} modified by {@code updater}.
     *
     * @param updater callback applied to a fresh {@link Parameters} instance
     * @throws NullPointerException if {@code updater} is null
     */
    public Glove(@NonNull Consumer<Parameters> updater) {
        super((Parameters) Functional.with(new Parameters(), updater));
        if (updater == null) {
            throw new NullPointerException("updater is marked non-null but is null");
        }
    }

    /**
     * Fits GloVe embeddings on {@code dataset}, replacing {@code vectorStore} with a new
     * {@link InMemoryVectorStore} that holds the trained vectors.
     *
     * @param dataset the training data; each observation of the configured inputs is read as a
     *                sequence of words
     * @throws NullPointerException if {@code dataset} is null
     */
    @Override
    public void estimate(@NonNull DataSet dataset) {
        if (dataset == null) {
            throw new NullPointerException("dataset is marked non-null but is null");
        }
        Parameters p = (Parameters) this.parameters;
        this.vectorStore = new InMemoryVectorStore((Integer) p.dimension.value(),
                                                   (String) p.unknownWord.value(),
                                                   (String[]) p.specialWords.value());
        List<Cooccurrence> cooccurrences = collectCooccurrences(dataset, p);
        fitVectors(cooccurrences, p);
    }

    /**
     * Builds the symmetric, distance-weighted co-occurrence table and returns the entries whose
     * count reaches {@link #MIN_COOCCURRENCE_COUNT}. Registers every word seen in
     * {@code vectorStore} as a side effect.
     */
    private List<Cooccurrence> collectCooccurrences(DataSet dataset, Parameters p) {
        Stopwatch stopwatch = Stopwatch.createStarted();
        int windowSize = (Integer) p.windowSize.value();
        DiskMap<IntPair, Double> counts = DiskMap.<IntPair, Double>builder()
                                                 .file(Resources.temporaryFile())
                                                 .namespace("counts")
                                                 .build();
        dataset.stream().forEach(datum -> datum
                .stream((Collection<String>) ((Parameters) this.parameters).inputs.value())
                .forEach(observation -> {
                    List<Integer> indices = toIndices(observation.asSequence());
                    for (int i = 1; i < indices.size(); i++) {
                        int word = indices.get(i);
                        for (int j = Math.max(0, i - windowSize); j < i; j++) {
                            int neighbor = indices.get(j);
                            // Nearer neighbors contribute more: weight is 1 / distance.
                            double updated = counts.getOrDefault(IntPair.of(word, neighbor), 0.0) + 1.0 / (i - j);
                            // Both orientations are always written together, so the table stays symmetric.
                            counts.put(IntPair.of(word, neighbor), updated);
                            counts.put(IntPair.of(neighbor, word), updated);
                        }
                    }
                }));
        stopwatch.stop();
        if ((Boolean) p.verbose.value()) {
            LogUtils.logInfo(log, "Cooccurrence Matrix computed in {0}", new Object[]{stopwatch});
        }

        List<Cooccurrence> cooccurrences = new ArrayList<>();
        counts.forEach((pair, count) -> {
            if (count >= MIN_COOCCURRENCE_COUNT) {
                cooccurrences.add(new Cooccurrence(pair.v1, pair.v2, count));
            }
        });
        counts.clear();
        return cooccurrences;
    }

    /**
     * Fits main/context vectors and biases to {@code log(count)} with AdaGrad, then stores the sum
     * of each word's main and context vector in {@code vectorStore}.
     *
     * <p>Rows {@code [0, V)} of the weight table are main vectors and rows {@code [V, 2V)} are
     * context vectors, where {@code V} is the vocabulary size.
     */
    private void fitVectors(List<Cooccurrence> cooccurrences, Parameters p) {
        int dimension = (Integer) p.dimension.value();
        int vocabSize = this.vectorStore.size();
        int rows = vocabSize * 2;
        double alphaExponent = p.alpha.value();
        int xMaxCount = p.xMax.value();
        double learningRate = p.learningRate.value();
        int maxIterations = p.maxIterations.value();
        boolean verbose = (Boolean) p.verbose.value();

        DoubleMatrix[] weights = new DoubleMatrix[rows];
        DoubleMatrix[] gradSq = new DoubleMatrix[rows];
        for (int i = 0; i < rows; i++) {
            weights[i] = DoubleMatrix.rand(dimension).sub(0.5d).divi(dimension);
            // AdaGrad accumulators start at 1 so the first step divides by sqrt(1).
            gradSq[i] = DoubleMatrix.ones(dimension);
        }
        DoubleMatrix biases = DoubleMatrix.rand(rows).sub(0.5d).divi(dimension);
        DoubleMatrix biasGradSq = DoubleMatrix.ones(rows);

        for (int iteration = 0; iteration < maxIterations; iteration++) {
            double cost = 0.0d;
            Collections.shuffle(cooccurrences);
            for (Cooccurrence c : cooccurrences) {
                int main = c.word1;
                int context = c.word2 + vocabSize;
                DoubleMatrix wMain = weights[main];
                DoubleMatrix wContext = weights[context];
                double bMain = biases.get(main);
                double bContext = biases.get(context);

                double diff = wMain.dot(wContext) + bMain + bContext - Math.log(c.count);
                double weightedDiff = c.count > xMaxCount
                        ? diff
                        : Math.pow(c.count / xMaxCount, alphaExponent) * diff;
                cost += 0.5d * weightedDiff * diff;
                double scale = weightedDiff * learningRate;

                // Both gradients come from the pre-update vectors. mul()/div() allocate new
                // matrices, so the AdaGrad accumulators below receive the raw gradients rather
                // than the already-rescaled step.
                DoubleMatrix gradMain = wContext.mul(scale);
                DoubleMatrix gradContext = wMain.mul(scale);
                wMain.subi(gradMain.div(MatrixFunctions.sqrt(gradSq[main])));
                wContext.subi(gradContext.div(MatrixFunctions.sqrt(gradSq[context])));
                gradSq[main].addi(gradMain.mul(gradMain));
                gradSq[context].addi(gradContext.mul(gradContext));

                biases.put(main, bMain - scale / Math.sqrt(biasGradSq.get(main)));
                biases.put(context, bContext - scale / Math.sqrt(biasGradSq.get(context)));
                double scaleSq = scale * scale;
                biasGradSq.put(main, biasGradSq.get(main) + scaleSq);
                biasGradSq.put(context, biasGradSq.get(context) + scaleSq);
            }
            if (verbose) {
                double meanCost = cooccurrences.isEmpty() ? 0.0d : cost / cooccurrences.size();
                LogUtils.logInfo(log, "Iteration: {0},  cost:{1}", new Object[]{iteration + 1, meanCost});
            }
        }

        for (int i = 0; i < vocabSize; i++) {
            weights[i].addi(weights[i + vocabSize]);
            this.vectorStore.updateVector(i, new DenseMatrix(weights[i]).T().setLabel(this.vectorStore.decode(i)));
        }
    }

    /**
     * Maps every variable of every element of {@code sequence} to its index in
     * {@code vectorStore}, adding unseen names to the store.
     *
     * @param sequence the observation sequence to index
     * @return the indices in sequence order
     */
    private List<Integer> toIndices(Sequence<? extends Observation> sequence) {
        VariableNameSpace nameSpace = (VariableNameSpace) ((Parameters) this.parameters).nameSpace.value();
        List<Integer> indices = new ArrayList<>();
        Iterator<? extends Observation> it = sequence.iterator();
        while (it.hasNext()) {
            it.next().getVariableSpace()
              .forEach(variable -> indices.add(this.vectorStore.addOrGetIndex(nameSpace.getName(variable))));
        }
        return indices;
    }
}
