package org.campagnelab.dl.framework.mappers.pretraining;

import java.util.function.Function;
import org.campagnelab.dl.framework.mappers.ConcatFeatureMapper;
import org.campagnelab.dl.framework.mappers.FeatureMapper;
import org.campagnelab.dl.framework.mappers.MappedDimensions;
import org.campagnelab.dl.framework.mappers.NAryFeatureMapper;
import org.campagnelab.dl.framework.mappers.RNNFeatureMapper;
import org.campagnelab.dl.framework.mappers.TwoDimensionalConcatFeatureMapper;
import org.campagnelab.dl.framework.mappers.processing.TwoDimensionalRemoveMaskFeatureMapper;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Feature mapper that wraps a two-dimensional feature mapper for RNN pretraining.
 *
 * <p>The wrapped mapper's output is laid out as three concatenated segments along the
 * sequence: the (possibly padded) input features, an EOS segment produced by
 * {@link #createEosMapper}, and the (possibly padded) input features again. The result is
 * wrapped in a {@link TwoDimensionalRemoveMaskFeatureMapper}. Every {@link FeatureMapper}
 * method simply forwards to {@link #delegate}.
 *
 * @param <RecordType> type of the records being mapped
 */
public class RNNPretrainingFeatureMapper<RecordType> implements FeatureMapper<RecordType> {
    private static final Logger LOG = LoggerFactory.getLogger(RNNPretrainingFeatureMapper.class);

    /** Composite mapper that all {@link FeatureMapper} calls are forwarded to. */
    public FeatureMapper<RecordType> delegate;

    /**
     * Builds the composite pretraining mapper around {@code featureMapper}.
     *
     * @param featureMapper mapper producing two-dimensional features; its second dimension
     *     ({@code numElements(1)}) is taken as the number of features per step
     * @param num index of the EOS feature, or {@code null} to append the EOS flag after all
     *     existing features; must be in {@code [0, numElements]}
     * @param function passed through to the {@code RNNFeatureMapper} built for the EOS segment
     *     (presumably yields a per-record sequence length — TODO confirm against RNNFeatureMapper)
     * @throws IllegalArgumentException if {@code featureMapper} is not two-dimensional, or if
     *     {@code num} is negative or greater than the number of features
     */
    public RNNPretrainingFeatureMapper(FeatureMapper<RecordType> featureMapper, Integer num, Function<RecordType, Integer> function) {
        MappedDimensions dimensions = featureMapper.dimensions();
        if (dimensions.numDimensions() != 2) {
            throw new IllegalArgumentException("Mapper must map two dimensional features");
        }
        int numElements = dimensions.numElements(1);
        if (num != null) {
            // A negative index would make the EOS segment (width numElements - num) disagree
            // with the unpadded input segments (width numElements).
            if (num < 0) {
                throw new IllegalArgumentException(String.format("Invalid EOS index %d: must not be negative", num));
            }
            if (num > numElements) {
                throw new IllegalArgumentException(String.format("Invalid EOS index %d greater than number of features %d", num, numElements));
            }
        }
        // createEosMapper yields numElements + 1 features exactly when num is null or equals
        // numElements (the EOS flag is a new trailing feature), and numElements features
        // otherwise. The first argument is 1 in exactly those cases, presumably padding the
        // input feature axis so both segments have the same width.
        int eosPadding = (num == null || num == numElements) ? 1 : 0;
        TwoDimensionalConcatFeatureMapper paddedInput = new TwoDimensionalConcatFeatureMapper(eosPadding, featureMapper);
        FeatureMapper<RecordType> eosMapper = createEosMapper(num, numElements, function, LOG);
        this.delegate = new TwoDimensionalRemoveMaskFeatureMapper(
                new TwoDimensionalConcatFeatureMapper(paddedInput, eosMapper, paddedInput));
    }

    @Override
    public int numberOfFeatures() {
        return this.delegate.numberOfFeatures();
    }

    @Override
    public MappedDimensions dimensions() {
        return this.delegate.dimensions();
    }

    @Override
    public void prepareToNormalize(RecordType record, int index) {
        this.delegate.prepareToNormalize(record, index);
    }

    @Override
    public void mapFeatures(RecordType record, INDArray inputs, int index) {
        this.delegate.mapFeatures(record, inputs, index);
    }

    @Override
    public boolean hasMask() {
        return this.delegate.hasMask();
    }

    @Override
    public void maskFeatures(RecordType record, INDArray mask, int index) {
        this.delegate.maskFeatures(record, mask, index);
    }

    @Override
    public boolean isMasked(RecordType record, int index) {
        return this.delegate.isMasked(record, index);
    }

    @Override
    public float produceFeature(RecordType record, int index) {
        return this.delegate.produceFeature(record, index);
    }

    /**
     * Builds the mapper for the EOS segment.
     *
     * <p>The features are a concatenation of {@code NAryFeatureMapper} pieces: {@code num}
     * features before the EOS slot, one EOS feature (built with a different first flag), and
     * {@code i - num - 1} features after it. Empty pieces are omitted. When {@code num} is
     * {@code null}, the layout is {@code i} features followed by the EOS feature. The meaning
     * of the {@code NAryFeatureMapper} boolean flags is defined by that class, not here.
     *
     * @param num EOS feature index, or {@code null} to place the EOS feature after all
     *     {@code i} features
     * @param i number of input features
     * @param function passed to the resulting {@code RNNFeatureMapper}
     * @param logger receives a warning when there are no features before or after the EOS slot
     * @return an {@code RNNFeatureMapper} over the concatenated segments
     */
    public static <RecordType> FeatureMapper<RecordType> createEosMapper(Integer num, int i, Function<RecordType, Integer> function, Logger logger) {
        ConcatFeatureMapper segments;
        if (num == null) {
            segments = new ConcatFeatureMapper(new NAryFeatureMapper(i, true, true, true), new NAryFeatureMapper(1, false, true, true));
        } else {
            int before = num;
            // Negative when num == i (EOS appended past the last feature).
            int after = i - num - 1;
            NAryFeatureMapper eos = new NAryFeatureMapper(1, false, true, true);
            // Build only the pieces that are used, so NAryFeatureMapper never receives a
            // non-positive count.
            if (before > 0 && after > 0) {
                segments = new ConcatFeatureMapper(new NAryFeatureMapper(before, true, true, true), eos, new NAryFeatureMapper(after, true, true, true));
            } else if (before > 0) {
                segments = new ConcatFeatureMapper(new NAryFeatureMapper(before, true, true, true), eos);
            } else if (after > 0) {
                segments = new ConcatFeatureMapper(eos, new NAryFeatureMapper(after, true, true, true));
            } else {
                logger.warn("EOS index may be invalid; no features before or after");
                segments = new ConcatFeatureMapper(eos);
            }
        }
        return new RNNFeatureMapper(function, segments);
    }
}
