package io.openlineage.spark3.agent.lifecycle.plan.column;

import io.openlineage.spark.agent.lifecycle.plan.column.ColumnLevelLineageBuilder;
import io.openlineage.spark.agent.lifecycle.plan.column.ColumnLevelLineageContext;
import io.openlineage.spark.agent.lifecycle.plan.column.TransformationInfo;
import io.openlineage.spark.agent.util.ScalaConversionUtils;
import io.openlineage.spark.shaded.org.apache.commons.lang3.reflect.MethodUtils;
import io.openlineage.spark3.agent.lifecycle.plan.column.visitors.ExpressionDependencyVisitor;
import io.openlineage.spark3.agent.lifecycle.plan.column.visitors.IcebergMergeIntoDependencyVisitor;
import io.openlineage.spark3.agent.lifecycle.plan.column.visitors.UnionDependencyVisitor;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.spark.sql.catalyst.expressions.Alias;
import org.apache.spark.sql.catalyst.expressions.AttributeReference;
import org.apache.spark.sql.catalyst.expressions.CaseWhen;
import org.apache.spark.sql.catalyst.expressions.Crc32;
import org.apache.spark.sql.catalyst.expressions.ExprId;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.HiveHash;
import org.apache.spark.sql.catalyst.expressions.If;
import org.apache.spark.sql.catalyst.expressions.Md5;
import org.apache.spark.sql.catalyst.expressions.Murmur3Hash;
import org.apache.spark.sql.catalyst.expressions.NamedExpression;
import org.apache.spark.sql.catalyst.expressions.Sha1;
import org.apache.spark.sql.catalyst.expressions.Sha2;
import org.apache.spark.sql.catalyst.expressions.XxHash64;
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression;
import org.apache.spark.sql.catalyst.expressions.aggregate.Count;
import org.apache.spark.sql.catalyst.plans.logical.Aggregate;
import org.apache.spark.sql.catalyst.plans.logical.CreateTableAsSelect;
import org.apache.spark.sql.catalyst.plans.logical.Distinct;
import org.apache.spark.sql.catalyst.plans.logical.Filter;
import org.apache.spark.sql.catalyst.plans.logical.Join;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.catalyst.plans.logical.Project;
import org.apache.spark.sql.catalyst.plans.logical.Sort;
import org.apache.spark.sql.catalyst.plans.logical.Union;
import org.apache.spark.sql.execution.datasources.LogicalRelation;
import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Option;
import scala.collection.Seq;
import scala.runtime.BoxedUnit;

/* loaded from: input_file:io/openlineage/spark3/agent/lifecycle/plan/column/ExpressionDependencyCollector.class */
/**
 * Collects column-level dependencies between expressions found in a Spark {@link LogicalPlan}.
 *
 * <p>Every plan node is inspected for the expressions it outputs (for example a projection list)
 * and for the expressions that influence the dataset as a whole (group-by keys, join and filter
 * conditions, sort order). Each expression tree is then walked down to its
 * {@link AttributeReference} leaves and the resulting dependencies are registered in the
 * {@link ColumnLevelLineageBuilder} of the supplied context.
 */
public class ExpressionDependencyCollector {
    private static final Logger log = LoggerFactory.getLogger(ExpressionDependencyCollector.class);

    /** Expression classes that {@link #isMasking(Expression)} reports as masking. */
    private static final List<Class<?>> MASKING_CLASSES = Arrays.asList(Crc32.class, HiveHash.class, Md5.class, Murmur3Hash.class, Sha1.class, Sha2.class, XxHash64.class, Count.class);

    /**
     * Masking expressions matched by canonical class name instead of by class literal, so this
     * class does not need the named class at compile time.
     */
    private static final List<String> MASKING_CLASS_NAMES = Collections.singletonList("org.apache.spark.sql.catalyst.expressions.Mask");

    /** Node-specific visitors that are given a chance to handle every plan node first. */
    private static final List<ExpressionDependencyVisitor> DEPENDENCY_VISITORS = Arrays.asList(new UnionDependencyVisitor(), new IcebergMergeIntoDependencyVisitor());

    /**
     * Whether {@link AggregateExpression} exposes an accessible no-arg {@code resultId} method.
     * When it does not, the collection-valued {@code resultIds} is invoked reflectively instead.
     * The answer cannot change for a loaded class, so it is resolved once rather than on every
     * aggregate expression.
     */
    private static final boolean HAS_SINGLE_RESULT_ID = MethodUtils.getAccessibleMethod(AggregateExpression.class, "resultId") != null;

