PruneFileSourcePartitionsSuite.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources
package org.apache.spark.sql.hive.execution

import org.scalatest.matchers.should.Matchers._

@@ -24,19 +24,18 @@ import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.datasources.{CatalogFileIndex, HadoopFsRelation, LogicalRelation, PruneFileSourcePartitions}
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.functions.broadcast
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.StructType

class PruneFileSourcePartitionsSuite extends PrunePartitionSuiteBase with SharedSparkSession {
class PruneFileSourcePartitionsSuite extends PrunePartitionSuiteBase {

override def format: String = "parquet"

@@ -46,27 +45,35 @@ class PruneFileSourcePartitionsSuite extends PrunePartitionSuiteBase with Shared

test("PruneFileSourcePartitions should not change the output of LogicalRelation") {
withTable("test") {
spark.range(10).selectExpr("id", "id % 3 as p").write.partitionBy("p").saveAsTable("test")
val tableMeta = spark.sharedState.externalCatalog.getTable("default", "test")
val catalogFileIndex = new CatalogFileIndex(spark, tableMeta, 0)

val dataSchema = StructType(tableMeta.schema.filterNot { f =>
tableMeta.partitionColumnNames.contains(f.name)
})
val relation = HadoopFsRelation(
location = catalogFileIndex,
partitionSchema = tableMeta.partitionSchema,
dataSchema = dataSchema,
bucketSpec = None,
fileFormat = new ParquetFileFormat(),
options = Map.empty)(sparkSession = spark)

val logicalRelation = LogicalRelation(relation, tableMeta)
val query = Project(Seq(Symbol("id"), Symbol("p")),
Filter(Symbol("p") === 1, logicalRelation)).analyze

val optimized = Optimize.execute(query)
assert(optimized.missingInput.isEmpty)
withTempDir { dir =>
sql(
s"""
|CREATE EXTERNAL TABLE test(i int)
|PARTITIONED BY (p int)
|STORED AS parquet
|LOCATION '${dir.toURI}'""".stripMargin)

val tableMeta = spark.sharedState.externalCatalog.getTable("default", "test")
val catalogFileIndex = new CatalogFileIndex(spark, tableMeta, 0)

val dataSchema = StructType(tableMeta.schema.filterNot { f =>
tableMeta.partitionColumnNames.contains(f.name)
})
val relation = HadoopFsRelation(
location = catalogFileIndex,
partitionSchema = tableMeta.partitionSchema,
dataSchema = dataSchema,
bucketSpec = None,
fileFormat = new ParquetFileFormat(),
options = Map.empty)(sparkSession = spark)

val logicalRelation = LogicalRelation(relation, tableMeta)
val query = Project(Seq(Symbol("i"), Symbol("p")),
Filter(Symbol("p") === 1, logicalRelation)).analyze

val optimized = Optimize.execute(query)
assert(optimized.missingInput.isEmpty)
}
}
}

@@ -135,10 +142,6 @@ class PruneFileSourcePartitionsSuite extends PrunePartitionSuiteBase with Shared
}
}

protected def collectPartitionFiltersFn(): PartialFunction[SparkPlan, Seq[Expression]] = {
case scan: FileSourceScanExec => scan.partitionFilters
}

override def getScanExecPartitionSize(plan: SparkPlan): Long = {
plan.collectFirst {
case p: FileSourceScanExec => p.selectedPartitions.length
PruneHiveTablePartitionsSuite.scala
@@ -18,16 +18,13 @@
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.datasources.PrunePartitionSuiteBase
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.LongType

class PruneHiveTablePartitionsSuite extends PrunePartitionSuiteBase with TestHiveSingleton {
class PruneHiveTablePartitionsSuite extends PrunePartitionSuiteBase {

override def format(): String = "hive"

@@ -134,10 +131,6 @@ class PruneHiveTablePartitionsSuite extends PrunePartitionSuiteBase with TestHiv
}
}

protected def collectPartitionFiltersFn(): PartialFunction[SparkPlan, Seq[Expression]] = {
case scan: HiveTableScanExec => scan.partitionPruningPred
}

override def getScanExecPartitionSize(plan: SparkPlan): Long = {
plan.collectFirst {
case p: HiveTableScanExec => p
PrunePartitionSuiteBase.scala
@@ -15,15 +15,16 @@
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources
package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.StatisticsCollectionTestBase
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BinaryOperator, Expression, IsNotNull, Literal}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, FileScan}
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf.ADAPTIVE_EXECUTION_ENABLED

abstract class PrunePartitionSuiteBase extends StatisticsCollectionTestBase {
abstract class PrunePartitionSuiteBase extends StatisticsCollectionTestBase with TestHiveSingleton {

protected def format: String

@@ -94,11 +95,11 @@ abstract class PrunePartitionSuiteBase extends StatisticsCollectionTestBase {
val plan = qe.sparkPlan
assert(getScanExecPartitionSize(plan) == expectedPartitionCount)

val collectFn: PartialFunction[SparkPlan, Seq[Expression]] = collectPartitionFiltersFn orElse {
val pushedDownPartitionFilters = plan.collectFirst {
case scan: FileSourceScanExec => scan.partitionFilters
case scan: HiveTableScanExec => scan.partitionPruningPred
case BatchScanExec(_, scan: FileScan, _) => scan.partitionFilters
}
val pushedDownPartitionFilters = plan.collectFirst(collectFn)
.map(exps => exps.filterNot(e => e.isInstanceOf[IsNotNull]))
}.map(exps => exps.filterNot(e => e.isInstanceOf[IsNotNull]))
val pushedFilters = pushedDownPartitionFilters.map(filters => {
filters.foldLeft("")((currentStr, exp) => {
if (currentStr == "") {
@@ -112,7 +113,5 @@
assert(pushedFilters == Some(expectedPushedDownFilters))
}

protected def collectPartitionFiltersFn(): PartialFunction[SparkPlan, Seq[Expression]]

protected def getScanExecPartitionSize(plan: SparkPlan): Long
}