package com.gengoai.apollo.ml.transform.vectorizer;

import com.gengoai.apollo.math.linalg.NDArray;
import com.gengoai.apollo.math.linalg.NDArrayFactory;
import com.gengoai.apollo.ml.encoder.Encoder;
import com.gengoai.apollo.ml.encoder.IndexEncoder;
import com.gengoai.apollo.ml.observation.Observation;
import com.gengoai.apollo.ml.observation.Sequence;
import com.gengoai.apollo.ml.observation.Variable;
import com.gengoai.apollo.ml.observation.VariableCollection;
import com.gengoai.apollo.ml.observation.VariableSequence;
import com.gengoai.conversion.Cast;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import lombok.NonNull;
import org.apache.mahout.math.list.DoubleArrayList;

/**
 * Vectorizer that converts observations into arrays of encoder indexes.
 *
 * <p>A {@link Variable} becomes a scalar holding its encoded index. A {@link VariableCollection}
 * becomes a single row holding the sorted indexes of its encodable variables. A {@link Sequence}
 * becomes one result per element (see {@link #transform(Observation)}).
 */
public class IndexingVectorizer extends Vectorizer<IndexingVectorizer> {
    private static final long serialVersionUID = 1L;

    /**
     * Creates an IndexingVectorizer backed by the given encoder.
     *
     * @param encoder the encoder used to map variable names to indexes
     * @throws NullPointerException if {@code encoder} is null
     */
    public IndexingVectorizer(@NonNull Encoder encoder) {
        // Validate before the superclass constructor ever sees the argument.
        super(Objects.requireNonNull(encoder, "encoder is marked non-null but is null"));
    }

    /**
     * Creates an IndexingVectorizer backed by a new, default-constructed {@link IndexEncoder}.
     */
    public IndexingVectorizer() {
        super(new IndexEncoder());
    }

    /**
     * Creates an IndexingVectorizer backed by a new {@link IndexEncoder} built from {@code str}.
     *
     * @param str forwarded to the {@link IndexEncoder} constructor (presumably the name used for
     *            unknown values — TODO confirm against IndexEncoder)
     */
    public IndexingVectorizer(String str) {
        super(new IndexEncoder(str));
    }

    /**
     * Converts the given observation into an NDArray of encoded indexes.
     *
     * <ul>
     *   <li>{@link Variable} or {@link VariableCollection}: delegates to
     *       {@link #transform(Observation, NDArray)} with a {@code null} destination.</li>
     *   <li>{@link VariableSequence}: each element is transformed independently and the results
     *       are stacked vertically.</li>
     *   <li>Any other {@link Sequence}: one row per element. Every row is as wide as the element
     *       with the most variables, or 1 if the sequence is empty.</li>
     * </ul>
     *
     * @param observation the observation to vectorize
     * @return the encoded array
     * @throws IllegalArgumentException if the observation is not a variable, variable collection
     *                                  or sequence
     */
    @Override
    public NDArray transform(Observation observation) {
        if (observation instanceof Variable || observation instanceof VariableCollection) {
            return transform(observation, null);
        }
        if (!(observation instanceof Sequence)) {
            throw new IllegalArgumentException("Unsupported Observation: " + observation.getClass());
        }
        Sequence<? extends Observation> sequence = Cast.as(observation);
        if (sequence instanceof VariableSequence) {
            return transformVariableSequence(sequence);
        }
        return transformSequence(sequence);
    }

    /** Transforms each element of a variable sequence independently and vertically stacks the results. */
    private NDArray transformVariableSequence(Sequence<? extends Observation> sequence) {
        List<NDArray> rows = new ArrayList<>();
        for (Iterator<? extends Observation> it = sequence.iterator(); it.hasNext(); ) {
            rows.add(transform(it.next(), null));
        }
        return NDArrayFactory.ND.vstack(rows);
    }

    /** Encodes each element of a general sequence into one row of a fixed-width matrix. */
    private NDArray transformSequence(Sequence<? extends Observation> sequence) {
        // The width is the largest variable count of any element. orElse(1) only applies when the
        // sequence is empty.
        int width = sequence.stream()
                            .mapToInt(element -> (int) element.getVariableSpace().count())
                            .max()
                            .orElse(1);
        NDArray matrix = this.ndArrayFactory.array(sequence.size(), width);
        for (int i = 0; i < sequence.size(); i++) {
            // Each row gets its own 1 x width buffer. Elements with fewer encodable variables leave
            // the trailing cells at their initial value.
            matrix.setRow(i, transform(sequence.get(i), NDArrayFactory.ND.array(1, width)));
        }
        return matrix;
    }

    /**
     * Encodes a single variable or a variable collection.
     *
     * <p>A variable becomes a scalar holding its encoded index, or {@code 0.0} when the encoder
     * returns a negative index. Any other observation is viewed as a variable collection. The
     * indexes of its variables that encode to a non-negative value are sorted ascending and
     * written to positions {@code 0..k-1} of the destination array.
     *
     * @param observation the observation to encode
     * @param nDArray     the destination array, or {@code null} to allocate a new {@code 1 x k}
     *                    array, where {@code k} is the number of encodable variables. When
     *                    supplied, it must have room for {@code k} values.
     * @return the scalar for a variable, otherwise the filled destination array
     */
    public NDArray transform(Observation observation, NDArray nDArray) {
        if (observation.isVariable()) {
            int index = this.encoder.encode(((Variable) observation).getName());
            return index >= 0 ? this.ndArrayFactory.scalar(index) : this.ndArrayFactory.scalar(0.0d);
        }
        VariableCollection collection = observation.asVariableCollection();
        DoubleArrayList indexes = new DoubleArrayList();
        collection.forEach(variable -> {
            int index = this.encoder.encode(variable.getName());
            // Variables the encoder does not know (negative index) are dropped.
            if (index >= 0) {
                indexes.add(index);
            }
        });
        indexes.sort();
        if (nDArray == null) {
            nDArray = NDArrayFactory.ND.array(1, indexes.size());
        }
        for (int i = 0; i < indexes.size(); i++) {
            nDArray.set(i, indexes.get(i));
        }
        return nDArray;
    }
}
