package com.gengoai.apollo.ml.model.embedding;

import com.gengoai.Validation;
import com.gengoai.apollo.math.linalg.NDArray;
import com.gengoai.apollo.math.linalg.VectorComposition;
import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.Datum;
import com.gengoai.apollo.ml.model.CombinableOutputModel;
import com.gengoai.apollo.ml.model.LabelType;
import com.gengoai.apollo.ml.model.embedding.TrainableWordEmbedding;
import com.gengoai.apollo.ml.model.embedding.WordEmbeddingFitParameters;
import com.gengoai.apollo.ml.observation.Observation;
import com.gengoai.apollo.ml.observation.Variable;
import com.gengoai.apollo.ml.observation.VariableNameSpace;
import com.gengoai.collection.Iterables;
import com.gengoai.conversion.Cast;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.NonNull;

/* loaded from: input_file:com/gengoai/apollo/ml/model/embedding/TrainableWordEmbedding.class */
/**
 * Base class for {@link WordEmbedding}s whose vectors are learned from a {@link DataSet}.
 *
 * <p>The behaviour of the embedding (which observations it reads, where it writes its output, and
 * whether multiple inputs are combined into a single vector) is driven entirely by the fit
 * parameters of type {@code F} supplied at construction time.
 *
 * @param <F> the fit-parameter type
 * @param <T> the concrete subclass type, returned by fluent setters such as {@link #inputs(String...)}
 */
public abstract class TrainableWordEmbedding<F extends WordEmbeddingFitParameters<F>, T extends TrainableWordEmbedding<F, T>> extends WordEmbedding implements CombinableOutputModel<F, TrainableWordEmbedding<F, T>> {
    private static final long serialVersionUID = 1L;

    /** Fit parameters controlling inputs, outputs and output combination; shared by reference. */
    protected final F parameters;

    /**
     * Creates a trainable embedding backed by the given fit parameters.
     *
     * @param parameters the fit parameters (stored by reference, not copied)
     * @throws NullPointerException if {@code parameters} is null
     */
    protected TrainableWordEmbedding(@NonNull F parameters) {
        if (parameters == null) {
            throw new NullPointerException("parameters is marked non-null but is null");
        }
        this.parameters = parameters;
    }

    /**
     * Fits this embedding on the given data set and then transforms that same data set.
     *
     * @param dataSet the data set to fit on and transform
     * @return the result of {@code transform(dataSet)} after {@code estimate(dataSet)} has run
     * @throws NullPointerException if {@code dataSet} is null
     */
    @Override
    public DataSet fitAndTransform(@NonNull DataSet dataSet) {
        if (dataSet == null) {
            throw new NullPointerException("dataset is marked non-null but is null");
        }
        estimate(dataSet);
        return transform(dataSet);
    }

    /**
     * @return the fit parameters this embedding was constructed with (the live instance, not a copy)
     */
    @Override
    public F getFitParameters() {
        return this.parameters;
    }

    /**
     * @return the names of the observations this embedding reads, as held by the fit parameters'
     *     {@code inputs} value
     */
    @Override
    public Set<String> getInputs() {
        return Cast.as(this.parameters.inputs.value());
    }

    /**
     * @return always {@link LabelType#NDArray}
     */
    @Override
    public LabelType getOutputType() {
        return LabelType.NDArray;
    }

    /**
     * Resolves the name of a variable using the fit parameters' {@code nameSpace}.
     *
     * <p>NOTE(review): the decompiler reports this was originally {@code protected}; it is kept
     * {@code public} here so existing callers are not broken.
     *
     * @param variable the variable to name
     * @return the name produced by the configured {@link VariableNameSpace}
     */
    @Override
    public String getVariableName(Variable variable) {
        return ((VariableNameSpace) this.parameters.nameSpace.value()).getName(variable);
    }

    /**
     * Sets the names of the observations this embedding reads.
     *
     * <p>The names are stored in the order given, with duplicates collapsed. Ordering is preserved
     * deliberately. When outputs are combined, the per-input vectors are composed in this order.
     * A hash-ordered set (such as {@code Set.of}) has an unspecified, per-JVM order, which would
     * make order-sensitive compositions non-reproducible across runs.
     *
     * @param inputs one or more input observation names
     * @return this embedding, for chaining
     * @throws NullPointerException     if {@code inputs} or any of its elements is null
     * @throws IllegalArgumentException if no inputs are given
     */
    public T inputs(@NonNull String... inputs) {
        if (inputs == null) {
            throw new NullPointerException("inputs is marked non-null but is null");
        }
        Validation.checkArgument(inputs.length > 0, "Must specify at least one input");
        for (String input : inputs) {
            Objects.requireNonNull(input, "inputs must not contain null elements");
        }
        this.parameters.inputs.set(Collections.unmodifiableSet(new LinkedHashSet<>(Arrays.asList(inputs))));
        return Cast.as(this);
    }

    /**
     * Embeds the configured input observations of the given datum, writing the results back into it.
     *
     * <ul>
     *   <li>If {@code combineOutputs} is false, each input {@code name} is transformed and stored
     *       under {@code name + outputSuffix}.</li>
     *   <li>If {@code combineOutputs} is true and there is more than one input, the transformed
     *       inputs are composed with {@code aggregationFunction}, in input iteration order. The
     *       result is stored under {@code output}.</li>
     *   <li>Otherwise the first (single) input is transformed and stored under {@code output}.</li>
     * </ul>
     *
     * @param datum the datum to update in place
     * @return the same {@code datum} instance
     * @throws NullPointerException if {@code datum} is null
     */
    @Override
    public Datum transform(@NonNull Datum datum) {
        if (datum == null) {
            throw new NullPointerException("datum is marked non-null but is null");
        }
        final Set<String> inputs = getInputs();
        final boolean combine = (Boolean) this.parameters.combineOutputs.value();
        if (!combine) {
            // Suffix is loop-invariant: read it once.
            final String suffix = (String) this.parameters.outputSuffix.value();
            for (String input : inputs) {
                datum.put(input + suffix, (Observation) transform(datum.get(input)));
            }
        } else if (inputs.size() > 1) {
            final VectorComposition composition = (VectorComposition) this.parameters.aggregationFunction.value();
            datum.put((String) this.parameters.output.value(),
                      (Observation) composition.compose((Collection<NDArray>) datum.stream(inputs)
                                                                                   .map(this::transform)
                                                                                   .collect(Collectors.toList())));
        } else {
            // Typed default so the lookup key is a String rather than an Object.
            final String firstInput = Iterables.getFirst(inputs, null);
            datum.put((String) this.parameters.output.value(), (Observation) transform(datum.get(firstInput)));
        }
        return datum;
    }
}
