package com.datastax.insight.ml.spark.ml.classification;

import com.datastax.insight.spec.DataSetOperator;
import com.datastax.insight.core.Consts;
import com.google.common.base.Strings;
import org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel;
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

/**
 * Multilayer perceptron classifier wrapper.
 *
 * <p>Thin static facade over Spark ML's {@link MultilayerPerceptronClassifier}: builds a configured
 * classifier from optional parameters, fits it, and applies the resulting model to a dataset.
 */
public class MultilayerPerceptronClassifierWrapper implements DataSetOperator {

    /**
     * Builds a {@link MultilayerPerceptronClassifier} from optional parameters.
     *
     * <p>Every parameter is optional: a {@code null} value (or an empty string for the
     * {@code String} parameters) leaves the corresponding setting at the classifier's default.
     *
     * @param featuresCol   name of the features column
     * @param labelCol      name of the label column
     * @param layer         layer sizes from input to output layer, joined by {@code Consts.DELIMITER}
     *                      (e.g. {@code "4,5,3"} if the delimiter is a comma); surrounding whitespace
     *                      around each size is ignored
     * @param blockSize     block size used to stack input rows
     * @param seed          random seed
     * @param maxIterations maximum number of iterations
     * @param tol           convergence tolerance
     * @param stepSize      step size
     * @param solver        optimization solver name
     * @return the configured (unfitted) classifier
     * @throws NumberFormatException if a layer size in {@code layer} is not a valid integer
     */
    public static MultilayerPerceptronClassifier getOperator(String featuresCol,
                                                             String labelCol,
                                                             String layer,
                                                             Integer blockSize,
                                                             Long seed,
                                                             Integer maxIterations,
                                                             Double tol,
                                                             Double stepSize,
                                                             String solver) {

        MultilayerPerceptronClassifier classifier = new MultilayerPerceptronClassifier();

        if (!Strings.isNullOrEmpty(featuresCol)) {
            classifier.setFeaturesCol(featuresCol);
        }

        if (!Strings.isNullOrEmpty(labelCol)) {
            classifier.setLabelCol(labelCol);
        }

        if (!Strings.isNullOrEmpty(layer)) {
            classifier.setLayers(parseLayers(layer));
        }

        if (blockSize != null) {
            classifier.setBlockSize(blockSize);
        }

        if (seed != null) {
            classifier.setSeed(seed);
        }

        if (maxIterations != null) {
            classifier.setMaxIter(maxIterations);
        }

        if (tol != null) {
            classifier.setTol(tol);
        }

        if (stepSize != null) {
            classifier.setStepSize(stepSize);
        }

        if (!Strings.isNullOrEmpty(solver)) {
            classifier.setSolver(solver);
        }

        return classifier;
    }

    /**
     * Parses a delimited layer-size specification into an int array.
     *
     * <p>Tokens are split with {@code Consts.DELIMITER} (interpreted by {@link String#split} as a
     * regular expression) and trimmed before parsing, so {@code "4, 5, 3"} is accepted.
     *
     * @param layer non-empty layer specification
     * @return the parsed layer sizes, in input-to-output order
     * @throws NumberFormatException if any token is not a valid integer; the message identifies
     *                               the offending token and the original exception is the cause
     */
    private static int[] parseLayers(String layer) {
        String[] tokens = layer.split(Consts.DELIMITER);
        int[] layers = new int[tokens.length];
        for (int i = 0; i < tokens.length; i++) {
            String token = tokens[i].trim();
            try {
                layers[i] = Integer.parseInt(token);
            } catch (NumberFormatException e) {
                // Keep the same exception type for existing callers, but add context and the cause.
                NumberFormatException detailed = new NumberFormatException(
                        "Invalid layer size '" + token + "' at position " + i
                                + " in layer specification '" + layer + "'");
                detailed.initCause(e);
                throw detailed;
            }
        }
        return layers;
    }

    /**
     * Builds a classifier via
     * {@link #getOperator(String, String, String, Integer, Long, Integer, Double, Double, String)}
     * and fits it on {@code data}.
     *
     * @param data          training dataset
     * @param featuresCol   name of the features column
     * @param labelCol      name of the label column
     * @param layer         delimited layer sizes (see {@code getOperator})
     * @param blockSize     block size used to stack input rows
     * @param seed          random seed
     * @param maxIterations maximum number of iterations
     * @param tol           convergence tolerance
     * @param stepSize      step size
     * @param solver        optimization solver name
     * @return the fitted model
     * @throws NumberFormatException if a layer size in {@code layer} is not a valid integer
     */
    public static MultilayerPerceptronClassificationModel fit(Dataset<Row> data,
                                                              String featuresCol,
                                                              String labelCol,
                                                              String layer,
                                                              Integer blockSize,
                                                              Long seed,
                                                              Integer maxIterations,
                                                              Double tol,
                                                              Double stepSize,
                                                              String solver) {
        MultilayerPerceptronClassifier classifier = getOperator(featuresCol, labelCol, layer,
                blockSize, seed, maxIterations, tol, stepSize, solver);
        return fit(classifier, data);
    }

    /**
     * Fits an already-configured classifier on {@code data}.
     *
     * @param classifier configured classifier
     * @param data       training dataset
     * @return the fitted model
     */
    public static MultilayerPerceptronClassificationModel fit(MultilayerPerceptronClassifier classifier, Dataset<Row> data) {
        return classifier.fit(data);
    }

    /**
     * Applies a fitted model to {@code data}.
     *
     * @param model fitted model
     * @param data  dataset to transform
     * @return the dataset returned by {@code model.transform(data)}
     */
    public static Dataset<Row> transform(MultilayerPerceptronClassificationModel model, Dataset<Row> data) {
        return model.transform(data);
    }
}
