package com.datastax.insight.ml.spark.ml.model;

import com.datastax.insight.core.entity.Cache;
import com.datastax.insight.core.entity.Model;
import com.datastax.insight.core.service.PersistService;
import com.datastax.insight.spec.DataSetOperator;
import com.datastax.util.lang.ReflectUtil;
import com.google.common.base.Strings;
import java.io.IOException;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.attribute.Attribute;
import org.apache.spark.ml.attribute.AttributeGroup;
import org.apache.spark.ml.attribute.NumericAttribute;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.ml.param.shared.HasFeaturesCol;
import org.apache.spark.ml.param.shared.HasInputCol;
import org.apache.spark.ml.param.shared.HasOutputCol;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

/* loaded from: input_file:com/datastax/insight/ml/spark/ml/model/ModelHandler.class */
/**
 * Static helpers that persist Spark ML models through the Insight DAO and score comma-separated
 * feature strings against a restored model.
 *
 * <p>Prediction supports {@link PredictionModel} and {@link PipelineModel}. Any other model type
 * is rejected with {@link UnsupportedOperationException}.
 */
public class ModelHandler implements DataSetOperator {

    /** Class name of the DAO whose methods are invoked by name through {@link PersistService}. */
    private static final String INSIGHT_DAO = "com.datastax.insight.agent.dao.InsightDAO";

    /**
     * Loads the model registered under {@code modelName}.
     *
     * <p>The DAO's {@code getModel} lookup supplies the model class name and storage path. The
     * model is then restored by invoking that class's static {@code load(String)} method.
     *
     * @param modelName name the model was saved under
     * @return the restored model
     * @throws IllegalArgumentException if the DAO returns no model for {@code modelName}
     */
    public static MLWritable load(String modelName) {
        Model model = (Model) PersistService.invoke(INSIGHT_DAO, "getModel",
                new String[]{String.class.getTypeName()}, new Object[]{modelName});
        if (model == null) {
            throw new IllegalArgumentException("No model found with name: " + modelName);
        }
        return loadByClass(model.getModelClass(), model.getPath());
    }

    /** Restores a model by reflectively calling the static {@code load(String)} of {@code modelClass}. */
    private static MLWritable loadByClass(String modelClass, String path) {
        return (MLWritable) ReflectUtil.invokeStaticMethod(modelClass, "load",
                new String[]{String.class.getTypeName()}, path);
    }

    /**
     * Writes {@code model} to {@code basePath} joined with {@code name} and registers it with the DAO.
     *
     * @param model     model to persist
     * @param name      model name; also used as the last path segment
     * @param basePath  directory or URI prefix the model is written under
     * @param overwrite whether an existing model at the target path may be replaced
     * @throws IOException if Spark fails to write the model
     * @deprecated use {@link #save(MLWritable, String, boolean)}. That overload writes to the path
     *     held in the {@code modelPath} cache entry and also records batch and flow-version ids.
     */
    @Deprecated
    public static void save(MLWritable model, String name, String basePath, boolean overwrite)
            throws IOException {
        String path = joinUriPath(basePath, name);
        writeModel(model, path, overwrite);
        PersistService.invoke(INSIGHT_DAO, "saveModel",
                new String[]{Long.class.getTypeName(), String.class.getTypeName(),
                        String.class.getTypeName(), String.class.getTypeName()},
                new Object[]{PersistService.getFlowId(), name, model.getClass().getName(), path});
    }

    /**
     * Writes {@code model} to the path stored in the {@code modelPath} cache entry.
     * It then registers the model with the DAO, along with the current flow, batch and
     * flow-version ids.
     *
     * @param model     model to persist
     * @param name      name to register the model under
     * @param overwrite whether an existing model at the target path may be replaced
     * @throws IllegalStateException if no {@code modelPath} cache entry is set
     * @throws IOException           if Spark fails to write the model
     */
    public static void save(MLWritable model, String name, boolean overwrite) throws IOException {
        Object modelPath = Cache.getCache("modelPath");
        if (modelPath == null) {
            throw new IllegalStateException(
                    "Cache entry 'modelPath' is not set; cannot save model: " + name);
        }
        String path = modelPath.toString();
        writeModel(model, path, overwrite);
        PersistService.invoke(INSIGHT_DAO, "saveModel",
                new String[]{Long.class.getTypeName(), Long.class.getTypeName(),
                        Long.class.getTypeName(), String.class.getTypeName(),
                        String.class.getTypeName(), String.class.getTypeName()},
                new Object[]{PersistService.getFlowId(), Long.valueOf(PersistService.getBatchId()),
                        Long.valueOf(PersistService.getFlowVersionId()), name,
                        model.getClass().getName(), path});
    }

    /** Saves {@code model} to {@code path}, replacing existing output when {@code overwrite} is set. */
    private static void writeModel(MLWritable model, String path, boolean overwrite)
            throws IOException {
        if (overwrite) {
            model.write().overwrite().save(path);
        } else {
            model.write().save(path);
        }
    }

