package ai.tripl.arc.transform;

import ai.tripl.arc.api.API;
import ai.tripl.arc.util.log.logger.Logger;
import java.util.UUID;
import org.apache.spark.ml.feature.CountVectorizer;
import org.apache.spark.ml.feature.CountVectorizerModel;
import org.apache.spark.ml.feature.MinHashLSH;
import org.apache.spark.ml.feature.MinHashLSHModel;
import org.apache.spark.ml.feature.NGram;
import org.apache.spark.ml.feature.RegexTokenizer;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.functions$;
import scala.Array$;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Option$;
import scala.Predef$;
import scala.Serializable;
import scala.Some;
import scala.Tuple16;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.immutable.List;
import scala.collection.immutable.List$;
import scala.collection.immutable.Map;
import scala.collection.immutable.Nil$;
import scala.reflect.ClassTag$;
import scala.reflect.api.Mirror;
import scala.reflect.api.TypeCreator;
import scala.reflect.api.Types;
import scala.reflect.api.Universe;
import scala.reflect.runtime.package$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: SimilarityJoinTransform.scala */
/* loaded from: input_file:ai/tripl/arc/transform/SimilarityJoinTransformStage$.class */
/**
 * Singleton companion ("module") of {@code SimilarityJoinTransformStage}.
 *
 * <p>Hosts the stage's {@link #execute} logic plus the {@code apply}/{@code unapply}
 * factory and extractor pair. The single instance is exposed through {@link #MODULE$}
 * and is preserved across Java serialization by {@code readResolve()}.
 */
public final class SimilarityJoinTransformStage$ implements Serializable {
    /** The one and only instance of this module. */
    public static final SimilarityJoinTransformStage$ MODULE$ = new SimilarityJoinTransformStage$();

