package com.gengoai.apollo.ml.transform;

import com.gengoai.apollo.math.linalg.NDArray;
import com.gengoai.apollo.math.linalg.NDArrayFactory;
import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.observation.Observation;
import com.gengoai.stream.MStream;
import lombok.NonNull;

/* loaded from: input_file:com/gengoai/apollo/ml/transform/SequenceVectorContext.class */
public class SequenceVectorContext extends AbstractSingleSourceTransform<SequenceVectorContext> {
    /** Number of preceding rows concatenated in front of each row. */
    private final int left;
    /** Number of following rows concatenated after each row. */
    private final int right;

    /**
     * Creates a context transform with the given window sizes.
     *
     * @param left  number of preceding rows to include; the sign is ignored
     * @param right number of following rows to include; the sign is ignored
     * @throws IllegalArgumentException if either size is {@link Integer#MIN_VALUE}, whose absolute
     *                                  value cannot be represented as an {@code int}
     */
    public SequenceVectorContext(int left, int right) {
        this.left = windowSize(left, "left");
        this.right = windowSize(right, "right");
    }

    /**
     * Returns the absolute value of a window size, rejecting the one value {@code Math.abs} cannot negate.
     */
    private static int windowSize(int size, String name) {
        if (size == Integer.MIN_VALUE) {
            throw new IllegalArgumentException(name + " context size out of range: " + size);
        }
        return Math.abs(size);
    }

    /**
     * No-op: this transform has nothing to learn from the data; only the null check is performed.
     *
     * @param observations the observations to fit on (must not be null)
     */
    @Override
    protected void fit(@NonNull MStream<Observation> observations) {
        if (observations == null) {
            throw new NullPointerException("observations is marked non-null but is null");
        }
    }

    /**
     * Builds a new array in which row {@code i} is the concatenation of input rows
     * {@code i - left} through {@code i + right} (inclusive), in that order.
     * <p>
     * The output has the same number of rows as the input and {@code columns * (left + 1 + right)}
     * columns. Window positions that fall before the first or after the last input row are skipped
     * (nothing is written into their column range).
     *
     * @param observation the observation to transform; converted with {@code asNDArray()}
     * @return the context-augmented array
     */
    @Override
    protected Observation transform(@NonNull Observation observation) {
        if (observation == null) {
            throw new NullPointerException("observation is marked non-null but is null");
        }
        NDArray sequence = observation.asNDArray();
        NDArray result = NDArrayFactory.ND.array(sequence.rows(),
                                                 sequence.columns()
                                                         + (this.left * sequence.columns())
                                                         + (this.right * sequence.columns()));
        for (int i = 0; i < sequence.rows(); i++) {
            final int targetRow = i;
            int columnOffset = 0;
            // Inclusive upper bound: the window covers left + 1 + right rows, matching the output width.
            for (int source = i - this.left; source <= i + this.right; source++) {
                if (source >= 0 && source < sequence.rows()) {
                    final int offset = columnOffset;
                    sequence.getRow(source).forEachSparse((index, value) -> {
                        result.set(targetRow, ((int) index) + offset, value);
                    });
                }
                // Advance to the next slot even when the source row is out of range, so each window
                // position always maps to the same column range.
                columnOffset += sequence.columns();
            }
        }
        return result;
    }

    /**
     * Records the output metadata: dimension {@code inputDimension * (left + 1 + right)},
     * type {@link NDArray}, and no encoder.
     *
     * @param data the dataset whose metadata is updated (must not be null)
     */
    @Override
    protected void updateMetadata(@NonNull DataSet data) {
        if (data == null) {
            throw new NullPointerException("data is marked non-null but is null");
        }
        long inputDimension = data.getMetadata(this.input).getDimension();
        long outputDimension = inputDimension + (this.left * inputDimension) + (this.right * inputDimension);
        data.updateMetadata(this.output, metadata -> {
            metadata.setDimension(outputDimension);
            metadata.setType(NDArray.class);
            metadata.setEncoder(null);
        });
    }
}
