package com.gengoai.apollo.ml.model.topic;

import com.gengoai.ParamMap;
import com.gengoai.ParameterDef;
import com.gengoai.apollo.math.linalg.NDArray;
import com.gengoai.apollo.math.linalg.NDArrayFactory;
import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.Datum;
import com.gengoai.apollo.ml.model.Params;
import com.gengoai.apollo.ml.observation.VariableCollection;
import com.gengoai.apollo.ml.observation.VariableNameSpace;
import com.gengoai.collection.Iterables;
import com.gengoai.collection.counter.Counter;
import com.gengoai.collection.counter.Counters;
import com.gengoai.function.Functional;
import java.lang.invoke.SerializedLambda;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.function.Consumer;
import java.util.stream.Stream;
import lombok.NonNull;
import org.apache.commons.math3.distribution.GammaDistribution;
import org.apache.commons.math3.special.Gamma;

/* loaded from: input_file:com/gengoai/apollo/ml/model/topic/OnlineLDA.class */
public class OnlineLDA extends BaseVectorTopicModel {
    private static final long serialVersionUID = 1;
    // Fit parameters supplied at construction; read by ModelP, eStep, mStep, encode and estimate.
    private final OnlineLDAFitParameters parameters;
    // Topic-by-term matrix assigned at the end of estimate() and used to seed ModelP in inference().
    private NDArray lambda;
    // Shared sampler used by gammaSample(); Commons Math constructs it as (shape = 100, scale = 0.01).
    private static final GammaDistribution GAMMA_DISTRIBUTION = new GammaDistribution(100.0d, 0.01d);
    // Parameter definition for the constant added to each document's gamma row in eStep().
    public static final ParameterDef<Double> alpha = ParameterDef.doubleParam("alpha");
    // Parameter definition for the constant added to the scaled sufficient statistics in mStep().
    public static final ParameterDef<Double> eta = ParameterDef.doubleParam("eta");
    // Parameter definition for the upper bound on per-document update iterations in eStep().
    public static final ParameterDef<Integer> inferenceSamples = ParameterDef.intParam("inferenceSamples");
    // Parameter definition for the (negated) exponent of the step-size power in mStep().
    public static final ParameterDef<Double> kappa = ParameterDef.doubleParam("kappa");
    // Parameter definition for the offset added to the batch counter before exponentiation in mStep().
    public static final ParameterDef<Double> tau0 = ParameterDef.doubleParam("tau0");

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/gengoai/apollo/ml/model/topic/OnlineLDA$ModelP.class */
    /**
     * Mutable working state shared by the E-step and M-step while fitting or running inference.
     *
     * <p>Holds the topic-by-term matrix {@code lambda}, its Dirichlet expectation and the
     * element-wise exponential of that expectation, the per-batch sufficient statistics, the
     * per-document {@code gamma} matrix (filled in by {@code eStep}), and cached copies of the
     * {@code K} and {@code alpha} fit parameters.
     */
    public class ModelP {
        NDArray lambda;
        NDArray eLogBeta;
        NDArray expELogBeta;
        NDArray stats;
        NDArray gamma;
        int K;
        double alpha;

        /**
         * Creates a state whose {@code lambda} is a fresh K x (encoder size) matrix of gamma samples.
         */
        public ModelP() {
            int topicCount = (Integer) OnlineLDA.this.parameters.K.value();
            initialize(OnlineLDA.this.gammaSample(topicCount, OnlineLDA.this.encoder.size()));
        }

        /**
         * Creates a state around an existing {@code lambda} matrix (stored by reference, not copied).
         *
         * @param nDArray the topic-by-term matrix to use as {@code lambda}
         */
        public ModelP(NDArray nDArray) {
            initialize(nDArray);
        }

        // Populates every field except gamma from the given lambda matrix and the fit parameters.
        private void initialize(NDArray topicTermMatrix) {
            this.K = (Integer) OnlineLDA.this.parameters.K.value();
            this.lambda = topicTermMatrix;
            this.eLogBeta = OnlineLDA.this.dirichletExpectation(this.lambda);
            this.expELogBeta = this.eLogBeta.map(Math::exp);
            this.stats = this.lambda.zeroLike();
            this.alpha = (Double) OnlineLDA.this.parameters.alpha.value();
        }
    }

