package io.squashql.transaction;

import io.squashql.SparkDatastore;
import io.squashql.SparkUtil;
import io.squashql.store.TypedField;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.spark.sql.AnalysisException;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalog.Table;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.StructType;
import org.eclipse.collections.api.list.ImmutableList;
import org.eclipse.collections.impl.list.immutable.ImmutableListFactoryImpl;

/* loaded from: input_file:io/squashql/transaction/SparkDataLoader.class */
/**
 * {@link DataLoader} that stores each table as a temporary view of the wrapped {@link SparkSession}.
 *
 * <p>Appending data to an existing table replaces its view with the positional union of the
 * previous content and the new rows (see {@link #appendDataset}).
 */
public class SparkDataLoader implements DataLoader {

    /** Name of the column that tags every row with the scenario it belongs to. */
    private static final String SCENARIO_FIELD_NAME = "scenario";

    /** Scenario for which {@link #load} does not require the table to have a scenario column. */
    private static final String BASE_SCENARIO_NAME = "base";

    /** Prefix of the short-lived view name used while appending to an existing view. */
    private static final String TMP_VIEW_PREFIX = "tmp_";

    protected final SparkSession spark;

    public SparkDataLoader(SparkSession sparkSession) {
        this.spark = sparkSession;
    }

    /**
     * Creates (or replaces) an empty temporary view with the given fields plus a trailing
     * {@code scenario} column.
     *
     * @param table name of the temporary view
     * @param fields columns of the view, in order
     */
    public void createTemporaryTable(String table, List<TypedField> fields) {
        createTemporaryTable(this.spark, table, fields, true);
    }

    /**
     * Creates (or replaces) an empty temporary view with the given fields.
     *
     * @param table name of the temporary view
     * @param fields columns of the view, in order
     * @param withScenarioColumn whether to append a trailing {@code scenario} string column
     */
    public void createTemporaryTable(String table, List<TypedField> fields, boolean withScenarioColumn) {
        createTemporaryTable(this.spark, table, fields, withScenarioColumn);
    }

    /**
     * Creates (or replaces) an empty temporary view on {@code sparkSession}.
     *
     * <p>Side effect: sets {@code spark.sql.caseSensitive=true} for the whole session.
     *
     * @param sparkSession session on which the view is registered
     * @param table name of the temporary view
     * @param fields columns of the view, in order
     * @param withScenarioColumn whether to append a trailing {@code scenario} string column
     */
    public static void createTemporaryTable(SparkSession sparkSession, String table, List<TypedField> fields,
            boolean withScenarioColumn) {
        ImmutableList<TypedField> allFields = ImmutableListFactoryImpl.INSTANCE.ofAll(fields);
        if (withScenarioColumn) {
            // Appended last: load() writes the scenario value into the final position of each row.
            allFields = allFields.newWith(new TypedField(table, SCENARIO_FIELD_NAME, String.class));
        }
        StructType schema = SparkUtil.createSchema(allFields.castToList());
        sparkSession.conf().set("spark.sql.caseSensitive", String.valueOf(true));
        sparkSession.createDataFrame(Collections.<Row>emptyList(), schema).createOrReplaceTempView(table);
    }

    /**
     * Appends rows to an existing temporary view.
     *
     * <p>When the view has a {@code scenario} column, {@code scenario} is appended as the last value
     * of every tuple. For any scenario other than {@code "base"}, the column is mandatory.
     *
     * @param scenario scenario the rows belong to
     * @param table name of the existing temporary view
     * @param tuples row values, ordered like the view's columns (excluding the scenario column)
     * @throws RuntimeException if {@code scenario} is not {@code "base"} and the view has no
     *     {@code scenario} column
     */
    public void load(String scenario, String table, List<Object[]> tuples) {
        // Read the schema once; it drives both the scenario check and the DataFrame creation.
        var fields = SparkDatastore.getFields(this.spark, table);
        boolean hasScenarioColumn = fields.stream()
                .anyMatch(typedField -> typedField.name().equals(SCENARIO_FIELD_NAME));
        if (!scenario.equals(BASE_SCENARIO_NAME) && !hasScenarioColumn) {
            throw new RuntimeException(String.format("%s field not found", SCENARIO_FIELD_NAME));
        }

        List<Row> rows = tuples.stream()
                .map(tuple -> {
                    Object[] values = tuple;
                    if (hasScenarioColumn) {
                        // Copy so the caller's array is not mutated; scenario goes in the last slot.
                        values = Arrays.copyOf(tuple, tuple.length + 1);
                        values[values.length - 1] = scenario;
                    }
                    return RowFactory.create(values);
                })
                .toList();
        appendDataset(this.spark, table, this.spark.createDataFrame(rows, SparkUtil.createSchema(fields)));
    }

    /**
     * Replaces the temporary view {@code table} with the union of its current content and
     * {@code dataset}.
     *
     * <p>{@code union} is positional: {@code dataset} columns must be in the same order as the view's.
     * If the union or the view replacement fails, the original view is restored under its name before
     * the exception propagates.
     *
     * @param sparkSession session holding the view
     * @param table name of the existing temporary view
     * @param dataset rows to append
     */
    static void appendDataset(SparkSession sparkSession, String table, Dataset<Row> dataset) {
        // The current content is moved aside first so the new definition does not reference the view
        // it replaces. NOTE(review): presumably to avoid Spark's recursive-view check — confirm.
        String tmpTable = TMP_VIEW_PREFIX + table;
        sparkSession.sql("ALTER VIEW " + table + " RENAME TO " + tmpTable);
        boolean replaced = false;
        try {
            sparkSession.table(tmpTable).union(dataset).createOrReplaceTempView(table);
            replaced = true;
        } finally {
            if (replaced) {
                sparkSession.catalog().dropTempView(tmpTable);
            } else {
                // Undo the rename so a failed append does not make the table disappear.
                sparkSession.sql("ALTER VIEW " + tmpTable + " RENAME TO " + table);
            }
        }
    }

    /**
     * Reads a CSV file, tags every row with {@code scenario}, and appends it to {@code table}. The
     * temporary view is created if it does not exist yet.
     *
     * <p>When the view is created, {@code spark.sql.caseSensitive} is set to {@code true} for the
     * session.
     *
     * @param scenario value written into the added {@code scenario} column
     * @param table name of the temporary view to create or append to
     * @param path location of the CSV file(s)
     * @param delimiter CSV field delimiter
     * @param header whether the first line of the file holds column names
     */
    public void loadCsv(String scenario, String table, String path, String delimiter, boolean header) {
        Dataset<Row> dataset = this.spark.read()
                .option("delimiter", delimiter)
                .option("header", header)
                .csv(path)
                .withColumn(SCENARIO_FIELD_NAME, functions.lit(scenario));
        if (this.spark.catalog().tableExists(table)) {
            appendDataset(this.spark, table, dataset);
        } else {
            this.spark.conf().set("spark.sql.caseSensitive", String.valueOf(true));
            dataset.createOrReplaceTempView(table);
        }
    }
}
