package org.campagnelab.dl.framework.mappers;

import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * {@link LabelMapper} that derives its labels and label mask from a wrapped
 * {@link RNNFeatureMapper}.
 *
 * <p>For a delegate whose current {@code sequenceLength} is {@code S}, time steps in
 * {@code [S, 2 * S)} form the "label window": {@link #maskLabels} writes {@code 1.0f} there,
 * {@link #isMasked} returns {@code true} there, and {@link #produceLabel} forwards to the
 * delegate there. Everywhere else the mask and the label are {@code 0.0f}.
 *
 * <p>The index arrays used with {@code INDArray.putScalar} are instance fields that are
 * overwritten on every call, so an instance is presumably not safe to share between threads
 * (TODO confirm against callers).
 *
 * @param <RecordType> type of the records handed to the wrapped feature mapper
 * @deprecated this class is annotated {@code @Deprecated}; this file does not name a replacement.
 */
@Deprecated
public class RNNPretrainingLabelMapper<RecordType> implements LabelMapper<RecordType> {
    /** Delegate that supplies the sequence geometry and the underlying feature values. */
    private final RNNFeatureMapper<RecordType> labelMapper;
    /** Dimensions reported by {@link #dimensions()}; built once in the constructor. */
    private final MappedDimensions dim;
    /** Value of {@code featuresPerTimeStep()} read from the delegate at construction time. */
    private final int labelsPerTimeStep;
    /** Value of {@code maxSequenceLength()} read from the delegate at construction time. */
    private final int maxSequenceLen;
    /** Reusable putScalar index for labels: {record index, label within time step, time step}. */
    private final int[] mapperIndices = new int[3];
    /** Reusable putScalar index for the mask: {record index, time step}. */
    private final int[] maskerIndices = new int[2];

    /**
     * Wraps the given feature mapper and caches its per-time-step width and maximum
     * sequence length.
     *
     * @param rNNFeatureMapper delegate feature mapper; must not be {@code null}
     */
    public RNNPretrainingLabelMapper(RNNFeatureMapper<RecordType> rNNFeatureMapper) {
        this.labelMapper = rNNFeatureMapper;
        this.labelsPerTimeStep = rNNFeatureMapper.featuresPerTimeStep();
        this.maxSequenceLen = rNNFeatureMapper.maxSequenceLength();
        int timeSteps = 1 + 2 * this.maxSequenceLen;
        this.dim = new MappedDimensions(this.labelsPerTimeStep, timeSteps);
    }

    /**
     * @return {@code 2 * n + 1}, where {@code n} is the delegate's {@code numberOfFeatures()}
     */
    @Override // org.campagnelab.dl.framework.mappers.LabelMapper
    public int numberOfLabels() {
        int delegateFeatures = this.labelMapper.numberOfFeatures();
        return 1 + 2 * delegateFeatures;
    }

    /**
     * @return the dimensions built in the constructor:
     *         {@code (labelsPerTimeStep, 2 * maxSequenceLen + 1)}
     */
    @Override // org.campagnelab.dl.framework.mappers.LabelMapper
    public MappedDimensions dimensions() {
        return this.dim;
    }

    /**
     * Forwards straight to the delegate's {@code prepareToNormalize}.
     *
     * @param recordtype record passed through to the delegate
     * @param i          index passed through to the delegate
     */
    @Override // org.campagnelab.dl.framework.mappers.LabelMapper
    public void prepareToNormalize(RecordType recordtype, int i) {
        this.labelMapper.prepareToNormalize(recordtype, i);
    }

    /**
     * Writes one value per (label, time step) pair for time steps {@code 0 .. maxSequenceLen - 1}
     * into {@code iNDArray} at index {@code [i, label, timeStep]}.
     *
     * <p>Each time step is wrapped with {@code % sequenceLength} before being turned into the flat
     * index given to {@link #produceLabel}; a delegate {@code sequenceLength} of zero therefore
     * raises {@link ArithmeticException} as soon as the loop runs.
     *
     * <p>NOTE(review): the wrapped time step is always below {@code sequenceLength}, which is the
     * region {@link #produceLabel} zeroes, so (absent an override or int overflow) this method
     * appears to store only {@code 0.0f}. It also stops at {@code maxSequenceLen} although
     * {@link #dimensions()} advertises {@code 2 * maxSequenceLen + 1} steps. Confirm whether
     * either is intended.
     *
     * @param recordtype record to map
     * @param iNDArray   destination array, written via {@code putScalar}
     * @param i          first-axis index of the record inside {@code iNDArray}
     */
    @Override // org.campagnelab.dl.framework.mappers.LabelMapper
    public void mapLabels(RecordType recordtype, INDArray iNDArray, int i) {
        this.labelMapper.prepareToNormalize(recordtype, i);
        this.mapperIndices[0] = i;
        for (int timeStep = 0; timeStep < this.maxSequenceLen; timeStep++) {
            this.mapperIndices[2] = timeStep;
            int wrappedStep = timeStep % this.labelMapper.sequenceLength;
            // Flat index of the first label belonging to the wrapped time step.
            int stepOffset = wrappedStep * this.labelsPerTimeStep;
            for (int labelInStep = 0; labelInStep < this.labelsPerTimeStep; labelInStep++) {
                this.mapperIndices[1] = labelInStep;
                float labelValue = produceLabel(recordtype, stepOffset + labelInStep);
                iNDArray.putScalar(this.mapperIndices, labelValue);
            }
        }
    }

    /**
     * @return always {@code true}
     */
    @Override // org.campagnelab.dl.framework.mappers.LabelMapper
    public boolean hasMask() {
        return true;
    }

    /**
     * Writes the mask for time steps {@code 0 .. maxSequenceLen - 1} into {@code iNDArray} at
     * index {@code [i, timeStep]}: {@code 1.0f} inside the label window
     * {@code [sequenceLength, 2 * sequenceLength)}, {@code 0.0f} elsewhere.
     *
     * @param recordtype record whose mask is written (only forwarded to the delegate's
     *                   {@code prepareToNormalize})
     * @param iNDArray   destination mask array, written via {@code putScalar}
     * @param i          first-axis index of the record inside {@code iNDArray}
     */
    @Override // org.campagnelab.dl.framework.mappers.LabelMapper
    public void maskLabels(RecordType recordtype, INDArray iNDArray, int i) {
        this.labelMapper.prepareToNormalize(recordtype, i);
        this.maskerIndices[0] = i;
        for (int timeStep = 0; timeStep < this.maxSequenceLen; timeStep++) {
            this.maskerIndices[1] = timeStep;
            float maskValue = insideLabelWindow(timeStep) ? 1.0f : 0.0f;
            iNDArray.putScalar(this.maskerIndices, maskValue);
        }
    }

    /**
     * Tells whether flat label index {@code i} belongs to a time step inside the label window
     * {@code [sequenceLength, 2 * sequenceLength)}.
     *
     * @param recordtype not consulted by this implementation
     * @param i          flat label index; its time step is {@code i / labelsPerTimeStep}
     * @return {@code true} when the time step lies inside the label window
     */
    @Override // org.campagnelab.dl.framework.mappers.LabelMapper
    public boolean isMasked(RecordType recordtype, int i) {
        return insideLabelWindow(i / this.labelsPerTimeStep);
    }

    /**
     * Produces the label for flat index {@code i}: the delegate's
     * {@code produceFeature(recordtype, i)} when the time step {@code i / labelsPerTimeStep}
     * lies inside the label window, {@code 0.0f} otherwise.
     *
     * @param recordtype record forwarded to the delegate
     * @param i          flat label index, forwarded unchanged to the delegate
     * @return the delegate's feature value, or {@code 0.0f} outside the label window
     */
    @Override // org.campagnelab.dl.framework.mappers.LabelMapper
    public float produceLabel(RecordType recordtype, int i) {
        int timeStep = i / this.labelsPerTimeStep;
        if (insideLabelWindow(timeStep)) {
            return this.labelMapper.produceFeature(recordtype, i);
        }
        return 0.0f;
    }

    /**
     * Tests whether {@code timeStep} falls in {@code [S, 2 * S)}, where {@code S} is the
     * delegate's current {@code sequenceLength} field.
     */
    private boolean insideLabelWindow(int timeStep) {
        int windowStart = this.labelMapper.sequenceLength;
        int windowEnd = windowStart * 2;
        return timeStep >= windowStart && timeStep < windowEnd;
    }
}
