package com.gengoai.apollo.ml.model.topic;

import cc.mallet.pipe.SerialPipes;
import cc.mallet.pipe.TargetStringToFeatures;
import cc.mallet.pipe.TokenSequence2FeatureSequence;
import cc.mallet.topics.ParallelTopicModel;
import cc.mallet.topics.TopicInferencer;
import cc.mallet.types.Alphabet;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import com.gengoai.ParamMap;
import com.gengoai.ParameterDef;
import com.gengoai.SystemInfo;
import com.gengoai.apollo.math.linalg.NDArray;
import com.gengoai.apollo.math.linalg.NDArrayFactory;
import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.Datum;
import com.gengoai.apollo.ml.model.Params;
import com.gengoai.apollo.ml.observation.Observation;
import com.gengoai.apollo.ml.observation.Variable;
import com.gengoai.apollo.ml.observation.VariableCollection;
import com.gengoai.apollo.ml.observation.VariableList;
import com.gengoai.apollo.ml.observation.VariableNameSpace;
import com.gengoai.collection.counter.Counter;
import com.gengoai.collection.counter.Counters;
import com.gengoai.function.Functional;
import com.gengoai.string.Strings;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.Set;
import java.util.TreeSet;
import java.util.function.Consumer;
import java.util.logging.Level;
import java.util.stream.Stream;
import lombok.NonNull;

/* loaded from: input_file:com/gengoai/apollo/ml/model/topic/MalletLDA.class */
/**
 * Topic model backed by Mallet's {@link ParallelTopicModel}.
 *
 * <p>{@link #estimate(DataSet)} pushes the configured input observations through a Mallet pipe,
 * trains a {@link ParallelTopicModel} and materializes one {@link Topic} per Mallet topic.
 * {@link #transform(Datum)} samples a topic distribution for a datum's observations using a lazily
 * created {@link TopicInferencer}.</p>
 */
public class MalletLDA extends TopicModel {
    private static final long serialVersionUID = 1;
    public static final ParameterDef<Integer> burnIn = ParameterDef.intParam("burnIn");
    public static final ParameterDef<Integer> optimizationInterval = ParameterDef.intParam("optimizationInterval");
    public static final ParameterDef<Boolean> symmetricAlpha = ParameterDef.boolParam("symmetricAlpha");
    /** Fixed seed handed to the inferencer so that repeated inference uses the same random stream. */
    private static final int INFERENCE_RANDOM_SEED = 1234;
    /** Arguments passed to {@code TopicInferencer.getSampledDistribution(instance, iterations, thinning, burnIn)}. */
    private static final int INFERENCE_ITERATIONS = 800;
    private static final int INFERENCE_THINNING = 5;
    private static final int INFERENCE_BURN_IN = 100;
    private final Parameters parameters;
    /** Lazily created by {@link #getInferencer()}; transient, so it is rebuilt on first use after deserialization. */
    private volatile transient TopicInferencer inferencer;
    private SerialPipes pipes;
    private ParallelTopicModel topicModel;

    /**
     * Fit parameters for {@link MalletLDA}. The defaults are the second argument of each
     * {@code parameter(...)} call below.
     */
    public static class Parameters extends TopicModelFitParameters {
        /** Number of topics (default 100). */
        public final ParamMap<TopicModelFitParameters>.Parameter<Integer> K = parameter(Params.Clustering.K, 100);
        /** Value passed to {@code ParallelTopicModel.setBurninPeriod} (default 500). */
        public final ParamMap<TopicModelFitParameters>.Parameter<Integer> burnIn = parameter(MalletLDA.burnIn, 500);
        /** Value passed to {@code ParallelTopicModel.setNumIterations} (default 2000). */
        public final ParamMap<TopicModelFitParameters>.Parameter<Integer> maxIterations = parameter(Params.Optimizable.maxIterations, 2000);
        /** Value passed to {@code ParallelTopicModel.setOptimizeInterval} (default 10). */
        public final ParamMap<TopicModelFitParameters>.Parameter<Integer> optimizationInterval = parameter(MalletLDA.optimizationInterval, 10);
        /** Value passed to {@code ParallelTopicModel.setSymmetricAlpha} (default false). */
        public final ParamMap<TopicModelFitParameters>.Parameter<Boolean> symmetricAlpha = parameter(MalletLDA.symmetricAlpha, false);
    }

    /**
     * Creates a model with default {@link Parameters}.
     */
    public MalletLDA() {
        this(new Parameters());
    }

    /**
     * Creates a model using the given fit parameters.
     *
     * @param parameters the fit parameters
     * @throws NullPointerException if {@code parameters} is null
     */
    public MalletLDA(@NonNull Parameters parameters) {
        if (parameters == null) {
            throw new NullPointerException("parameters is marked non-null but is null");
        }
        this.parameters = parameters;
    }

