package org.campagnelab.dl.framework.mappers;

import java.util.function.Function;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/campagnelab/dl/framework/mappers/RNNLabelMapper.class */
/**
 * Label mapper that lays out the labels of a record as a sequence of time steps, delegating
 * each time step to its own {@link OneHotBaseLabelMapper}.
 *
 * <p>Time steps at or beyond the record's sequence length (as reported by the
 * {@code recordToSequenceLength} function) are written as zeros in the label array and
 * flagged with 0 in the mask array.
 *
 * <p>Instances reuse internal index buffers across calls to {@link #mapLabels} and
 * {@link #maskLabels}; a single instance must therefore not be used by several threads
 * concurrently.
 *
 * @param <RecordType> type of the records that labels are extracted from
 */
public class RNNLabelMapper<RecordType> implements LabelMapper<RecordType> {
    /** Number of labels produced by each delegate, i.e. for each time step. */
    private final int labelsPerTimeStep;
    /** One delegate per time step; the array index is the time step. */
    private final OneHotBaseLabelMapper<RecordType>[] delegates;
    /** Yields the number of time steps actually populated for a record. */
    private final Function<RecordType, Integer> recordToSequenceLength;
    /** Reusable index buffer for the label array: {example index, label index, time step}. */
    private final int[] indicesMapper = {0, 0, 0};
    /** Reusable index buffer for the mask array: {example index, time step}. */
    private final int[] indicesMasker = {0, 0};
    private final MappedDimensions dim;

    /**
     * Builds one {@link OneHotBaseLabelMapper} per time step.
     *
     * @param labelsPerTimeStep      number of labels each delegate is expected to expose
     * @param numTimeSteps           number of delegates to create; must be strictly positive
     * @param recordToLabels         function handed unchanged to every delegate
     *                               (presumably yields one label index per time step; confirm
     *                               against {@code OneHotBaseLabelMapper})
     * @param recordToSequenceLength function returning how many time steps of a record are valid
     * @throws IllegalArgumentException if {@code numTimeSteps} is not positive, or if a delegate
     *                                  reports a number of labels different from
     *                                  {@code labelsPerTimeStep}
     */
    public RNNLabelMapper(int labelsPerTimeStep, int numTimeSteps,
                          Function<RecordType, int[]> recordToLabels,
                          Function<RecordType, Integer> recordToSequenceLength) {
        if (numTimeSteps <= 0) {
            // Fail with a clear message instead of an ArrayIndexOutOfBounds/NegativeArraySize error.
            throw new IllegalArgumentException(
                    "Number of time steps must be strictly positive, but was " + numTimeSteps);
        }
        this.recordToSequenceLength = recordToSequenceLength;
        // Generic arrays cannot be created directly; only RecordType delegates are ever stored.
        @SuppressWarnings("unchecked")
        final OneHotBaseLabelMapper<RecordType>[] mappers = new OneHotBaseLabelMapper[numTimeSteps];
        for (int timeStep = 0; timeStep < numTimeSteps; timeStep++) {
            mappers[timeStep] = new OneHotBaseLabelMapper<>(timeStep, labelsPerTimeStep, recordToLabels);
        }
        this.delegates = mappers;
        this.labelsPerTimeStep = mappers[0].numberOfLabels();
        for (OneHotBaseLabelMapper<RecordType> delegate : mappers) {
            if (delegate.numberOfLabels() != labelsPerTimeStep) {
                throw new IllegalArgumentException(
                        "All delegate one hot base mappers should have same number of labels");
            }
        }
        this.dim = new MappedDimensions(labelsPerTimeStep, mappers.length);
    }

    /**
     * @return total number of labels across all time steps
     *         (number of delegates times labels per time step)
     */
    @Override // org.campagnelab.dl.framework.mappers.LabelMapper
    public int numberOfLabels() {
        return this.delegates.length * this.labelsPerTimeStep;
    }

    /**
     * @return the dimensions created at construction time from the labels per time step and
     *         the number of delegates
     */
    @Override // org.campagnelab.dl.framework.mappers.LabelMapper
    public MappedDimensions dimensions() {
        assert this.dim.numElements() == numberOfLabels()
                : "Number of elements must match number of labels.";
        return this.dim;
    }

