package io.cdap.mmds.modeler.param;

import com.google.common.collect.ImmutableSet;
import io.cdap.mmds.spec.BoolParam;
import io.cdap.mmds.spec.ParamSpec;
import io.cdap.mmds.spec.Params;
import io.cdap.mmds.spec.StringParam;
import java.util.List;
import java.util.Map;
import org.apache.spark.ml.regression.GeneralizedLinearRegression;
import org.apache.twill.internal.Constants;

/* loaded from: input_file:lib/mmds-model-1.7.1.jar:io/cdap/mmds/modeler/param/GeneralizedLinearRegressionParams.class */
/**
 * Modeler parameters for a Spark {@link GeneralizedLinearRegression}.
 *
 * <p>Adds three parameters on top of those handled by {@link RegressionParams}:
 * {@code family}, {@code link} and {@code fitIntercept}. Each one is constructed from the
 * map handed to the constructor, together with a default value ({@code "gaussian"},
 * {@code "identity"} and {@code true} respectively).
 *
 * <p>No check is made in this class that the chosen link is compatible with the chosen family.
 */
public class GeneralizedLinearRegressionParams extends RegressionParams {
    private final StringParam family;
    private final StringParam link;
    private final BoolParam fitIntercept;

    /**
     * Builds the parameter set from a raw string map.
     *
     * @param map raw parameter values; it is passed to the superclass constructor and to each
     *            parameter defined by this class
     */
    public GeneralizedLinearRegressionParams(Map<String, String> map) {
        super(map);
        this.family = new StringParam(
                "family",
                "Family",
                "The error distribution to be used in the model.",
                "gaussian",
                ImmutableSet.of("gaussian", "binomial", "poisson", "gamma"),
                map);
        // The "log" link is spelled out as a literal. It previously came from
        // org.apache.twill.internal.Constants.LOG_TOPIC, a logging-topic constant that is
        // unrelated to GLM link functions and must not define this parameter's valid values.
        this.link = new StringParam(
                "link",
                "Link",
                "Relationship between the linear predictor and the mean of the distribution function.",
                "identity",
                ImmutableSet.of("identity", "log", "inverse", "logit", "probit", "cloglog", "sqrt"),
                map);
        this.fitIntercept = new BoolParam(
                "fitIntercept",
                "Fit Intercept",
                "If the intercept should be fit",
                true,
                map);
    }

    /**
     * Copies the parameter values held by this object onto the given Spark estimator.
     *
     * <p>The maximum iterations, regularization parameter and tolerance come from fields that
     * are not declared in this class (they are inherited); family, link and fit-intercept come
     * from the fields declared here.
     *
     * @param generalizedLinearRegression the estimator to configure
     */
    public void setParams(GeneralizedLinearRegression generalizedLinearRegression) {
        // Inherited parameters.
        generalizedLinearRegression.setMaxIter(this.maxIterations.getVal().intValue());
        generalizedLinearRegression.setRegParam(this.regularizationParam.getVal().doubleValue());
        generalizedLinearRegression.setTol(this.tolerance.getVal().doubleValue());
        // Parameters specific to generalized linear regression.
        generalizedLinearRegression.setFitIntercept(this.fitIntercept.getVal().booleanValue());
        generalizedLinearRegression.setFamily(this.family.getVal());
        generalizedLinearRegression.setLink(this.link.getVal());
    }

    /**
     * Returns the superclass map with this class's family, link and fit-intercept parameters
     * added through {@link Params#putParams}.
     *
     * @return the combined parameter map
     */
    @Override // io.cdap.mmds.modeler.param.RegressionParams, io.cdap.mmds.spec.Parameters
    public Map<String, String> toMap() {
        return Params.putParams(super.toMap(), this.family, this.link, this.fitIntercept);
    }

    /**
     * Returns the superclass parameter specifications with this class's family, link and
     * fit-intercept parameters added through {@link Params#addParams}.
     *
     * @return the combined list of parameter specifications
     */
    @Override // io.cdap.mmds.modeler.param.RegressionParams, io.cdap.mmds.spec.Parameters
    public List<ParamSpec> getSpec() {
        return Params.addParams(super.getSpec(), this.family, this.link, this.fitIntercept);
    }
}
