diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 527618b8e2c5..e81df911852b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -207,13 +207,15 @@ trait PredicateHelper extends Logging {
    * CNF can explode exponentially in the size of the input expression when converting [[Or]]
    * clauses. Use a configuration [[SQLConf.MAX_CNF_NODE_COUNT]] to prevent such cases.
    *
-   * @param condition to be converted into CNF.
+   * @param condition Condition to be converted into CNF.
+   * @param groupExpsFunc A method for grouping intermediate results so that the final result can be
+   *                      shorter.
    * @return the CNF result as sequence of disjunctive expressions. If the number of expressions
    *         exceeds threshold on converting `Or`, `Seq.empty` is returned.
    */
-  protected def conjunctiveNormalForm(
+  protected def CNFConversion(
       condition: Expression,
-      groupExpsFunc: Seq[Expression] => Seq[Expression]): Seq[Expression] = {
+      groupExpsFunc: Seq[Expression] => Seq[Expression] = identity): Seq[Expression] = {
     val postOrderNodes = postOrderTraversal(condition)
     val resultStack = new mutable.Stack[Seq[Expression]]
     val maxCnfNodeCount = SQLConf.get.maxCnfNodeCount
@@ -256,33 +258,15 @@ trait PredicateHelper extends Logging {
    * when expand predicates, we can group by the qualifier avoiding generate unnecessary
    * expression to control the length of final result since there are multiple tables.
    *
-   * @param condition condition need to be converted
+   * @param condition Condition to be converted into CNF.
    * @return the CNF result as sequence of disjunctive expressions. If the number of expressions
    *         exceeds threshold on converting `Or`, `Seq.empty` is returned.
    */
   def CNFWithGroupExpressionsByQualifier(condition: Expression): Seq[Expression] = {
-    conjunctiveNormalForm(condition, (expressions: Seq[Expression]) =>
+    CNFConversion(condition, (expressions: Seq[Expression]) =>
       expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq)
   }
 
-  /**
-   * Convert an expression to conjunctive normal form for predicate pushdown and partition pruning.
-   * When expanding predicates, this method groups expressions by their references for reducing
-   * the size of pushed down predicates and corresponding codegen. In partition pruning strategies,
-   * we split filters by [[splitConjunctivePredicates]] and partition filters by judging if it's
-   * references is subset of partCols, if we combine expressions group by reference when expand
-   * predicate of [[Or]], it won't impact final predicate pruning result since
-   * [[splitConjunctivePredicates]] won't split [[Or]] expression.
-   *
-   * @param condition condition need to be converted
-   * @return the CNF result as sequence of disjunctive expressions. If the number of expressions
-   *         exceeds threshold on converting `Or`, `Seq.empty` is returned.
-   */
-  def CNFWithGroupExpressionsByReference(condition: Expression): Seq[Expression] = {
-    conjunctiveNormalForm(condition, (expressions: Seq[Expression]) =>
-      expressions.groupBy(e => AttributeSet(e.references)).map(_._2.reduceLeft(And)).toSeq)
-  }
-
   /**
    * Iterative post order traversal over a binary tree built by And/Or clauses with two stacks.
    * For example, a condition `(a And b) Or c`, the postorder traversal is
@@ -294,7 +278,7 @@ trait PredicateHelper extends Logging {
    * 2.1 Pop a node from first stack and push it to second stack
    * 2.2 Push the children of the popped node to first stack
    *
-   * @param condition to be traversed as binary tree
+   * @param condition Condition to be traversed as binary tree
    * @return sub-expressions in post order traversal as a stack.
    *         The first element of result stack is the leftmost node.
    */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
index 576a826faf89..580a0a773a8b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
@@ -53,11 +53,20 @@ private[sql] object PruneFileSourcePartitions
     val partitionColumns =
       relation.resolve(partitionSchema, sparkSession.sessionState.analyzer.resolver)
     val partitionSet = AttributeSet(partitionColumns)
-    val (partitionFilters, dataFilters) = normalizedFilters.partition(f =>
+    val (partitionFilters, remainingFilters) = normalizedFilters.partition(f =>
       f.references.subsetOf(partitionSet)
     )
 
-    (ExpressionSet(partitionFilters), dataFilters)
+    // Try extracting more convertible partition filters from the remaining filters by converting
+    // them into CNF.
+    val remainingFilterInCnf = remainingFilters.flatMap(CNFConversion(_))
+    val extraPartitionFilters =
+      remainingFilterInCnf.filter(f => f.references.subsetOf(partitionSet))
+
+    // For the filters that can't be used for partition pruning, we simply use `remainingFilters`
+    // instead of using the non-convertible part from `remainingFilterInCnf`. Otherwise, the
+    // result filters can be very long.
+    (ExpressionSet(partitionFilters ++ extraPartitionFilters), remainingFilters)
   }
 
   private def rebuildPhysicalOperation(
@@ -88,12 +97,9 @@ private[sql] object PruneFileSourcePartitions
         _,
         _))
       if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined =>
-      val predicates = CNFWithGroupExpressionsByReference(filters.reduceLeft(And))
-      val finalPredicates = if (predicates.nonEmpty) predicates else filters
       val (partitionKeyFilters, _) = getPartitionKeyFiltersAndDataFilters(
-        fsRelation.sparkSession, logicalRelation, partitionSchema, finalPredicates,
+        fsRelation.sparkSession, logicalRelation, partitionSchema, filters,
         logicalRelation.output)
-
       if (partitionKeyFilters.nonEmpty) {
         val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq)
         val prunedFsRelation =
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala
index c4885f284259..de9bd58ba829 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala
@@ -54,9 +54,15 @@ private[sql] class PruneHiveTablePartitions(session: SparkSession)
     val normalizedFilters = DataSourceStrategy.normalizeExprs(
       filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)), relation.output)
     val partitionColumnSet = AttributeSet(relation.partitionCols)
-    ExpressionSet(normalizedFilters.filter { f =>
+    val (partitionFilters, remainingFilters) = normalizedFilters.partition { f =>
       !f.references.isEmpty && f.references.subsetOf(partitionColumnSet)
-    })
+    }
+    // Try extracting more convertible partition filters from the remaining filters by converting
+    // them into CNF.
+    val remainingFilterInCnf = remainingFilters.flatMap(CNFConversion(_))
+    val extraPartitionFilters = remainingFilterInCnf.filter(f =>
+      !f.references.isEmpty && f.references.subsetOf(partitionColumnSet))
+    ExpressionSet(partitionFilters ++ extraPartitionFilters)
   }
 
   /**
@@ -103,9 +109,7 @@ private[sql] class PruneHiveTablePartitions(session: SparkSession)
   override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     case op @ PhysicalOperation(projections, filters, relation: HiveTableRelation)
       if filters.nonEmpty && relation.isPartitioned && relation.prunedPartitions.isEmpty =>
-      val predicates = CNFWithGroupExpressionsByReference(filters.reduceLeft(And))
-      val finalPredicates = if (predicates.nonEmpty) predicates else filters
-      val partitionKeyFilters = getPartitionKeyFilters(finalPredicates, relation)
+      val partitionKeyFilters = getPartitionKeyFilters(filters, relation)
       if (partitionKeyFilters.nonEmpty) {
         val newPartitions = prunePartitions(relation, partitionKeyFilters)
         val newTableMeta = updateTableMeta(relation.tableMeta, newPartitions)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PrunePartitionSuiteBase.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PrunePartitionSuiteBase.scala
index d088061cdc6e..539c405e22d5 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PrunePartitionSuiteBase.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PrunePartitionSuiteBase.scala
@@ -67,6 +67,31 @@ abstract class PrunePartitionSuiteBase extends QueryTest with SQLTestUtils with
     }
   }
 
+  test("SPARK-32284: Avoid expanding too many CNF predicates in partition pruning") {
+    withTempView("temp") {
+      withTable("t") {
+        sql(
+          s"""
+             |CREATE TABLE t(i INT, p0 INT, p1 INT)
+             |USING $format
+             |PARTITIONED BY (p0, p1)""".stripMargin)
+
+        spark.range(0, 10, 1).selectExpr("id as col")
+          .createOrReplaceTempView("temp")
+
+        for (part <- (0 to 25)) {
+          sql(
+            s"""
+               |INSERT OVERWRITE TABLE t PARTITION (p0='$part', p1='$part')
+               |SELECT col FROM temp""".stripMargin)
+        }
+        val scale = 20
+        val predicate = (1 to scale).map(i => s"(p0 = '$i' AND p1 = '$i')").mkString(" OR ")
+        assertPrunedPartitions(s"SELECT * FROM t WHERE $predicate", scale)
+      }
+    }
+  }
+
   protected def assertPrunedPartitions(query: String, expected: Long): Unit = {
     val plan = sql(query).queryExecution.sparkPlan
     assert(getScanExecPartitionSize(plan) == expected)
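
For context on what `CNFConversion` computes, below is a minimal, self-contained Scala sketch of the conversion. The `Expr`/`Leaf`/`And`/`Or` toy tree is made up for this sketch and is not Catalyst's `Expression` hierarchy; the real implementation in `PredicateHelper` is iterative (post-order traversal with two stacks) and returns `Seq.empty` once `SQLConf.MAX_CNF_NODE_COUNT` is exceeded. The recursive sketch only shows the distribution law and the role of `groupExpsFunc`, which now defaults to `identity`:

```scala
// Toy model only: these types stand in for Catalyst's Expression tree.
sealed trait Expr
case class Leaf(name: String) extends Expr
case class And(l: Expr, r: Expr) extends Expr
case class Or(l: Expr, r: Expr) extends Expr

object CnfSketch {
  // Returns the conjuncts of the CNF form (implicitly And-ed together).
  // Converting an Or distributes it over the conjuncts of both children,
  // which is where the exponential blow-up can come from. The `group`
  // parameter mirrors `groupExpsFunc` and defaults to `identity`.
  def toCnf(e: Expr, group: Seq[Expr] => Seq[Expr] = identity): Seq[Expr] = e match {
    case And(l, r) => toCnf(l, group) ++ toCnf(r, group)
    case Or(l, r) =>
      val (ls, rs) = (toCnf(l, group), toCnf(r, group))
      group(for { a <- ls; b <- rs } yield Or(a, b))
    case leaf: Leaf => Seq(leaf)
  }

  def main(args: Array[String]): Unit = {
    // (a And b) Or c  ==>  (a Or c) And (b Or c), as in the doc comment above.
    val cond = Or(And(Leaf("a"), Leaf("b")), Leaf("c"))
    println(toCnf(cond)) // List(Or(Leaf(a),Leaf(c)), Or(Leaf(b),Leaf(c)))
  }
}
```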
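
Why converting `remainingFilters` into CNF can surface extra partition filters: a predicate such as `(p = 1 AND i = 1) OR (p = 2 AND i = 2)` references the data column `i` as a whole, so the plain `references.subsetOf(partitionSet)` test rejects it; its CNF, however, contains the conjunct `(p = 1 OR p = 2)`, which touches only the partition column and can prune partitions by itself. Continuing the toy sketch above (run in a REPL after the definitions; the `startsWith("p")` check is a crude stand-in for `f.references.subsetOf(partitionSet)`):

```scala
import CnfSketch.toCnf

// In this sketch, "p=..." leaves reference a partition column, "i=..." a data column.
val mixed = Or(And(Leaf("p=1"), Leaf("i=1")), And(Leaf("p=2"), Leaf("i=2")))

def referencesOnlyPartitionCols(e: Expr): Boolean = e match {
  case Leaf(n)   => n.startsWith("p")
  case And(l, r) => referencesOnlyPartitionCols(l) && referencesOnlyPartitionCols(r)
  case Or(l, r)  => referencesOnlyPartitionCols(l) && referencesOnlyPartitionCols(r)
}

// Mirrors `remainingFilterInCnf.filter(f => f.references.subsetOf(partitionSet))`.
val extraPartitionFilters = toCnf(mixed).filter(referencesOnlyPartitionCols)
println(extraPartitionFilters) // List(Or(Leaf(p=1),Leaf(p=2)))
```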
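
On the size bound the new test exercises: a disjunction of `scale` two-column conjuncts, the exact shape of the test's predicate, expands to 2^scale CNF conjuncts if nothing intervenes. That is the blow-up `SQLConf.MAX_CNF_NODE_COUNT` guards against, and why `PruneFileSourcePartitions` keeps the original `remainingFilters` as the data filters rather than the non-convertible part of `remainingFilterInCnf`. A quick check with the toy converter (small `n` only; the conjunct count doubles with each OR-ed pair):

```scala
import CnfSketch.toCnf

// (p0=1 And p1=1) Or ... Or (p0=n And p1=n), the shape used in the test.
def scalePredicate(n: Int): Expr =
  (1 to n).map(i => And(Leaf(s"p0=$i"), Leaf(s"p1=$i")): Expr).reduceLeft(Or(_, _))

for (n <- 1 to 10) {
  println(s"n=$n -> ${toCnf(scalePredicate(n)).size} conjuncts") // 2^n
}
// At the test's scale of 20 this would be 2^20 = 1,048,576 conjuncts.
```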