    /**
     * Writes the labels of {@code record} into {@code labels} at positions
     * {@code [indexOfRecord, labelIndex, timeStep]}. Time steps at or past the record's
     * sequence length are filled with zeros.
     *
     * @param record        record to extract labels from
     * @param labels        destination array, indexed with three indices
     * @param indexOfRecord index of the record along the first dimension of {@code labels}
     */
    @Override // org.campagnelab.dl.framework.mappers.LabelMapper
    public void mapLabels(RecordType record, INDArray labels, int indexOfRecord) {
        // The sequence length does not depend on the loop variables: evaluate it once per record
        // (assumes the function is side-effect free).
        final int sequenceLength = this.recordToSequenceLength.apply(record);
        this.indicesMapper[0] = indexOfRecord;
        for (int timeStep = 0; timeStep < this.delegates.length; timeStep++) {
            this.indicesMapper[2] = timeStep;
            final OneHotBaseLabelMapper<RecordType> delegate = this.delegates[timeStep];
            final boolean withinSequence = timeStep < sequenceLength;
            final int delegateLabels = delegate.numberOfLabels();
            for (int labelIndex = 0; labelIndex < delegateLabels; labelIndex++) {
                this.indicesMapper[1] = labelIndex;
                final float value = withinSequence ? delegate.produceLabel(record, labelIndex) : 0.0f;
                labels.putScalar(this.indicesMapper, value);
            }
        }
    }

    /** @return always {@code true}: this mapper writes a mask for every record */
    @Override // org.campagnelab.dl.framework.mappers.LabelMapper
    public boolean hasMask() {
        return true;
    }

    /**
     * Writes the mask of {@code record} into {@code mask} at positions
     * {@code [indexOfRecord, timeStep]}: 1 for time steps before the record's sequence length,
     * 0 for the remaining ones.
     *
     * @param record        record whose sequence length determines the mask
     * @param mask          destination array, indexed with two indices
     * @param indexOfRecord index of the record along the first dimension of {@code mask}
     */
    @Override // org.campagnelab.dl.framework.mappers.LabelMapper
    public void maskLabels(RecordType record, INDArray mask, int indexOfRecord) {
        // Evaluate the sequence length once rather than once per time step.
        final int sequenceLength = this.recordToSequenceLength.apply(record);
        this.indicesMasker[0] = indexOfRecord;
        for (int timeStep = 0; timeStep < this.delegates.length; timeStep++) {
            this.indicesMasker[1] = timeStep;
            mask.putScalar(this.indicesMasker, timeStep < sequenceLength ? 1.0f : 0.0f);
        }
    }

    /**
     * Produces a single label from a flat label index. The flat index is split into a delegate
     * (time step) index, {@code labelIndex / labelsPerTimeStep}, and an index within that
     * delegate, {@code labelIndex % labelsPerTimeStep}.
     *
     * @param record     record to extract the label from
     * @param labelIndex flat index in {@code [0, numberOfLabels())}
     * @return the value produced by the selected delegate
     */
    @Override // org.campagnelab.dl.framework.mappers.LabelMapper
    public float produceLabel(RecordType record, int labelIndex) {
        final int timeStep = labelIndex / this.labelsPerTimeStep;
        final int indexInDelegate = labelIndex % this.labelsPerTimeStep;
        return this.delegates[timeStep].produceLabel(record, indexInDelegate);
    }

    /**
     * Reports the mask state of a flat label index by forwarding to the delegate selected the
     * same way as in {@link #produceLabel}.
     *
     * @param record     record being queried
     * @param labelIndex flat index in {@code [0, numberOfLabels())}
     * @return whatever the selected delegate reports for the index within that delegate
     */
    @Override // org.campagnelab.dl.framework.mappers.LabelMapper
    public boolean isMasked(RecordType record, int labelIndex) {
        final int timeStep = labelIndex / this.labelsPerTimeStep;
        final int indexInDelegate = labelIndex % this.labelsPerTimeStep;
        return this.delegates[timeStep].isMasked(record, indexInDelegate);
    }

    /** No-op: this mapper performs no preparation before normalization. */
    @Override // org.campagnelab.dl.framework.mappers.LabelMapper
    public void prepareToNormalize(RecordType record, int indexOfRecord) {
    }
}
