package org.campagnelab.dl.framework.mappers;

import java.util.function.BiFunction;
import java.util.function.Function;
import org.campagnelab.goby.util.WarningCounter;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/campagnelab/dl/framework/mappers/OneHotBaseFeatureMapper.class */
/**
 * Feature mapper that one-hot encodes a single character (base) of a string derived from each
 * record.
 *
 * <p>{@link #prepareToNormalize} caches the string produced by {@code recordToString}.
 * {@link #produceFeature} converts the character at {@code baseIndex} of that cached string into
 * an integer code with {@code recordStringAtBaseToInteger}. It returns 1 for the feature whose
 * index equals that code and 0 for every other feature.
 *
 * <p>The default encoder maps (case-insensitively) A=0, T=1, C=2, G=3, N=4. Any other character,
 * or an index outside the string, maps to 5. The default number of features is therefore 6.
 *
 * <p>NOTE(review): {@link #produceFeature} and {@link #mapFeatures} ignore their record argument
 * and read the string cached by the most recent {@link #prepareToNormalize} call. Callers
 * presumably call {@code prepareToNormalize} for each record before mapping it — confirm against
 * the FeatureMapper contract.
 *
 * @param <RecordType> type of the records being mapped
 */
public class OneHotBaseFeatureMapper<RecordType> implements FeatureMapper<RecordType> {
    /** Number of features used with the default encoder: one per code 0..5. */
    private static final int DEFAULT_NUM_FEATURES = 6;
    /** Code the default encoder returns for unrecognized characters or out-of-range indices. */
    private static final int OTHER_BASE_CODE = 5;

    private static final Logger LOG = LoggerFactory.getLogger(OneHotBaseFeatureMapper.class);
    /** Shared by all instances; passed the logger, format and arguments of each warning. */
    private static final WarningCounter counter = new WarningCounter();

    private final int numFeatures;
    private final Function<RecordType, String> recordToString;
    private final BiFunction<String, Integer, Integer> recordStringAtBaseToInteger;
    private final int baseIndex;
    /** String derived from the record most recently passed to prepareToNormalize. */
    private String cachedString;

    /**
     * Creates a mapper using the default base encoder (A=0, T=1, C=2, G=3, N=4, other=5) and
     * 6 features.
     *
     * @param baseIndex      index of the character to encode within the record's string
     * @param recordToString extracts the string to encode from a record
     */
    public OneHotBaseFeatureMapper(int baseIndex, Function<RecordType, String> recordToString) {
        this(baseIndex, recordToString, OneHotBaseFeatureMapper::getIntegerOfBase, DEFAULT_NUM_FEATURES);
    }

    /**
     * Creates a mapper with a custom character encoder.
     *
     * @param baseIndex                   index of the character to encode within the record's string
     * @param recordToString              extracts the string to encode from a record
     * @param recordStringAtBaseToInteger maps (string, index) to the code of the feature to set to 1
     * @param numFeatures                 number of one-hot features produced per record
     */
    public OneHotBaseFeatureMapper(int baseIndex, Function<RecordType, String> recordToString,
                                   BiFunction<String, Integer, Integer> recordStringAtBaseToInteger,
                                   int numFeatures) {
        this.baseIndex = baseIndex;
        this.recordToString = recordToString;
        this.recordStringAtBaseToInteger = recordStringAtBaseToInteger;
        this.numFeatures = numFeatures;
    }

    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public int numberOfFeatures() {
        return this.numFeatures;
    }

    /** Caches the string derived from {@code record} for the subsequent feature calls. */
    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public void prepareToNormalize(RecordType record, int indexOfRecord) {
        this.cachedString = this.recordToString.apply(record);
    }

    /**
     * Writes all one-hot features into row {@code indexOfRecord} of {@code inputs}. Column
     * {@code f} receives {@code produceFeature(record, f)}.
     */
    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public void mapFeatures(RecordType record, INDArray inputs, int indexOfRecord) {
        // Allocated per call: a static index buffer shared by all instances would be
        // clobbered when several mappers fill arrays concurrently.
        int[] indices = {indexOfRecord, 0};
        for (int featureIndex = 0; featureIndex < numberOfFeatures(); featureIndex++) {
            indices[1] = featureIndex;
            inputs.putScalar(indices, produceFeature(record, featureIndex));
        }
    }

    /**
     * Returns 1 if the encoded base of the cached string equals {@code featureIndex}, otherwise 0.
     * The {@code record} argument is not used; the string cached by
     * {@link #prepareToNormalize} is read instead.
     */
    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public float produceFeature(RecordType record, int featureIndex) {
        int code = this.recordStringAtBaseToInteger.apply(this.cachedString, this.baseIndex);
        return code == featureIndex ? 1.0f : 0.0f;
    }

    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public boolean hasMask() {
        return false;
    }

    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public MappedDimensions dimensions() {
        return new MappedDimensions(numberOfFeatures());
    }

    /** No-op: this mapper has no mask. */
    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public void maskFeatures(RecordType record, INDArray mask, int indexOfRecord) {
    }

    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public boolean isMasked(RecordType record, int featureIndex) {
        return false;
    }

    /**
     * Default encoder: maps the character at {@code index} of {@code context} to a base code.
     *
     * @param context string containing the base; must be non-null
     * @param index   position of the base within {@code context}
     * @return A/a=0, T/t=1, C/c=2, G/g=3, N/n=4, any other character=5; also 5 (with a rate-limited
     *     warning) when {@code index} is outside {@code context}
     */
    private static int getIntegerOfBase(String context, int index) {
        if (index < 0 || index >= context.length()) {
            // Pass the SLF4J-style format and its arguments through; pre-formatting with
            // String.format would leave the {} placeholders unfilled and drop the values.
            counter.warn(LOG, "incompatible character index: {} for context: {} of length {}",
                    new Object[]{index, context, context.length()});
            return OTHER_BASE_CODE;
        }
        switch (context.charAt(index)) {
            case 'A':
            case 'a':
                return 0;
            case 'T':
            case 't':
                return 1;
            case 'C':
            case 'c':
                return 2;
            case 'G':
            case 'g':
                return 3;
            case 'N':
            case 'n':
                return 4;
            default:
                return OTHER_BASE_CODE;
        }
    }
}
