package sklearn.linear_model;

import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.dmg.pmml.Model;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.regression.RegressionModel;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.ScalarLabelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.converter.regression.RegressionModelUtil;
import org.jpmml.python.ClassDictUtil;
import sklearn.SkLearnRegressor;

/* loaded from: input_file:sklearn/linear_model/LinearRegressor.class */
public class LinearRegressor extends SkLearnRegressor {

    /**
     * Creates a new linear regressor.
     *
     * <p>Both arguments are forwarded unchanged to the {@link SkLearnRegressor} constructor.
     * NOTE(review): presumably the Python module and class name of the unpickled estimator — confirm.
     *
     * @param module first constructor argument of {@link SkLearnRegressor}
     * @param name second constructor argument of {@link SkLearnRegressor}
     */
    public LinearRegressor(String module, String name) {
        super(module, name);
    }

    /**
     * Returns the number of input features, derived from the shape of {@code coef_}.
     *
     * <p>A 2-D {@code coef_} is treated as (outputs, features), so the feature count is its
     * second dimension. A 1-D {@code coef_} holds one coefficient per feature.
     *
     * @return the number of features the model expects
     */
    @Override // sklearn.Estimator, sklearn.HasNumberOfFeatures
    public int getNumberOfFeatures() {
        int[] coefShape = getCoefShape();
        return coefShape.length == 2 ? coefShape[1] : coefShape[0];
    }

    /**
     * Returns the number of regression targets, derived from the shape of {@code coef_}.
     *
     * @return the first dimension of a 2-D {@code coef_}, or {@code 1} otherwise
     */
    @Override // sklearn.Regressor, sklearn.Estimator, sklearn.HasNumberOfOutputs
    public int getNumberOfOutputs() {
        int[] coefShape = getCoefShape();
        if (coefShape.length == 2) {
            return coefShape[0];
        }
        return 1;
    }

    /**
     * Encodes this estimator as a PMML model.
     *
     * <p>With a single output, a single {@link RegressionModel} is built from all of {@code coef_}
     * and the sole {@code intercept_} value. With several outputs, one {@link RegressionModel} is
     * built per output from the corresponding row of {@code coef_} and entry of
     * {@code intercept_}, each against a schema relabelled to that output's scalar label. The
     * models are then combined into a multi-model chain that uses
     * {@code MissingPredictionTreatment.CONTINUE}.
     *
     * @param schema the schema providing the label and the input features
     * @return the encoded PMML model
     * @throws IllegalArgumentException if the number of outputs is less than one, or if the
     *     intercept or scalar-label counts do not match the number of outputs
     */
    @Override // sklearn.Estimator
    /* renamed from: encodeModel */
    public Model mo7encodeModel(Schema schema) {
        List<? extends Number> coef = getCoef();
        List<? extends Number> intercept = getIntercept();
        int numberOfOutputs = getNumberOfOutputs();

        if (numberOfOutputs < 1) {
            throw new IllegalArgumentException("Expected at least one output, got " + numberOfOutputs);
        }

        if (numberOfOutputs == 1) {
            // getOnlyElement rejects an intercept_ that does not hold exactly one value.
            return createRegression(coef, Iterables.getOnlyElement(intercept), schema);
        }

        Label label = schema.getLabel();
        int numberOfFeatures = schema.getFeatures().size();

        // One scalar label per output; intercept_ must supply one value per output as well.
        List<? extends Label> scalarLabels = ScalarLabelUtil.toScalarLabels(label);
        ClassDictUtil.checkSize(numberOfOutputs, new Collection<?>[]{intercept, scalarLabels});

        List<RegressionModel> regressionModels = new ArrayList<>(numberOfOutputs);
        for (int i = 0; i < numberOfOutputs; i++) {
            // coef_ is flattened; row i holds the coefficients of output i.
            List<? extends Number> rowCoef = CMatrixUtil.getRow(coef, numberOfOutputs, numberOfFeatures, i);
            Schema outputSchema = schema.toRelabeledSchema(scalarLabels.get(i));
            regressionModels.add(createRegression(rowCoef, intercept.get(i), outputSchema));
        }

        return MiningModelUtil.createMultiModelChain(regressionModels, Segmentation.MissingPredictionTreatment.CONTINUE);
    }

    /**
     * Builds a single regression model over the schema's features.
     *
     * <p>No normalization method is applied. The {@code null} is cast to select the intended
     * {@code RegressionModelUtil.createRegression} overload.
     *
     * @param list the coefficients, one per feature in {@code schema}
     * @param number the intercept
     * @param schema the schema whose features and label the model is built against
     * @return the regression model
     */
    protected RegressionModel createRegression(List<? extends Number> list, Number number, Schema schema) {
        return RegressionModelUtil.createRegression(schema.getFeatures(), list, number, (RegressionModel.NormalizationMethod) null, schema);
    }

    /**
     * Returns the flattened {@code coef_} array.
     *
     * @return the coefficient values
     */
    public List<? extends Number> getCoef() {
        return getNumberArray("coef_");
    }

    /**
     * Returns the shape of the {@code coef_} array.
     *
     * @return the array dimensions
     */
    public int[] getCoefShape() {
        return getArrayShape("coef_");
    }

    /**
     * Returns the {@code intercept_} values.
     *
     * @return the intercept values
     */
    public List<? extends Number> getIntercept() {
        return getNumberArray("intercept_");
    }
}
