package org.campagnelab.dl.framework.mappers;

import java.util.function.Function;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Feature mapper for recurrent networks that lays the features of a record out over time steps.
 *
 * <p>Each delegate mapper produces the one-dimensional features of a single time step, so the
 * number of delegates is the maximum sequence length and every delegate must report the same
 * {@link MappedDimensions}. Time steps at or beyond the record's sequence length are written as
 * zeros by {@link #mapFeatures} and given a mask value of {@code 0} by {@link #maskFeatures}.
 *
 * <p>The sequence length is per-record state: it is recorded by {@link #prepareToNormalize} and
 * then used by {@link #mapFeatures}, {@link #maskFeatures}, {@link #produceFeature} and
 * {@link #isMasked}. Together with the reused scratch index buffers, this means an instance must
 * not be used to map several records concurrently.
 *
 * @param <RecordType> type of the records being mapped
 */
public class RNNFeatureMapper<RecordType> implements FeatureMapper<RecordType> {
    /** Number of features each delegate contributes for one time step. */
    private final int featuresPerTimeStep;
    /** One delegate per time step; the array length is the maximum sequence length. */
    private final FeatureMapper<RecordType>[] delegates;
    /** Returns the actual sequence length of a record. */
    private final Function<RecordType, Integer> recordToSequenceLength;
    /** Scratch index {record, feature, timeStep} reused across calls to mapFeatures. */
    private final int[] indicesMapper = new int[3];
    /** Scratch index {record, timeStep} reused across calls to maskFeatures. */
    private final int[] indicesMasker = new int[2];
    /** Sequence length of the record passed to the most recent prepareToNormalize call. */
    int sequenceLength;
    /** Dimensions reported by this mapper: (featuresPerTimeStep, maximum sequence length). */
    private final MappedDimensions dim;

    /**
     * Creates a mapper whose {@code maxSequenceLength} time steps are each encoded by a
     * {@code OneHotBaseFeatureMapper}.
     *
     * @param maxSequenceLength number of time steps (one delegate each); must be positive
     * @param recordToSequence function handed to every {@code OneHotBaseFeatureMapper}
     *     (presumably extracts the sequence string from a record — TODO confirm)
     * @param recordToSequenceLength returns the actual sequence length of a record
     * @throws IllegalArgumentException if {@code maxSequenceLength} is not positive
     */
    public RNNFeatureMapper(int maxSequenceLength, Function<RecordType, String> recordToSequence,
            Function<RecordType, Integer> recordToSequenceLength) {
        this(recordToSequenceLength, createOneHotBaseFeatureMappers(maxSequenceLength, recordToSequence));
    }

    /**
     * Creates a mapper that uses one delegate per time step.
     *
     * @param recordToSequenceLength returns the actual sequence length of a record
     * @param delegates per-time-step feature mappers; must be non-empty, one-dimensional and all
     *     report equal dimensions. The array is copied, so later changes to it have no effect.
     * @throws IllegalArgumentException if no delegates are given, or if the delegates' dimensions
     *     differ or are not one-dimensional
     */
    @SafeVarargs
    public RNNFeatureMapper(Function<RecordType, Integer> recordToSequenceLength,
            FeatureMapper<RecordType>... delegates) {
        if (delegates == null || delegates.length == 0) {
            throw new IllegalArgumentException("At least one delegate feature mapper is required");
        }
        MappedDimensions expected = delegates[0].dimensions();
        for (FeatureMapper<RecordType> delegate : delegates) {
            MappedDimensions actual = delegate.dimensions();
            if (!expected.equals(actual)) {
                throw new IllegalArgumentException("All delegate feature mappers should have same dimensions");
            }
            if (actual.numDimensions() != 1) {
                throw new IllegalArgumentException("All delegate feature mappers should be one dimensional");
            }
        }
        this.recordToSequenceLength = recordToSequenceLength;
        this.featuresPerTimeStep = expected.numElements();
        // Defensive copy: the caller keeps a reference to the varargs array and could mutate it.
        this.delegates = delegates.clone();
        this.dim = new MappedDimensions(this.featuresPerTimeStep, this.delegates.length);
    }

    /** Returns the total number of features: delegates (time steps) times features per step. */
    @Override
    public int numberOfFeatures() {
        return delegates.length * featuresPerTimeStep;
    }

    /**
     * Returns the dimensions (featuresPerTimeStep, maximum sequence length).
     *
     * <p>When assertions are enabled, checks that the element count matches
     * {@link #numberOfFeatures()}.
     */
    @Override
    public MappedDimensions dimensions() {
        assert dim.numElements() == numberOfFeatures() : "Number of elements must match number of features.";
        return dim;
    }

    /**
     * Records the sequence length of {@code record} and forwards the call to every delegate.
     *
     * @param record record about to be mapped
     * @param indexOfRecord index forwarded unchanged to each delegate
     */
    @Override
    public void prepareToNormalize(RecordType record, int indexOfRecord) {
        sequenceLength = recordToSequenceLength.apply(record);
        for (FeatureMapper<RecordType> delegate : delegates) {
            delegate.prepareToNormalize(record, indexOfRecord);
        }
    }

    /**
     * Writes the features of {@code record} into {@code inputs} at
     * {@code [indexOfRecord, feature, timeStep]}.
     *
     * <p>Time steps at or beyond the sequence length recorded by {@link #prepareToNormalize} are
     * zero-filled and their delegates are not consulted.
     *
     * @param record record to map
     * @param inputs destination array, indexed as {record, feature, timeStep}
     * @param indexOfRecord first index into {@code inputs}
     */
    @Override
    public void mapFeatures(RecordType record, INDArray inputs, int indexOfRecord) {
        indicesMapper[0] = indexOfRecord;
        for (int timeStep = 0; timeStep < delegates.length; timeStep++) {
            indicesMapper[2] = timeStep;
            FeatureMapper<RecordType> delegate = delegates[timeStep];
            boolean inSequence = timeStep < sequenceLength;
            for (int feature = 0; feature < delegate.numberOfFeatures(); feature++) {
                indicesMapper[1] = feature;
                float value = inSequence ? delegate.produceFeature(record, feature) : 0.0f;
                inputs.putScalar(indicesMapper, value);
            }
        }
    }

    /** Always {@code true}: padded time steps are masked out by {@link #maskFeatures}. */
    @Override
    public boolean hasMask() {
        return true;
    }

    /**
     * Writes {@code 1} into {@code mask} at {@code [indexOfRecord, timeStep]} for time steps
     * inside the current sequence length and {@code 0} for the remaining (padded) steps.
     *
     * @param record record being masked (unused; the recorded sequence length is used instead)
     * @param mask destination array, indexed as {record, timeStep}
     * @param indexOfRecord first index into {@code mask}
     */
    @Override
    public void maskFeatures(RecordType record, INDArray mask, int indexOfRecord) {
        indicesMasker[0] = indexOfRecord;
        for (int timeStep = 0; timeStep < delegates.length; timeStep++) {
            indicesMasker[1] = timeStep;
            mask.putScalar(indicesMasker, timeStep < sequenceLength ? 1.0f : 0.0f);
        }
    }

    /**
     * Returns one feature from a flat, time-step-major index.
     *
     * <p>{@code featureIndex = timeStep * featuresPerTimeStep + featureWithinStep}. Features of
     * time steps at or beyond the current sequence length are {@code 0}.
     *
     * @param record record being mapped
     * @param featureIndex flat index in {@code [0, numberOfFeatures())}
     * @return the delegate's feature value, or {@code 0} for padded time steps
     */
    @Override
    public float produceFeature(RecordType record, int featureIndex) {
        int timeStep = featureIndex / featuresPerTimeStep;
        if (timeStep >= sequenceLength) {
            return 0.0f;
        }
        return delegates[timeStep].produceFeature(record, featureIndex % featuresPerTimeStep);
    }

    /**
     * Reports the mask state of a feature, using the same flat index layout as
     * {@link #produceFeature}.
     *
     * @param record record being mapped
     * @param featureIndex flat index in {@code [0, numberOfFeatures())}
     * @return {@code false} for padded time steps; otherwise {@code true} when the delegate has
     *     no mask, else the delegate's own {@code isMasked} result
     */
    @Override
    public boolean isMasked(RecordType record, int featureIndex) {
        int timeStep = featureIndex / featuresPerTimeStep;
        if (timeStep >= sequenceLength) {
            return false;
        }
        FeatureMapper<RecordType> delegate = delegates[timeStep];
        return !delegate.hasMask() || delegate.isMasked(record, featureIndex % featuresPerTimeStep);
    }

    /** Returns the maximum sequence length, i.e. the number of delegates. */
    public int maxSequenceLength() {
        return delegates.length;
    }

    /** Returns the number of features each delegate contributes for one time step. */
    public int featuresPerTimeStep() {
        return featuresPerTimeStep;
    }

    /**
     * Builds {@code count} one-hot delegates; the i-th is constructed with index {@code i} and
     * {@code recordToSequence}.
     *
     * @throws IllegalArgumentException if {@code count} is not positive
     */
    private static <RecordType> FeatureMapper<RecordType>[] createOneHotBaseFeatureMappers(
            int count, Function<RecordType, String> recordToSequence) {
        if (count <= 0) {
            throw new IllegalArgumentException("maxSequenceLength must be positive, got " + count);
        }
        // NOTE(review): raw array type as emitted by the decompiler; returning it as a generic
        // array is an unchecked conversion if OneHotBaseFeatureMapper is generic.
        OneHotBaseFeatureMapper[] mappers = new OneHotBaseFeatureMapper[count];
        for (int index = 0; index < count; index++) {
            mappers[index] = new OneHotBaseFeatureMapper(index, recordToSequence);
        }
        return mappers;
    }
}
