package org.campagnelab.dl.framework.domains.prediction;

import java.util.Arrays;
import java.util.function.Function;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/campagnelab/dl/framework/domains/prediction/TimeSeriesPredictionInterpreter.class */
/**
 * Interprets network output for time-series (per-time-step) predictions, producing a
 * {@link TimeSeriesPrediction} that holds both the predicted labels and the true labels.
 *
 * @param <RecordType> type of the record from which true labels (and, optionally, the
 *     sequence length) are extracted
 */
public class TimeSeriesPredictionInterpreter<RecordType>
        implements PredictionInterpreter<RecordType, TimeSeriesPrediction> {
    private static final Logger LOG = LoggerFactory.getLogger(TimeSeriesPredictionInterpreter.class);

    /** Extracts the true label of each time step from a record. */
    private final Function<RecordType, int[]> recordToLabel;

    /**
     * Extracts the sequence length from a record. May be {@code null} (see the one-argument
     * constructor), in which case the length of the true-label array is used instead.
     */
    private final Function<RecordType, Integer> recordToSequenceLength;

    /**
     * Creates an interpreter that takes the sequence length from the true-label array.
     *
     * @param recordToLabel extracts per-time-step true labels from a record
     */
    public TimeSeriesPredictionInterpreter(Function<RecordType, int[]> recordToLabel) {
        this(recordToLabel, null);
    }

    /**
     * Creates an interpreter.
     *
     * @param recordToLabel extracts per-time-step true labels from a record
     * @param recordToSequenceLength extracts the sequence length from a record; may be
     *     {@code null}, in which case the true-label array length is used
     */
    public TimeSeriesPredictionInterpreter(Function<RecordType, int[]> recordToLabel,
                                           Function<RecordType, Integer> recordToSequenceLength) {
        this.recordToLabel = recordToLabel;
        this.recordToSequenceLength = recordToSequenceLength;
    }

    /**
     * Builds a prediction from a record and the model output for that record.
     *
     * <p>The sequence length and the true-label length should agree. If they do not, a warning
     * is logged and both are clipped to the shorter of the two.
     *
     * @param record the record supplying the true labels (and possibly the sequence length)
     * @param output model output; its second dimension is sliced to the effective length
     * @return the prediction with predicted and true labels clipped to the same length
     */
    @Override
    public TimeSeriesPrediction interpret(RecordType record, INDArray output) {
        int[] trueLabels = recordToLabel.apply(record);
        int sequenceLength = sequenceLengthOf(record, trueLabels);
        if (sequenceLength != trueLabels.length) {
            LOG.warn("sequence and true labels lengths must agree. Make sure the function recordToSequenceLength accounts for sequence clipping to maxLength.");
        }
        int length = Math.min(sequenceLength, trueLabels.length);
        // Keep the boxed argument: the constructor was invoked with an Integer originally.
        TimeSeriesPrediction prediction = new TimeSeriesPrediction(Integer.valueOf(length));
        prediction.setPredictedLabels(output.get(NDArrayIndex.all(), NDArrayIndex.interval(0, length)));
        prediction.setTrueLabels(length == trueLabels.length
                ? trueLabels
                : Arrays.copyOfRange(trueLabels, 0, length));
        return prediction;
    }

    /**
     * Builds a prediction for example {@code index} of a minibatch.
     *
     * @param trueLabels true labels of the minibatch
     * @param output model output for the minibatch
     * @param index index of the example within the minibatch
     * @return the prediction for that example
     */
    @Override
    public TimeSeriesPrediction interpret(INDArray trueLabels, INDArray output, int index) {
        TimeSeriesPrediction prediction = new TimeSeriesPrediction();
        prediction.setPredictedLabels(output, index);
        prediction.setTrueLabels(trueLabels, index);
        return prediction;
    }

    /**
     * Builds a prediction for example {@code index} of a minibatch. Labels and output are first
     * clipped along their third dimension to the sequence length given by the mask.
     *
     * @param trueMasks 2D mask array; row {@code index} holds the mask of the example
     * @param trueLabels 3D true-label array
     * @param output 3D model output array
     * @param index index of the example within the minibatch
     * @return the prediction for that example
     */
    public TimeSeriesPrediction interpret(INDArray trueMasks, INDArray trueLabels, INDArray output, int index) {
        assert trueLabels.shape().length == 3 : "True labels should be a 3D array";
        assert trueMasks.shape().length == 2 : "True masks should be a 2D array";
        assert output.shape().length == 3 : "Output should be a 3D array";
        assert index < trueMasks.shape()[0] : "prediction index is out of bounds for true masks";
        assert index < trueLabels.shape()[0] : "prediction index is out of bounds for true labels";
        assert index < output.shape()[0] : "prediction index is out of bounds for output";

        // The sequence length is the count of positive mask entries in this example's row.
        // NOTE(review): slicing [0, length) assumes the unmasked steps form a contiguous
        // prefix (i.e. padding at the end) — confirm against how masks are built.
        int sequenceLength = trueMasks.getRow(index).gt(0).sumNumber().intValue();
        INDArray clippedLabels = trueLabels.get(
                NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, sequenceLength));
        INDArray clippedOutput = output.get(
                NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, sequenceLength));
        return interpret(clippedLabels, clippedOutput, index);
    }

    /**
     * Returns the sequence length reported for {@code record}. If no length function was
     * configured, or it returns {@code null}, the true-label array length is used instead.
     */
    private int sequenceLengthOf(RecordType record, int[] trueLabels) {
        if (recordToSequenceLength == null) {
            return trueLabels.length;
        }
        Integer length = recordToSequenceLength.apply(record);
        return length == null ? trueLabels.length : length;
    }
}
