diff --git a/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParseHelper.scala b/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParseHelper.scala index bcab9b74fef..5251adf4da8 100644 --- a/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParseHelper.scala +++ b/extensions/spark/kyuubi-spark-lineage/src/main/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParseHelper.scala @@ -119,14 +119,15 @@ trait LineageParser { val exps = named.map { case exp: Alias => val references = - if (exp.references.nonEmpty) exp.references - else { + if (exp.references.isEmpty || exp.child.isInstanceOf[ScalarSubquery]) { val attrRefs = getExpressionSubqueryPlans(exp.child) .map(extractColumnsLineage(_, ListMap[Attribute, AttributeSet]())) .foldLeft(ListMap[Attribute, AttributeSet]())(mergeColumnsLineage).values .foldLeft(AttributeSet.empty)(_ ++ _) .map(attr => attr.withQualifier(attr.qualifier :+ SUBQUERY_COLUMN_IDENTIFIER)) AttributeSet(attrRefs) + } else { + exp.references } ( exp.toAttribute, diff --git a/extensions/spark/kyuubi-spark-lineage/src/test/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParserHelperSuite.scala b/extensions/spark/kyuubi-spark-lineage/src/test/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParserHelperSuite.scala index e3cda6959fb..993720a9427 100644 --- a/extensions/spark/kyuubi-spark-lineage/src/test/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParserHelperSuite.scala +++ b/extensions/spark/kyuubi-spark-lineage/src/test/scala/org/apache/kyuubi/plugin/lineage/helper/SparkSQLLineageParserHelperSuite.scala @@ -1143,6 +1143,18 @@ abstract class SparkSQLLineageParserHelperSuite extends KyuubiFunSuite List( ("a", Set(s"$DEFAULT_CATALOG.default.table1.a")), ("b", Set(s"$DEFAULT_CATALOG.default.table1.b"))))) + + val sql12 = + """ + |select (select sum(a) from table0 where table1.b = table0.b) as aa, b from table1 + |""".stripMargin + val ret12 = extractLineage(sql12) + assert(ret12 == Lineage( + List(s"$DEFAULT_CATALOG.default.table0", s"$DEFAULT_CATALOG.default.table1"), + List(), + List( + ("aa", Set(s"$DEFAULT_CATALOG.default.table0.a")), + ("b", Set(s"$DEFAULT_CATALOG.default.table1.b"))))) } }