package org.campagnelab.dl.framework.iterators;

import it.unimi.dsi.fastutil.ints.IntArraySet;
import java.util.Arrays;
import org.campagnelab.dl.framework.mappers.FeatureMapper;
import org.campagnelab.dl.framework.mappers.MappedDimensions;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/campagnelab/dl/framework/iterators/ConcatFeatureMapper.class */
/**
 * A {@link FeatureMapper} that concatenates the features of several delegate mappers into one
 * feature vector. Feature {@code j} of the k-th delegate is exposed at position
 * {@code offsets[k] + j} of the concatenated vector.
 *
 * @param <RecordType> type of record consumed by every delegate mapper
 */
public class ConcatFeatureMapper<RecordType> implements FeatureMapper<RecordType> {
    /** Delegate mappers, in concatenation order. */
    protected FeatureMapper<RecordType>[] mappers;
    /** Total number of features across all delegates. */
    protected int numFeatures = 0;
    /**
     * Start position of each delegate in the concatenated vector, plus a trailing entry equal to
     * {@link #numFeatures}. Holds repeated values when a delegate contributes zero features.
     */
    protected int[] offsets;
    /** Set by {@link #prepareToNormalize}; checked (under {@code -ea}) by {@link #mapFeatures}. */
    private boolean normalizedCalled;

    /**
     * Creates a mapper that concatenates the given delegates in order.
     *
     * @param featureMapperArr delegates to concatenate; the array is copied, so later changes to
     *     it do not affect this mapper
     */
    @SafeVarargs
    public ConcatFeatureMapper(FeatureMapper<RecordType>... featureMapperArr) {
        // Copy so the caller cannot later desynchronize mappers from offsets.
        this.mappers = featureMapperArr.clone();
        this.offsets = new int[mappers.length + 1];
        this.offsets[0] = 0;
        IntArraySet dimensionCounts = new IntArraySet();
        for (int k = 0; k < mappers.length; k++) {
            FeatureMapper<RecordType> mapper = mappers[k];
            numFeatures += mapper.numberOfFeatures();
            offsets[k + 1] = numFeatures;
            dimensionCounts.add(mapper.dimensions().numDimensions());
        }
        // NOTE(review): this only checks that delegates agree on the number of dimensions, while
        // dimensions() reports a single flat size and mapFeatures writes [record, feature]
        // indices -- presumably delegates are expected to be 1-D; confirm.
        assert mappers.length == 0 || dimensionCounts.size() == 1
                : "All feature mappers must have the same dimensions to be concatenated.";
    }

    /** Returns the total number of features across all delegates. */
    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public int numberOfFeatures() {
        return this.numFeatures;
    }

    /** Returns a {@link MappedDimensions} built from the concatenated feature count. */
    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public MappedDimensions dimensions() {
        return new MappedDimensions(numberOfFeatures());
    }

    /**
     * Forwards the call to every delegate, then records that normalization was prepared.
     *
     * @param record record passed to each delegate
     * @param indexOfRecord index forwarded unchanged to each delegate
     */
    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public void prepareToNormalize(RecordType record, int indexOfRecord) {
        for (FeatureMapper<RecordType> mapper : this.mappers) {
            mapper.prepareToNormalize(record, indexOfRecord);
        }
        this.normalizedCalled = true;
    }

    /**
     * Writes every concatenated feature of {@code record} into {@code inputs} at index
     * {@code [indexOfRecord, position]}, where position runs over the concatenated vector.
     *
     * @param record record whose features are produced
     * @param inputs destination array, indexed with two coordinates
     * @param indexOfRecord first coordinate used for every write
     */
    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public void mapFeatures(RecordType record, INDArray inputs, int indexOfRecord) {
        assert this.normalizedCalled : "prepareToNormalize must be called before mapFeatures.";
        int[] indices = {indexOfRecord, 0};
        int start = 0;
        for (FeatureMapper<RecordType> mapper : this.mappers) {
            int count = mapper.numberOfFeatures();
            for (int j = 0; j < count; j++) {
                indices[1] = start + j;
                inputs.putScalar(indices, mapper.produceFeature(record, j));
            }
            start += count;
        }
    }

    /** Returns {@code true} if at least one delegate has a mask. */
    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public boolean hasMask() {
        for (FeatureMapper<RecordType> mapper : this.mappers) {
            if (mapper.hasMask()) {
                return true;
            }
        }
        return false;
    }

    /**
     * When any delegate has a mask, writes 1 (masked) or 0 for every concatenated feature into
     * {@code mask} at index {@code [indexOfRecord, position]}. Does nothing otherwise.
     *
     * @param record record whose mask is produced
     * @param mask destination array, indexed with two coordinates
     * @param indexOfRecord first coordinate used for every write
     */
    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public void maskFeatures(RecordType record, INDArray mask, int indexOfRecord) {
        if (!hasMask()) {
            return;
        }
        int[] indices = {indexOfRecord, 0};
        int start = 0;
        for (FeatureMapper<RecordType> mapper : this.mappers) {
            int count = mapper.numberOfFeatures();
            for (int j = 0; j < count; j++) {
                indices[1] = start + j;
                inputs(mask, indices, mapper.isMasked(record, j) ? 1 : 0);
            }
            start += count;
        }
    }

    /** Writes an int mask value, keeping the original putScalar(int[], int) overload. */
    private static void inputs(INDArray target, int[] indices, int value) {
        target.putScalar(indices, value);
    }

    /**
     * Returns whether the concatenated feature at {@code featureIndex} is masked, by delegating
     * to the mapper that owns that position.
     *
     * @throws ArrayIndexOutOfBoundsException if {@code featureIndex} is outside
     *     {@code [0, numberOfFeatures())}
     */
    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public boolean isMasked(RecordType record, int featureIndex) {
        int k = mapperIndexFor(featureIndex);
        return this.mappers[k].isMasked(record, featureIndex - this.offsets[k]);
    }

    /**
     * Returns the concatenated feature at {@code featureIndex}, by delegating to the mapper that
     * owns that position.
     *
     * @throws ArrayIndexOutOfBoundsException if {@code featureIndex} is outside
     *     {@code [0, numberOfFeatures())}
     */
    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public float produceFeature(RecordType record, int featureIndex) {
        int k = mapperIndexFor(featureIndex);
        return this.mappers[k].produceFeature(record, featureIndex - this.offsets[k]);
    }

    /**
     * Returns the index of the delegate that owns position {@code featureIndex} of the
     * concatenated vector, i.e. the k with {@code offsets[k] <= featureIndex < offsets[k + 1]}.
     */
    private int mapperIndexFor(int featureIndex) {
        if (featureIndex < 0 || featureIndex >= this.numFeatures) {
            throw new ArrayIndexOutOfBoundsException(
                    "Feature index " + featureIndex + " out of range [0, " + this.numFeatures + ")");
        }
        int found = Arrays.binarySearch(this.offsets, 0, this.mappers.length, featureIndex);
        int k = found >= 0 ? found : -(found + 1) - 1;
        // Zero-width delegates share their offset with the next delegate, and binarySearch may
        // return any of several equal entries. Advance to the last one: it is the delegate that
        // actually has a feature at this position.
        while (k + 1 < this.mappers.length && this.offsets[k + 1] <= featureIndex) {
            k++;
        }
        return k;
    }
}
