package org.deidentifier.arx.aggregates.classification;

import org.apache.hadoop.mapred.lib.aggregate.ValueAggregatorDescriptor;
import org.apache.mahout.classifier.sgd.ElasticBandPrior;
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.L2;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.classifier.sgd.PriorFunction;
import org.apache.mahout.classifier.sgd.UniformPrior;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;
import org.deidentifier.arx.DataHandleInternal;
import org.deidentifier.arx.aggregates.ClassificationConfigurationLogisticRegression;
import org.deidentifier.arx.common.WrappedBoolean;

/* loaded from: input_file:libarx-3.7.1.jar:org/deidentifier/arx/aggregates/classification/MultiClassLogisticRegression.class */
/**
 * Multi-class classifier backed by Mahout's {@link OnlineLogisticRegression}.
 * <p>
 * Every record is converted into a vector of the configured length using Mahout feature
 * encoders: an intercept term plus one entry per feature attribute. Values for which the
 * feature metadata yields a number are encoded with that number as weight; all other values
 * are encoded as an attribute/value word with weight 1.
 */
public class MultiClassLogisticRegression extends ClassificationMethod {

    /**
     * Separator placed between the attribute index and the raw value when a non-numeric
     * value is encoded as a word. Kept as a local constant so that feature naming does not
     * depend on a constant owned by an unrelated library.
     */
    private static final String VALUE_SEPARATOR = ":";

    /** Configuration providing the vector length, prior function and learning parameters. */
    private final ClassificationConfigurationLogisticRegression config;
    /** Encoder for the constant intercept term. */
    private final ConstantValueEncoder interceptEncoder;
    /** The underlying Mahout model. */
    private final OnlineLogisticRegression lr;
    /** Specification providing class map, class index, feature indices and feature metadata. */
    private final ClassificationDataSpecification specification;
    /** Encoder for the feature values. */
    private final StaticWordValueEncoder wordEncoder;
    /** Handle read during classification for features flagged as numeric microaggregation. */
    private final DataHandleInternal inputHandle;

    /**
     * Creates a new, untrained instance.
     *
     * @param wrappedBoolean passed on to the {@link ClassificationMethod} constructor
     * @param classificationDataSpecification provides the class map, class index, feature
     *        indices and feature metadata
     * @param classificationConfigurationLogisticRegression provides the prior function, the
     *        vector length and the learning parameters
     * @param dataHandleInternal handle that is read during classification for features
     *        flagged as numeric microaggregation
     * @throws IllegalArgumentException if the configured prior function is not supported
     */
    public MultiClassLogisticRegression(WrappedBoolean wrappedBoolean, ClassificationDataSpecification classificationDataSpecification, ClassificationConfigurationLogisticRegression classificationConfigurationLogisticRegression, DataHandleInternal dataHandleInternal) {
        super(wrappedBoolean);
        this.config = classificationConfigurationLogisticRegression;
        this.specification = classificationDataSpecification;
        this.inputHandle = dataHandleInternal;

        // Resolve the prior before touching the specification, so that an unsupported
        // prior is reported before any other problem with the arguments.
        final PriorFunction prior = createPrior(classificationConfigurationLogisticRegression);

        // One output category per entry of the class map
        this.lr = new OnlineLogisticRegression(this.specification.classMap.size(),
                                               classificationConfigurationLogisticRegression.getVectorLength(),
                                               prior);
        this.lr.learningRate(classificationConfigurationLogisticRegression.getLearningRate());
        this.lr.alpha(classificationConfigurationLogisticRegression.getAlpha());
        this.lr.lambda(classificationConfigurationLogisticRegression.getLambda());
        this.lr.stepOffset(classificationConfigurationLogisticRegression.getStepOffset());
        this.lr.decayExponent(classificationConfigurationLogisticRegression.getDecayExponent());

        this.interceptEncoder = new ConstantValueEncoder("intercept");
        this.wordEncoder = new StaticWordValueEncoder("feature");
    }

