package com.gengoai.apollo.ml.model.embedding;

import com.gengoai.ParamMap;
import com.gengoai.apollo.math.linalg.DenseMatrix;
import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.model.Params;
import com.gengoai.function.Functional;
import java.lang.invoke.SerializedLambda;
import java.util.Collection;
import java.util.Date;
import java.util.List;
import java.util.Objects;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import lombok.NonNull;
import org.apache.spark.mllib.feature.Word2VecModel;
import org.jblas.FloatMatrix;
import org.jblas.MatrixFunctions;
import scala.collection.JavaConversions;

/* loaded from: input_file:com/gengoai/apollo/ml/model/embedding/Word2Vec.class */
public class Word2Vec extends TrainableWordEmbedding<Parameters, Word2Vec> {
    private static final long serialVersionUID = 1;

    /* loaded from: input_file:com/gengoai/apollo/ml/model/embedding/Word2Vec$Parameters.class */
    public static class Parameters extends WordEmbeddingFitParameters<Parameters> {
        public final ParamMap<Parameters>.Parameter<Double> learningRate = parameter(Params.Optimizable.learningRate, Double.valueOf(0.025d));
        public final ParamMap<Parameters>.Parameter<Integer> maxIterations = parameter(Params.Optimizable.maxIterations, 1);
    }

    public Word2Vec() {
        super(new Parameters());
    }

    public Word2Vec(@NonNull Parameters parameters) {
        super(parameters);
        if (parameters == null) {
            throw new NullPointerException("parameters is marked non-null but is null");
        }
    }

    public Word2Vec(@NonNull Consumer<Parameters> consumer) {
        super((Parameters) Functional.with(new Parameters(), consumer));
        if (consumer == null) {
            throw new NullPointerException("updater is marked non-null but is null");
        }
    }

    @Override // com.gengoai.apollo.ml.model.Model
    public void estimate(@NonNull DataSet dataSet) {
        if (dataSet == null) {
            throw new NullPointerException("dataset is marked non-null but is null");
        }
        org.apache.spark.mllib.feature.Word2Vec word2Vec = new org.apache.spark.mllib.feature.Word2Vec();
        word2Vec.setMinCount(1);
        word2Vec.setVectorSize(((Integer) ((Parameters) this.parameters).dimension.value()).intValue());
        word2Vec.setLearningRate(((Double) ((Parameters) this.parameters).learningRate.value()).doubleValue());
        word2Vec.setNumIterations(((Integer) ((Parameters) this.parameters).maxIterations.value()).intValue());
        word2Vec.setWindowSize(((Integer) ((Parameters) this.parameters).windowSize.value()).intValue());
        word2Vec.setSeed(new Date().getTime());
        Word2VecModel fit = word2Vec.fit(dataSet.stream().toDistributedStream().flatMap(datum -> {
            return datum.stream((Collection<String>) ((Parameters) this.parameters).inputs.value());
        }).map(observation -> {
            return (List) observation.getVariableSpace().map(this::getVariableName).collect(Collectors.toList());
        }).getRDD());
        this.vectorStore = new InMemoryVectorStore(((Integer) ((Parameters) this.parameters).dimension.value()).intValue(), (String) ((Parameters) this.parameters).unknownWord.value(), (String[]) ((Parameters) this.parameters).specialWords.value());
        JavaConversions.mapAsJavaMap(fit.getVectors()).forEach((str, fArr) -> {
            this.vectorStore.updateVector(this.vectorStore.addOrGetIndex(str), new DenseMatrix(MatrixFunctions.floatToDouble(new FloatMatrix(1, fArr.length, fArr))).setLabel(str));
        });
    }

    private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
        String implMethodName = serializedLambda.getImplMethodName();
        boolean z = -1;
        switch (implMethodName.hashCode()) {
            case 203636563:
                if (implMethodName.equals("lambda$estimate$3a27d8ff$1")) {
                    z = false;
                    break;
                }
                break;
            case 203636564:
                if (implMethodName.equals("lambda$estimate$3a27d8ff$2")) {
                    z = true;
                    break;
                }
                break;
        }
        switch (z) {
            case false:
                if (serializedLambda.getImplMethodKind() == 7 && serializedLambda.getFunctionalInterfaceClass().equals("com/gengoai/function/SerializableFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("apply") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("com/gengoai/apollo/ml/model/embedding/Word2Vec") && serializedLambda.getImplMethodSignature().equals("(Lcom/gengoai/apollo/ml/Datum;)Ljava/util/stream/Stream;")) {
                    Word2Vec word2Vec = (Word2Vec) serializedLambda.getCapturedArg(0);
                    return datum -> {
                        return datum.stream((Collection<String>) ((Parameters) this.parameters).inputs.value());
                    };
                }
                break;
            case true:
                if (serializedLambda.getImplMethodKind() == 7 && serializedLambda.getFunctionalInterfaceClass().equals("com/gengoai/function/SerializableFunction") && serializedLambda.getFunctionalInterfaceMethodName().equals("apply") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;") && serializedLambda.getImplClass().equals("com/gengoai/apollo/ml/model/embedding/Word2Vec") && serializedLambda.getImplMethodSignature().equals("(Lcom/gengoai/apollo/ml/observation/Observation;)Ljava/util/List;")) {
                    Word2Vec word2Vec2 = (Word2Vec) serializedLambda.getCapturedArg(0);
                    return observation -> {
                        return (List) observation.getVariableSpace().map(this::getVariableName).collect(Collectors.toList());
                    };
                }
                break;
        }
        throw new IllegalArgumentException("Invalid lambda deserialization");
    }
}
