package com.datastax.insight.ml.spark.ml.classification;

import com.datastax.insight.spec.DataSetOperator;
import com.google.common.base.Strings;
import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel;
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/* loaded from: input_file:com/datastax/insight/ml/spark/ml/classification/MultilayerPerceptronClassifierWrapper.class */
/**
 * Static helpers for building, fitting and applying a Spark ML
 * {@link MultilayerPerceptronClassifier}.
 *
 * <p>Every optional parameter is applied only when it is non-null (or, for strings,
 * non-null and non-empty); otherwise the classifier's own default is left untouched.
 */
public class MultilayerPerceptronClassifierWrapper implements DataSetOperator {

    /**
     * Creates a {@link MultilayerPerceptronClassifier} configured from the given optional settings.
     *
     * @param str  features column name; ignored when null or empty
     * @param str2 label column name; ignored when null or empty
     * @param str3 layer sizes as a {@code ';'}-separated list of integers (e.g. {@code "4;5;3"});
     *             whitespace around each size is ignored; ignored when null or empty
     * @param num  block size; ignored when null
     * @param l    random seed; ignored when null
     * @param num2 maximum number of iterations; ignored when null
     * @param d    convergence tolerance; ignored when null
     * @param d2   step size; ignored when null
     * @param str4 solver name; ignored when null or empty
     * @return a new, configured (unfitted) classifier
     * @throws NumberFormatException if a layer size in {@code str3} is not a valid integer
     */
    public static MultilayerPerceptronClassifier getOperator(String str, String str2, String str3, Integer num, Long l, Integer num2, Double d, Double d2, String str4) {
        MultilayerPerceptronClassifier classifier = new MultilayerPerceptronClassifier();
        if (!Strings.isNullOrEmpty(str)) {
            classifier.setFeaturesCol(str);
        }
        if (!Strings.isNullOrEmpty(str2)) {
            classifier.setLabelCol(str2);
        }
        if (!Strings.isNullOrEmpty(str3)) {
            classifier.setLayers(parseLayers(str3));
        }
        if (num != null) {
            classifier.setBlockSize(num.intValue());
        }
        if (l != null) {
            classifier.setSeed(l.longValue());
        }
        if (num2 != null) {
            classifier.setMaxIter(num2.intValue());
        }
        if (d != null) {
            classifier.setTol(d.doubleValue());
        }
        if (d2 != null) {
            classifier.setStepSize(d2.doubleValue());
        }
        if (!Strings.isNullOrEmpty(str4)) {
            classifier.setSolver(str4);
        }
        return classifier;
    }

    /**
     * Parses a {@code ';'}-separated list of layer sizes into an int array.
     *
     * <p>Each token is trimmed before parsing so that specs such as {@code "4; 5; 3"} are accepted.
     * Trailing empty tokens are dropped by {@link String#split(String)}; any other empty or
     * non-numeric token is rejected.
     *
     * @param spec non-empty layer specification
     * @return the parsed layer sizes, in order
     * @throws NumberFormatException if any token is not a valid integer; the message names the
     *         offending token, its position and the full spec, and the original exception is the cause
     */
    private static int[] parseLayers(String spec) {
        String[] tokens = spec.split(";");
        int[] layers = new int[tokens.length];
        for (int i = 0; i < tokens.length; i++) {
            String token = tokens[i].trim();
            try {
                layers[i] = Integer.parseInt(token);
            } catch (NumberFormatException e) {
                // Keep the same exception type so existing callers still catch it, but add context.
                NumberFormatException detailed = new NumberFormatException(
                        "Invalid layer size '" + token + "' at position " + i + " in layers spec '" + spec + "'");
                detailed.initCause(e);
                throw detailed;
            }
        }
        return layers;
    }

    /**
     * Builds a classifier via {@link #getOperator} and fits it on {@code dataset}.
     *
     * @param dataset training data
     * @param str     features column name (see {@link #getOperator})
     * @param str2    label column name
     * @param str3    {@code ';'}-separated layer sizes
     * @param num     block size
     * @param l       random seed
     * @param num2    maximum number of iterations
     * @param d       convergence tolerance
     * @param d2      step size
     * @param str4    solver name
     * @return the fitted model
     */
    public static MultilayerPerceptronClassificationModel fit(Dataset<Row> dataset, String str, String str2, String str3, Integer num, Long l, Integer num2, Double d, Double d2, String str4) {
        return getOperator(str, str2, str3, num, l, num2, d, d2, str4).fit(dataset);
    }

    /**
     * Fits an already-configured classifier on {@code dataset}.
     *
     * @param multilayerPerceptronClassifier the classifier to fit
     * @param dataset                        training data
     * @return the fitted model
     */
    public static MultilayerPerceptronClassificationModel fit(MultilayerPerceptronClassifier multilayerPerceptronClassifier, Dataset<Row> dataset) {
        return multilayerPerceptronClassifier.fit(dataset);
    }

    /**
     * Applies a fitted model to {@code dataset}.
     *
     * @param multilayerPerceptronClassificationModel the fitted model
     * @param dataset                                 data to transform
     * @return the result of {@code model.transform(dataset)}
     */
    public static Dataset<Row> transform(MultilayerPerceptronClassificationModel multilayerPerceptronClassificationModel, Dataset<Row> dataset) {
        return multilayerPerceptronClassificationModel.transform(dataset);
    }
}
