package sklearn.linear_model;

import java.util.ArrayList;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.regression.RegressionModel;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.Transformation;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.converter.regression.RegressionModelUtil;
import sklearn.Classifier;
import sklearn.Estimator;

/* loaded from: input_file:sklearn/linear_model/LinearClassifier.class */
/**
 * Classifier whose decision function is linear in the features, read from the
 * {@code coef_} and {@code intercept_} attributes.
 *
 * <p>The {@code coef_} attribute is treated as a two-dimensional array: the first
 * dimension selects a row of coefficients, the second dimension is the number of
 * features.
 */
public class LinearClassifier extends Classifier {
    public LinearClassifier(String str, String str2) {
        super(str, str2);
    }

    /**
     * Returns the number of features, taken from the second dimension of the
     * {@code coef_} shape.
     */
    @Override // sklearn.Estimator, sklearn.HasNumberOfFeatures
    public int getNumberOfFeatures() {
        return getCoefShape()[1];
    }

    /**
     * Encodes this classifier as a PMML model.
     *
     * <ul>
     *   <li>One coefficient row: the label must have exactly two categories and a
     *       single binary logistic classification (LOGIT normalization) is built.</li>
     *   <li>Three or more coefficient rows: the label must have as many categories as
     *       there are rows; one LOGIT-normalized regression is built per category and
     *       the results are combined with SIMPLEMAX normalization.</li>
     * </ul>
     *
     * @param schema the schema supplying the label and the features
     * @return the encoded model
     * @throws IllegalArgumentException if the number of coefficient rows is neither
     *         one nor at least three
     */
    @Override // sklearn.Estimator
    /* renamed from: encodeModel */
    public Model mo16encodeModel(Schema schema) {
        int[] coefShape = getCoefShape();
        int numberOfRows = coefShape[0];
        int numberOfColumns = coefShape[1];

        boolean hasProbabilityDistribution = hasProbabilityDistribution();

        List<? extends Number> coef = getCoef();
        List<? extends Number> intercept = getIntercept();

        // The cast is harmless if getLabel() already returns CategoricalLabel and
        // required if it returns a more general label type.
        CategoricalLabel label = (CategoricalLabel) schema.getLabel();
        // Left raw on purpose: the element type is not among this file's imports.
        List features = schema.getFeatures();

        if (numberOfRows == 1) {
            SchemaUtil.checkSize(2, label);

            return RegressionModelUtil.createBinaryLogisticClassification(
                features,
                CMatrixUtil.getRow(coef, numberOfRows, numberOfColumns, 0),
                intercept.get(0),
                RegressionModel.NormalizationMethod.LOGIT,
                hasProbabilityDistribution,
                schema);
        }

        if (numberOfRows < 3) {
            throw new IllegalArgumentException(
                "Expected coef_ to have either 1 row or at least 3 rows, got " + numberOfRows);
        }

        SchemaUtil.checkSize(numberOfRows, label);

        Schema segmentSchema = schema.toAnonymousRegressorSchema(DataType.DOUBLE).toEmptySchema();

        int numberOfCategories = label.size();
        List<Model> segmentModels = new ArrayList<>(numberOfCategories);

        for (int row = 0; row < numberOfCategories; row++) {
            // Each per-category regression exposes its value under a decision-function
            // field name derived from the category value.
            segmentModels.add(
                RegressionModelUtil.createRegression(
                        features,
                        CMatrixUtil.getRow(coef, numberOfRows, numberOfColumns, row),
                        intercept.get(row),
                        RegressionModel.NormalizationMethod.LOGIT,
                        segmentSchema)
                    .setOutput(ModelUtil.createPredictedOutput(
                        FieldNameUtil.create(Estimator.FIELD_DECISION_FUNCTION, new Object[]{label.getValue(row)}),
                        OpType.CONTINUOUS,
                        DataType.DOUBLE,
                        new Transformation[0])));
        }

        return MiningModelUtil.createClassification(
            segmentModels,
            RegressionModel.NormalizationMethod.SIMPLEMAX,
            hasProbabilityDistribution,
            schema);
    }

    /** Returns the values of the {@code coef_} attribute as a list of numbers. */
    public List<? extends Number> getCoef() {
        return getNumberArray("coef_");
    }

    /** Returns the shape of the {@code coef_} attribute, requested as two-dimensional. */
    public int[] getCoefShape() {
        return getArrayShape("coef_", 2);
    }

    /** Returns the values of the {@code intercept_} attribute as a list of numbers. */
    public List<? extends Number> getIntercept() {
        return getNumberArray("intercept_");
    }
}