    /**
     * Tells whether the given expression is one of the known masking expressions.
     *
     * @param expression expression to classify; must not be {@code null}
     * @return {@code true} when the exact class of the expression is listed in
     *     {@link #MASKING_CLASSES} or its canonical name is listed in {@link #MASKING_CLASS_NAMES}
     */
    private static Boolean isMasking(Expression expression) {
        Class<?> expressionClass = expression.getClass();
        // getCanonicalName() may be null (anonymous/local classes); contains(null) is simply false.
        return MASKING_CLASSES.contains(expressionClass)
                || MASKING_CLASS_NAMES.contains(expressionClass.getCanonicalName());
    }

    /**
     * Collects expression dependencies from every node of the given plan.
     *
     * @param columnLevelLineageContext context holding the builder that receives the dependencies
     * @param logicalPlan root of the plan to walk
     */
    public static void collect(ColumnLevelLineageContext columnLevelLineageContext, LogicalPlan logicalPlan) {
        logicalPlan.foreach(node -> {
            collectFromNode(columnLevelLineageContext, node);
            // Scala's foreach expects a function returning Unit.
            return BoxedUnit.UNIT;
        });
    }

    /**
     * Collects expression dependencies from a single plan node.
     *
     * <p>The node is first offered to the registered visitors and to the custom collectors. After
     * that, the node's output expressions are traversed with an identity transformation, and any
     * dataset-wide expressions (group by, join, filter, sort) are traversed with the matching
     * indirect transformation against a freshly generated expression id that is registered as a
     * dataset dependency.
     *
     * @param columnLevelLineageContext context holding the builder that receives the dependencies
     * @param logicalPlan node to inspect
     */
    static void collectFromNode(ColumnLevelLineageContext columnLevelLineageContext, LogicalPlan logicalPlan) {
        for (ExpressionDependencyVisitor visitor : DEPENDENCY_VISITORS) {
            if (visitor.isDefinedAt(logicalPlan)) {
                visitor.apply(logicalPlan, columnLevelLineageContext.getBuilder());
            }
        }
        CustomCollectorsUtils.collectExpressionDependencies(columnLevelLineageContext, logicalPlan);

        // Expressions producing output columns of this node.
        List<NamedExpression> outputExpressions = new LinkedList<>();
        // Expressions that affect the dataset as a whole rather than a single column.
        List<Expression> datasetDependencies = new LinkedList<>();
        Optional<TransformationInfo> datasetTransformation = Optional.empty();

        if (logicalPlan instanceof Project) {
            outputExpressions.addAll(ScalaConversionUtils.fromSeq(((Project) logicalPlan).projectList()));
        } else if ((logicalPlan instanceof CreateTableAsSelect) && (logicalPlan.children() == null || logicalPlan.children().isEmpty())) {
            // The query is not exposed through children() here, so it is collected explicitly.
            collectFromNode(columnLevelLineageContext, ((CreateTableAsSelect) logicalPlan).query());
        } else if (logicalPlan instanceof Distinct) {
            collectFromNode(columnLevelLineageContext, ((Distinct) logicalPlan).child());
        } else if (logicalPlan instanceof Aggregate) {
            Aggregate aggregate = (Aggregate) logicalPlan;
            // An aggregate over a union that groups by all of its aggregate expressions is skipped.
            boolean skipped = (aggregate.child() instanceof Union) && doesGroupByAllAggregateExpressions(aggregate);
            if (!skipped) {
                datasetDependencies.addAll(ScalaConversionUtils.fromSeq(aggregate.groupingExpressions()));
                datasetTransformation = Optional.of(TransformationInfo.indirect(TransformationInfo.Subtypes.GROUP_BY));
                outputExpressions.addAll(ScalaConversionUtils.fromSeq(aggregate.aggregateExpressions()));
            }
        } else if (logicalPlan instanceof Join) {
            Option<Expression> condition = ((Join) logicalPlan).condition();
            if (condition.isDefined()) {
                datasetTransformation = Optional.of(TransformationInfo.indirect(TransformationInfo.Subtypes.JOIN));
                datasetDependencies.add(condition.get());
            }
        } else if (logicalPlan instanceof Filter) {
            datasetDependencies.add(((Filter) logicalPlan).condition());
            datasetTransformation = Optional.of(TransformationInfo.indirect(TransformationInfo.Subtypes.FILTER));
        } else if (logicalPlan instanceof Sort) {
            datasetDependencies.addAll(ScalaConversionUtils.fromSeq(((Sort) logicalPlan).order()));
            datasetTransformation = Optional.of(TransformationInfo.indirect(TransformationInfo.Subtypes.SORT));
        } else if ((logicalPlan instanceof LogicalRelation) && (((LogicalRelation) logicalPlan).relation() instanceof JDBCRelation)) {
            JdbcColumnLineageCollector.extractExpressionsFromJDBC(columnLevelLineageContext, logicalPlan);
        }

        for (NamedExpression output : outputExpressions) {
            traverseExpression((Expression) output, output.exprId(), TransformationInfo.identity(), columnLevelLineageContext.getBuilder());
        }

        if (datasetTransformation.isPresent()) {
            TransformationInfo transformation = datasetTransformation.get();
            // Dataset-wide dependencies are attached to a synthetic expression id.
            ExprId datasetExprId = NamedExpression.newExprId();
            columnLevelLineageContext.getBuilder().addDatasetDependency(datasetExprId);
            for (Expression dependency : datasetDependencies) {
                traverseExpression(dependency, datasetExprId, transformation, columnLevelLineageContext.getBuilder());
            }
        }
    }