    /**
     * Performs an approximate (MinHash LSH) similarity join between the stage's left and
     * right views and registers the result as the stage's output view.
     *
     * <p>Pipeline, as visible in the code below:
     * <ol>
     *   <li>the configured left/right fields are concatenated and trimmed into a temporary
     *       column whose name is a random UUID;</li>
     *   <li>that column is run through {@code RegexTokenizer} -> {@code NGram} ->
     *       {@code CountVectorizer} (vocabulary fitted on the left side only);</li>
     *   <li>rows are filtered by a UDF over the resulting vector column;</li>
     *   <li>a {@code MinHashLSH} model fitted on the left side joins both sides with a
     *       distance threshold of {@code 1 - threshold};</li>
     *   <li>the output keeps the left columns, the right columns and a {@code similarity}
     *       column computed as {@code 1 - distCol}, then is optionally repartitioned.</li>
     * </ol>
     *
     * <p>The two intermediate left-side datasets are cached for the duration of the call
     * and are always released before returning, whether the call succeeds or fails.
     *
     * @param similarityJoinTransformStage the stage configuration; its {@code stageDetail()}
     *        map is updated with {@code outputColumns}, {@code numPartitions} and (when
     *        {@code persist()} is set) {@code records} for non-streaming output
     * @param sparkSession session used to resolve the left and right views
     * @param logger not used by this method
     * @param aRCContext supplies the storage level and the {@code immutableViews} flag
     * @return the joined dataset wrapped via {@code Option.apply}
     * @throws RuntimeException any {@code Exception} raised from fitting the LSH model
     *         onwards is rethrown wrapped in the stage-specific
     *         {@code SimilarityJoinTransformStage$$anon$1} (presumably an exception carrying
     *         the stage detail - confirm against that class)
     */
    public Option<Dataset<Row>> execute(SimilarityJoinTransformStage similarityJoinTransformStage, SparkSession sparkSession, Logger logger, API.ARCContext aRCContext) {
        // Random column name so the temporary concatenated column cannot clash with a user column.
        String shingleSourceCol = UUID.randomUUID().toString();

        // Empty split pattern with a minimum token length of 1; NOTE(review): with
        // RegexTokenizer's default gap-splitting this should give character-level tokens - confirm.
        RegexTokenizer tokenizer = new RegexTokenizer().setInputCol(shingleSourceCol).setPattern("").setMinTokenLength(1).setToLowercase(!similarityJoinTransformStage.caseSensitive());
        NGram nGram = new NGram().setInputCol(tokenizer.getOutputCol()).setN(similarityJoinTransformStage.shingleLength());
        CountVectorizer countVectorizer = new CountVectorizer().setInputCol(nGram.getOutputCol());
        MinHashLSH minHashLsh = new MinHashLSH().setInputCol(countVectorizer.getOutputCol()).setNumHashTables(similarityJoinTransformStage.numHashTables());

        // Boolean UDF over a SparseVector. NOTE(review): presumably rejects all-zero vectors,
        // which MinHash cannot hash - confirm in SimilarityJoinTransformStage$$anonfun$4.
        UserDefinedFunction vectorFilter = functions$.MODULE$.udf(new SimilarityJoinTransformStage$$anonfun$4(), package$.MODULE$.universe().TypeTag().Boolean(), package$.MODULE$.universe().TypeTag().apply(package$.MODULE$.universe().runtimeMirror(getClass().getClassLoader()), new TypeCreator() { // from class: ai.tripl.arc.transform.SimilarityJoinTransformStage$$typecreator1$1
            public <U extends Universe> Types.TypeApi apply(Mirror<U> mirror) {
                mirror.universe();
                return mirror.staticClass("org.apache.spark.ml.linalg.SparseVector").asType().toTypeConstructor();
            }
        }));

        Dataset leftInput = sparkSession.table(similarityJoinTransformStage.leftView());
        Dataset rightInput = sparkSession.table(similarityJoinTransformStage.rightView());

        // Output projections derived from the ORIGINAL input columns, so none of the helper
        // columns added below reach the result. NOTE(review): the exact name -> Column mapping
        // lives in $$anonfun$5 / $$anonfun$6 - confirm there.
        Column[] leftOutputCols = (Column[]) Predef$.MODULE$.refArrayOps(leftInput.columns()).map(new SimilarityJoinTransformStage$$anonfun$5(), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Column.class)));
        Column[] rightOutputCols = (Column[]) Predef$.MODULE$.refArrayOps(rightInput.columns()).map(new SimilarityJoinTransformStage$$anonfun$6(), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Column.class)));

        Dataset leftShingled = nGram.transform(tokenizer.transform(leftInput.select(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col("*"), functions$.MODULE$.trim(functions$.MODULE$.concat((Seq) similarityJoinTransformStage.leftFields().map(new SimilarityJoinTransformStage$$anonfun$7(), List$.MODULE$.canBuildFrom()))).alias(shingleSourceCol)}))));
        Dataset leftFeaturized = null;

        // Cached because it feeds both the vocabulary fit and the subsequent transform.
        leftShingled.persist(aRCContext.storageLevel());
        try {
            CountVectorizerModel vocabulary = countVectorizer.fit(leftShingled);
            leftFeaturized = vocabulary.transform(leftShingled).filter(vectorFilter.apply(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col(countVectorizer.getOutputCol())})));
            // Cached because it feeds both the LSH fit and the join.
            leftFeaturized.persist(aRCContext.storageLevel());

            // The right side reuses the vocabulary fitted on the left side.
            Dataset rightFeaturized = vocabulary.transform(nGram.transform(tokenizer.transform(rightInput.select(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col("*"), functions$.MODULE$.trim(functions$.MODULE$.concat((Seq) similarityJoinTransformStage.rightFields().map(new SimilarityJoinTransformStage$$anonfun$8(), List$.MODULE$.canBuildFrom()))).alias(shingleSourceCol)}))))).filter(vectorFilter.apply(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.col(countVectorizer.getOutputCol())})));

            try {
                MinHashLSHModel lshModel = minHashLsh.fit(leftFeaturized);

                // approxSimilarityJoin works on distance, the stage is configured with a
                // similarity threshold, hence the two "1 - x" conversions.
                Dataset joined = lshModel.approxSimilarityJoin(lshModel.transform(leftFeaturized), lshModel.transform(rightFeaturized), 1.0d - similarityJoinTransformStage.threshold()).select(Predef$.MODULE$.wrapRefArray((Object[]) Predef$.MODULE$.refArrayOps((Object[]) Predef$.MODULE$.refArrayOps(leftOutputCols).$plus$plus(Predef$.MODULE$.refArrayOps(rightOutputCols), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Column.class)))).$plus$plus(Seq$.MODULE$.apply(Predef$.MODULE$.wrapRefArray(new Column[]{functions$.MODULE$.lit(BoxesRunTime.boxToDouble(1.0d)).$minus(functions$.MODULE$.col("distCol")).alias("similarity")})), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Column.class)))));

                Dataset output = repartitionOutput(similarityJoinTransformStage, joined);

                // createTempView fails if the view already exists, which enforces immutability.
                if (aRCContext.immutableViews()) {
                    output.createTempView(similarityJoinTransformStage.outputView());
                } else {
                    output.createOrReplaceTempView(similarityJoinTransformStage.outputView());
                }

                // Statistics are only collected for batch output.
                if (!output.isStreaming()) {
                    similarityJoinTransformStage.stageDetail().put("outputColumns", Integer.valueOf(output.schema().length()));
                    similarityJoinTransformStage.stageDetail().put("numPartitions", Integer.valueOf(output.rdd().partitions().length));
                    if (similarityJoinTransformStage.persist()) {
                        output.persist(aRCContext.storageLevel());
                        similarityJoinTransformStage.stageDetail().put("records", Long.valueOf(output.count()));
                    }
                }
                return Option$.MODULE$.apply(output);
            } catch (Exception e) {
                throw new SimilarityJoinTransformStage$$anon$1(similarityJoinTransformStage, e);
            }
        } finally {
            // Release the intermediate caches on success AND on failure; previously a failed
            // fit/join/count left both datasets cached for the lifetime of the session.
            leftShingled.unpersist();
            if (leftFeaturized != null) {
                leftFeaturized.unpersist();
            }
        }
    }

    /**
     * Applies the stage's {@code partitionBy} / {@code numPartitions} settings to the joined dataset.
     *
     * <ul>
     *   <li>no partition columns, no count: dataset returned untouched;</li>
     *   <li>no partition columns, count given: {@code repartition(count)};</li>
     *   <li>partition columns, no count: {@code repartition(columns)};</li>
     *   <li>partition columns and count: {@code repartition(count, columns)}.</li>
     * </ul>
     *
     * @throws MatchError if {@code numPartitions()} is neither {@code Some} nor {@code None}
     */
    private Dataset repartitionOutput(SimilarityJoinTransformStage stage, Dataset joined) {
        List<String> partitionBy = stage.partitionBy();
        if (Nil$.MODULE$.equals(partitionBy)) {
            Option<Object> numPartitions = stage.numPartitions();
            if (numPartitions instanceof Some) {
                return joined.repartition(BoxesRunTime.unboxToInt(numPartitions.get()));
            }
            if (None$.MODULE$.equals(numPartitions)) {
                return joined;
            }
            throw new MatchError(numPartitions);
        }

        // NOTE(review): $$anonfun$9 presumably resolves each name to a Column of `joined` - confirm.
        List partitionCols = (List) partitionBy.map(new SimilarityJoinTransformStage$$anonfun$9(joined), List$.MODULE$.canBuildFrom());
        Option<Object> numPartitions = stage.numPartitions();
        if (numPartitions instanceof Some) {
            return joined.repartition(BoxesRunTime.unboxToInt(numPartitions.get()), partitionCols);
        }
        if (None$.MODULE$.equals(numPartitions)) {
            return joined.repartition(partitionCols);
        }
        throw new MatchError(numPartitions);
    }

    /**
     * Factory for {@code SimilarityJoinTransformStage}; forwards every argument, in order,
     * to the stage constructor.
     *
     * <p>NOTE(review): judging by the matching types in {@link #unapply}, the positional
     * arguments appear to be plugin, name, description, leftView, leftFields, rightView,
     * rightFields, outputView, persist, shingleLength, numHashTables, threshold,
     * caseSensitive, partitionBy, numPartitions, params - confirm against the constructor.
     */
    public SimilarityJoinTransformStage apply(SimilarityJoinTransform similarityJoinTransform, String str, Option<String> option, String str2, List<String> list, String str3, List<String> list2, String str4, boolean z, int i, int i2, double d, boolean z2, List<String> list3, Option<Object> option2, Map<String, String> map) {
        return new SimilarityJoinTransformStage(similarityJoinTransform, str, option, str2, list, str3, list2, str4, z, i, i2, d, z2, list3, option2, map);
    }

    /**
     * Extractor counterpart of {@link #apply}: decomposes a stage into a 16-tuple of its
     * fields (primitives boxed).
     *
     * @param similarityJoinTransformStage the stage to decompose; may be {@code null}
     * @return {@code None} for a {@code null} stage, otherwise {@code Some} of the field tuple
     */
    public Option<Tuple16<SimilarityJoinTransform, String, Option<String>, String, List<String>, String, List<String>, String, Object, Object, Object, Object, Object, List<String>, Option<Object>, Map<String, String>>> unapply(SimilarityJoinTransformStage similarityJoinTransformStage) {
        return similarityJoinTransformStage == null ? None$.MODULE$ : new Some(new Tuple16(similarityJoinTransformStage.plugin(), similarityJoinTransformStage.name(), similarityJoinTransformStage.description(), similarityJoinTransformStage.leftView(), similarityJoinTransformStage.leftFields(), similarityJoinTransformStage.rightView(), similarityJoinTransformStage.rightFields(), similarityJoinTransformStage.outputView(), BoxesRunTime.boxToBoolean(similarityJoinTransformStage.persist()), BoxesRunTime.boxToInteger(similarityJoinTransformStage.shingleLength()), BoxesRunTime.boxToInteger(similarityJoinTransformStage.numHashTables()), BoxesRunTime.boxToDouble(similarityJoinTransformStage.threshold()), BoxesRunTime.boxToBoolean(similarityJoinTransformStage.caseSensitive()), similarityJoinTransformStage.partitionBy(), similarityJoinTransformStage.numPartitions(), similarityJoinTransformStage.params()));
    }

    /** Serialization hook: deserializing this module always yields the singleton. */
    private Object readResolve() {
        return MODULE$;
    }

    /** Private: the only instance is created by the {@link #MODULE$} initializer. */
    private SimilarityJoinTransformStage$() {
    }
}
