package org.deeplearning4j.spark.canova;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.apache.spark.api.java.function.Function;
import org.canova.api.io.WritableConverter;
import org.canova.api.io.converters.WritableConverterException;
import org.canova.api.writable.Writable;
import org.canova.common.data.NDArrayWritable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;

/* loaded from: input_file:org/deeplearning4j/spark/canova/CanovaDataSetFunction.class */
/**
 * Spark function that turns one Canova record (a collection of {@link Writable} values) into a
 * single-example ND4J {@link DataSet}.
 *
 * <p>One column, selected by {@code labelIndex}, becomes the label; every other column becomes a
 * feature. If {@code labelIndex} is negative and {@code numPossibleLabels >= 1}, the last column
 * is used as the label. If {@code labelIndex} is negative and {@code numPossibleLabels < 1}, no
 * label column is used and the feature array doubles as the label array.
 */
public class CanovaDataSetFunction implements Function<Collection<Writable>, DataSet>, Serializable {
    private final int labelIndex;
    private final int numPossibleLabels;
    private final boolean regression;
    private final DataSetPreProcessor preProcessor;
    private final WritableConverter converter;
    /** Initialised to -1 and never read by this class; presumably kept for subclasses. */
    protected int batchSize;

    /**
     * Creates a function with no pre-processor and no writable converter.
     *
     * @param labelIndex column index of the label; negative selects the last column when
     *     {@code numPossibleLabels >= 1}
     * @param numPossibleLabels number of classes; must be >= 1 whenever a label column is used
     * @param regression {@code true} to emit the raw label value, {@code false} for a one-hot vector
     */
    public CanovaDataSetFunction(int labelIndex, int numPossibleLabels, boolean regression) {
        this(labelIndex, numPossibleLabels, regression, null, null);
    }

    /**
     * Creates a fully configured function.
     *
     * @param labelIndex column index of the label; negative selects the last column when
     *     {@code numPossibleLabels >= 1}
     * @param numPossibleLabels number of classes; must be >= 1 whenever a label column is used
     * @param regression {@code true} to emit the raw label value, {@code false} for a one-hot vector
     * @param preProcessor optional pre-processor applied to each produced DataSet; may be null
     * @param converter optional converter applied once to every column before use; may be null
     */
    public CanovaDataSetFunction(int labelIndex, int numPossibleLabels, boolean regression,
            DataSetPreProcessor preProcessor, WritableConverter converter) {
        this.batchSize = -1;
        this.labelIndex = labelIndex;
        this.numPossibleLabels = numPossibleLabels;
        this.regression = regression;
        this.preProcessor = preProcessor;
        this.converter = converter;
    }

    /**
     * Converts one record into a single-example DataSet.
     *
     * @param record the column values of one example
     * @return a DataSet of (features, label), or (features, features) when no label column is used
     * @throws IllegalStateException if a label column is used with {@code numPossibleLabels < 1},
     *     or a classification label lies outside {@code [0, numPossibleLabels)}
     * @throws UnsupportedOperationException if a non-label column cannot be read as a double and is
     *     not an {@link NDArrayWritable}
     * @throws Exception if the configured converter fails on any column
     */
    public DataSet call(Collection<Writable> record) throws Exception {
        List<Writable> values = record instanceof List ? (List<Writable>) record : new ArrayList<>(record);

        int labelIdx = labelIndex;
        if (numPossibleLabels >= 1 && labelIdx < 0) {
            // A negative label index with a valid class count means "the last column is the label".
            labelIdx = values.size() - 1;
        }

        INDArray label = null;
        INDArray features = null;
        int featureCount = 0;
        for (int j = 0; j < values.size(); j++) {
            Writable current = values.get(j);
            if (converter != null) {
                // Convert each column exactly once (the label used to be converted a second time).
                current = converter.convert(current);
            }

            if (labelIdx >= 0 && j == labelIdx) {
                label = toLabel(current);
                continue;
            }

            try {
                double value = current.toDouble();
                if (features == null) {
                    // One slot per non-label column.
                    features = Nd4j.create(labelIdx >= 0 ? values.size() - 1 : values.size());
                }
                features.putScalar(featureCount++, value);
            } catch (UnsupportedOperationException e) {
                // A column that is not scalar is only accepted if it already is an ND array.
                if (!(current instanceof NDArrayWritable)) {
                    throw e;
                }
                // An NDArrayWritable must supply the whole feature array, so no scalar feature may
                // precede it. NOTE(review): scalar columns *after* it are written into this array via
                // putScalar above — confirm that is intended.
                assert features == null : "NDArrayWritable column found after scalar feature columns";
                features = ((NDArrayWritable) current).get();
            }
        }

        DataSet dataSet = new DataSet(features, labelIdx >= 0 ? label : features);
        if (preProcessor != null) {
            preProcessor.preProcess(dataSet);
        }
        return dataSet;
    }

    /**
     * Builds the label array from an already-converted label column.
     *
     * @param labelValue the label column's value
     * @return a scalar for regression, otherwise a one-hot vector of length {@code numPossibleLabels}
     * @throws IllegalStateException if {@code numPossibleLabels < 1} or the class index is out of range
     */
    private INDArray toLabel(Writable labelValue) {
        if (numPossibleLabels < 1) {
            throw new IllegalStateException("Number of possible labels invalid, must be >= 1");
        }
        if (regression) {
            return Nd4j.scalar(labelValue.toDouble());
        }
        int classIdx = labelValue.toInt();
        // Reject negatives too, matching the "0 <= labelIdx" contract stated in the message.
        if (classIdx < 0 || classIdx >= numPossibleLabels) {
            throw new IllegalStateException("Invalid input: class label is " + classIdx + " with numPossibleLables = " + numPossibleLabels + " (class label must be 0 <= labelIdx < numPossibleLabels)");
        }
        return FeatureUtil.toOutcomeVector(classIdx, numPossibleLabels);
    }
}
