package org.campagnelab.dl.framework.mappers;

import java.util.function.Function;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Label mapper for recurrent networks that assigns one delegate {@link LabelMapper} per time step:
 * delegate {@code t} supplies the labels of time step {@code t}.
 *
 * <p>All delegates must be one-dimensional and report equal {@link MappedDimensions}. The combined
 * dimensions are {@code (labelsPerTimeStep, maxSequenceLength)}. A flat label index {@code k} (as
 * used by {@link #produceLabel} and {@link #isMasked}) refers to time step
 * {@code k / labelsPerTimeStep} and per-step label {@code k % labelsPerTimeStep}.
 *
 * <p>{@link #prepareToNormalize} records the current record's sequence length. Time steps at or
 * beyond that length produce label {@code 0} and mask value {@code 0}.
 *
 * <p>Instances hold per-record state and reuse internal index buffers without synchronization, so
 * they are not safe for concurrent use.
 *
 * @param <RecordType> type of the records labels are produced from
 */
public class RNNLabelMapper<RecordType> implements LabelMapper<RecordType> {
    /** Number of elements in each delegate's (one-dimensional) dimensions. */
    private final int labelsPerTimeStep;
    /** One delegate per time step; the array length is the maximum sequence length. */
    private final LabelMapper<RecordType>[] delegates;
    /** Extracts a record's actual sequence length. Package-visible and non-final for compatibility. */
    Function<RecordType, Integer> recordToSequenceLength;
    /** Reusable index buffer for {@link #mapLabels}: {first-dimension index, label, time step}. */
    private final int[] indicesMapper = new int[3];
    /** Reusable index buffer for {@link #maskLabels}: {first-dimension index, time step}. */
    private final int[] indicesMasker = new int[2];
    /** Combined dimensions: {@code (labelsPerTimeStep, number of delegates)}. */
    private final MappedDimensions dim;
    /** Sequence length of the record most recently passed to {@link #prepareToNormalize}. */
    private int sequenceLength;

    /**
     * Creates a mapper whose delegates are {@code maxSequenceLength} {@code OneHotBaseLabelMapper}
     * instances. The {@code t}-th delegate is built as
     * {@code new OneHotBaseLabelMapper(t, oneHotParameter, recordToLabels)}.
     *
     * @param maxSequenceLength number of time steps, i.e. delegates to create
     * @param oneHotParameter second constructor argument forwarded to every
     *     {@code OneHotBaseLabelMapper} (presumably the number of one-hot classes per step; TODO confirm)
     * @param recordToLabels function forwarded to every {@code OneHotBaseLabelMapper}
     * @param recordToSequenceLength extracts the actual sequence length of a record
     */
    public RNNLabelMapper(int maxSequenceLength, int oneHotParameter,
                          Function<RecordType, int[]> recordToLabels,
                          Function<RecordType, Integer> recordToSequenceLength) {
        this(recordToSequenceLength,
                createOneHotBaseLabelMappers(maxSequenceLength, oneHotParameter, recordToLabels));
    }

    /**
     * Creates a mapper that uses {@code delegates[t]} for time step {@code t}.
     *
     * @param recordToSequenceLength extracts the actual sequence length of a record
     * @param delegates one label mapper per time step; all must be one-dimensional with equal
     *     dimensions. The array is copied, so later changes by the caller have no effect.
     * @throws IllegalArgumentException if {@code delegates} is null or empty
     * @throws RuntimeException if the delegates' dimensions differ or are not one-dimensional
     */
    @SafeVarargs
    public RNNLabelMapper(Function<RecordType, Integer> recordToSequenceLength,
                          LabelMapper<RecordType>... delegates) {
        if (delegates == null || delegates.length == 0) {
            throw new IllegalArgumentException("At least one delegate label mapper is required");
        }
        this.recordToSequenceLength = recordToSequenceLength;
        MappedDimensions expected = delegates[0].dimensions();
        for (LabelMapper<RecordType> delegate : delegates) {
            MappedDimensions actual = delegate.dimensions();
            if (!expected.equals(actual)) {
                throw new RuntimeException("All delegate label mappers should have same dimensions");
            }
            if (actual.numDimensions() != 1) {
                throw new RuntimeException("All delegate label mappers should be one dimensional");
            }
        }
        this.labelsPerTimeStep = expected.numElements();
        // Defensive copy: the varargs array may be caller-owned and mutated after construction.
        this.delegates = delegates.clone();
        this.dim = new MappedDimensions(this.labelsPerTimeStep, this.delegates.length);
    }

    /** Returns the total number of labels: time steps times labels per time step. */
    @Override
    public int numberOfLabels() {
        return delegates.length * labelsPerTimeStep;
    }

    /** Returns the combined {@code (labelsPerTimeStep, maxSequenceLength)} dimensions. */
    @Override
    public MappedDimensions dimensions() {
        assert dim.numElements() == numberOfLabels() : "Number of elements must match number of labels.";
        return dim;
    }

    /**
     * Writes every delegate's labels into {@code labels} at {@code [indexOfRecord, label, timeStep]}.
     * Time steps at or beyond the current sequence length are written as {@code 0}.
     *
     * @param record record to produce labels for
     * @param labels destination array, indexed with three indices
     * @param indexOfRecord value used as the first index (presumably the record's position in the
     *     minibatch; confirm against callers)
     */
    @Override
    public void mapLabels(RecordType record, INDArray labels, int indexOfRecord) {
        indicesMapper[0] = indexOfRecord;
        for (int timeStep = 0; timeStep < delegates.length; timeStep++) {
            LabelMapper<RecordType> delegate = delegates[timeStep];
            boolean withinSequence = timeStep < sequenceLength;
            int delegateLabels = delegate.numberOfLabels();
            indicesMapper[2] = timeStep;
            for (int label = 0; label < delegateLabels; label++) {
                indicesMapper[1] = label;
                labels.putScalar(indicesMapper,
                        withinSequence ? delegate.produceLabel(record, label) : 0.0f);
            }
        }
    }

    /** Always {@code true}: padded time steps beyond the sequence length are masked out. */
    @Override
    public boolean hasMask() {
        return true;
    }

    /**
     * Writes the per-time-step mask at {@code [indexOfRecord, timeStep]}: {@code 1} for time steps
     * before the current sequence length, {@code 0} otherwise.
     */
    @Override
    public void maskLabels(RecordType record, INDArray mask, int indexOfRecord) {
        indicesMasker[0] = indexOfRecord;
        for (int timeStep = 0; timeStep < delegates.length; timeStep++) {
            indicesMasker[1] = timeStep;
            mask.putScalar(indicesMasker, timeStep < sequenceLength ? 1.0f : 0.0f);
        }
    }

    /**
     * Returns the label at flat index {@code labelIndex}. The result is {@code 0} for time steps
     * at or beyond the sequence length; otherwise the call is delegated to that time step's mapper.
     */
    @Override
    public float produceLabel(RecordType record, int labelIndex) {
        int timeStep = timeStepOf(labelIndex);
        if (timeStep >= sequenceLength) {
            return 0.0f;
        }
        return delegates[timeStep].produceLabel(record, labelIndex % labelsPerTimeStep);
    }

    /**
     * Returns whether flat label index {@code labelIndex} is in use. It is {@code false} beyond the
     * sequence length; otherwise it is {@code true} unless the time step's delegate has a mask that
     * excludes the label.
     */
    @Override
    public boolean isMasked(RecordType record, int labelIndex) {
        int timeStep = timeStepOf(labelIndex);
        if (timeStep >= sequenceLength) {
            return false;
        }
        LabelMapper<RecordType> delegate = delegates[timeStep];
        return !delegate.hasMask() || delegate.isMasked(record, labelIndex % labelsPerTimeStep);
    }

    /**
     * Records {@code record}'s sequence length and forwards the call to every delegate. Must be
     * called before the other per-record methods.
     */
    @Override
    public void prepareToNormalize(RecordType record, int indexOfRecord) {
        sequenceLength = recordToSequenceLength.apply(record).intValue();
        for (LabelMapper<RecordType> delegate : delegates) {
            delegate.prepareToNormalize(record, indexOfRecord);
        }
    }

    /** Returns the maximum sequence length, i.e. the number of delegates. */
    public int maxSequenceLength() {
        return delegates.length;
    }

    /** Returns the number of labels produced per time step. */
    public int labelsPerTimeStep() {
        return labelsPerTimeStep;
    }

    /** Maps a flat label index to its time step. */
    private int timeStepOf(int labelIndex) {
        return labelIndex / labelsPerTimeStep;
    }

    /** Builds {@code count} one-hot delegates, the {@code t}-th receiving {@code t} as first argument. */
    @SuppressWarnings({"unchecked", "rawtypes"}) // generic array creation is not expressible in Java
    private static <RecordType> LabelMapper<RecordType>[] createOneHotBaseLabelMappers(
            int count, int oneHotParameter, Function<RecordType, int[]> recordToLabels) {
        OneHotBaseLabelMapper[] mappers = new OneHotBaseLabelMapper[count];
        for (int t = 0; t < count; t++) {
            mappers[t] = new OneHotBaseLabelMapper(t, oneHotParameter, recordToLabels);
        }
        return mappers;
    }
}
