diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 5f136629eb15..339fbb8d8b57 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -152,10 +152,10 @@ abstract class UnaryNode extends LogicalPlan { override final def children: Seq[LogicalPlan] = child :: Nil /** - * Generates an additional set of aliased constraints by replacing the original constraint - * expressions with the corresponding alias + * Generates all valid constraints including an set of aliased constraints by replacing the + * original constraint expressions with the corresponding alias */ - protected def getAliasedConstraints(projectList: Seq[NamedExpression]): Set[Expression] = { + protected def getAllValidConstraints(projectList: Seq[NamedExpression]): Set[Expression] = { var allConstraints = child.constraints.asInstanceOf[Set[Expression]] projectList.foreach { case a @ Alias(l: Literal, _) => @@ -170,7 +170,7 @@ abstract class UnaryNode extends LogicalPlan { case _ => // Don't change. } - allConstraints -- child.constraints + allConstraints } override protected def validConstraints: Set[Expression] = child.constraints diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 7ff83a9be362..f09c5ceefed1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -64,7 +64,7 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) } override def validConstraints: Set[Expression] = - child.constraints.union(getAliasedConstraints(projectList)) + getAllValidConstraints(projectList) } /** @@ -595,7 +595,7 @@ case class Aggregate( override def validConstraints: Set[Expression] = { val nonAgg = aggregateExpressions.filter(_.find(_.isInstanceOf[AggregateExpression]).isEmpty) - child.constraints.union(getAliasedConstraints(nonAgg)) + getAllValidConstraints(nonAgg) } }