package com.datastax.insight.ml.spark.ml.tuning;

import com.datastax.insight.spec.DataSetOperator;
import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.evaluation.Evaluator;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.tuning.CrossValidator;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/* loaded from: input_file:com/datastax/insight/ml/spark/ml/tuning/CrossValidatorWrapper.class */
/**
 * Static helpers that build, fit, and apply a Spark ML {@link CrossValidator}.
 *
 * <p>Each method delegates directly to the corresponding Spark API; this class holds no state.
 */
public class CrossValidatorWrapper implements DataSetOperator {

    /**
     * Creates a new {@link CrossValidator} and applies every setting that was supplied.
     *
     * <p>A {@code null} argument means "do not call the corresponding setter". That parameter
     * keeps whatever value (or unset state) a freshly constructed {@code CrossValidator} has.
     *
     * @param estimator the estimator to tune, or {@code null} to leave it unset
     * @param paramMapArr the parameter maps to search over, or {@code null} to leave them unset
     * @param evaluator the evaluator used to score candidate models, or {@code null} to leave it
     *     unset
     * @param num the number of folds, or {@code null} to keep the validator's default
     * @param l the random seed, or {@code null} to keep the validator's default
     * @return the configured, not-yet-fitted cross-validator
     */
    public static CrossValidator getOperator(Estimator<?> estimator, ParamMap[] paramMapArr, Evaluator evaluator, Integer num, Long l) {
        CrossValidator crossValidator = new CrossValidator();
        if (estimator != null) {
            crossValidator.setEstimator(estimator);
        }
        if (paramMapArr != null) {
            crossValidator.setEstimatorParamMaps(paramMapArr);
        }
        if (evaluator != null) {
            crossValidator.setEvaluator(evaluator);
        }
        if (num != null) {
            crossValidator.setNumFolds(num.intValue());
        }
        if (l != null) {
            crossValidator.setSeed(l.longValue());
        }
        return crossValidator;
    }

    /**
     * Fits an already-configured cross-validator on {@code dataset}.
     *
     * @param dataset the training data
     * @param crossValidator the cross-validator to run
     * @return the model produced by {@link CrossValidator#fit}
     */
    public static CrossValidatorModel fit(Dataset<Row> dataset, CrossValidator crossValidator) {
        return crossValidator.fit(dataset);
    }

    /**
     * Builds a cross-validator from the supplied settings and fits it on {@code dataset}.
     *
     * <p>This is equivalent to calling {@code getOperator} followed by the two-argument
     * {@code fit}. {@code null} settings are handled exactly as in {@code getOperator}.
     *
     * @param dataset the training data
     * @param estimator the estimator to tune, or {@code null}
     * @param paramMapArr the parameter maps to search over, or {@code null}
     * @param evaluator the evaluator used to score candidate models, or {@code null}
     * @param num the number of folds, or {@code null}
     * @param l the random seed, or {@code null}
     * @return the fitted cross-validator model
     */
    public static CrossValidatorModel fit(Dataset<Row> dataset, Estimator<?> estimator, ParamMap[] paramMapArr, Evaluator evaluator, Integer num, Long l) {
        return fit(dataset, getOperator(estimator, paramMapArr, evaluator, num, l));
    }

    /**
     * Applies a fitted cross-validator model to {@code dataset}.
     *
     * @param dataset the data to transform
     * @param crossValidatorModel the fitted model
     * @return the result of {@link CrossValidatorModel#transform}
     */
    public static Dataset<Row> transform(Dataset<Row> dataset, CrossValidatorModel crossValidatorModel) {
        return crossValidatorModel.transform(dataset);
    }
}
