package io.qbeast.spark.index;

import io.qbeast.core.model.ColumnsToIndexSelector;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.linalg.Matrix;
import org.apache.spark.ml.stat.Correlation$;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.functions$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.TimestampType$;
import scala.Array$;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.generic.GenericTraversableTemplate;
import scala.collection.mutable.ArrayOps;
import scala.math.Ordering$Double$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.ScalaRunTime$;

/* compiled from: SparkColumnsToIndexSelector.scala */
/* loaded from: input_file:io/qbeast/spark/index/SparkColumnsToIndexSelector$.class */
/**
 * Decompiled form of the Scala singleton {@code object SparkColumnsToIndexSelector}: a
 * {@link ColumnsToIndexSelector} over {@code Dataset<Row>} that chooses which column names to
 * index.
 *
 * <p>NOTE(review): this file is decompiler output of Scala bytecode (see the "compiled from" /
 * "loaded from" header above). {@code MODULE$}, the {@code $anonfun$...} static methods, the
 * single-argument {@code selectColumnsToIndex(Object)} forwarder and {@code readResolve} are
 * artifacts of how scalac encodes a serializable {@code object}. Behavioral changes belong in the
 * Scala source rather than in this file.
 */
public final class SparkColumnsToIndexSelector$ implements ColumnsToIndexSelector<Dataset<Row>>, Serializable {
    /** The singleton instance; assigned by the private constructor, which the static initializer runs. */
    public static SparkColumnsToIndexSelector$ MODULE$;

    static {
        // The constructor stores `this` into MODULE$, so the created instance is deliberately not kept here.
        new SparkColumnsToIndexSelector$();
    }

    /**
     * Single-argument entry point with an erased ({@code Object}) parameter. It appears to be
     * scalac's forwarder to the trait's concrete method: it delegates to the static
     * {@code ColumnsToIndexSelector.selectColumnsToIndex$}.
     *
     * <p>NOTE(review): presumably that trait method calls
     * {@link #selectColumnsToIndex(Dataset, int)} with {@link #MAX_COLUMNS_TO_INDEX()} as the
     * limit. Confirm in {@code ColumnsToIndexSelector}.
     */
    public Seq selectColumnsToIndex(Object obj) {
        return ColumnsToIndexSelector.selectColumnsToIndex$(this, obj);
    }

    /**
     * Returns the configured maximum number of columns to index.
     *
     * @return the value of {@code org.apache.spark.qbeast.config.package$.MAX_NUM_COLUMNS_TO_INDEX()}
     */
    public int MAX_COLUMNS_TO_INDEX() {
        return org.apache.spark.qbeast.config.package$.MODULE$.MAX_NUM_COLUMNS_TO_INDEX();
    }

    /**
     * Replaces every {@code TimestampType} field of {@code seq} with its Unix timestamp.
     *
     * <p>Each matching column is overwritten under the same name by
     * {@code unix_timestamp(col(name))} via {@code Dataset.withColumns(Map)}. Fields of other types,
     * and dataset columns not listed in {@code seq}, are left unchanged.
     *
     * @param dataset input data
     * @param seq     fields to inspect; only those whose data type equals {@code TimestampType} are converted
     * @return a new dataset with the timestamp columns converted
     */
    private Dataset<Row> withUnixTimestamp(Dataset<Row> dataset, Seq<StructField> seq) {
        return dataset.withColumns(((TraversableOnce) ((TraversableLike) seq.filter(structField -> {
            return BoxesRunTime.boxToBoolean($anonfun$withUnixTimestamp$1(structField));
        })).map(structField2 -> {
            return new Tuple2(structField2.name(), functions$.MODULE$.unix_timestamp(functions$.MODULE$.col(structField2.name())));
        }, Seq$.MODULE$.canBuildFrom())).toMap(Predef$.MODULE$.$conforms()));
    }

