package com.gengoai.apollo.ml.transform.vectorizer;

import com.gengoai.apollo.math.linalg.NDArray;
import com.gengoai.apollo.ml.encoder.Encoder;
import com.gengoai.apollo.ml.observation.Observation;
import com.gengoai.apollo.ml.observation.Sequence;
import com.gengoai.apollo.ml.observation.Variable;
import com.gengoai.apollo.ml.observation.VariableCollection;
import com.gengoai.apollo.ml.transform.vectorizer.AbstractVariableVectorizer;
import com.gengoai.conversion.Cast;
import java.util.ArrayList;
import java.util.List;
import lombok.NonNull;

/* loaded from: input_file:com/gengoai/apollo/ml/transform/vectorizer/AbstractVariableVectorizer.class */
/**
 * Base class for vectorizers that turn {@link Variable}-based observations into {@link NDArray}s.
 * <p>
 * Subclasses decide how a single {@link Variable} is written into an array by implementing
 * {@link #encodeVariableInto(Variable, NDArray)}. This class handles dispatch over the supported
 * observation kinds: {@link Variable}, {@link VariableCollection} and {@link Sequence}.
 *
 * @param <T> the concrete vectorizer type (self type)
 */
public abstract class AbstractVariableVectorizer<T extends AbstractVariableVectorizer<T>> extends Vectorizer<T> {

    /**
     * Creates a vectorizer backed by the given encoder.
     *
     * @param encoder the encoder handed to the superclass. NOTE(review): presumably the same
     *                instance later read as {@code this.encoder}, whose {@code size()} sets the
     *                column count of produced arrays. Confirm against {@code Vectorizer}.
     * @throws NullPointerException if {@code encoder} is null
     */
    public AbstractVariableVectorizer(@NonNull Encoder encoder) {
        super(encoder);
        // Explicit check mirroring the Lombok @NonNull contract (kept from the original output).
        if (encoder == null) {
            throw new NullPointerException("encoder is marked non-null but is null");
        }
    }

    /**
     * Writes the encoding of a single variable into the given array.
     *
     * @param variable the variable to encode
     * @param nDArray  the 1 x {@code encoder.size()} array to write into. When encoding a
     *                 {@link VariableCollection}, the same array is passed for every variable,
     *                 so it may already hold values written for earlier variables.
     */
    protected abstract void encodeVariableInto(Variable variable, NDArray nDArray);

    /**
     * Converts an observation into an {@link NDArray}.
     * <ul>
     *   <li>{@link Variable}: a single 1 x {@code encoder.size()} row with the variable encoded into it.</li>
     *   <li>{@link VariableCollection}: a single 1 x {@code encoder.size()} row into which every
     *       variable of the collection is encoded.</li>
     *   <li>{@link Sequence}: each element is transformed by this method, and the resulting arrays are
     *       stacked vertically with {@code ndArrayFactory.vstack}. NOTE(review): the result for an
     *       empty sequence depends on {@code vstack}'s handling of an empty list; confirm.</li>
     * </ul>
     *
     * @param observation the observation to vectorize
     * @return the vectorized observation
     * @throws IllegalArgumentException if the observation is not one of the supported kinds
     */
    @Override // com.gengoai.apollo.ml.transform.AbstractSingleSourceTransform
    public final NDArray transform(Observation observation) {
        if (observation instanceof Variable) {
            NDArray row = newRow();
            encodeVariableInto((Variable) observation, row);
            return row;
        }
        if (observation instanceof VariableCollection) {
            NDArray row = newRow();
            // All variables of the collection accumulate into one shared row.
            observation.asVariableCollection().forEach(variable -> encodeVariableInto(variable, row));
            return row;
        }
        if (observation instanceof Sequence) {
            List<NDArray> rows = new ArrayList<>();
            ((Sequence) observation).forEach(element -> rows.add(transform(element)));
            return this.ndArrayFactory.vstack(rows);
        }
        throw new IllegalArgumentException("Unsupported Observation: " + observation.getClass());
    }

    /** Allocates a new 1 x {@code encoder.size()} array from this vectorizer's NDArray factory. */
    private NDArray newRow() {
        return this.ndArrayFactory.array(1, this.encoder.size());
    }
}
