diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 23f8c51045f0c..e239174e40ad4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -679,7 +679,7 @@ case class EnsureRequirements(
       applyPartialClustering: Boolean,
       replicatePartitions: Boolean): SparkPlan = plan match {
     case scan: BatchScanExec =>
-      scan.copy(
+      val newScan = scan.copy(
         spjParams = scan.spjParams.copy(
           commonPartitionValues = Some(values),
           joinKeyPositions = joinKeyPositions,
@@ -688,6 +688,8 @@ case class EnsureRequirements(
           replicatePartitions = replicatePartitions
         )
       )
+      newScan.copyTagsFrom(scan)
+      newScan
     case node => node.mapChildren(child => populateCommonPartitionInfo(
       child, values, joinKeyPositions, reducers, applyPartialClustering, replicatePartitions))
@@ -698,11 +700,13 @@ case class EnsureRequirements(
       plan: SparkPlan,
       joinKeyPositions: Option[Seq[Int]]): SparkPlan = plan match {
     case scan: BatchScanExec =>
-      scan.copy(
+      val newScan = scan.copy(
         spjParams = scan.spjParams.copy(
           joinKeyPositions = joinKeyPositions
         )
       )
+      newScan.copyTagsFrom(scan)
+      newScan
     case node => node.mapChildren(child => populateJoinKeyPositions(
       child, joinKeyPositions))
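
For context on why the patch calls copyTagsFrom: a case-class copy on a Catalyst TreeNode builds a brand-new node whose tag map starts out empty, so metadata attached to the original scan (for example the links adaptive execution stores as tags) would be dropped by the copied BatchScanExec. Below is a minimal sketch of that behaviour using the public TreeNodeTag API; LocalRelation and the tag name "note" are illustrative choices only and are not part of this change.

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.types.IntegerType

object CopyTagsSketch {
  def main(args: Array[String]): Unit = {
    // Any TreeNode would do; LocalRelation is simply easy to build without a SparkSession.
    val relation = LocalRelation(Seq(AttributeReference("a", IntegerType)()))

    // Attach a tag to the original node.
    val tag = TreeNodeTag[String]("note")
    relation.setTagValue(tag, "important metadata")

    // A case-class copy produces a fresh node with an empty tag map ...
    val copied = relation.copy()
    assert(copied.getTagValue(tag).isEmpty)

    // ... so tags must be carried over explicitly, which is what the patch now does
    // for the copied BatchScanExec.
    copied.copyTagsFrom(relation)
    assert(copied.getTagValue(tag).contains("important metadata"))
  }
}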