package com.gengoai.apollo.ml.model.topic;

import com.gengoai.ParamMap;
import com.gengoai.apollo.math.linalg.NDArray;
import com.gengoai.apollo.math.linalg.NDArrayFactory;
import com.gengoai.apollo.math.linalg.SparkLinearAlgebra;
import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.Datum;
import com.gengoai.apollo.ml.model.Params;
import com.gengoai.apollo.ml.observation.VariableCollection;
import com.gengoai.apollo.ml.observation.VariableNameSpace;
import com.gengoai.collection.counter.Counter;
import com.gengoai.collection.counter.Counters;
import com.gengoai.function.Functional;
import com.gengoai.stream.spark.SparkStream;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
import java.util.stream.Stream;
import lombok.NonNull;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.distributed.RowMatrix;

/**
 * LSA topic model.
 *
 * <p>Fitting encodes the configured inputs of every {@link Datum} as count vectors, runs
 * {@link SparkLinearAlgebra#sparkSVD} over them and keeps the first {@code K} rows of the
 * transposed {@code V} factor, one row per topic. Each row is stored both as a topic vector
 * (used by {@link #inference(NDArray)} and {@link #getTopicDistribution(String)}) and as a
 * {@code Topic} whose counter maps decoded feature names to the row's values.
 *
 * <p>NOTE: the synthetic {@code $deserializeLambda$} method is intentionally not declared here;
 * javac generates it for the serializable lambdas used in {@link #estimate(DataSet)}.
 */
public class LSA extends BaseVectorTopicModel {
    private static final long serialVersionUID = 1;

    /** Fit parameters supplied at construction time. */
    private final Parameters parameters;

    /** One vector per fitted topic; entry {@code i} is built together with the {@code i}-th topic. */
    private final List<NDArray> topicVectors = new ArrayList<>();

    /**
     * Fit parameters for {@link LSA}.
     */
    public static class Parameters extends TopicModelFitParameters {
        /** Number of topics to keep; registered under {@code Params.Clustering.K} with a default of 100. */
        public final ParamMap<TopicModelFitParameters>.Parameter<Integer> K = parameter(Params.Clustering.K, 100);
    }

    /**
     * Creates a model using a fresh, default {@link Parameters} instance.
     */
    public LSA() {
        this(new Parameters());
    }

    /**
     * Creates a model using the given fit parameters.
     *
     * @param parameters the fit parameters, must not be null
     * @throws NullPointerException if {@code parameters} is null
     */
    public LSA(@NonNull Parameters parameters) {
        if (parameters == null) {
            throw new NullPointerException("parameters is marked non-null but is null");
        }
        this.parameters = parameters;
    }

    /**
     * Creates a model whose fit parameters are a new {@link Parameters} instance passed through
     * {@code Functional.with} together with the given consumer.
     *
     * @param consumer callback handed the new parameters, must not be null
     * @throws NullPointerException if {@code consumer} is null
     */
    public LSA(@NonNull Consumer<Parameters> consumer) {
        if (consumer == null) {
            // Message text is kept as-is so the runtime behaviour is unchanged.
            throw new NullPointerException("updater is marked non-null but is null");
        }
        this.parameters = (Parameters) Functional.with(new Parameters(), consumer);
    }

    /**
     * Converts one datum into count vectors.
     *
     * <p>When {@code combineInputs} is set, the configured inputs are merged through
     * {@code VariableCollection.mergeVariableSpace} and the resulting variable space is mapped to
     * count vectors; otherwise each input observation is mapped to its own count vector.
     *
     * @param datum the datum to encode
     * @return the count vectors produced for the datum
     */
    private Stream<NDArray> encode(Datum datum) {
        final VariableNameSpace nameSpace = (VariableNameSpace) this.parameters.namingPattern.value();
        final boolean mergeInputs = (Boolean) this.parameters.combineInputs.value();
        if (mergeInputs) {
            return VariableCollection.mergeVariableSpace(datum.stream(getInputs()))
                    .getVariableSpace()
                    .map(variable -> toCountVector(variable, nameSpace));
        }
        return datum.stream(getInputs())
                .map(observation -> toCountVector(observation, nameSpace));
    }

