package org.campagnelab.dl.framework.mappers.processing;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;
import org.campagnelab.dl.framework.mappers.FeatureMapper;
import org.campagnelab.dl.framework.mappers.MappedDimensions;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/campagnelab/dl/framework/mappers/processing/TwoDimensionalRemoveMaskFeatureMapper.class */
/**
 * Wraps a two-dimensional {@link FeatureMapper} and re-indexes its time steps so that runs of
 * steps for which the delegate's {@code isMasked} returns false, and which are followed by a step
 * where it returns true, are skipped.
 *
 * <p>The delegate's dimensions are read as {@code numElements(1)} features per time step and
 * {@code numElements(2)} time steps. A flat feature index is
 * {@code timeStep * featuresPerTimeStep + feature}. Call {@link #prepareToNormalize} for a record
 * before the other per-record methods so the runs to skip are computed for that record. Until it
 * has been called, no time steps are skipped.
 *
 * <p>Not thread-safe: per-record state and reusable index buffers are kept in instance fields.
 *
 * @param <RecordType> type of record handled by the delegate mapper
 */
public class TwoDimensionalRemoveMaskFeatureMapper<RecordType> implements FeatureMapper<RecordType> {
    private final FeatureMapper<RecordType> delegate;
    private final MappedDimensions dim;
    private final int featuresPerTimeStep;
    private final int numTimeSteps;
    private final Function<RecordType, Integer> recordToPaddingLength;
    /** Reusable index buffer for mapFeatures: {record position, feature, time step}. */
    private final int[] mapperIndices = new int[]{0, 0, 0};
    /** Reusable index buffer for maskFeatures: {record position, time step}. */
    private final int[] maskerIndices = new int[]{0, 0};
    /**
     * Runs of time steps to skip for the record last passed to prepareToNormalize. Starts empty
     * so the mapper is usable (as an identity re-indexing) before prepareToNormalize is called.
     */
    private List<Bounds> boundsList = new ArrayList<>();

    /**
     * Creates a mapper whose mask scan starts at time step 0 for every record.
     *
     * @param featureMapper two-dimensional delegate mapper
     */
    public TwoDimensionalRemoveMaskFeatureMapper(FeatureMapper<RecordType> featureMapper) {
        this(featureMapper, 0);
    }

    /**
     * Creates a mapper whose mask scan starts at the same time step for every record.
     *
     * @param featureMapper two-dimensional delegate mapper
     * @param paddingLength first time step examined by prepareToNormalize
     */
    public TwoDimensionalRemoveMaskFeatureMapper(FeatureMapper<RecordType> featureMapper, int paddingLength) {
        this(featureMapper, ignored -> paddingLength);
    }

    /**
     * Creates a mapper whose mask scan starts at a per-record time step.
     *
     * @param featureMapper         two-dimensional delegate mapper
     * @param recordToPaddingLength returns, for a record, the first time step examined by
     *                              prepareToNormalize (only invoked when the delegate has a mask)
     * @throws NullPointerException     if {@code featureMapper} is null
     * @throws IllegalArgumentException if the delegate's dimensions are not two-dimensional
     */
    public TwoDimensionalRemoveMaskFeatureMapper(FeatureMapper<RecordType> featureMapper,
                                                 Function<RecordType, Integer> recordToPaddingLength) {
        this.delegate = Objects.requireNonNull(featureMapper, "featureMapper");
        this.dim = featureMapper.dimensions();
        if (dim.numDimensions() != 2) {
            throw new IllegalArgumentException("Delegate mapper must be two dimensional");
        }
        this.featuresPerTimeStep = dim.numElements(1);
        this.numTimeSteps = dim.numElements(2);
        this.recordToPaddingLength = recordToPaddingLength;
    }

    /** Returns the delegate's number of features. */
    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public int numberOfFeatures() {
        return delegate.numberOfFeatures();
    }

    /** Returns the delegate's dimensions; this mapper does not change the output shape. */
    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public MappedDimensions dimensions() {
        return delegate.dimensions();
    }