    /* loaded from: input_file:com/gengoai/apollo/ml/model/topic/OnlineLDA$OnlineLDAFitParameters.class */
    /**
     * Fit parameters for {@link OnlineLDA}, registered on top of those of
     * {@code TopicModelFitParameters}. Each field is declared with the default shown.
     */
    public static class OnlineLDAFitParameters extends TopicModelFitParameters {
        /** Number of topics (default 100). */
        public final ParamMap<TopicModelFitParameters>.Parameter<Integer> K = parameter(Params.Clustering.K, 100);
        /** Size passed to the data set's batch iterator during estimation (default 512). */
        public final ParamMap<TopicModelFitParameters>.Parameter<Integer> batchSize = parameter(Params.Optimizable.batchSize, 512);
        /** Constant added to each document's gamma update in the E-step (default 0.1). */
        public final ParamMap<TopicModelFitParameters>.Parameter<Double> alpha = parameter(OnlineLDA.alpha, 0.1);
        /** Constant added to the scaled sufficient statistics in the M-step (default 0.01). */
        public final ParamMap<TopicModelFitParameters>.Parameter<Double> eta = parameter(OnlineLDA.eta, 0.01);
        /** Offset added to the batch counter when computing the M-step weight (default 1.0). */
        public final ParamMap<TopicModelFitParameters>.Parameter<Double> tau0 = parameter(OnlineLDA.tau0, 1.0);
        /** Exponent (applied negated) when computing the M-step weight (default 0.75). */
        public final ParamMap<TopicModelFitParameters>.Parameter<Double> kappa = parameter(OnlineLDA.kappa, 0.75);
        /** Maximum number of per-document update iterations in the E-step (default 100). */
        public final ParamMap<TopicModelFitParameters>.Parameter<Integer> inferenceSamples = parameter(OnlineLDA.inferenceSamples, 100);
    }

    /**
     * Creates a model configured with a freshly constructed {@link OnlineLDAFitParameters},
     * i.e. with every parameter at its declared default.
     */
    public OnlineLDA() {
        this(new OnlineLDAFitParameters());
    }

    /**
     * Creates a model that uses the given fit parameters (stored by reference).
     *
     * @param onlineLDAFitParameters the fit parameters; must not be null
     * @throws NullPointerException if {@code onlineLDAFitParameters} is null
     */
    public OnlineLDA(@NonNull OnlineLDAFitParameters onlineLDAFitParameters) {
        this.parameters = java.util.Objects.requireNonNull(
                onlineLDAFitParameters, "parameters is marked non-null but is null");
    }

    public OnlineLDA(@NonNull Consumer<OnlineLDAFitParameters> consumer) {
        if (consumer == null) {
            throw new NullPointerException("updater is marked non-null but is null");
        }
        this.parameters = (OnlineLDAFitParameters) Functional.with(new OnlineLDAFitParameters(), consumer);
    }

    /**
     * Applies digamma to every entry of the input (into a new array) and then subtracts, via
     * {@code subiColumnVector}, the digamma of the input's row sums.
     *
     * @param nDArray the array of Dirichlet parameters; it is not modified here
     * @return the newly mapped array after the in-place subtraction
     */
    private NDArray dirichletExpectation(NDArray nDArray) {
        NDArray digammaOfEntries = nDArray.map(Gamma::digamma);
        NDArray digammaOfRowSums = nDArray.rowSums().mapi(Gamma::digamma);
        return digammaOfEntries.subiColumnVector(digammaOfRowSums);
    }