    /**
     * Classifies one record.
     *
     * @param dataHandleInternal handle from which the feature values are read (except for
     *        features flagged as numeric microaggregation, see
     *        {@link #encodeFeatures(DataHandleInternal, int, boolean)})
     * @param i index of the record within the handle
     * @return result wrapping the model's full score vector together with the class map
     */
    @Override
    public ClassificationResult classify(DataHandleInternal dataHandleInternal, int i) {
        final Vector features = encodeFeatures(dataHandleInternal, i, true);
        return new MultiClassLogisticRegressionClassificationResult(this.lr.classifyFull(features),
                                                                    this.specification.classMap);
    }

    /**
     * Closes the underlying Mahout model.
     */
    @Override
    public void close() {
        this.lr.close();
    }

    /**
     * Trains the model with one record.
     *
     * @param dataHandleInternal handle from which the feature values are read
     * @param dataHandleInternal2 handle from which the class value is read
     * @param i index of the record within both handles
     */
    @Override
    public void train(DataHandleInternal dataHandleInternal, DataHandleInternal dataHandleInternal2, int i) {
        // Encode the class first, then the features (same order as argument evaluation)
        final int encodedClass = encodeClass(dataHandleInternal2, i);
        final Vector features = encodeFeatures(dataHandleInternal, i, false);
        this.lr.train(encodedClass, features);
    }

    /**
     * Instantiates the Mahout prior matching the configured prior function.
     *
     * @param config configuration to read the prior function from
     * @return a new prior instance
     * @throws IllegalArgumentException if the configured prior function is not supported
     */
    private static PriorFunction createPrior(ClassificationConfigurationLogisticRegression config) {
        switch (config.getPriorFunction()) {
            case ELASTIC_BAND:
                return new ElasticBandPrior();
            case L1:
                return new L1();
            case L2:
                return new L2();
            case UNIFORM:
                return new UniformPrior();
            default:
                throw new IllegalArgumentException("Unknown prior function");
        }
    }

    /**
     * Maps the class value of a record to its numeric code using the class map.
     * <p>
     * NOTE(review): a value missing from the class map presumably results in a
     * {@link NullPointerException} here; confirm that the map always covers the handle.
     *
     * @param handle handle to read the class value from
     * @param row index of the record
     * @return numeric code of the record's class value
     */
    private int encodeClass(DataHandleInternal handle, int row) {
        final String classValue = handle.getValue(row, this.specification.classIndex, true);
        return this.specification.classMap.get(classValue).intValue();
    }

    /**
     * Builds the feature vector for one record.
     *
     * @param handle handle to read feature values from
     * @param row index of the record
     * @param classify whether the vector is built for classification; if so, features
     *        flagged as numeric microaggregation are read from the input handle instead of
     *        {@code handle}
     * @return vector of the configured length containing the intercept and all features
     */
    private Vector encodeFeatures(DataHandleInternal handle, int row, boolean classify) {
        final Vector vector = new DenseVector(this.config.getVectorLength());
        this.interceptEncoder.addToVector("1", vector);

        // Special case: without any feature attribute a single constant feature is used
        if (this.specification.featureIndices.length == 0) {
            this.wordEncoder.addToVector("Feature:1", 1.0d, vector);
            return vector;
        }

        // featureMetadata is indexed by position in featureIndices, not by column index
        int position = 0;
        for (int column : this.specification.featureIndices) {
            final ClassificationFeatureMetadata metadata = this.specification.featureMetadata[position];
            final DataHandleInternal source = (classify && metadata.isNumericMicroaggregation()) ? this.inputHandle : handle;
            final String value = source.getValue(row, column, true);

            // NaN signals that the value has no numeric representation
            final double numeric = metadata.getNumericValue(value);
            if (Double.isNaN(numeric)) {
                // Non-numeric: one word per attribute/value combination, weight 1
                this.wordEncoder.addToVector("Attribute-" + column + VALUE_SEPARATOR + value, 1.0d, vector);
            } else {
                // Numeric: one word per attribute, weighted by the numeric value
                this.wordEncoder.addToVector("Attribute-" + column, numeric, vector);
            }
            position++;
        }
        return vector;
    }
}
