package org.campagnelab.dl.framework.domains.prediction;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.accum.Sum;
import org.nd4j.linalg.api.ops.impl.indexaccum.IMax;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/campagnelab/dl/framework/domains/prediction/TimeSeriesPrediction.class */
/**
 * Prediction for a single time series: one integer label per position for the ground truth and
 * one for the model output.
 *
 * <p>Labels can be given directly as {@code int[]}, or as ND4J arrays that are reduced to one
 * arg-max index per column (see {@link #getIntArgMaxArray(INDArray)}). Shape and length checks
 * use {@code assert}, so they are enforced only when the JVM runs with assertions enabled
 * ({@code -ea}).
 */
public class TimeSeriesPrediction extends Prediction {
    /** True label per position, or {@code null} until a {@code setTrueLabels} overload is called. */
    public int[] trueLabels;
    /** Predicted label per position, or {@code null} until a {@code setPredictedLabels} overload is called. */
    public int[] predictedLabels;

    /**
     * Sets the true labels.
     *
     * @param labels one label per position; when predicted labels are already set, the lengths
     *     must match (checked only with assertions enabled)
     * @return this instance, for chaining
     */
    public TimeSeriesPrediction setTrueLabels(int[] labels) {
        assert this.predictedLabels == null || labels.length == this.predictedLabels.length
                : "Labels should have same length";
        this.trueLabels = labels;
        return this;
    }

    /**
     * Sets the true labels from a 2D array holding a single time series, reducing it with
     * {@link #getIntArgMaxArray(INDArray)}.
     *
     * @param labels rank-2 array (checked only with assertions enabled)
     * @return this instance, for chaining
     */
    public TimeSeriesPrediction setTrueLabels(INDArray labels) {
        assert labels.shape().length == 2
                : "True labels should be a 2D array (i.e., just for one time series)";
        return setTrueLabels(getIntArgMaxArray(labels));
    }

    /**
     * Sets the true labels from entry {@code index} of a 3D array that holds several time series.
     *
     * @param allLabels rank-3 array (checked only with assertions enabled)
     * @param index position along dimension 0; must be in {@code [0, shape[0])}
     * @return this instance, for chaining
     */
    public TimeSeriesPrediction setTrueLabels(INDArray allLabels, int index) {
        assert allLabels.shape().length == 3 : "All true labels should be a 3D array";
        assert index >= 0 && index < allLabels.shape()[0] : "label index is out of bounds";
        // NOTE(review): relies on getRow() returning the rank-2 sub-array along dimension 0 for a
        // rank-3 input; this is ND4J-version dependent — confirm (slice(index) is the explicit form).
        return setTrueLabels(allLabels.getRow(index));
    }

    /**
     * Sets the predicted labels.
     *
     * @param labels one label per position; when true labels are already set, the lengths must
     *     match (checked only with assertions enabled)
     * @return this instance, for chaining
     */
    public TimeSeriesPrediction setPredictedLabels(int[] labels) {
        assert this.trueLabels == null || this.trueLabels.length == labels.length
                : "Labels should have same length";
        this.predictedLabels = labels;
        return this;
    }

    /**
     * Sets the predicted labels from a 2D array holding a single time series, reducing it with
     * {@link #getIntArgMaxArray(INDArray)}.
     *
     * @param labels rank-2 array (checked only with assertions enabled)
     * @return this instance, for chaining
     */
    public TimeSeriesPrediction setPredictedLabels(INDArray labels) {
        assert labels.shape().length == 2
                : "Predicted labels should be a 2D array (i.e., just for one time series)";
        return setPredictedLabels(getIntArgMaxArray(labels));
    }

    /**
     * Sets the predicted labels from entry {@code index} of a 3D array that holds several time
     * series.
     *
     * @param allLabels rank-3 array (checked only with assertions enabled)
     * @param index position along dimension 0; must be in {@code [0, shape[0])}
     * @return this instance, for chaining
     */
    public TimeSeriesPrediction setPredictedLabels(INDArray allLabels, int index) {
        assert allLabels.shape().length == 3 : "All predicted labels should be a 3D array";
        assert index >= 0 && index < allLabels.shape()[0] : "label index is out of bounds";
        // NOTE(review): same getRow() rank-3 caveat as setTrueLabels(INDArray, int).
        return setPredictedLabels(allLabels.getRow(index));
    }

    /** @return the true labels, or {@code null} if not set */
    public int[] trueLabels() {
        return this.trueLabels;
    }

    /** @return the predicted labels, or {@code null} if not set */
    public int[] predictedLabels() {
        return this.predictedLabels;
    }

    /**
     * Reduces a 2D array to one arg-max index per column.
     *
     * <p>Both the arg-max and the column sums are taken over dimension 0, so each position along
     * dimension 1 yields one label (presumably dimension 0 = classes, dimension 1 = time steps —
     * confirm against callers). When at least one column has a positive sum, the result keeps
     * only as many leading columns as there are positive-sum columns. Otherwise every column is
     * returned.
     *
     * <p>NOTE(review): model score columns likely always have a positive sum, so predicted labels
     * may never be truncated while padded true labels are. That would trip the length assertions
     * in the {@code int[]} setters — confirm the intended masking behavior.
     *
     * @param labels rank-2 array to reduce
     * @return the arg-max index for each retained column
     */
    private static int[] getIntArgMaxArray(INDArray labels) {
        // Columns that sum to zero are treated as padding.
        // NOTE(review): this assumes the padding columns form a suffix; an all-zero column in the
        // middle would make the truncation cut at the wrong place.
        INDArray columnSums = Nd4j.getExecutioner().exec(new Sum(labels), new int[]{0});
        int nonEmptyColumns = columnSums.gt(0).sumNumber().intValue();
        INDArray argMax = Nd4j.getExecutioner().exec(new IMax(labels), new int[]{0});

        int totalColumns = (int) argMax.length();
        int count = nonEmptyColumns > 0 ? Math.min(nonEmptyColumns, totalColumns) : totalColumns;

        // Read element-by-element by linear index instead of taking a sub-array view and calling
        // data(). On a view, data() can expose the whole backing buffer, which would silently undo
        // the truncation. This form also works whether argMax has rank 1 or rank 2.
        int[] result = new int[count];
        for (int t = 0; t < count; t++) {
            result[t] = (int) argMax.getDouble(t);
        }
        return result;
    }
}
