package com.datastax.insight.ml.spark.ml.tuning;

import com.alibaba.fastjson.JSON;
import com.datastax.insight.spec.DataSetOperator;
import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.evaluation.Evaluator;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.tuning.TrainValidationSplit;
import org.apache.spark.ml.tuning.TrainValidationSplitModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/* loaded from: input_file:com/datastax/insight/ml/spark/ml/tuning/TrainValidationSplitWrapper.class */
/**
 * Static facade over Spark ML's {@link TrainValidationSplit} hyper-parameter tuner.
 *
 * <p>Every configuration argument accepted here is optional: a {@code null} value means the
 * corresponding setter on {@link TrainValidationSplit} is simply not called, so Spark's own
 * default (or unset state) is kept.
 */
public class TrainValidationSplitWrapper implements DataSetOperator {

    /**
     * Builds a new, unfitted {@link TrainValidationSplit} configured with the non-null settings.
     *
     * @param estimator   estimator whose parameters are tuned; not set when {@code null}
     * @param paramMapArr candidate parameter maps to evaluate; not set when {@code null}
     * @param evaluator   evaluator used to score each candidate; not set when {@code null}
     * @param trainRatio  fraction of the data used for training (Spark documents this as
     *                    between 0 and 1); not set when {@code null}
     * @param seed        random seed passed to {@code setSeed}; not set when {@code null}
     * @return the configured tuner
     */
    public static TrainValidationSplit getOperator(Estimator<?> estimator, ParamMap[] paramMapArr,
            Evaluator evaluator, Double trainRatio, Long seed) {
        TrainValidationSplit trainValidationSplit = new TrainValidationSplit();
        if (estimator != null) {
            trainValidationSplit.setEstimator(estimator);
        }
        if (paramMapArr != null) {
            trainValidationSplit.setEstimatorParamMaps(paramMapArr);
        }
        if (evaluator != null) {
            trainValidationSplit.setEvaluator(evaluator);
        }
        if (seed != null) {
            trainValidationSplit.setSeed(seed.longValue());
        }
        if (trainRatio != null) {
            trainValidationSplit.setTrainRatio(trainRatio.doubleValue());
        }
        return trainValidationSplit;
    }

    /**
     * Configures a tuner via {@link #getOperator} and fits it on {@code dataset}.
     *
     * @param dataset     data to split into training and validation parts
     * @param estimator   estimator whose parameters are tuned; may be {@code null}
     * @param paramMapArr candidate parameter maps; may be {@code null}
     * @param evaluator   evaluator used to score candidates; may be {@code null}
     * @param trainRatio  training fraction; may be {@code null}
     * @param seed        random seed; may be {@code null}
     * @return the fitted tuning model
     */
    public static TrainValidationSplitModel fit(Dataset<Row> dataset, Estimator<?> estimator,
            ParamMap[] paramMapArr, Evaluator evaluator, Double trainRatio, Long seed) {
        return fit(dataset, getOperator(estimator, paramMapArr, evaluator, trainRatio, seed));
    }

    /**
     * Fits an already configured tuner on {@code dataset} and prints the best model's
     * parameters to standard output.
     *
     * @param dataset              data to split into training and validation parts
     * @param trainValidationSplit configured tuner to fit
     * @return the fitted tuning model
     */
    public static TrainValidationSplitModel fit(Dataset<Row> dataset, TrainValidationSplit trainValidationSplit) {
        TrainValidationSplitModel model = trainValidationSplit.fit(dataset);
        // NOTE(review): ParamMap is a Scala class with no JavaBean getters, so fastjson's bean
        // serialization is expected to print an empty object. ParamMap.toString() (used
        // implicitly by println) lists the actual param/value pairs of the best model.
        System.out.println(model.bestModel().extractParamMap());
        return model;
    }

    /**
     * Applies a fitted tuning model (i.e. its best model) to {@code dataset}.
     *
     * @param dataset                   data to transform
     * @param trainValidationSplitModel fitted tuning model
     * @return the transformed dataset
     */
    public static Dataset<Row> transform(Dataset<Row> dataset, TrainValidationSplitModel trainValidationSplitModel) {
        return trainValidationSplitModel.transform(dataset);
    }
}