    /**
     * Scores one comma-separated feature row with the model registered under {@code modelName}.
     *
     * @param modelName name passed to {@link #load(String)}
     * @param features  comma-separated numeric feature values, e.g. {@code "1.0,2.5,3"}
     * @return the prediction as a string
     */
    public static String predict(String modelName, String features) {
        return predictOne(load(modelName), features);
    }

    /**
     * Scores several comma-separated feature rows with the model registered under {@code modelName}.
     *
     * @param modelName   name passed to {@link #load(String)}
     * @param featureRows one comma-separated numeric feature string per row
     * @return one prediction string per input row, in input order
     */
    public static String[] predict(String modelName, String[] featureRows) {
        return predictAll(load(modelName), parseVectors(featureRows));
    }

    /**
     * Scores one comma-separated feature row with a model restored directly from storage.
     *
     * @param modelClass fully-qualified class whose static {@code load(String)} restores the model
     * @param modelPath  storage path passed to that {@code load} method
     * @param features   comma-separated numeric feature values
     * @return the prediction as a string
     */
    public static String predict(String modelClass, String modelPath, String features) {
        return predictOne(loadByClass(modelClass, modelPath), features);
    }

    /**
     * Scores several comma-separated feature rows with a model restored directly from storage.
     *
     * @param modelClass  fully-qualified class whose static {@code load(String)} restores the model
     * @param modelPath   storage path passed to that {@code load} method
     * @param featureRows one comma-separated numeric feature string per row
     * @return one prediction string per input row, in input order
     */
    public static String[] predict(String modelClass, String modelPath, String[] featureRows) {
        return predictAll(loadByClass(modelClass, modelPath), parseVectors(featureRows));
    }

    /** Parses a comma-separated list of numbers into a dense vector. */
    private static Vector parseVector(String csv) {
        double[] values = Arrays.stream(csv.split(",")).mapToDouble(Double::parseDouble).toArray();
        return Vectors.dense(values);
    }

    /** Parses each comma-separated row into a dense vector, preserving order. */
    private static List<Vector> parseVectors(String[] rows) {
        return Arrays.stream(rows).map(ModelHandler::parseVector).collect(Collectors.toList());
    }

    /** Scores a single comma-separated feature row. */
    private static String predictOne(MLWritable model, String features) {
        return predictAll(model, Arrays.asList(parseVector(features)))[0];
    }

    /**
     * Scores {@code vectors} with {@code model}.
     *
     * @return one prediction per vector, in input order. For a {@link PipelineModel} containing a
     *     {@link StringIndexerModel}, predicted indices are mapped back to labels.
     * @throws UnsupportedOperationException if the model type, or the pipeline layout, is not supported
     */
    @SuppressWarnings("rawtypes")
    private static String[] predictAll(MLWritable model, List<Vector> vectors) {
        if (model instanceof PredictionModel) {
            return predictDirectly((PredictionModel) model, vectors);
        }
        if (model instanceof PipelineModel) {
            return predictWithPipeline((PipelineModel) model, vectors);
        }
        throw unsupported(model);
    }

    /** Calls {@code predictor.predict} on each vector and renders the numeric result. */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static String[] predictDirectly(PredictionModel predictor, List<Vector> vectors) {
        String[] result = new String[vectors.size()];
        for (int i = 0; i < result.length; i++) {
            result[i] = String.valueOf(predictor.predict(vectors.get(i)));
        }
        return result;
    }

