package com.gengoai.apollo.ml.model.topic;

import com.gengoai.apollo.math.linalg.NDArray;
import com.gengoai.apollo.math.linalg.NDArrayFactory;
import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.Datum;
import com.gengoai.apollo.ml.encoder.IndexEncoder;
import com.gengoai.apollo.ml.observation.Observation;
import com.gengoai.apollo.ml.observation.VariableCollection;
import com.gengoai.apollo.ml.observation.VariableNameSpace;
import com.gengoai.stream.MStream;
import com.gengoai.string.Strings;
import java.lang.invoke.SerializedLambda;
import java.util.Collection;
import java.util.Objects;
import java.util.Set;
import lombok.NonNull;

/* loaded from: input_file:com/gengoai/apollo/ml/model/topic/BaseVectorTopicModel.class */
/**
 * Base class for topic models that operate on per-observation count vectors.
 *
 * <p>Subclasses fit {@link #encoder} with {@link #encoderFit(DataSet, Collection, VariableNameSpace)}
 * and implement {@link #inference(NDArray)}, which turns a count vector produced by
 * {@link #toCountVector(Observation, VariableNameSpace)} into the vector stored on the datum by
 * {@link #transform(Datum)}.
 */
public abstract class BaseVectorTopicModel extends TopicModel {
    /** Maps encoded variable names to indices of the count vector; unknown names encode to a negative index. */
    protected final IndexEncoder encoder = new IndexEncoder();

    /**
     * Fits {@link #encoder} on every variable found in the given sources of every datum in the data set.
     *
     * <p>Each variable is passed through {@code nameSpace.transform} before being handed to the encoder.
     *
     * @param dataset   the data set whose datums supply the observations
     * @param sources   the names of the datum sources to read observations from
     * @param nameSpace the naming scheme applied to each variable before encoding
     * @throws NullPointerException if any argument is {@code null}
     */
    public void encoderFit(@NonNull DataSet dataset,
                           @NonNull Collection<String> sources,
                           @NonNull VariableNameSpace nameSpace) {
        if (dataset == null) {
            throw new NullPointerException("dataset is marked non-null but is null");
        }
        if (sources == null) {
            throw new NullPointerException("sources is marked non-null but is null");
        }
        if (nameSpace == null) {
            throw new NullPointerException("nameSpace is marked non-null but is null");
        }
        // MStream operators take serializable functions; javac emits the $deserializeLambda$
        // bridge for these lambdas/method references itself, so none is written by hand.
        encoder.fit(dataset.stream()
                           .flatMap(datum -> datum.stream(sources))
                           .flatMap(Observation::getVariableSpace)
                           .map(nameSpace::transform));
    }

    /**
     * Maps a count vector (as built by {@link #toCountVector(Observation, VariableNameSpace)}) to the
     * model's output vector.
     *
     * @param nDArray the count vector for one observation
     * @return the vector to store on the datum
     */
    protected abstract NDArray inference(NDArray nDArray);

    /**
     * Builds a count vector of length {@code encoder.size()} for an observation.
     *
     * <p>For every variable in the observation's variable space, the value is added to the index that
     * {@link #encoder} assigns to {@code nameSpace.getName(variable)}. Variables whose name encodes to a
     * negative index (not known to the encoder) are skipped.
     *
     * @param observation the observation whose variables are counted
     * @param nameSpace   the naming scheme used to derive each variable's encoded name
     * @return a new vector holding the summed variable values per encoded index
     * @throws NullPointerException if any argument is {@code null}
     */
    public NDArray toCountVector(@NonNull Observation observation, @NonNull VariableNameSpace nameSpace) {
        if (observation == null) {
            throw new NullPointerException("observation is marked non-null but is null");
        }
        if (nameSpace == null) {
            throw new NullPointerException("nameSpace is marked non-null but is null");
        }
        NDArray counts = NDArrayFactory.ND.array(encoder.size());
        observation.getVariableSpace().forEach(variable -> {
            int index = encoder.encode(nameSpace.getName(variable));
            if (index >= 0) {
                counts.set(index, counts.get(index) + variable.getValue());
            }
        });
        return counts;
    }

    /**
     * Runs inference on the configured inputs of a datum and stores the results back on it.
     *
     * <p>When {@code combineOutputs} is set, the variable spaces of all inputs are merged into one
     * observation and a single result is stored under {@code output}. Otherwise, each input is processed
     * separately and its result is stored under the input name followed by {@code outputSuffix}
     * ({@code null} suffix treated as empty).
     *
     * @param datum the datum to update in place
     * @return the same {@code datum}
     * @throws NullPointerException if {@code datum} is {@code null}
     */
    @Override
    public Datum transform(@NonNull Datum datum) {
        if (datum == null) {
            throw new NullPointerException("datum is marked non-null but is null");
        }
        // Parameter values are loop-invariant: read them once rather than per input.
        @SuppressWarnings("unchecked")
        Set<String> inputs = (Set<String>) getFitParameters().inputs.value();
        VariableNameSpace nameSpace = (VariableNameSpace) getFitParameters().namingPattern.value();

        if ((Boolean) getFitParameters().combineOutputs.value()) {
            Observation merged = VariableCollection.mergeVariableSpace(datum.stream(inputs));
            datum.put((String) getFitParameters().output.value(),
                      (Observation) inference(toCountVector(merged, nameSpace)));
        } else {
            String suffix = Strings.nullToEmpty((String) getFitParameters().outputSuffix.value());
            for (String input : inputs) {
                // NOTE(review): if datum.get(input) can return null for a missing source,
                // toCountVector throws NullPointerException here — confirm that is intended.
                datum.put(input + suffix, (Observation) inference(toCountVector(datum.get(input), nameSpace)));
            }
        }
        return datum;
    }
}