    /**
     * Builds, fits and applies the per-column preprocessing pipeline.
     *
     * <p>The pipeline stages are obtained by applying the partial function
     * {@code SparkColumnsToIndexSelector$$anonfun$1} to each field of {@code seq}. Fields it does
     * not match are skipped, as with {@code Seq.collect}. The per-field stage sequences are then
     * flattened into one stage array. The pipeline is fit on {@code dataset} and then used to
     * transform that same dataset.
     *
     * <p>NOTE(review): {@code $$anonfun$1} is not visible here.
     * {@link #selectTopNCorrelatedColumns} expects it to produce a {@code <name>_Vec} vector column
     * for every field. Confirm that it does.
     *
     * @param dataset input data, used both to fit and to transform
     * @param seq     fields to build preprocessing stages for
     * @return {@code dataset} with the pipeline's output columns added
     */
    public Dataset<Row> withPreprocessedPipeline(Dataset<Row> dataset, Seq<StructField> seq) {
        return new Pipeline().setStages((PipelineStage[]) ((GenericTraversableTemplate) seq.collect(new SparkColumnsToIndexSelector$$anonfun$1(), Seq$.MODULE$.canBuildFrom())).flatten(Predef$.MODULE$.$conforms()).toArray(ClassTag$.MODULE$.apply(PipelineStage.class))).fit(dataset).transform(dataset);
    }