    /**
     * Scores {@code vectors} with a pipeline.
     *
     * <p>Stages are scanned from last to first to discover:
     * <ul>
     *   <li>the predictor's prediction and features columns;</li>
     *   <li>the upstream single-input column that ultimately produces the features column;</li>
     *   <li>the labels of a {@link StringIndexerModel}.</li>
     * </ul>
     * If no upstream input column is found, the vectors are scored directly against the last
     * {@link PredictionModel} stage. Otherwise they are fed into that input column, the whole
     * pipeline is run, and the prediction column is read back.
     */
    @SuppressWarnings("rawtypes")
    private static String[] predictWithPipeline(PipelineModel pipeline, List<Vector> vectors) {
        Transformer[] stages = pipeline.stages();
        String predictionCol = null;
        String featuresCol = null;
        String inputCol = null;
        String[] labels = null;

        for (int i = stages.length - 1; i >= 0; i--) {
            Transformer stage = stages[i];
            if (stage instanceof StringIndexerModel) {
                // Each indexer visited overwrites this, so the one kept is the last visited
                // (nearest the front of the pipeline) before the loop stops.
                // NOTE(review): this may pick a feature indexer rather than the label indexer in
                // some pipeline layouts; confirm against the pipelines actually persisted.
                labels = ((StringIndexerModel) stage).labels();
            }
            if (stage instanceof PredictionModel && Strings.isNullOrEmpty(predictionCol)) {
                predictionCol = ((PredictionModel) stage).getPredictionCol();
            }
            if (stage instanceof HasFeaturesCol && Strings.isNullOrEmpty(featuresCol)) {
                featuresCol = ((HasFeaturesCol) stage).getFeaturesCol();
            }
            if (stage instanceof HasOutputCol && stage instanceof HasInputCol) {
                // Follow the column lineage backwards. A stage that outputs the column being
                // tracked (initially the features column) reveals the column feeding it.
                String tracked = Strings.isNullOrEmpty(inputCol) ? featuresCol : inputCol;
                if (((HasOutputCol) stage).getOutputCol().equals(tracked)) {
                    inputCol = ((HasInputCol) stage).getInputCol();
                }
            }
            if (!Strings.isNullOrEmpty(featuresCol) && !Strings.isNullOrEmpty(inputCol)
                    && labels != null) {
                break;
            }
        }

        if (Strings.isNullOrEmpty(inputCol)) {
            PredictionModel predictor = lastPredictionModel(stages);
            if (predictor == null) {
                throw unsupported(pipeline);
            }
            return predictDirectly(predictor, vectors);
        }
        if (Strings.isNullOrEmpty(predictionCol)) {
            // Nothing to read back; fail before paying for a DataFrame build and transform.
            throw unsupported(pipeline);
        }
        if (vectors.isEmpty()) {
            return new String[0];
        }

        Dataset<Row> scored = pipeline.transform(featureFrame(inputCol, vectors));
        List<Row> rows = scored.select(predictionCol, new String[0]).collectAsList();
        String[] result = new String[rows.size()];
        for (int i = 0; i < result.length; i++) {
            Row row = rows.get(i);
            double prediction = row.getDouble(row.fieldIndex(predictionCol));
            result[i] = labels == null ? String.valueOf(prediction) : labels[(int) prediction];
        }
        return result;
    }

    /** Returns the last {@link PredictionModel} among {@code stages}, or {@code null} if none. */
    @SuppressWarnings("rawtypes")
    private static PredictionModel lastPredictionModel(Transformer[] stages) {
        for (int i = stages.length - 1; i >= 0; i--) {
            if (stages[i] instanceof PredictionModel) {
                return (PredictionModel) stages[i];
            }
        }
        return null;
    }

    /**
     * Builds a single-column DataFrame named {@code column} that holds {@code vectors}.
     * The column carries attribute-group metadata with numeric attributes {@code f1..fN}, where N
     * is the size of the first vector.
     * Callers must pass a non-empty list.
     */
    private static Dataset<Row> featureFrame(String column, List<Vector> vectors) {
        int size = vectors.get(0).size();
        NumericAttribute template = NumericAttribute.defaultAttr();
        Attribute[] attributes = new Attribute[size];
        for (int i = 0; i < size; i++) {
            attributes[i] = template.withName("f" + (i + 1));
        }
        StructType schema = new StructType(
                new StructField[]{new AttributeGroup(column, attributes).toStructField()});
        List<Row> rows = vectors.stream()
                .map(vector -> RowFactory.create(new Object[]{vector}))
                .collect(Collectors.toList());
        return SparkSession.builder().getOrCreate().createDataFrame(rows, schema);
    }

    /** Builds the exception reported for model types or pipelines that cannot be scored. */
    private static UnsupportedOperationException unsupported(MLWritable model) {
        return new UnsupportedOperationException(
                "predict method is no support for Class:" + model.getClass().getTypeName());
    }

    /**
     * Fetches settings through the DAO's {@code getSystemSettings} method.
     * Not currently called from within this class.
     *
     * @param id key passed to the DAO (NOTE(review): presumably a user id; confirm)
     */
    @SuppressWarnings({"unchecked", "unused"})
    private static Map<String, String> getSystemSettings(Long id) {
        return (Map<String, String>) PersistService.invoke(INSIGHT_DAO, "getSystemSettings",
                new String[]{Long.class.getTypeName()}, new Object[]{id});
    }

    /**
     * Joins {@code base} and {@code child}. HDFS URIs (case-insensitive {@code hdfs://} prefix)
     * are joined as strings with a single '/'. Anything else goes through {@link Paths#get}.
     * Not currently called from within this class.
     */
    @SuppressWarnings("unused")
    private static String resolvePath(String base, String child) {
        String hdfsScheme = "hdfs://";
        if (base.regionMatches(true, 0, hdfsScheme, 0, hdfsScheme.length())) {
            return joinUriPath(base, child);
        }
        return Paths.get(base, child).toString();
    }

    /**
     * Concatenates {@code base} and {@code child} so exactly one '/' separates them.
     * A duplicate slash at the seam is collapsed and a missing one is inserted. No other
     * normalization is applied.
     */
    private static String joinUriPath(String base, String child) {
        boolean baseHasSlash = base.endsWith("/");
        boolean childHasSlash = child.startsWith("/");
        if (baseHasSlash && childHasSlash) {
            return base + child.substring(1);
        }
        if (baseHasSlash || childHasSlash) {
            return base + child;
        }
        return base + "/" + child;
    }
}
