package org.datavec.nlp.transforms;

import java.util.Collections;
import java.util.List;
import java.util.Objects;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.NDArrayMetaData;
import org.datavec.api.transform.transform.BaseColumnTransform;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.list.NDArrayList;
import org.nd4j.shade.jackson.annotation.JsonCreator;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/* loaded from: input_file:org/datavec/nlp/transforms/MultiNlpTransform.class */
/**
 * A {@link BagOfWordsTransform} that applies several bag-of-words transforms to the same input
 * column and concatenates their results, each flattened to rank 1, into a single array.
 *
 * <p>All wrapped transforms must report equal {@link BagOfWordsTransform#vocabWords()}; this is
 * checked at construction time. Only sequence input is supported: the single-value {@code map}
 * methods throw {@link UnsupportedOperationException}.
 */
public class MultiNlpTransform extends BaseColumnTransform implements BagOfWordsTransform {
    /** The wrapped transforms; their outputs are concatenated in array order. */
    private BagOfWordsTransform[] transforms;
    /** Name of the output column; also what {@link #toString()} returns. */
    private String newColumnName;
    /** Vocabulary shared by every wrapped transform (taken from the first one). */
    private List<String> vocabWords;

    /**
     * Creates a transform that combines the given bag-of-words transforms.
     *
     * @param columnName    the input column to transform
     * @param transforms    the transforms to combine; must be non-null, non-empty, contain no
     *                      null elements, and all report equal vocabularies
     * @param newColumnName the name of the output column
     * @throws IllegalArgumentException if {@code transforms} is null, empty, contains a null
     *                                  element, or the transforms' vocabularies differ
     */
    @JsonCreator
    public MultiNlpTransform(@JsonProperty("columnName") String columnName,
                             @JsonProperty("transforms") BagOfWordsTransform[] transforms,
                             @JsonProperty("newColumnName") String newColumnName) {
        super(columnName);
        if (transforms == null || transforms.length == 0) {
            throw new IllegalArgumentException("At least one transform must be specified");
        }
        for (int i = 0; i < transforms.length; i++) {
            if (transforms[i] == null) {
                throw new IllegalArgumentException("Transform at index " + i + " is null");
            }
        }
        // Defensive copy so later mutation of the caller's array cannot bypass the vocab check.
        this.transforms = transforms.clone();
        this.vocabWords = this.transforms[0].vocabWords();
        for (int i = 1; i < this.transforms.length; i++) {
            // Objects.equals tolerates a transform that reports a null vocabulary.
            if (!Objects.equals(this.transforms[i].vocabWords(), this.vocabWords)) {
                throw new IllegalArgumentException("Vocab words not consistent across transforms!");
            }
        }
        this.newColumnName = newColumnName;
    }

    /**
     * Applies every wrapped transform to {@code sequence} and concatenates the flattened results.
     *
     * @param sequence expected to be a {@code List<List<Object>>}; other types fail with a
     *                 {@link ClassCastException}
     * @return the concatenated array, exactly as produced by {@link #transformFromObject(List)}
     */
    @Override
    public Object mapSequence(Object sequence) {
        // Delegate so the Object path flattens and concatenates exactly like transformFromObject.
        @SuppressWarnings("unchecked")
        List<List<Object>> rows = (List<List<Object>>) sequence;
        return transformFromObject(rows);
    }

    /**
     * Collapses the whole input sequence into a single time step holding one
     * {@link NDArrayWritable} with the concatenated output of all wrapped transforms.
     */
    @Override
    public List<List<Writable>> mapSequence(List<List<Writable>> sequence) {
        return Collections.singletonList(Collections.singletonList(new NDArrayWritable(transformFrom(sequence))));
    }

    /** Describes the output column as an NDArray whose shape is {@link #outputShape()}. */
    @Override
    public ColumnMetaData getNewColumnMetaData(String newName, ColumnMetaData oldColumnType) {
        return new NDArrayMetaData(newName, outputShape());
    }

    /** Not supported: this transform only operates on sequences. */
    @Override
    public Writable map(Writable columnWritable) {
        throw new UnsupportedOperationException("Only able to add for time series");
    }

    /** Returns the configured output column name. */
    @Override
    public String toString() {
        return this.newColumnName;
    }

    /** Not supported: this transform only operates on sequences. */
    @Override
    public Object map(Object input) {
        throw new UnsupportedOperationException("Only able to add for time series");
    }

    /**
     * Returns the dimension-wise sum of the wrapped transforms' output shapes.
     *
     * <p>NOTE(review): {@link #transformFrom(List)} flattens each result to rank 1 before
     * concatenating. For per-transform shapes of rank greater than 1, this per-dimension sum may
     * not describe the concatenated array. Confirm this is intended.
     *
     * @throws IllegalArgumentException if the wrapped transforms report shapes of different rank
     */
    @Override
    public long[] outputShape() {
        // Query each transform once rather than once per dimension.
        long[][] shapes = new long[this.transforms.length][];
        for (int i = 0; i < this.transforms.length; i++) {
            shapes[i] = this.transforms[i].outputShape();
        }
        int rank = shapes[0].length;
        long[] combined = new long[rank];
        for (int i = 0; i < shapes.length; i++) {
            if (shapes[i].length != rank) {
                throw new IllegalArgumentException("Inconsistent shape length at transform " + i + " , should have been: " + rank);
            }
            for (int dim = 0; dim < rank; dim++) {
                combined[dim] += shapes[i][dim];
            }
        }
        return combined;
    }

    /** Returns the vocabulary shared by all wrapped transforms. */
    @Override
    public List<String> vocabWords() {
        return this.vocabWords;
    }

    /**
     * Applies every wrapped transform to the object rows and concatenates the results, each
     * flattened to rank 1, in transform order.
     */
    @Override
    public INDArray transformFromObject(List<List<Object>> tokens) {
        NDArrayList combined = new NDArrayList();
        for (BagOfWordsTransform transform : this.transforms) {
            appendFlattened(combined, transform.transformFromObject(tokens));
        }
        return combined.array();
    }

    /**
     * Applies every wrapped transform to the writable rows and concatenates the results, each
     * flattened to rank 1, in transform order.
     */
    @Override
    public INDArray transformFrom(List<List<Writable>> tokens) {
        NDArrayList combined = new NDArrayList();
        for (BagOfWordsTransform transform : this.transforms) {
            appendFlattened(combined, transform.transformFrom(tokens));
        }
        return combined.array();
    }

    /** Flattens {@code array} to rank 1 and appends all of its elements to {@code target}. */
    private static void appendFlattened(NDArrayList target, INDArray array) {
        INDArray flat = array.reshape(new long[]{array.length()});
        // toIntExact fails loudly instead of silently truncating arrays longer than Integer.MAX_VALUE.
        target.addAll(new NDArrayList(flat, Math.toIntExact(flat.length())));
    }
}