    /**
     * Traverses an expression tree starting with an identity transformation.
     *
     * @param expression root of the expression tree
     * @param exprId id of the expression the discovered inputs are attributed to
     * @param columnLevelLineageBuilder builder that receives the dependencies
     */
    public static void traverseExpression(Expression expression, ExprId exprId, ColumnLevelLineageBuilder columnLevelLineageBuilder) {
        traverseExpression(expression, exprId, TransformationInfo.identity(), columnLevelLineageBuilder);
    }

    /**
     * Traverses an expression tree and registers a dependency from {@code exprId} to every
     * {@link AttributeReference} reached, carrying the transformation accumulated on the way.
     *
     * <p>A {@code null} expression, or one whose {@code children()} is {@code null}, is ignored.
     * An attribute reference carrying the same id as {@code exprId} is not registered.
     *
     * @param expression root of the expression tree; may be {@code null}
     * @param exprId id of the expression the discovered inputs are attributed to
     * @param transformationInfo transformation accumulated so far
     * @param columnLevelLineageBuilder builder that receives the dependencies
     */
    public static void traverseExpression(Expression expression, ExprId exprId, TransformationInfo transformationInfo, ColumnLevelLineageBuilder columnLevelLineageBuilder) {
        if (expression instanceof AttributeReference) {
            ExprId inputId = ((AttributeReference) expression).exprId();
            // Skip self-references.
            if (!inputId.equals(exprId)) {
                columnLevelLineageBuilder.addDependency(exprId, inputId, transformationInfo);
            }
        } else if (expression instanceof Alias) {
            handleExpression((Alias) expression, exprId, transformationInfo, columnLevelLineageBuilder);
        } else if (expression instanceof CaseWhen) {
            handleExpression((CaseWhen) expression, exprId, transformationInfo, columnLevelLineageBuilder);
        } else if (expression instanceof If) {
            handleExpression((If) expression, exprId, transformationInfo, columnLevelLineageBuilder);
        } else if (expression instanceof AggregateExpression) {
            handleExpression((AggregateExpression) expression, exprId, transformationInfo, columnLevelLineageBuilder);
        } else if (expression != null && expression.children() != null) {
            handleGenericExpression(expression, exprId, transformationInfo, columnLevelLineageBuilder);
        }
    }

    /**
     * Traverses every child of an expression that has no dedicated handler, merging a
     * transformation step (flagged as masking when the parent is a masking expression) into the
     * accumulated transformation.
     */
    private static void handleGenericExpression(Expression expression, ExprId exprId, TransformationInfo transformationInfo, ColumnLevelLineageBuilder columnLevelLineageBuilder) {
        // Depends only on the parent expression, so it is evaluated once for all children.
        Boolean masking = isMasking(expression);
        for (Expression child : ScalaConversionUtils.fromSeq(expression.children())) {
            traverseExpression(child, exprId, transformationInfo.merge(TransformationInfo.transformation(masking)), columnLevelLineageBuilder);
        }
    }