    /**
     * Prepares the delegate and recomputes the runs of time steps to skip for {@code record}.
     *
     * <p>If the delegate has a mask, steps are scanned from {@code recordToPaddingLength(record)}
     * to the last time step. Each step is tested with the delegate's {@code isMasked} at its first
     * feature. Each maximal run of steps reported as not masked becomes one {@code Bounds}, but only
     * when a step reported as masked follows it. A run that reaches the last time step is not
     * recorded. Every recorded {@code Bounds} after the first receives its predecessor via
     * {@code setShiftedSize}. NOTE(review): presumably so {@code contains} works in shifted
     * coordinates; confirm against {@code Bounds}.
     *
     * @param record        record to prepare
     * @param indexOfRecord forwarded unchanged to the delegate's prepareToNormalize
     */
    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public void prepareToNormalize(RecordType record, int indexOfRecord) {
        delegate.prepareToNormalize(record, indexOfRecord);
        boundsList = new ArrayList<>();
        if (!delegate.hasMask()) {
            return;
        }
        // Start as if the step before the scan was masked, so a leading unmasked run opens a Bounds.
        boolean previousMasked = true;
        Bounds current = new Bounds();
        for (int step = recordToPaddingLength.apply(record); step < numTimeSteps; step++) {
            boolean masked = delegate.isMasked(record, step * featuresPerTimeStep);
            if (masked && !previousMasked) {
                // An unmasked run just ended: close it and record it.
                current.setEnd(step);
                if (!boundsList.isEmpty()) {
                    current.setShiftedSize(boundsList.get(boundsList.size() - 1));
                }
                boundsList.add(current);
                current = new Bounds();
            } else if (!masked && previousMasked) {
                // An unmasked run starts here.
                current.setStart(step);
            }
            previousMasked = masked;
        }
    }

    /**
     * Writes every feature of every time step into {@code inputs} at
     * {@code {indexOfRecord, feature, timeStep}}, using {@link #produceFeature}.
     *
     * @param record        record to map
     * @param inputs        destination array (assumed rank 3, indexed as above)
     * @param indexOfRecord position of the record along the first axis of {@code inputs}
     */
    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public void mapFeatures(RecordType record, INDArray inputs, int indexOfRecord) {
        mapperIndices[0] = indexOfRecord;
        for (int step = 0; step < numTimeSteps; step++) {
            mapperIndices[2] = step;
            for (int feature = 0; feature < featuresPerTimeStep; feature++) {
                mapperIndices[1] = feature;
                inputs.putScalar(mapperIndices, produceFeature(record, step * featuresPerTimeStep + feature));
            }
        }
    }

    /** Always true: this mapper always produces a mask. */
    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public boolean hasMask() {
        return true;
    }

    /**
     * Writes one mask value per time step into {@code mask} at {@code {indexOfRecord, timeStep}}.
     * The value is 1.0 when {@link #isMasked} reports the step's first feature as masked, and 0.0
     * otherwise.
     *
     * @param record        record to mask
     * @param mask          destination array (assumed rank 2, indexed as above)
     * @param indexOfRecord position of the record along the first axis of {@code mask}
     */
    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public void maskFeatures(RecordType record, INDArray mask, int indexOfRecord) {
        maskerIndices[0] = indexOfRecord;
        for (int step = 0; step < numTimeSteps; step++) {
            maskerIndices[1] = step;
            maskerIndices[1] = step;
            mask.putScalar(maskerIndices, isMasked(record, step * featuresPerTimeStep) ? 1.0f : 0.0f);
        }
    }

    /**
     * Returns the delegate's mask for the same feature at the shifted time step.
     *
     * @return false when the shifted time step is past the last time step
     */
    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public boolean isMasked(RecordType record, int featureIndex) {
        int sourceStep = shiftedTimeStep(featureIndex / featuresPerTimeStep);
        if (sourceStep >= numTimeSteps) {
            return false;
        }
        return delegate.isMasked(record, sourceStep * featuresPerTimeStep + featureIndex % featuresPerTimeStep);
    }

    /**
     * Returns the delegate's value for the same feature at the shifted time step.
     *
     * @return 0 when the shifted time step is past the last time step
     */
    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public float produceFeature(RecordType record, int featureIndex) {
        int sourceStep = shiftedTimeStep(featureIndex / featuresPerTimeStep);
        if (sourceStep >= numTimeSteps) {
            return 0.0f;
        }
        return delegate.produceFeature(record, sourceStep * featuresPerTimeStep + featureIndex % featuresPerTimeStep);
    }

    /**
     * Maps an output time step to the delegate time step it reads from. The result is
     * {@code timeStep} plus the {@code size()} of every recorded Bounds whose
     * {@code contains(timeStep)} is true.
     */
    private int shiftedTimeStep(int timeStep) {
        int shift = 0;
        for (Bounds bounds : boundsList) {
            if (bounds.contains(timeStep)) {
                shift += bounds.size();
            }
        }
        return timeStep + shift;
    }
}