    /**
     * Runs the per-document update for a batch, writing {@code modelP.gamma} and {@code modelP.stats}.
     *
     * <p>{@code gamma} is re-sampled for the whole batch and {@code stats} is reset to zeros. Each
     * document with at least one sparse index then has its gamma row updated repeatedly until the
     * mean absolute change drops below 0.001 or the {@code inferenceSamples} limit is reached.
     * Documents without sparse indices keep their sampled gamma row and add nothing to
     * {@code stats}. Finally {@code stats} is multiplied in place by {@code expELogBeta}.
     *
     * @param modelP working state to read from and update
     * @param list   one count vector per document in the batch
     */
    private void eStep(ModelP modelP, List<NDArray> list) {
        final int documentCount = list.size();
        modelP.gamma = gammaSample(documentCount, modelP.K);
        NDArray expELogTheta = dirichletExpectation(modelP.gamma).map(Math::exp);
        modelP.stats = modelP.lambda.zeroLike();

        for (int doc = 0; doc < documentCount; doc++) {
            NDArray document = list.get(doc);
            int[] termIds = document.sparseIndices();
            if (termIds.length == 0) {
                continue;
            }

            // Dense copy of the document's values at its sparse indices.
            NDArray termCounts = NDArrayFactory.DENSE.array(termIds.length);
            for (int t = 0; t < termIds.length; t++) {
                termCounts.set(t, document.get(termIds[t]));
            }

            NDArray gammaRow = modelP.gamma.getRow(doc);
            NDArray expELogThetaRow = expELogTheta.getRow(doc);
            NDArray expELogBetaCols = modelP.expELogBeta.getColumns(termIds);
            // The tiny constant keeps the later division away from zero.
            NDArray phiNorm = expELogThetaRow.mmul(expELogBetaCols).addi(1.0E-100d);

            int maxIterations = (Integer) this.parameters.inferenceSamples.value();
            for (int iteration = 0; iteration < maxIterations; iteration++) {
                NDArray previousGammaRow = gammaRow;
                gammaRow = expELogThetaRow
                        .mul(termCounts.div(phiNorm).mmul(expELogBetaCols.T()))
                        .addi(modelP.alpha);
                expELogThetaRow = dirichletExpectation(gammaRow).map(Math::exp);
                phiNorm = expELogThetaRow.mmul(expELogBetaCols).addi(1.0E-100d);
                double meanChange = gammaRow
                        .map(previousGammaRow, (current, previous) -> Math.abs(current - previous))
                        .mean();
                if (meanChange < 0.001d) {
                    break;
                }
            }

            modelP.gamma.setRow(doc, gammaRow);
            NDArray contribution = outer(expELogThetaRow, termCounts.div(phiNorm));
            for (int t = 0; t < termIds.length; t++) {
                modelP.stats.incrementiColumn(termIds[t], contribution.getColumn(t));
            }
        }

        modelP.stats.muli(modelP.expELogBeta);
    }

    /**
     * Converts a datum's input observations into count vectors.
     *
     * <p>When the {@code combineInputs} parameter is true, the inputs are first merged with
     * {@code VariableCollection.mergeVariableSpace} and the merged variable space is mapped;
     * otherwise each input observation is mapped on its own.
     *
     * @param datum the datum whose inputs are encoded
     * @return a stream of count vectors built with {@code toCountVector}
     */
    private Stream<NDArray> encode(Datum datum) {
        boolean mergeInputs = (Boolean) this.parameters.combineInputs.value();
        if (!mergeInputs) {
            return datum.stream(getInputs())
                    .map(observation -> toCountVector(observation, (VariableNameSpace) this.parameters.namingPattern.value()));
        }
        return VariableCollection.mergeVariableSpace(datum.stream(getInputs()))
                .getVariableSpace()
                .map(variable -> toCountVector(variable, (VariableNameSpace) this.parameters.namingPattern.value()));
    }