    /**
     * Fits the model: fits the encoder, decomposes the encoded dataset and rebuilds the topics.
     *
     * <p>Previously fitted topics are discarded once the decomposition has succeeded, so fitting
     * the same instance again replaces the earlier model instead of appending to it.
     *
     * @param dataSet the data to fit on, must not be null
     * @throws NullPointerException if {@code dataSet} is null
     */
    @Override // com.gengoai.apollo.ml.model.Model
    public void estimate(@NonNull DataSet dataSet) {
        if (dataSet == null) {
            throw new NullPointerException("dataset is marked non-null but is null");
        }
        final int numTopics = (Integer) this.parameters.K.value();
        final VariableNameSpace nameSpace = (VariableNameSpace) this.parameters.namingPattern.value();

        encoderFit(dataSet, getInputs(), nameSpace);

        // The SparkStream wrapper is built as a raw type, as in the original expression; its
        // generic signature is not visible from this file, so the unchecked use is confined here.
        @SuppressWarnings({"unchecked", "rawtypes"})
        final RowMatrix encodedRows = new RowMatrix(new SparkStream(dataSet.parallelStream()
                .toDistributedStream()
                .flatMap(this::encode)
                .map(vector -> new DenseVector(vector.toDoubleArray())))
                .cache()
                .getRDD()
                .rdd());

        final Matrix rightFactor = (Matrix) SparkLinearAlgebra.sparkSVD(encodedRows, numTopics).V();
        final NDArray topicMatrix = SparkLinearAlgebra.toMatrix(rightFactor.transpose());

        // Drop the results of any earlier fit so topics and topicVectors stay index-aligned.
        // NOTE(review): assumes the inherited `topics` collection supports clear() - confirm.
        this.topics.clear();
        this.topicVectors.clear();

        // NOTE(review): the loop reads K rows; presumably the decomposition always yields at
        // least K columns in V - confirm for vocabularies smaller than K.
        for (int topicId = 0; topicId < numTopics; topicId++) {
            final NDArray topicVector = NDArrayFactory.ND.columnVector(topicMatrix.getRow(topicId).toDoubleArray());
            final Counter<String> termWeights = Counters.newCounter();
            topicVector.forEachSparse((index, weight) -> termWeights.set(this.encoder.decode(index), weight));
            this.topics.add(new Topic(topicId, termWeights));
            this.topicVectors.add(topicVector);
        }
    }

    /**
     * Returns the fit parameters this model was constructed with.
     *
     * @return the fit parameters
     */
    @Override // com.gengoai.apollo.ml.model.topic.TopicModel, com.gengoai.apollo.ml.model.MultiInputModel, com.gengoai.apollo.ml.model.CombinableOutputModel, com.gengoai.apollo.ml.model.Model
    public Parameters getFitParameters() {
        return this.parameters;
    }

    /**
     * Looks up a feature's value in every topic vector.
     *
     * @param str the feature name
     * @return a row vector with one entry per topic; all zeros when the encoder returns
     *     {@code -1} for the feature
     */
    @Override // com.gengoai.apollo.ml.model.topic.TopicModel
    public NDArray getTopicDistribution(String str) {
        final int topicCount = this.topics.size();
        final double[] distribution = new double[topicCount];
        final int featureIndex = this.encoder.encode(str);
        // -1 is the encoder's "not found" result; the zero-filled array is returned unchanged.
        if (featureIndex != -1) {
            for (int topicId = 0; topicId < topicCount; topicId++) {
                distribution[topicId] = this.topicVectors.get(topicId).get(featureIndex);
            }
        }
        return NDArrayFactory.ND.rowVector(distribution);
    }

    /**
     * Scores a vector against every topic by taking its dot product with each topic vector.
     *
     * @param nDArray the vector to score
     * @return a row vector with one dot product per topic
     */
    @Override // com.gengoai.apollo.ml.model.topic.BaseVectorTopicModel
    protected NDArray inference(NDArray nDArray) {
        final int topicCount = this.topics.size();
        final double[] scores = new double[topicCount];
        for (int topicId = 0; topicId < topicCount; topicId++) {
            scores[topicId] = nDArray.dot(this.topicVectors.get(topicId));
        }
        return NDArrayFactory.ND.rowVector(scores);
    }
}
