[CORE] Add logical link to rewritten spark plan #4817

Merged · Mar 1, 2024 · 3 commits (diff below shows changes from 1 commit)
@@ -16,7 +16,7 @@
  */
 package io.glutenproject.extension
 
-import io.glutenproject.extension.columnar.{AddTransformHintRule, TRANSFORM_UNSUPPORTED, TransformHint, TransformHints}
+import io.glutenproject.extension.columnar.{AddTransformHintRule, TransformHint, TransformHints}
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -66,11 +66,7 @@ class RewriteSparkPlanRulesManager(rewriteRules: Seq[Rule[SparkPlan]]) extends R
       case p if p.nodeName == origin.nodeName => p
     }
     assert(target.size == 1)
-    if (TransformHints.isTransformable(target.head)) {
-      None
-    } else {
-      Some(TransformHints.getHint(target.head))
-    }
+    TransformHints.getHintOption(target.head)
   }
 
   private def applyRewriteRules(origin: SparkPlan): (SparkPlan, Option[String]) = {
@@ -109,15 +105,19 @@ class RewriteSparkPlanRulesManager(rewriteRules: Seq[Rule[SparkPlan]]) extends R
     } else {
       addHint.apply(rewrittenPlan)
       val hint = getTransformHintBack(origin, rewrittenPlan)
-      hint match {
-        case Some(tu @ TRANSFORM_UNSUPPORTED(_, _)) =>
-          // If the rewritten plan is still not transformable, return the original plan.
-          TransformHints.tag(origin, tu)
-          origin
-        case None =>
-          rewrittenPlan.transformUp { case wall: RewrittenNodeWall => wall.originalChild }
-        case _ =>
-          throw new IllegalStateException("Unreachable code")
+      if (hint.isDefined) {
+        // If the rewritten plan is still not transformable, return the original plan.
+        TransformHints.tag(origin, hint.get)
+        origin
+      } else {
+        rewrittenPlan.transformUp {
+          case wall: RewrittenNodeWall => wall.originalChild
+          case p if p.logicalLink.isEmpty =>
+            // Add logical link to pull out project to make fallback reason work,
+            // see `GlutenFallbackReporter`.
+            origin.logicalLink.foreach(p.setLogicalLink)
+            p
+        }
       }
     }
   }
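
For context on the mechanism above: `logicalLink` is standard Spark API on `SparkPlan` (`setLogicalLink` / `logicalLink`), letting a physical node point back to the logical node it was planned from; `GlutenFallbackReporter` uses that link to attribute fallback reasons. Below is a minimal sketch of the propagation pattern, assuming only Spark on the classpath; the helper name `propagateLogicalLink` is illustrative, not part of the patch.

import org.apache.spark.sql.execution.SparkPlan

// Sketch: copy the origin node's logical link onto any rewritten node that
// lacks one, so fallback reporting can still map physical nodes back to the
// user query. Mirrors the transformUp case added in the hunk above.
def propagateLogicalLink(origin: SparkPlan, rewritten: SparkPlan): SparkPlan =
  rewritten.transformUp {
    case p if p.logicalLink.isEmpty =>
      // setLogicalLink stores the logical plan on the physical node.
      origin.logicalLink.foreach(p.setLogicalLink)
      p
  }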
@@ -146,7 +146,7 @@ object TransformHints {
       throw new IllegalStateException("Transform hint tag not set in plan: " + plan.toString()))
   }
 
-  private def getHintOption(plan: SparkPlan): Option[TransformHint] = {
+  def getHintOption(plan: SparkPlan): Option[TransformHint] = {
     plan.getTagValue(TAG)
   }
 
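
`getHintOption` is a thin wrapper over Spark's `TreeNode` tag API: tags are typed keys stored on plan nodes, and `getTagValue` returns None when a tag was never set, which is exactly what the rewrite manager now consumes instead of the isTransformable/getHint pair. A self-contained sketch of that mechanism, using an illustrative String tag rather than Gluten's actual TransformHint tag:

import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.execution.SparkPlan

// Illustrative tag key; Gluten's real TAG carries a TransformHint.
val DEMO_TAG = TreeNodeTag[String]("demo.hint")

def writeHint(plan: SparkPlan, value: String): Unit =
  plan.setTagValue(DEMO_TAG, value)

// None means "no hint recorded on this node".
def readHint(plan: SparkPlan): Option[String] =
  plan.getTagValue(DEMO_TAG)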
@@ -22,6 +22,7 @@ import io.glutenproject.execution.FileSourceScanExecTransformer
 
 import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
 import org.apache.spark.sql.{GlutenSQLTestsTrait, Row}
+import org.apache.spark.sql.execution.ProjectExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.ui.{GlutenSQLAppStatusStore, SparkListenerSQLExecutionStart}
 import org.apache.spark.status.ElementTrackingStore
@@ -133,4 +134,33 @@ class GlutenFallbackSuite extends GlutenSQLTestsTrait with AdaptiveSparkPlanHelp
       }
     }
   }
+
+  test("Add logical link to rewritten spark plan") {
+    val events = new ArrayBuffer[GlutenPlanFallbackEvent]
+    val listener = new SparkListener {
+      override def onOtherEvent(event: SparkListenerEvent): Unit = {
+        event match {
+          case e: GlutenPlanFallbackEvent => events.append(e)
+          case _ =>
+        }
+      }
+    }
+    spark.sparkContext.addSparkListener(listener)
+    withSQLConf(GlutenConfig.EXPRESSION_BLACK_LIST.key -> "add") {
+      try {
+        val df = spark.sql("select sum(id +1) from range(10)")

[Review comment · Contributor] nit: missing a space after "+".

+        spark.sparkContext.listenerBus.waitUntilEmpty()
+        df.collect()
+        val project = find(df.queryExecution.executedPlan) {
+          _.isInstanceOf[ProjectExec]
+        }
+        assert(project.isDefined)
+        events.exists(
+          _.fallbackNodeToReason.values.toSet
+            .contains("Project: Not supported to map spark function name"))
+      } finally {
+        spark.sparkContext.removeSparkListener(listener)
+      }
+    }
+  }
 }
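
As a usage note, the property this test exercises can also be checked directly on a plan: after rewrite, the pulled-out project should no longer be missing its logical link. A hypothetical standalone checker, assuming a plan obtained from `df.queryExecution.executedPlan` (some nodes may legitimately lack their own link, so this is a debugging aid rather than a strict invariant):

import org.apache.spark.sql.execution.SparkPlan

// Collect physical nodes without a logical link; with this patch the
// ProjectExec pulled out by the rewrite rules should not appear here.
def unlinkedNodes(plan: SparkPlan): Seq[SparkPlan] =
  plan.collect { case p if p.logicalLink.isEmpty => p }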