package org.campagnelab.dl.framework.mappers;

import java.util.Arrays;
import java.util.Objects;
import java.util.function.Function;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/campagnelab/dl/framework/mappers/OneHotBaseLabelMapper.class */
/**
 * Label mapper that one-hot encodes a single element of an integer label vector extracted from a
 * record.
 *
 * <p>{@link #prepareToNormalize} extracts the label vector for a record with the supplied function
 * and caches it; {@link #produceLabel} and {@link #mapLabels} then emit a one-hot encoding of width
 * {@code numLabels} for the value found at position {@code baseIndex} of that cached vector.
 *
 * <p>Instances keep per-record state in an unsynchronized field, so a single instance must not be
 * used from several threads at once.
 *
 * @param <RecordType> type of record the labels are extracted from
 */
public class OneHotBaseLabelMapper<RecordType> implements LabelMapper<RecordType> {
    private static final Logger LOG = LoggerFactory.getLogger(OneHotBaseLabelMapper.class);

    /** Position in the extracted label vector whose value is one-hot encoded. */
    private final int baseIndex;
    /** Width of the one-hot encoding produced for each record. */
    private final int numLabels;
    /** Extracts the integer label vector from a record. */
    private final Function<RecordType, int[]> recordToLabel;
    /** Label vector of the record last passed to {@link #prepareToNormalize}; null before that. */
    private int[] cachedLabel;

    /**
     * Creates a mapper.
     *
     * @param baseIndex     position in the extracted label vector to one-hot encode
     * @param numLabels     width of the one-hot encoding
     * @param recordToLabel function extracting the integer label vector from a record
     * @throws NullPointerException if {@code recordToLabel} is null
     */
    public OneHotBaseLabelMapper(int baseIndex, int numLabels, Function<RecordType, int[]> recordToLabel) {
        this.baseIndex = baseIndex;
        this.numLabels = numLabels;
        this.recordToLabel = Objects.requireNonNull(recordToLabel, "recordToLabel");
    }

    /**
     * Extracts and caches the label vector of {@code record}; must be called before
     * {@link #produceLabel} or {@link #mapLabels} for that record.
     *
     * @param record record whose labels will be produced next
     * @param index  not used by this mapper
     */
    @Override
    public void prepareToNormalize(RecordType record, int index) {
        this.cachedLabel = this.recordToLabel.apply(record);
    }

    /** @return the width of the one-hot encoding */
    @Override
    public int numberOfLabels() {
        return this.numLabels;
    }

    /** @return dimensions describing a one-hot vector of {@link #numberOfLabels()} entries */
    @Override
    public MappedDimensions dimensions() {
        return new MappedDimensions(numberOfLabels());
    }

    /**
     * Writes the one-hot encoding of the cached label into {@code labels} at positions
     * {@code [indexOfRecord, 0 .. numberOfLabels()-1]}.
     *
     * @param record        record being mapped (labels come from the vector cached by
     *                      {@link #prepareToNormalize})
     * @param labels        destination array, indexed as {@code [indexOfRecord, labelIndex]}
     * @param indexOfRecord first coordinate to write in {@code labels}
     */
    @Override
    public void mapLabels(RecordType record, INDArray labels, int indexOfRecord) {
        // Per-call buffer: a static shared array would be clobbered by concurrent mappers.
        final int[] indices = {indexOfRecord, 0};
        final int width = numberOfLabels();
        for (int labelIndex = 0; labelIndex < width; labelIndex++) {
            indices[1] = labelIndex;
            labels.putScalar(indices, produceLabel(record, labelIndex));
        }
    }

    /**
     * Returns element {@code labelIndex} of the one-hot encoding of the cached label.
     *
     * <p>If {@code baseIndex} lies outside the cached label vector, a warning is logged and the last
     * label ({@code numLabels - 1}) is encoded instead, so the result is still a 0/1 value.
     *
     * @param record     ignored; the vector cached by {@link #prepareToNormalize} is used
     * @param labelIndex position within the one-hot vector
     * @return 1 if {@code labelIndex} is the hot position, 0 otherwise
     * @throws IllegalStateException if {@link #prepareToNormalize} has not been called
     */
    @Override
    public float produceLabel(RecordType record, int labelIndex) {
        final int[] label = this.cachedLabel;
        if (label == null) {
            throw new IllegalStateException("prepareToNormalize must be called before producing labels");
        }
        if (this.baseIndex >= 0 && this.baseIndex < label.length) {
            return labelIndex == label[this.baseIndex] ? 1.0f : 0.0f;
        }
        // Explicit Object[] keeps compatibility with both varargs and array-taking SLF4J overloads.
        LOG.warn("incompatible base index: {} for label: {} of length {}",
                new Object[]{this.baseIndex, Arrays.toString(label), label.length});
        // Previously returned numLabels - 1 for every position, which is not a one-hot value;
        // encode the last label as hot instead.
        return labelIndex == this.numLabels - 1 ? 1.0f : 0.0f;
    }

    /** @return false; this mapper never masks labels */
    @Override
    public boolean hasMask() {
        return false;
    }

    /** No-op: this mapper does not produce label masks. */
    @Override
    public void maskLabels(RecordType record, INDArray mask, int indexOfRecord) {
    }

    /** @return false; no label is ever masked */
    @Override
    public boolean isMasked(RecordType record, int labelIndex) {
        return false;
    }
}