    /**
     * Registers dependencies on the result id(s) of an aggregate expression and then traverses its
     * aggregate function, all with an aggregation step merged into the transformation.
     *
     * <p>Failures of the reflective {@code resultIds} call are logged and do not stop the traversal
     * of the aggregate function.
     */
    private static void handleExpression(AggregateExpression aggregateExpression, ExprId exprId, TransformationInfo transformationInfo, ColumnLevelLineageBuilder columnLevelLineageBuilder) {
        if (HAS_SINGLE_RESULT_ID) {
            columnLevelLineageBuilder.addDependency(exprId, aggregateExpression.resultId(), transformationInfo.merge(TransformationInfo.aggregation()));
        } else {
            try {
                // The cast cannot be checked at runtime; elements are expected to be ExprId instances.
                @SuppressWarnings("unchecked")
                Seq<ExprId> resultIds = (Seq<ExprId>) MethodUtils.invokeMethod(aggregateExpression, "resultIds");
                for (ExprId resultId : ScalaConversionUtils.fromSeq(resultIds)) {
                    columnLevelLineageBuilder.addDependency(exprId, resultId, transformationInfo.merge(TransformationInfo.aggregation()));
                }
            } catch (Exception e) {
                // Best effort: lineage collection must not fail because of a reflection problem.
                log.warn("Failed extracting resultIds from AggregateExpression", e);
            }
        }
        traverseExpression(aggregateExpression.aggregateFunction(), exprId, transformationInfo.merge(TransformationInfo.aggregation()), columnLevelLineageBuilder);
    }

    /**
     * Traverses an {@code If} expression: the predicate is an indirect, conditional dependency,
     * while both result values keep the incoming transformation.
     */
    private static void handleExpression(If ifExpression, ExprId exprId, TransformationInfo transformationInfo, ColumnLevelLineageBuilder columnLevelLineageBuilder) {
        traverseExpression(ifExpression.predicate(), exprId, transformationInfo.merge(TransformationInfo.indirect(TransformationInfo.Subtypes.CONDITIONAL)), columnLevelLineageBuilder);
        traverseExpression(ifExpression.trueValue(), exprId, transformationInfo, columnLevelLineageBuilder);
        traverseExpression(ifExpression.falseValue(), exprId, transformationInfo, columnLevelLineageBuilder);
    }

    /**
     * Traverses a {@code CaseWhen} expression: first all branch conditions as indirect, conditional
     * dependencies, then all branch values, and finally the else value when it is defined.
     */
    private static void handleExpression(CaseWhen caseWhen, ExprId exprId, TransformationInfo transformationInfo, ColumnLevelLineageBuilder columnLevelLineageBuilder) {
        List<Expression> conditions = new LinkedList<>();
        List<Expression> values = new LinkedList<>();
        // Each branch is a (condition, value) pair.
        ScalaConversionUtils.fromSeq(caseWhen.branches()).forEach(branch -> {
            conditions.add(branch._1());
            values.add(branch._2());
        });

        for (Expression condition : conditions) {
            traverseExpression(condition, exprId, transformationInfo.merge(TransformationInfo.indirect(TransformationInfo.Subtypes.CONDITIONAL)), columnLevelLineageBuilder);
        }
        for (Expression value : values) {
            traverseExpression(value, exprId, transformationInfo, columnLevelLineageBuilder);
        }

        Option<Expression> elseValue = caseWhen.elseValue();
        if (elseValue.isDefined()) {
            traverseExpression(elseValue.get(), exprId, transformationInfo, columnLevelLineageBuilder);
        }
    }

    /** Traverses the child of an alias, merging an identity step into the transformation. */
    private static void handleExpression(Alias alias, ExprId exprId, TransformationInfo transformationInfo, ColumnLevelLineageBuilder columnLevelLineageBuilder) {
        traverseExpression(alias.child(), exprId, transformationInfo.merge(TransformationInfo.identity()), columnLevelLineageBuilder);
    }

    /**
     * Checks whether every aggregate expression of the node is also one of its grouping attributes.
     *
     * @param aggregate aggregate node to inspect
     * @return {@code true} when the ids of all aggregate expressions are contained in the ids of
     *     the grouping expressions that are plain attribute references
     */
    private static boolean doesGroupByAllAggregateExpressions(Aggregate aggregate) {
        Set<ExprId> groupingIds = ScalaConversionUtils.fromSeq(aggregate.groupingExpressions()).stream()
                .filter(AttributeReference.class::isInstance)
                .map(AttributeReference.class::cast)
                .map(AttributeReference::exprId)
                .collect(Collectors.toSet());
        Set<ExprId> aggregateIds = ScalaConversionUtils.fromSeq(aggregate.aggregateExpressions()).stream()
                .map(NamedExpression::exprId)
                .collect(Collectors.toSet());
        return groupingIds.containsAll(aggregateIds);
    }
}