    /**
     * Creates a model whose default {@link Parameters} are first handed to the given consumer.
     *
     * @param consumer callback applied to a fresh {@link Parameters} instance
     * @throws NullPointerException if {@code consumer} is null
     */
    public MalletLDA(@NonNull Consumer<Parameters> consumer) {
        if (consumer == null) {
            throw new NullPointerException("updater is marked non-null but is null");
        }
        this.parameters = (Parameters) Functional.with(new Parameters(), consumer);
    }

    /**
     * Builds the {@link Topic} for one Mallet topic from structures computed once by the caller.
     *
     * @param topicIndex  index of the topic
     * @param alphabet    alphabet used to map feature ids back to their string form
     * @param sortedWords result of {@code ParallelTopicModel.getSortedWords()}
     * @param topicWords  result of {@code ParallelTopicModel.getTopicWords(true, true)}, indexed [topic][feature id]
     * @return the topic holding a score for every word listed for {@code topicIndex} in {@code sortedWords}
     */
    private Topic createTopic(int topicIndex,
                              Alphabet alphabet,
                              ArrayList<TreeSet<IDSorter>> sortedWords,
                              double[][] topicWords) {
        Counter<String> wordScores = Counters.newCounter();
        for (IDSorter entry : sortedWords.get(topicIndex)) {
            int featureId = entry.getID();
            wordScores.set(alphabet.lookupObject(featureId).toString(), topicWords[topicIndex][featureId]);
        }
        return new Topic(topicIndex, wordScores);
    }

    /**
     * Returns the observation to feed to Mallet for a single input: the observation itself when the
     * naming pattern is {@link VariableNameSpace#Full}, otherwise a {@link VariableList} whose
     * variables are renamed through the configured naming pattern.
     */
    private Observation applyNamingPattern(Observation observation) {
        final VariableNameSpace namingPattern = (VariableNameSpace) this.parameters.namingPattern.value();
        if (namingPattern == VariableNameSpace.Full) {
            return observation;
        }
        return new VariableList((Stream<Variable>) observation.getVariableSpace()
                .map(variable -> Variable.real(namingPattern.getName(variable), variable.getValue())));
    }

    /**
     * Merges the variable spaces of all configured inputs of the datum into one observation.
     */
    private Observation mergeInputs(Datum datum) {
        return VariableCollection.mergeVariableSpace(
                datum.stream((Collection<String>) this.parameters.inputs.value()),
                (VariableNameSpace) this.parameters.namingPattern.value());
    }

    /**
     * Trains the topic model on the given data set and rebuilds the list of topics.
     *
     * <p>When {@code combineInputs} is set, each datum contributes one Mallet instance built from
     * all inputs merged together; otherwise each configured input contributes its own instance.</p>
     *
     * @param dataSet the data to train on
     * @throws NullPointerException if {@code dataSet} is null
     * @throws RuntimeException     wrapping any {@link IOException} raised by Mallet during estimation
     */
    @Override // com.gengoai.apollo.ml.model.Model
    public void estimate(@NonNull DataSet dataSet) {
        if (dataSet == null) {
            throw new NullPointerException("dataset is marked non-null but is null");
        }
        final boolean verbose = ((Boolean) this.parameters.verbose.value()).booleanValue();
        ParallelTopicModel.logger.setLevel(verbose ? Level.INFO : Level.OFF);

        this.topics.clear();
        this.pipes = new SerialPipes(Arrays.asList(new TargetStringToFeatures(), new InstanceToTokenSequence(), new TokenSequence2FeatureSequence()));
        InstanceList instanceList = new InstanceList(this.pipes);
        final boolean combineInputs = ((Boolean) this.parameters.combineInputs.value()).booleanValue();
        for (Iterator<Datum> data = dataSet.iterator(); data.hasNext(); ) {
            Datum datum = data.next();
            if (combineInputs) {
                instanceList.addThruPipe(new Instance(mergeInputs(datum), "", (Object) null, (Object) null));
            } else {
                for (String input : (Set<String>) this.parameters.inputs.value()) {
                    instanceList.addThruPipe(new Instance(applyNamingPattern(datum.get(input)), "", (Object) null, (Object) null));
                }
            }
        }

        final int numTopics = this.parameters.K.value();
        this.topicModel = new ParallelTopicModel(numTopics);
        // The cached inferencer belongs to the previous model; drop it so inference after
        // re-training does not keep sampling from stale topics.
        this.inferencer = null;
        this.topicModel.addInstances(instanceList);
        this.topicModel.setNumIterations(this.parameters.maxIterations.value());
        // Leave one processor free, but never ask for zero threads: on a single-processor host
        // NUMBER_OF_PROCESSORS - 1 is 0 and Mallet divides the document count by the thread count.
        this.topicModel.setNumThreads(Math.max(1, SystemInfo.NUMBER_OF_PROCESSORS - 1));
        this.topicModel.setBurninPeriod(this.parameters.burnIn.value());
        this.topicModel.setOptimizeInterval(this.parameters.optimizationInterval.value());
        this.topicModel.setSymmetricAlpha(this.parameters.symmetricAlpha.value());
        try {
            this.topicModel.estimate();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }

        // Compute the per-topic structures once instead of once per topic.
        Alphabet alphabet = this.pipes.getDataAlphabet();
        ArrayList<TreeSet<IDSorter>> sortedWords = this.topicModel.getSortedWords();
        double[][] topicWords = this.topicModel.getTopicWords(true, true);
        for (int i = 0; i < numTopics; i++) {
            this.topics.add(createTopic(i, alphabet, sortedWords, topicWords));
        }
    }