    /**
     * Fits the model on the given data set.
     *
     * <p>The encoder is fitted first, then the data set is consumed in batches of
     * {@code batchSize}; every batch is encoded, run through {@code eStep} and folded into
     * {@code lambda} by {@code mStep}. Afterwards each row of {@code lambda} is divided by its
     * row sum, turned into a {@code Topic}, and the matrix is retained for inference.
     *
     * @param dataSet the training data; must not be null
     * @throws NullPointerException if {@code dataSet} is null
     */
    @Override // com.gengoai.apollo.ml.model.Model
    public void estimate(@NonNull DataSet dataSet) {
        if (dataSet == null) {
            throw new NullPointerException("dataset is marked non-null but is null");
        }
        encoderFit(dataSet, getInputs(), (VariableNameSpace) this.parameters.namingPattern.value());
        ModelP modelP = new ModelP();
        double corpusSize = dataSet.size();
        int batchSize = (Integer) this.parameters.batchSize.value();
        int updateCount = 0;
        for (DataSet batch : Iterables.asIterable(dataSet.batchIterator(batchSize))) {
            List<NDArray> vectors = batch.parallelStream().flatMap(this::encode).collect();
            if (vectors.isEmpty()) {
                // A batch that encodes to nothing would make mStep divide by zero and turn
                // lambda into NaN, so it is skipped without counting as an update.
                continue;
            }
            eStep(modelP, vectors);
            mStep(modelP, corpusSize, updateCount, vectors.size());
            updateCount++;
        }
        // Normalize each topic row by its sum before exposing it as a feature distribution.
        modelP.lambda.diviColumnVector(modelP.lambda.rowSums());
        for (int topicId = 0; topicId < modelP.lambda.rows(); topicId++) {
            NDArray topicRow = modelP.lambda.getRow(topicId);
            Counter<String> termWeights = Counters.newCounter(new String[0]);
            topicRow.forEachSparse((index, weight) -> {
                termWeights.set(this.encoder.decode(index), weight);
            });
            this.topics.add(new Topic(topicId, termWeights));
        }
        this.lambda = modelP.lambda;
    }

    /**
     * Builds a dense {@code i} x {@code i2} array whose entries are each replaced by an
     * independent draw from {@code GAMMA_DISTRIBUTION}.
     *
     * @param i  first dimension passed to the dense array factory
     * @param i2 second dimension passed to the dense array factory
     * @return the sampled array
     */
    private NDArray gammaSample(int i, int i2) {
        NDArray samples = NDArrayFactory.DENSE.array(i, i2);
        return samples.mapi(ignored -> GAMMA_DISTRIBUTION.sample());
    }

    /**
     * Returns the fit parameters this model was constructed with (the same instance, not a copy).
     *
     * @return this model's {@link OnlineLDAFitParameters}
     */
    @Override // com.gengoai.apollo.ml.model.topic.TopicModel, com.gengoai.apollo.ml.model.MultiInputModel, com.gengoai.apollo.ml.model.CombinableOutputModel, com.gengoai.apollo.ml.model.Model
    public OnlineLDAFitParameters getFitParameters() {
        return this.parameters;
    }

    /**
     * Collects the weight of a feature in every topic.
     *
     * @param str the feature name to look up in each topic's feature distribution
     * @return an array with one slot per topic, where the slot at a topic's id holds the value
     *         that topic's feature distribution returns for {@code str}
     */
    @Override // com.gengoai.apollo.ml.model.topic.TopicModel
    public NDArray getTopicDistribution(String str) {
        NDArray distribution = NDArrayFactory.ND.array(getNumberOfTopics());
        // Bind each topic once so the id and the weight come from the same element.
        for (Topic topic : this.topics) {
            distribution.set(topic.getId(), topic.getFeatureDistribution().get(str));
        }
        return distribution;
    }

    /**
     * Infers the topic distribution of a single count vector.
     *
     * <p>Wraps the fitted {@code lambda} in a fresh {@code ModelP}, runs {@code eStep} on a
     * one-element batch, and divides the resulting gamma in place by its total.
     *
     * @param nDArray the document's count vector
     * @return the gamma array after normalization by its sum
     */
    @Override // com.gengoai.apollo.ml.model.topic.BaseVectorTopicModel
    protected NDArray inference(NDArray nDArray) {
        ModelP state = new ModelP(this.lambda);
        eStep(state, Collections.singletonList(nDArray));
        NDArray documentGamma = state.gamma;
        return documentGamma.divi(documentGamma.sum());
    }

