package org.campagnelab.dl.framework.mappers;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link LabelMapper} that ignores the content of the record it is given and writes the same
 * constant value into each of its {@code numLabels} label columns. The value is 0.0 when the
 * mapper is built with {@code isZero == true}, and 1.0 otherwise.
 *
 * <p>Mask behaviour is constant as well: {@link #hasMask()} and {@code isMasked} return the flags
 * supplied to the constructor, regardless of the record.
 *
 * <p>All fields are final and assigned only in the constructor. No per-call scratch state is kept
 * in fields.
 */
public class NAryLabelMapper<RecordType> implements LabelMapper<RecordType> {
    private static final Logger LOG = LoggerFactory.getLogger(NAryLabelMapper.class);

    /** Number of label columns written for each record. */
    private final int numLabels;
    /** When true every label is 0.0; otherwise every label is 1.0. */
    private final boolean isZero;
    /** Value reported by {@link #hasMask()}. */
    private final boolean hasMask;
    /** Value reported by {@code isMasked} for every record and label index. */
    private final boolean isMasked;
    /** Dimensions object built once from {@code numLabels}. */
    private final MappedDimensions dim;

    /**
     * Creates a mapper that emits a constant label value.
     *
     * @param numLabels number of label columns written per record
     * @param isZero    if {@code true} every label is 0.0, otherwise every label is 1.0
     * @param hasMask   value returned by {@link #hasMask()}
     * @param isMasked  value returned by {@code isMasked} for every record and label index
     */
    public NAryLabelMapper(int numLabels, boolean isZero, boolean hasMask, boolean isMasked) {
        this.numLabels = numLabels;
        this.isZero = isZero;
        if (numLabels > 1 && !isZero) {
            // Setting several columns to 1.0 for the same record cannot form a valid one-hot vector.
            LOG.warn("Mapping multiple 1s may break one-hot encoding");
        }
        this.hasMask = hasMask;
        this.isMasked = isMasked;
        this.dim = new MappedDimensions(numLabels);
    }

    /** Returns the number of label columns this mapper writes per record. */
    @Override // org.campagnelab.dl.framework.mappers.LabelMapper
    public int numberOfLabels() {
        return numLabels;
    }

    /** Returns the dimensions created from {@code numLabels} at construction time. */
    @Override // org.campagnelab.dl.framework.mappers.LabelMapper
    public MappedDimensions dimensions() {
        return dim;
    }

    /** No-op: constant labels need no normalization statistics. */
    @Override // org.campagnelab.dl.framework.mappers.LabelMapper
    public void prepareToNormalize(RecordType record, int indexOfRecord) {
    }

    /**
     * Writes {@code produceLabel(record, j)} at position {@code [indexOfRecord, j]} of
     * {@code labels}, for every {@code j} in {@code [0, numLabels)}.
     *
     * <p>NOTE(review): assumes {@code labels} is indexed as (record, labelColumn); confirm
     * against callers.
     *
     * @param record        the record being mapped (its content is not inspected here)
     * @param labels        destination array
     * @param indexOfRecord row of {@code labels} that receives this record's labels
     */
    @Override // org.campagnelab.dl.framework.mappers.LabelMapper
    public void mapLabels(RecordType record, INDArray labels, int indexOfRecord) {
        // Per-call index buffer. A shared field would let concurrent or re-entrant calls on
        // the same mapper overwrite each other's coordinates.
        final int[] position = {indexOfRecord, 0};
        for (int labelIndex = 0; labelIndex < numLabels; labelIndex++) {
            position[1] = labelIndex;
            labels.putScalar(position, produceLabel(record, labelIndex));
        }
    }

    /** Returns the {@code hasMask} flag supplied at construction. */
    @Override // org.campagnelab.dl.framework.mappers.LabelMapper
    public boolean hasMask() {
        return hasMask;
    }

    /**
     * Writes 1.0 at linear index {@code indexOfRecord} of {@code mask} when
     * {@code isMasked(record, -1)} is true, and 0.0 otherwise.
     *
     * @param record        the record whose mask value is written
     * @param mask          destination mask array (addressed with a single linear index)
     * @param indexOfRecord linear position in {@code mask} to write
     */
    @Override // org.campagnelab.dl.framework.mappers.LabelMapper
    public void maskLabels(RecordType record, INDArray mask, int indexOfRecord) {
        // -1 is passed as the label index; presumably meaning "whole record, no specific label".
        // This implementation ignores the index anyway.
        mask.putScalar(indexOfRecord, isMasked(record, -1) ? 1.0f : 0.0f);
    }

    /** Returns the {@code isMasked} flag supplied at construction, for any record and index. */
    @Override // org.campagnelab.dl.framework.mappers.LabelMapper
    public boolean isMasked(RecordType record, int indexOfRecord) {
        return isMasked;
    }

    /** Returns 0.0 if this mapper was built with {@code isZero == true}, otherwise 1.0. */
    @Override // org.campagnelab.dl.framework.mappers.LabelMapper
    public float produceLabel(RecordType record, int labelIndex) {
        return isZero ? 0.0f : 1.0f;
    }
}
