diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 692cf77d9afb9..c5c7dae4aeecb 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -963,6 +963,16 @@ def test_unpivot_negative(self):
         ):
             df.unpivot("id", ["int", "str"], "var", "val").collect()
 
+    def test_melt_groupby(self):
+        df = self.spark.createDataFrame(
+            [(1, 2, 3, 4, 5, 6)],
+            ["f1", "f2", "label", "pred", "model_version", "ts"],
+        )
+        self.assertEqual(
+            df.melt("model_version", ["label", "f2"], "f1", "f2").groupby("f1").count().count(),
+            2,
+        )
+
     def test_observe(self):
         # SPARK-36263: tests the DataFrame.observe(Observation, *Column) method
         from pyspark.sql import Observation
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala
index 56f6b116759a7..3403e3c5b71cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala
@@ -177,10 +177,11 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan =>
             self.markRuleAsIneffective(ruleId)
             self
           } else {
-            rewritten_plan
+            copyPlanIdTag(self, rewritten_plan)
          }
        } else {
-          afterRule.mapChildren(_.resolveOperatorsDownWithPruning(cond, ruleId)(rule))
+          copyPlanIdTag(self,
+            afterRule.mapChildren(_.resolveOperatorsDownWithPruning(cond, ruleId)(rule)))
        }
      }
    } else {
@@ -188,6 +189,12 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan =>
    }
  }
 
+  def copyPlanIdTag(oldPlan: LogicalPlan, newPlan: LogicalPlan): LogicalPlan = {
+    oldPlan.getTagValue(LogicalPlan.PLAN_ID_TAG)
+      .foreach(id => newPlan.setTagValue(LogicalPlan.PLAN_ID_TAG, id))
+    newPlan
+  }
+
   /**
    * A variant of `transformUpWithNewOutput`, which skips touching already analyzed plan.
    */
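
Note (context, not part of the patch): LogicalPlan.PLAN_ID_TAG is the tree-node tag that Spark Connect uses to tie column references from the client back to the logical subtree they were created against. resolveOperatorsDownWithPruning can return a freshly rewritten plan, and a tag set on the old node does not travel to the new one by itself; the change above re-attaches it, so later lookups by plan id (as exercised by test_melt_groupby) still resolve. Below is a minimal standalone sketch of that carry-the-tag-across-a-rewrite pattern, using a hypothetical Node type rather than Spark's actual TreeNode tag API:

// Minimal standalone sketch of the pattern the patch applies (hypothetical
// Node type; Spark's real TreeNode/tag API differs): a rule rewrite returns
// a brand-new node, so any identity tag on the old node must be copied over
// explicitly or it is silently lost.
object PlanIdTagSketch {
  final class Node(val name: String) {
    private var planId: Option[Long] = None
    def setPlanId(id: Long): Unit = planId = Some(id)
    def getPlanId: Option[Long] = planId
  }

  // Mirrors copyPlanIdTag in the diff: propagate the tag from the
  // pre-rewrite node to the post-rewrite node, then return the new node.
  def copyPlanIdTag(oldNode: Node, newNode: Node): Node = {
    oldNode.getPlanId.foreach(newNode.setPlanId)
    newNode
  }

  def main(args: Array[String]): Unit = {
    val original = new Node("Unpivot")
    original.setPlanId(42L) // id a client assigned to this subtree

    val rewritten = new Node("Expand") // a rule replaced the operator
    assert(rewritten.getPlanId.isEmpty) // the tag does not travel by itself

    val preserved = copyPlanIdTag(original, rewritten)
    assert(preserved.getPlanId.contains(42L)) // tag survives the rewrite
    println(s"${preserved.name} keeps plan id ${preserved.getPlanId.get}")
  }
}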