    /**
     * Ranks the fields of {@code seq} by absolute Pearson correlation and returns up to {@code i} names.
     *
     * <p>Steps:
     * <ol>
     *   <li>Assemble the {@code <name>_Vec} columns into a {@code "features"} vector
     *       ({@code handleInvalid = "keep"}).</li>
     *   <li>Compute the correlation matrix with {@code Correlation.corr} and read it from the
     *       {@code "pearson(features)"} column of the single result row.</li>
     *   <li>Take {@code |x|} of every entry of {@code Matrix.toArray()} (column-major) and split it
     *       into chunks of {@code strArr.length}. Keep the first chunk, which is the first column of
     *       the matrix and, by symmetry, its first row: each field's absolute correlation with the
     *       first field.</li>
     *   <li>Pair each value with its index and sort ascending by value. Keep the first {@code i}
     *       pairs ({@code take} tolerates {@code i} larger than the length). Map each index back to
     *       {@code seq.apply(index).name()}.</li>
     * </ol>
     *
     * <p>NOTE(review): because a field's correlation with itself is 1.0, the first field sorts last
     * (unless its self-correlation is NaN). Despite the method name, the result is therefore the
     * fields <em>least</em> correlated with the first field. Confirm this is the intended ranking.
     *
     * <p>NOTE(review): the chunk size {@code strArr.length} assumes every {@code _Vec} column is
     * one-dimensional. If a preprocessing stage emits a wider vector, the matrix is larger than
     * {@code strArr.length} squared, the chosen indices stop matching positions in {@code seq}, and
     * {@code seq.apply} may go out of range. Verify against {@code SparkColumnsToIndexSelector$$anonfun$1}.
     *
     * <p>NOTE(review): presumably an empty {@code seq} fails inside {@code grouped(0)}. NaN
     * correlations (e.g. constant columns) presumably sort after all numbers under
     * {@code Ordering.Double}. Confirm both for the Scala version in use.
     *
     * @param dataset preprocessed data that already contains a {@code <name>_Vec} column per field
     * @param seq     fields to rank; their order defines matrix indices
     * @param i       maximum number of column names to return
     * @return selected field names, ordered by ascending absolute correlation with the first field
     */
    public String[] selectTopNCorrelatedColumns(Dataset<Row> dataset, Seq<StructField> seq, int i) {
        // Names of the per-field vector columns expected from the preprocessing pipeline.
        String[] strArr = (String[]) ((TraversableOnce) seq.map(structField -> {
            return new StringBuilder(4).append(structField.name()).append("_Vec").toString();
        }, Seq$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.apply(String.class));
        // Read inside-out: assemble -> corr -> |x| -> first column -> zipWithIndex -> sortBy -> take(i) -> index -> name.
        return (String[]) new ArrayOps.ofInt(Predef$.MODULE$.intArrayOps((int[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Tuple2[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps((double[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps((Object[]) new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps((double[]) new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(((Matrix) ((Row) Correlation$.MODULE$.corr(new VectorAssembler().setInputCols(strArr).setOutputCol("features").setHandleInvalid("keep").transform(dataset), "features").select("pearson(features)", Predef$.MODULE$.wrapRefArray(new String[0])).head()).getAs(0)).toArray())).map(d -> {
            return Math.abs(d);
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Double())))).grouped(strArr.length).toArray(ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Double.TYPE))))).head())).zipWithIndex(Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class))))).sortBy(tuple2 -> {
            return BoxesRunTime.boxToDouble($anonfun$selectTopNCorrelatedColumns$3(tuple2));
        }, Ordering$Double$.MODULE$))).take(i))).map(tuple22 -> {
            return BoxesRunTime.boxToInteger(tuple22._2$mcI$sp());
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Int())))).map(obj -> {
            return $anonfun$selectTopNCorrelatedColumns$5(seq, BoxesRunTime.unboxToInt(obj));
        }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(String.class)));
    }

    /**
     * Chooses up to {@code i} column names of {@code dataset} to index.
     *
     * <p>If the dataset has no rows, no statistics are computed and the first {@code i} column
     * names are returned in schema order. Otherwise every schema field goes through three steps:
     * timestamps are converted to Unix timestamps, the preprocessing pipeline is applied, and the
     * fields are ranked by {@link #selectTopNCorrelatedColumns}.
     *
     * <p>NOTE(review): {@code schema} is captured before {@link #withUnixTimestamp} runs. The
     * pipeline is therefore told that converted fields are still {@code TimestampType}, although
     * their column now holds the {@code unix_timestamp} value. Confirm that
     * {@code SparkColumnsToIndexSelector$$anonfun$1} handles this.
     *
     * @param dataset data whose columns are candidates
     * @param i       maximum number of columns to return
     * @return the selected column names
     */
    public Seq<String> selectColumnsToIndex(Dataset<Row> dataset, int i) {
        if (dataset.isEmpty()) {
            return Predef$.MODULE$.wrapRefArray((Object[]) new ArrayOps.ofRef(Predef$.MODULE$.refArrayOps(dataset.columns())).take(i));
        }
        StructType schema = dataset.schema();
        return Predef$.MODULE$.wrapRefArray(selectTopNCorrelatedColumns(withPreprocessedPipeline(withUnixTimestamp(dataset, schema), schema), schema, i));
    }

    /** Java-serialization hook: replaces a deserialized copy with the singleton {@code MODULE$}. */
    private Object readResolve() {
        return MODULE$;
    }

    /**
     * Filter predicate used by {@code withUnixTimestamp}.
     *
     * @return {@code true} when the field's data type equals {@code TimestampType} (null-safe equality)
     */
    public static final /* synthetic */ boolean $anonfun$withUnixTimestamp$1(StructField structField) {
        DataType dataType = structField.dataType();
        TimestampType$ timestampType$ = TimestampType$.MODULE$;
        return dataType != null ? dataType.equals(timestampType$) : timestampType$ == null;
    }

    /**
     * Sort key used by {@code selectTopNCorrelatedColumns}.
     *
     * @return the double (absolute correlation) held by a {@code (value, index)} tuple
     * @throws MatchError if {@code tuple2} is {@code null}
     */
    public static final /* synthetic */ double $anonfun$selectTopNCorrelatedColumns$3(Tuple2 tuple2) {
        if (tuple2 != null) {
            return tuple2._1$mcD$sp();
        }
        throw new MatchError(tuple2);
    }

    /** Maps a correlation-matrix index back to the name of the field at that position in {@code seq}. */
    public static final /* synthetic */ String $anonfun$selectTopNCorrelatedColumns$5(Seq seq, int i) {
        return ((StructField) seq.apply(i)).name();
    }

    /** Private singleton constructor: publishes {@code this} as {@code MODULE$}, then runs the trait initializer. */
    private SparkColumnsToIndexSelector$() {
        MODULE$ = this;
        ColumnsToIndexSelector.$init$(this);
    }
}