    /**
     * @return the fit parameters this model was constructed with
     */
    @Override // com.gengoai.apollo.ml.model.topic.TopicModel, com.gengoai.apollo.ml.model.MultiInputModel, com.gengoai.apollo.ml.model.CombinableOutputModel, com.gengoai.apollo.ml.model.Model
    public Parameters getFitParameters() {
        return this.parameters;
    }

    /**
     * Returns the inferencer for the current topic model, creating and seeding it on first use.
     * Creation is guarded by a lock on this instance so that only one inferencer is built.
     */
    private TopicInferencer getInferencer() {
        TopicInferencer current = this.inferencer;
        if (current == null) {
            synchronized (this) {
                current = this.inferencer;
                if (current == null) {
                    current = this.topicModel.getInferencer();
                    current.setRandomSeed(INFERENCE_RANDOM_SEED);
                    this.inferencer = current;
                }
            }
        }
        return current;
    }

    /**
     * Looks up the per-topic values of a single word.
     *
     * <p>Note that every call for a known word recomputes the full topic-word matrix via
     * {@code getTopicWords(true, true)}.</p>
     *
     * @param str the word to look up in the data alphabet
     * @return a row vector with one entry per topic taken from the topic-word matrix, or
     * {@code NDArrayFactory.ND.array(numTopics)} when the word is not in the alphabet
     */
    @Override // com.gengoai.apollo.ml.model.topic.TopicModel
    public NDArray getTopicDistribution(String str) {
        final int numTopics = this.topicModel.numTopics;
        // The second argument (false) means an unknown word is not added; -1 signals "not found".
        final int featureId = this.pipes.getDataAlphabet().lookupIndex(str, false);
        if (featureId == -1) {
            return NDArrayFactory.ND.array(numTopics);
        }
        double[][] topicWords = this.topicModel.getTopicWords(true, true);
        double[] distribution = new double[numTopics];
        for (int topic = 0; topic < numTopics; topic++) {
            distribution[topic] = topicWords[topic][featureId];
        }
        return NDArrayFactory.ND.rowVector(distribution);
    }

    /**
     * Samples the topic distribution of one observation by running it through the training pipe
     * and the (lazily created) inferencer.
     */
    private NDArray inference(Observation observation) {
        InstanceList instanceList = new InstanceList(this.pipes);
        instanceList.addThruPipe(new Instance(observation, "", (Object) null, (Object) null));
        Instance instance = (Instance) instanceList.get(0);
        double[] distribution = getInferencer().getSampledDistribution(
                instance, INFERENCE_ITERATIONS, INFERENCE_THINNING, INFERENCE_BURN_IN);
        return NDArrayFactory.ND.rowVector(distribution);
    }

    /**
     * Adds inferred topic distributions to the datum and returns the same datum.
     *
     * <p>When {@code combineOutputs} is set, all inputs are merged and a single distribution is
     * stored under the configured output name; otherwise each input gets its own distribution
     * stored under the input name followed by the (null-safe) output suffix.</p>
     *
     * @param datum the datum to annotate in place
     * @return the given datum
     * @throws NullPointerException if {@code datum} is null
     */
    @Override // com.gengoai.apollo.ml.transform.Transform
    public Datum transform(@NonNull Datum datum) {
        if (datum == null) {
            throw new NullPointerException("datum is marked non-null but is null");
        }
        // NOTE(review): estimate() decides whether to merge inputs via combineInputs, while this
        // method uses combineOutputs - confirm the two flags are meant to differ.
        if (((Boolean) this.parameters.combineOutputs.value()).booleanValue()) {
            datum.put((String) this.parameters.output.value(), (Observation) inference(mergeInputs(datum)));
            return datum;
        }
        final String suffix = Strings.nullToEmpty((String) this.parameters.outputSuffix.value());
        for (String input : (Set<String>) this.parameters.inputs.value()) {
            Observation observation = applyNamingPattern(datum.get(input));
            datum.put(input + suffix, (Observation) inference(observation));
        }
        return datum;
    }
}