    /**
     * Blends the current batch's estimate into {@code modelP.lambda} and refreshes the cached
     * Dirichlet expectations.
     *
     * <p>With {@code rho = (tau0 + i)^(-kappa)} and
     * {@code lambdaHat = eta + (d / i2) * stats}, the update is
     * {@code lambda = (1 - rho) * lambda + rho * lambdaHat}. The batch estimate must be weighted
     * by {@code rho}; adding it unweighted is not a convex combination and lets {@code lambda}
     * grow as {@code rho} shrinks.
     *
     * @param modelP working state whose {@code lambda}, {@code eLogBeta} and {@code expELogBeta} are replaced
     * @param d      total number of documents in the data set
     * @param i      zero-based index of this update
     * @param i2     number of documents in the current batch; expected to be positive
     */
    private void mStep(ModelP modelP, double d, int i, int i2) {
        double tauOffset = (Double) this.parameters.tau0.value();
        double decay = (Double) this.parameters.kappa.value();
        double topicPrior = (Double) this.parameters.eta.value();

        double rho = Math.pow(tauOffset + i, -decay);
        // stats.mul(...) yields a new array, so the in-place addi does not touch modelP.stats.
        NDArray lambdaHat = modelP.stats.mul(d / i2).addi(topicPrior);
        modelP.lambda = modelP.lambda.mul(1.0d - rho).add(lambdaHat.mul(rho));
        modelP.eLogBeta = dirichletExpectation(modelP.lambda);
        modelP.expELogBeta = modelP.eLogBeta.map(Math::exp);
    }

    /**
     * Computes the outer product of two vectors into a new dense array.
     *
     * @param nDArray  the left vector; its length gives the number of rows
     * @param nDArray2 the right vector; its length gives the number of columns
     * @return a dense array whose entry (r, c) is {@code nDArray.get(r) * nDArray2.get(c)}
     */
    private NDArray outer(NDArray nDArray, NDArray nDArray2) {
        final long rowCount = nDArray.length();
        final long columnCount = nDArray2.length();
        NDArray product = NDArrayFactory.DENSE.array((int) rowCount, (int) columnCount);
        // Both loops are bounded and step by 1, so the method always terminates.
        for (long r = 0; r < rowCount; r++) {
            for (long c = 0; c < columnCount; c++) {
                product.set((int) r, (int) c, nDArray.get(r) * nDArray2.get(c));
            }
        }
        return product;
    }

    /**
     * Recreates the serialized {@code this::encode} method reference.
     *
     * <p>NOTE(review): this method is marked synthetic in the decompiled output; a compiler may
     * emit its own version for serializable lambdas in this class, so confirm that a hand-written
     * copy is actually wanted before keeping it in source form.
     *
     * @param serializedLambda the serialized form describing the lambda to rebuild
     * @return the {@code encode} method reference bound to the captured {@code OnlineLDA} instance
     * @throws IllegalArgumentException if the description does not match the {@code encode} reference
     */
    private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
        // Kind 7 is MethodHandleInfo.REF_invokeSpecial.
        boolean describesEncode = "encode".equals(serializedLambda.getImplMethodName())
                && serializedLambda.getImplMethodKind() == 7
                && serializedLambda.getFunctionalInterfaceClass().equals("com/gengoai/function/SerializableFunction")
                && serializedLambda.getFunctionalInterfaceMethodName().equals("apply")
                && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;")
                && serializedLambda.getImplClass().equals("com/gengoai/apollo/ml/model/topic/OnlineLDA")
                && serializedLambda.getImplMethodSignature().equals("(Lcom/gengoai/apollo/ml/Datum;)Ljava/util/stream/Stream;");
        if (describesEncode) {
            OnlineLDA onlineLDA = (OnlineLDA) serializedLambda.getCapturedArg(0);
            // A method reference needs a functional-interface target type; it cannot be returned as a bare Object.
            return (com.gengoai.function.SerializableFunction<Datum, Stream<NDArray>>) onlineLDA::encode;
        }
        throw new IllegalArgumentException("Invalid lambda deserialization");
    }
}
