package com.datastax.insight.ml.spark.ml.tuning;

import com.datastax.insight.spec.DataSetOperator;
import com.google.common.base.Strings;
import org.apache.commons.lang.ArrayUtils;
import org.apache.spark.ml.param.*;
import org.apache.spark.ml.tuning.ParamGridBuilder;

import java.util.Arrays;
import java.util.Map;

/**
 * Parameter grid: builds a Spark ML hyper-parameter grid ({@link ParamMap} array) from a map
 * of parameters to candidate-value strings.
 *
 * <p>Candidate values are separated by {@code ;}. Whitespace around each candidate is ignored,
 * and empty candidates are dropped (so {@code "10; 20;;50"} yields 10, 20 and 50).
 *
 * <p>Supported key types:
 * <ul>
 *   <li>{@link IntParam}, {@link LongParam}, {@link FloatParam}, {@link DoubleParam}: each
 *       candidate is parsed as the corresponding numeric type.</li>
 *   <li>{@link BooleanParam}: only handled when the value has no candidates ({@code null},
 *       empty or blank), in which case {@code ParamGridBuilder.addGrid(BooleanParam)} is used.</li>
 * </ul>
 * Any other key type, and a {@link BooleanParam} with explicit candidates, is ignored.
 */
public class ParamGridBuilderWrapper implements DataSetOperator {

    /** Separator between candidate values in a parameter's value string. */
    private static final String VALUE_SEPARATOR = ";";

    /**
     * Builds the parameter grid from {@code params}.
     *
     * @param args   not used by this implementation (presumably kept to match the operator
     *               calling convention; verify against callers)
     * @param params parameter keys mapped to {@code ;}-separated candidate values
     * @return the cartesian product of all candidate values, as built by {@link ParamGridBuilder}
     */
    public static ParamMap[] getOperator(String args, Map<Object, String> params) {
        return getOperator(params);
    }

    /**
     * Builds the parameter grid from {@code params}.
     *
     * @param params parameter keys mapped to {@code ;}-separated candidate values; must not be
     *               {@code null}
     * @return the cartesian product of all candidate values, as built by {@link ParamGridBuilder}
     * @throws NumberFormatException if a candidate for a numeric parameter cannot be parsed
     */
    public static ParamMap[] getOperator(Map<Object, String> params) {
        ParamGridBuilder builder = new ParamGridBuilder();

        for (Map.Entry<Object, String> entry : params.entrySet()) {
            Object key = entry.getKey();
            String value = entry.getValue();
            String[] values = Strings.isNullOrEmpty(value) ? new String[0] : splitValues(value);

            if (values.length == 0) {
                // No explicit candidates. Booleans fall back to the builder's default boolean
                // grid. Other params are skipped, because an empty candidate list would make
                // the cartesian product (and so the whole grid) empty.
                if (key instanceof BooleanParam) {
                    builder.addGrid((BooleanParam) key);
                }
                continue;
            }

            if (key instanceof IntParam) {
                builder.addGrid((IntParam) key,
                        Arrays.stream(values).mapToInt(Integer::parseInt).toArray());
            } else if (key instanceof FloatParam) {
                builder.addGrid((FloatParam) key, parseFloats(values));
            } else if (key instanceof LongParam) {
                builder.addGrid((LongParam) key,
                        Arrays.stream(values).mapToLong(Long::parseLong).toArray());
            } else if (key instanceof DoubleParam) {
                builder.addGrid((DoubleParam) key,
                        Arrays.stream(values).mapToDouble(Double::parseDouble).toArray());
            }
            // NOTE(review): other key types (and BooleanParam with explicit candidates) are
            // silently ignored, as before.
        }

        return builder.build();
    }

    /**
     * Splits a value string on {@link #VALUE_SEPARATOR}, trimming each candidate and dropping
     * empty ones (Integer/Long parsing rejects surrounding whitespace).
     */
    private static String[] splitValues(String value) {
        return Arrays.stream(value.split(VALUE_SEPARATOR))
                .map(String::trim)
                .filter(token -> !token.isEmpty())
                .toArray(String[]::new);
    }

    /** Parses each token as a {@code float} directly into a primitive array (no boxing). */
    private static float[] parseFloats(String[] tokens) {
        float[] result = new float[tokens.length];
        for (int i = 0; i < tokens.length; i++) {
            result[i] = Float.parseFloat(tokens[i]);
        }
        return result;
    }
}
