package com.gengoai.apollo.ml.model.topic;

import com.gengoai.apollo.math.linalg.NDArray;
import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.model.LabelType;
import com.gengoai.apollo.ml.model.MultiInputModel;
import com.gengoai.collection.Iterators;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import lombok.NonNull;

/**
 * Base class for topic models.
 *
 * <p>A topic model holds an ordered list of {@link Topic}s. Topics can be looked up by index or
 * by name and iterated over read-only. The model also acts as a {@link DataSet} transform: each
 * datum is mapped through the model, and every output is re-described as an {@link NDArray}
 * whose dimension equals the number of topics.
 */
public abstract class TopicModel implements MultiInputModel<TopicModelFitParameters, TopicModel>, Iterable<Topic> {
    private static final long serialVersionUID = 1;

    /**
     * The model's topics in index order. This class never adds to the list itself;
     * presumably subclasses populate it when fitting (confirm in implementations).
     */
    protected final List<Topic> topics = new ArrayList<>();

    @Override
    public abstract TopicModelFitParameters getFitParameters();

    /**
     * Returns the number of topics currently held by this model.
     *
     * @return the size of the topic list
     */
    public final int getNumberOfTopics() {
        return topics.size();
    }

    /**
     * Returns the output label type produced by this model.
     *
     * @return always {@link LabelType#NDArray}
     */
    @Override
    public LabelType getOutputType() {
        return LabelType.NDArray;
    }

    /**
     * Returns the topic at the given index.
     *
     * @param i the zero-based topic index
     * @return the topic at index {@code i}
     * @throws IndexOutOfBoundsException if {@code i} is negative or not less than
     *     {@link #getNumberOfTopics()}
     */
    public Topic getTopic(int i) {
        return topics.get(i);
    }

    /**
     * Returns the first topic whose {@code getName()} equals {@code name}.
     *
     * @param name the topic name to look up (must not be null)
     * @return the first matching topic in index order
     * @throws NullPointerException if {@code name} is null
     * @throws IndexOutOfBoundsException if no topic has the given name
     */
    public Topic getTopic(@NonNull String name) {
        // Explicit check kept; Lombok does not add a second one when this is already present.
        if (name == null) {
            throw new NullPointerException("name is marked non-null but is null");
        }
        return topics.stream()
                .filter(topic -> name.equals(topic.getName()))
                .findFirst()
                // Same exception type as before, now with a diagnostic message.
                .orElseThrow(() -> new IndexOutOfBoundsException("No topic named '" + name + "'"));
    }

    /**
     * Returns the topic distribution for the given string.
     *
     * @param feature the input string (NOTE(review): presumably a vocabulary word/feature —
     *     confirm against implementing subclasses)
     * @return the distribution as an {@link NDArray}; its shape and normalization are defined by
     *     the implementing subclass
     */
    public abstract NDArray getTopicDistribution(String feature);

    /**
     * Returns a read-only iterator over this model's topics in index order.
     *
     * @return an iterator over the topics, wrapped by {@code Iterators.unmodifiableIterator}
     */
    @Override
    public Iterator<Topic> iterator() {
        return Iterators.unmodifiableIterator(topics.iterator());
    }

    /**
     * Transforms a dataset by mapping every datum through this model.
     *
     * <p>For each output name returned by {@code getOutputs()}, the output's observation metadata
     * is rewritten as follows:
     * <ul>
     *   <li>dimension = {@link #getNumberOfTopics()}</li>
     *   <li>no encoder</li>
     *   <li>type {@link NDArray}</li>
     * </ul>
     *
     * @param dataset the dataset to transform (must not be null)
     * @return the mapped dataset with updated output metadata
     * @throws NullPointerException if {@code dataset} is null
     */
    @Override
    public DataSet transform(@NonNull DataSet dataset) {
        if (dataset == null) {
            throw new NullPointerException("dataset is marked non-null but is null");
        }
        // The method reference targets the per-Datum transform. The compiler synthesizes the
        // $deserializeLambda$ support needed to serialize it; do not hand-write that method.
        DataSet mapped = dataset.map(this::transform);
        Iterator<String> outputs = getOutputs().iterator();
        while (outputs.hasNext()) {
            mapped.updateMetadata(outputs.next(), metadata -> {
                metadata.setDimension(getNumberOfTopics());
                metadata.setEncoder(null);
                metadata.setType(NDArray.class);
            });
        }
        return mapped;
    }
}
