package org.campagnelab.dl.framework.mappers;

import it.unimi.dsi.fastutil.ints.IntArraySet;
import java.util.Arrays;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/campagnelab/dl/framework/mappers/ConcatLabelMapper.class */
public class ConcatLabelMapper<RecordType> implements LabelMapper<RecordType> {
    /** Wrapped mappers, in the order their labels are concatenated. */
    protected LabelMapper<RecordType>[] mappers;
    /** Total number of labels across all wrapped mappers (sum of their numberOfLabels()). */
    protected int numFeatures = 0;
    /**
     * Start offset of each wrapped mapper's labels in the concatenated label vector:
     * {@code offsets[k]} is the first concatenated index served by {@code mappers[k]}, and
     * {@code offsets[mappers.length]} equals {@link #numFeatures}. Consecutive entries are equal
     * when a mapper reports zero labels.
     */
    protected int[] offsets;
    /** Set by {@link #prepareToNormalize}; checked by {@link #mapLabels} when assertions are enabled. */
    private boolean normalizedCalled;

    /**
     * Creates a mapper whose labels are the labels of {@code labelMappers}, concatenated in order.
     *
     * <p>When assertions are enabled, asserts that all mappers report the same
     * {@code dimensions().numDimensions()}.
     *
     * @param labelMappers mappers to concatenate; the array is copied
     */
    @SafeVarargs
    public ConcatLabelMapper(LabelMapper<RecordType>... labelMappers) {
        // Copy so later changes to the caller's array cannot desynchronize mappers and offsets.
        this.mappers = labelMappers.clone();
        this.offsets = new int[mappers.length + 1];
        IntArraySet dimensionCounts = new IntArraySet();
        for (int k = 0; k < mappers.length; k++) {
            LabelMapper<RecordType> mapper = mappers[k];
            numFeatures += mapper.numberOfLabels();
            offsets[k + 1] = numFeatures;
            dimensionCounts.add(mapper.dimensions().numDimensions());
        }
        assert mappers.length == 0 || dimensionCounts.size() == 1
                : "All feature mappers must have the same dimensions to be concatenated.";
    }

    /** @return the total number of labels across all wrapped mappers */
    @Override // org.campagnelab.dl.framework.mappers.LabelMapper
    public int numberOfLabels() {
        return numFeatures;
    }

    /** @return a {@code MappedDimensions} constructed from {@link #numberOfLabels()} */
    @Override // org.campagnelab.dl.framework.mappers.LabelMapper
    public MappedDimensions dimensions() {
        return new MappedDimensions(numberOfLabels());
    }

    /**
     * Forwards the call to every wrapped mapper, in order, then records that it has been called
     * so that {@link #mapLabels} can assert it when assertions are enabled.
     */
    @Override // org.campagnelab.dl.framework.mappers.LabelMapper
    public void prepareToNormalize(RecordType input, int i) {
        for (LabelMapper<RecordType> mapper : mappers) {
            mapper.prepareToNormalize(input, i);
        }
        normalizedCalled = true;
    }

    /**
     * Writes the concatenated labels of {@code input} into {@code labels} at positions
     * {@code [indexOfRecord, j]}, where {@code j} runs over the concatenated label indices.
     *
     * @param input record to produce labels for
     * @param labels destination array, written with two-element index positions
     * @param indexOfRecord value used as the first index of every written position
     */
    @Override // org.campagnelab.dl.framework.mappers.LabelMapper
    public void mapLabels(RecordType input, INDArray labels, int indexOfRecord) {
        assert normalizedCalled : "prepareToNormalize must be called before mapLabels.";
        int[] position = {indexOfRecord, 0};
        int base = 0;
        for (LabelMapper<RecordType> mapper : mappers) {
            int count = mapper.numberOfLabels();
            for (int j = 0; j < count; j++) {
                position[1] = base + j;
                labels.putScalar(position, mapper.produceLabel(input, j));
            }
            base += count;
        }
    }

    /** @return {@code true} if at least one wrapped mapper reports {@code hasMask()} */
    @Override // org.campagnelab.dl.framework.mappers.LabelMapper
    public boolean hasMask() {
        boolean anyMask = false;
        // Deliberately queries every mapper (no short-circuit), as the original loop did.
        for (LabelMapper<RecordType> mapper : mappers) {
            anyMask |= mapper.hasMask();
        }
        return anyMask;
    }

    /**
     * If any wrapped mapper has a mask, writes 1 or 0 at positions {@code [indexOfRecord, j]}
     * according to each wrapped mapper's {@code isMasked} for its own label index; otherwise
     * writes nothing.
     */
    @Override // org.campagnelab.dl.framework.mappers.LabelMapper
    public void maskLabels(RecordType input, INDArray mask, int indexOfRecord) {
        if (!hasMask()) {
            return;
        }
        int[] position = {indexOfRecord, 0};
        int base = 0;
        for (LabelMapper<RecordType> mapper : mappers) {
            int count = mapper.numberOfLabels();
            for (int j = 0; j < count; j++) {
                position[1] = base + j;
                mask.putScalar(position, mapper.isMasked(input, j) ? 1 : 0);
            }
            base += count;
        }
    }

    /**
     * Delegates to the wrapped mapper that owns concatenated label {@code labelIndex}.
     *
     * @throws ArrayIndexOutOfBoundsException if {@code labelIndex} is outside
     *     {@code [0, numberOfLabels())}
     */
    @Override // org.campagnelab.dl.framework.mappers.LabelMapper
    public boolean isMasked(RecordType input, int labelIndex) {
        int k = mapperIndexFor(labelIndex);
        return mappers[k].isMasked(input, labelIndex - offsets[k]);
    }

    /**
     * Delegates to the wrapped mapper that owns concatenated label {@code labelIndex}, passing the
     * index relative to that mapper's first label.
     *
     * @throws ArrayIndexOutOfBoundsException if {@code labelIndex} is outside
     *     {@code [0, numberOfLabels())}
     */
    @Override // org.campagnelab.dl.framework.mappers.LabelMapper
    public float produceLabel(RecordType input, int labelIndex) {
        int k = mapperIndexFor(labelIndex);
        return mappers[k].produceLabel(input, labelIndex - offsets[k]);
    }

    /**
     * Returns the index {@code k} of the wrapped mapper with
     * {@code offsets[k] <= labelIndex < offsets[k + 1]}.
     *
     * <p>This is an upper-bound search (last {@code k} whose start offset is {@code <= labelIndex}),
     * so mappers with zero labels, which create duplicate offsets, are never selected.
     * {@code Arrays.binarySearch} gives no such guarantee for duplicate keys.
     */
    private int mapperIndexFor(int labelIndex) {
        if (labelIndex < 0 || labelIndex >= numFeatures) {
            throw new ArrayIndexOutOfBoundsException(
                    "Label index " + labelIndex + " out of range [0, " + numFeatures + ")");
        }
        int lo = 0;
        int hi = mappers.length - 1;
        // Invariant: offsets[lo] <= labelIndex, and the answer lies in [lo, hi].
        while (lo < hi) {
            int mid = (lo + hi + 1) >>> 1;
            if (offsets[mid] <= labelIndex) {
                lo = mid;
            } else {
                hi = mid - 1;
            }
        }
        return lo;
    }
}
