@@ -86,7 +86,7 @@ object ResolveHints {
       }
     }
 
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
       case h: UnresolvedHint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) =>
         if (h.parameters.isEmpty) {
           // If there is no table alias specified, turn the entire subtree into a BroadcastHint.
@@ -134,7 +134,7 @@ object ResolveHints {
    * This must be executed after all the other hint rules are executed.
    */
   object RemoveAllHints extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
       case h: UnresolvedHint => h.child
     }
   }
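Both hint rules above switch from resolveOperators to resolveOperatorsUp, making the bottom-up traversal explicit. For readers skimming the diff, here is a minimal standalone sketch of why direction matters; the toy tree types below are illustrative, not Spark's classes.

// Minimal sketch, not Spark code: the same rule applied top-down vs. bottom-up.
sealed trait Node
case class Leaf(name: String) extends Node
case class Branch(name: String, children: Node*) extends Node

object TraversalDemo extends App {
  type TreeRule = PartialFunction[Node, Node]

  // Top-down: apply the rule to a node before recursing into its children.
  def applyDown(n: Node, rule: TreeRule): Node =
    rule.applyOrElse(n, identity[Node]) match {
      case Branch(name, cs @ _*) => Branch(name, cs.map(applyDown(_, rule)): _*)
      case leaf => leaf
    }

  // Bottom-up: rebuild the children first, then apply the rule to the result.
  def applyUp(n: Node, rule: TreeRule): Node = {
    val rebuilt = n match {
      case Branch(name, cs @ _*) => Branch(name, cs.map(applyUp(_, rule)): _*)
      case leaf => leaf
    }
    rule.applyOrElse(rebuilt, identity[Node])
  }

  // Collapse a branch whose children are all leaves into a single leaf.
  val collapse: TreeRule = {
    case Branch(name, cs @ _*) if cs.forall(_.isInstanceOf[Leaf]) => Leaf(name)
  }

  val tree = Branch("a", Branch("b", Leaf("x")), Leaf("y"))
  println(applyDown(tree, collapse)) // Branch(a, Leaf(b), Leaf(y)): "a" was visited too early
  println(applyUp(tree, collapse))   // Leaf(a): "b" collapsed first, so "a" could collapse too
}

Bottom-up application lets a parent's case observe children the rule has already rewritten, which is the behavior these analyzer rules rely on.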
@@ -318,7 +318,7 @@ object TypeCoercion {
    */
   object WidenSetOperationTypes extends Rule[LogicalPlan] {
 
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
       case s @ Except(left, right, isAll) if s.childrenResolved &&
           left.output.length == right.output.length && !s.resolved =>
         val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
@@ -757,17 +757,18 @@ object TypeCoercion {
    */
   case class ConcatCoercion(conf: SQLConf) extends TypeCoercionRule {
 
-    override protected def coerceTypes(
-        plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case p =>
-      p transformExpressionsUp {
-        // Skip nodes if unresolved or empty children
-        case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c
-        case c @ Concat(children) if conf.concatBinaryAsString ||
+    override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = {
+      plan resolveOperators { case p =>
+        p transformExpressionsUp {
+          // Skip nodes if unresolved or empty children
+          case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c
+          case c @ Concat(children) if conf.concatBinaryAsString ||
             !children.map(_.dataType).forall(_ == BinaryType) =>
-          val newChildren = c.children.map { e =>
-            ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e)
-          }
-          c.copy(children = newChildren)
+            val newChildren = c.children.map { e =>
+              ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e)
+            }
+            c.copy(children = newChildren)
+        }
       }
     }
   }
@@ -780,23 +781,24 @@ object TypeCoercion {
    */
   case class EltCoercion(conf: SQLConf) extends TypeCoercionRule {
 
-    override protected def coerceTypes(
-        plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown { case p =>
-      p transformExpressionsUp {
-        // Skip nodes if unresolved or not enough children
-        case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c
-        case c @ Elt(children) =>
-          val index = children.head
-          val newIndex = ImplicitTypeCasts.implicitCast(index, IntegerType).getOrElse(index)
-          val newInputs = if (conf.eltOutputAsString ||
+    override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = {
+      plan resolveOperators { case p =>
+        p transformExpressionsUp {
+          // Skip nodes if unresolved or not enough children
+          case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c
+          case c @ Elt(children) =>
+            val index = children.head
+            val newIndex = ImplicitTypeCasts.implicitCast(index, IntegerType).getOrElse(index)
+            val newInputs = if (conf.eltOutputAsString ||
               !children.tail.map(_.dataType).forall(_ == BinaryType)) {
-            children.tail.map { e =>
-              ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e)
-            }
-          } else {
-            children.tail
-          }
-          c.copy(children = newIndex +: newInputs)
+              children.tail.map { e =>
+                ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e)
+              }
+            } else {
+              children.tail
+            }
+            c.copy(children = newIndex +: newInputs)
+        }
       }
     }
   }
@@ -1007,7 +1009,7 @@ trait TypeCoercionRule extends Rule[LogicalPlan] with Logging {
 
   protected def coerceTypes(plan: LogicalPlan): LogicalPlan
 
-  private def propagateTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+  private def propagateTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
     // No propagation required for leaf nodes.
     case q: LogicalPlan if q.children.isEmpty => q
 
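All four TypeCoercion changes keep the decision logic intact and only change the traversal entry point. As a reference for what ConcatCoercion actually decides, here is a toy restatement of its guard; the DType values and function name are stand-ins, not Spark's API.

// Toy restatement of ConcatCoercion's guard: concat stays binary only when the
// concatBinaryAsString flag is off and every child is binary; otherwise every
// child is implicitly cast to string.
object ConcatCoercionSketch extends App {
  sealed trait DType
  case object BinaryT extends DType
  case object StringT extends DType
  case object IntT extends DType

  def concatResultType(children: Seq[DType], concatBinaryAsString: Boolean): DType =
    if (!concatBinaryAsString && children.nonEmpty && children.forall(_ == BinaryT)) BinaryT
    else StringT

  assert(concatResultType(Seq(BinaryT, BinaryT), concatBinaryAsString = false) == BinaryT)
  assert(concatResultType(Seq(BinaryT, IntT), concatBinaryAsString = false) == StringT)
  assert(concatResultType(Seq(BinaryT, BinaryT), concatBinaryAsString = true) == StringT)
}

EltCoercion applies the same string-or-binary decision to the inputs after the index expression, which is instead cast toward IntegerType.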
@@ -32,20 +32,17 @@ import org.apache.spark.sql.types.DataType
  */
 case class ResolveHigherOrderFunctions(catalog: SessionCatalog) extends Rule[LogicalPlan] {
 
-  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
-    case q: LogicalPlan =>
-      q.transformExpressions {
-        case u @ UnresolvedFunction(fn, children, false)
-            if hasLambdaAndResolvedArguments(children) =>
-          withPosition(u) {
-            catalog.lookupFunction(fn, children) match {
-              case func: HigherOrderFunction => func
-              case other => other.failAnalysis(
-                "A lambda function should only be used in a higher order function. However, " +
-                s"its class is ${other.getClass.getCanonicalName}, which is not a " +
-                s"higher order function.")
-            }
-          }
-      }
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions {
+    case u @ UnresolvedFunction(fn, children, false)
+        if hasLambdaAndResolvedArguments(children) =>
+      withPosition(u) {
+        catalog.lookupFunction(fn, children) match {
+          case func: HigherOrderFunction => func
+          case other => other.failAnalysis(
+            "A lambda function should only be used in a higher order function. However, " +
+            s"its class is ${other.getClass.getCanonicalName}, which is not a " +
+            s"higher order function.")
+        }
+      }
   }
 
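This hunk is a flattening rather than a behavior change: resolveExpressions applies an expression-level partial function across every not-yet-analyzed operator, so the explicit case q => q.transformExpressions wrapper becomes redundant. Schematically, as a sketch assuming Spark's LogicalPlan and Expression types, where `rule` stands in for the UnresolvedFunction match above:

object RewriteShape {
  import org.apache.spark.sql.catalyst.expressions.Expression
  import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

  // Before: match every operator, then transform its expressions.
  def before(plan: LogicalPlan, rule: PartialFunction[Expression, Expression]): LogicalPlan =
    plan.resolveOperators { case q => q.transformExpressions(rule) }

  // After: let resolveExpressions do both levels in one call.
  def after(plan: LogicalPlan, rule: PartialFunction[Expression, Expression]): LogicalPlan =
    plan.resolveExpressions(rule)
}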
@@ -48,7 +48,7 @@ import org.apache.spark.sql.internal.SQLConf
  * completely resolved during the batch of Resolution.
  */
 case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport {
-  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
     case v @ View(desc, output, child) if child.resolved && output != child.output =>
       val resolver = conf.resolver
       val queryColumnNames = desc.viewQueryColumnNames
@@ -60,6 +60,19 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan =>
    */
   def analyzed: Boolean = _analyzed
 
+  /**
+   * Returns a copy of this node where `rule` has been recursively applied to the tree. When
+   * `rule` does not apply to a given node, it is left unchanged. This function is similar to
+   * `transform`, but skips sub-trees that have already been marked as analyzed.
+   * Users should not expect a specific directionality. If a specific directionality is needed,
+   * [[resolveOperatorsUp]] or [[resolveOperatorsDown]] should be used.
+   *
+   * @param rule the function used to transform this node's children
+   */
+  def resolveOperators(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = {
+    resolveOperatorsDown(rule)
+  }
+
   /**
    * Returns a copy of this node where `rule` has been recursively applied first to all of its
    * children and then itself (post-order, bottom-up). When `rule` does not apply to a given node,
@@ -68,10 +81,10 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan =>
    *
    * @param rule the function used to transform this node's children
    */
-  def resolveOperators(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = {
+  def resolveOperatorsUp(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = {
     if (!analyzed) {
       AnalysisHelper.allowInvokingTransformsInAnalyzer {
-        val afterRuleOnChildren = mapChildren(_.resolveOperators(rule))
+        val afterRuleOnChildren = mapChildren(_.resolveOperatorsUp(rule))
         if (self fastEquals afterRuleOnChildren) {
           CurrentOrigin.withOrigin(origin) {
             rule.applyOrElse(self, identity[LogicalPlan])
@@ -87,7 +100,7 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan =>
     }
   }
 
-  /** Similar to [[resolveOperators]], but does it top-down. */
+  /** Similar to [[resolveOperatorsUp]], but does it top-down. */
   def resolveOperatorsDown(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = {
     if (!analyzed) {
       AnalysisHelper.allowInvokingTransformsInAnalyzer {
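The new doc comment encodes the API contract this change introduces: rules with no ordering requirement call resolveOperators and let the framework choose the traversal (currently a delegation to resolveOperatorsDown), while order-sensitive rules name their direction. A sketch of that split, with hypothetical rule names; Rule and LogicalPlan are Spark's:

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

// Hypothetical skeletons showing the intended split after this change.
object SomeOrderInsensitiveRule extends Rule[LogicalPlan] {
  // No directionality assumption: the framework picks the traversal.
  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
    case p => p // replace with a real match
  }
}

object SomeBottomUpRule extends Rule[LogicalPlan] {
  // Needs already-rewritten children: ask for bottom-up explicitly.
  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
    case p if p.childrenResolved => p // replace with a real match
  }
}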
sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala (8 changes: 3 additions & 5 deletions)
@@ -1383,8 +1383,7 @@ class Dataset[T] private[sql](
   @InterfaceStability.Evolving
   def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = {
     implicit val encoder = c1.encoder
-    val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil,
-      logicalPlan)
+    val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, logicalPlan)
 
     if (encoder.flat) {
       new Dataset[U1](sparkSession, project, encoder)
@@ -1658,15 +1657,14 @@ class Dataset[T] private[sql](
   @Experimental
   @InterfaceStability.Evolving
   def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = {
-    val inputPlan = logicalPlan
-    val withGroupingKey = AppendColumns(func, inputPlan)
+    val withGroupingKey = AppendColumns(func, logicalPlan)
     val executed = sparkSession.sessionState.executePlan(withGroupingKey)
 
     new KeyValueGroupedDataset(
       encoderFor[K],
       encoderFor[T],
       executed,
-      inputPlan.output,
+      logicalPlan.output,
       withGroupingKey.newColumns)
   }
 
@@ -131,7 +131,7 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast
       projectList
   }
 
-  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     case CreateTable(tableDesc, mode, None) if DDLUtils.isDatasourceTable(tableDesc) =>
       DDLUtils.checkDataColNames(tableDesc)
       CreateDataSourceTableCommand(tableDesc, ignoreIfExists = mode == SaveMode.Ignore)
@@ -252,7 +252,7 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan]
       table.partitionSchema.asNullable.toAttributes)
   }
 
-  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     case i @ InsertIntoTable(UnresolvedCatalogRelation(tableMeta), _, _, _, _)
         if DDLUtils.isDatasourceTable(tableMeta) =>
       i.copy(table = readDataSourceTable(tableMeta))
@@ -73,7 +73,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi
   // catalog is a def and not a val/lazy val as the latter would introduce a circular reference
   private def catalog = sparkSession.sessionState.catalog
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown {
+  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     // When we CREATE TABLE without specifying the table schema, we should fail the query if
     // bucketing information is specified, as we can't infer bucketing from data files currently.
     // Since the runtime inferred partition columns could be different from what user specified,
@@ -365,7 +365,7 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] {
     }
   }
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown {
+  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     case i @ InsertIntoTable(table, _, query, _, _) if table.resolved && query.resolved =>
       table match {
         case relation: HiveTableRelation =>
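Note that the sql/core rules in the last two files move in the opposite direction, from resolveOperatorsDown to resolveOperators. Given the new AnalysisHelper definition above, resolveOperators currently delegates to resolveOperatorsDown, so these hunks are behavior-preserving; they simply drop an explicit directionality claim that the rules never depended on.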