From a87b6f0a47ef909c552461c5d9f0a0ea2de1e826 Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Wed, 25 Jan 2023 14:14:16 +0800
Subject: [PATCH 1/2] init

---
 .../connect/planner/SparkConnectPlanner.scala | 52 ++++++++++++++++---
 python/pyspark/sql/column.py                  |  4 +-
 .../sql/tests/connect/test_connect_basic.py   | 32 ++++++++++++
 3 files changed, 78 insertions(+), 10 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index dc921cee2822..ef0ac5f2ba92 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1235,14 +1235,21 @@ class SparkConnectPlanner(val session: SparkSession) {
     }
   }
 
+  private def splitConjunctivePredicates(condition: Expression): Seq[Expression] = {
+    condition match {
+      case UnresolvedFunction(Seq("and"), Seq(cond1, cond2), _, _, _) =>
+        splitConjunctivePredicates(cond1) ++ splitConjunctivePredicates(cond2)
+      case other => other :: Nil
+    }
+  }
+
   private def transformJoin(rel: proto.Join): LogicalPlan = {
     assert(rel.hasLeft && rel.hasRight, "Both join sides must be present")
     if (rel.hasJoinCondition && rel.getUsingColumnsCount > 0) {
       throw InvalidPlanInput(
         s"Using columns or join conditions cannot be set at the same time in Join")
     }
-    val joinCondition =
-      if (rel.hasJoinCondition) Some(transformExpression(rel.getJoinCondition)) else None
+
     val catalystJointype = transformJoinType(
       if (rel.getJoinType != null) rel.getJoinType else proto.Join.JoinType.JOIN_TYPE_INNER)
     val joinType = if (rel.getUsingColumnsCount > 0) {
@@ -1250,12 +1257,41 @@
     } else {
       catalystJointype
     }
-    logical.Join(
-      left = transformRelation(rel.getLeft),
-      right = transformRelation(rel.getRight),
-      joinType = joinType,
-      condition = joinCondition,
-      hint = logical.JoinHint.NONE)
+
+    if (rel.hasJoinCondition) {
+      val leftDF = Dataset.ofRows(session, transformRelation(rel.getLeft))
+      val rightDF = Dataset.ofRows(session, transformRelation(rel.getRight))
+      val joinExprs = splitConjunctivePredicates(transformExpression(rel.getJoinCondition))
+        .map {
+          case func @ UnresolvedFunction(Seq(f), Seq(l, r), _, _, _)
+              if Seq("==", "<=>").contains(f) =>
+            val l2 = l match {
+              case UnresolvedAttribute(Seq(c)) => leftDF.apply(c).expr
+              case other => other
+            }
+            val r2 = r match {
+              case UnresolvedAttribute(Seq(c)) => rightDF.apply(c).expr
+              case other => other
+            }
+            func.copy(arguments = Seq(l2, r2))
+
+          case other => other
+        }
+        .reduce(And)
+
+      leftDF
+        .join(right = rightDF, joinExprs = Column(joinExprs), joinType = joinType.sql)
+        .logicalPlan
+
+    } else {
+
+      logical.Join(
+        left = transformRelation(rel.getLeft),
+        right = transformRelation(rel.getRight),
+        joinType = joinType,
+        condition = None,
+        hint = logical.JoinHint.NONE)
+    }
   }
 
   private def transformJoinType(t: proto.Join.JoinType): JoinType = {
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index a790f191110f..8729c06bb157 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -315,9 +315,9 @@ def __ne__(  # type: ignore[override]
         ...     Row(value = 'bar'),
         ...     Row(value = None)
         ... ])
-        >>> df1.join(df2, df1["value"] == df2["value"]).count()  # doctest: +SKIP
+        >>> df1.join(df2, df1["value"] == df2["value"]).count()
         0
-        >>> df1.join(df2, df1["value"].eqNullSafe(df2["value"])).count()  # doctest: +SKIP
+        >>> df1.join(df2, df1["value"].eqNullSafe(df2["value"])).count()
         1
         >>> df2 = spark.createDataFrame([
         ...     Row(id=1, value=float('NaN')),
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 3f7494a63853..1059a0b27488 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -337,6 +337,38 @@ def test_join_condition_column_list_columns(self):
         )
         self.assert_eq(joined_plan3.toPandas(), joined_plan4.toPandas())
 
+    def test_join_ambiguous_cols(self):
+        # SPARK-41812: test join with ambiguous columns
+        data1 = [Row(id=1, value="foo"), Row(id=2, value=None)]
+        cdf1 = self.connect.createDataFrame(data1)
+        sdf1 = self.spark.createDataFrame(data1)
+
+        data2 = [Row(value="bar"), Row(value=None)]
+        cdf2 = self.connect.createDataFrame(data2)
+        sdf2 = self.spark.createDataFrame(data2)
+
+        cdf3 = cdf1.join(cdf2, cdf1["value"] == cdf2["value"])
+        sdf3 = sdf1.join(sdf2, sdf1["value"] == sdf2["value"])
+
+        self.assertEqual(cdf3.schema, sdf3.schema)
+        self.assertEqual(cdf3.collect(), sdf3.collect())
+
+        cdf4 = cdf1.join(cdf2, cdf1["value"].eqNullSafe(cdf2["value"]))
+        sdf4 = sdf1.join(sdf2, sdf1["value"].eqNullSafe(sdf2["value"]))
+
+        self.assertEqual(cdf4.schema, sdf4.schema)
+        self.assertEqual(cdf4.collect(), sdf4.collect())
+
+        cdf5 = cdf1.join(
+            cdf2, (cdf1["value"] == cdf2["value"]) & (cdf1["value"].eqNullSafe(cdf2["value"]))
+        )
+        sdf5 = sdf1.join(
+            sdf2, (sdf1["value"] == sdf2["value"]) & (sdf1["value"].eqNullSafe(sdf2["value"]))
+        )
+
+        self.assertEqual(cdf5.schema, sdf5.schema)
+        self.assertEqual(cdf5.collect(), sdf5.collect())
+
     def test_collect(self):
         cdf = self.connect.read.table(self.tbl_name)
         sdf = self.spark.read.table(self.tbl_name)

From 0d567519b0b5af51e03b1c4687767741b02c4404 Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Wed, 25 Jan 2023 15:02:05 +0800
Subject: [PATCH 2/2] fix scala test

---
 .../planner/SparkConnectPlannerSuite.scala | 30 +++++++++----------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index d8baa182e5ab..436083ada209 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -260,21 +260,21 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
         .addArguments(unresolvedAttribute)
         .build())
 
-    val simpleJoin = proto.Relation.newBuilder
-      .setJoin(
-        proto.Join.newBuilder
-          .setLeft(readRel)
-          .setRight(readRel)
-          .setJoinType(proto.Join.JoinType.JOIN_TYPE_INNER)
-          .setJoinCondition(joinCondition)
-          .build())
-      .build()
-
-    val res = transform(simpleJoin)
-    assert(res.nodeName == "Join")
-    assert(res != null)
+    val e0 = intercept[AnalysisException] {
+      val simpleJoin = proto.Relation.newBuilder
+        .setJoin(
+          proto.Join.newBuilder
+            .setLeft(readRel)
+            .setRight(readRel)
+            .setJoinType(proto.Join.JoinType.JOIN_TYPE_INNER)
+            .setJoinCondition(joinCondition)
+            .build())
+        .build()
+      transform(simpleJoin)
+    }
+    assert(e0.getMessage.contains("TABLE_OR_VIEW_NOT_FOUND"))
 
-    val e = intercept[InvalidPlanInput] {
+    val e1 = intercept[InvalidPlanInput] {
       val simpleJoin = proto.Relation.newBuilder
         .setJoin(
           proto.Join.newBuilder
@@ -286,7 +286,7 @@
       transform(simpleJoin)
     }
     assert(
-      e.getMessage.contains(
+      e1.getMessage.contains(
         "Using columns or join conditions cannot be set at the same time in Join"))
   }
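
Reviewer note: the planner change in PATCH 1/2 hinges on first flattening the join
condition into its conjuncts so that each equality can be rewritten on its own. Below
is a minimal, self-contained sketch of that recursion; Expr, And, and Leaf are
illustrative stand-ins for Spark's Expression and UnresolvedFunction("and", ...), not
the real classes.

    sealed trait Expr
    case class And(left: Expr, right: Expr) extends Expr
    case class Leaf(text: String) extends Expr

    // Flatten a tree of nested And nodes into a flat list of conjuncts,
    // mirroring splitConjunctivePredicates in SparkConnectPlanner.
    def splitConjuncts(e: Expr): Seq[Expr] = e match {
      case And(l, r) => splitConjuncts(l) ++ splitConjuncts(r)
      case other => other :: Nil
    }

    // splitConjuncts(And(And(Leaf("a == b"), Leaf("c <=> d")), Leaf("e > 1")))
    // returns List(Leaf(a == b), Leaf(c <=> d), Leaf(e > 1))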
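
Each "==" or "<=>" conjunct whose operands are plain single-part column references is
then rebound: the left operand is resolved against the left DataFrame and the right
operand against the right one, which is what disambiguates df1["value"] == df2["value"]
when both inputs share a column name. A sketch of that step under the same caveat
(Ref and Bound are hypothetical stand-ins for UnresolvedAttribute and a side-resolved
column expression):

    case class Ref(name: String)                  // unresolved single-part reference
    case class Bound(side: String, name: String)  // reference pinned to one join side

    // Pin the operands of an equality to their join sides, mirroring the map step
    // inside transformJoin; any other operator is left untouched.
    def rebind(op: String, l: Ref, r: Ref): (Any, Any) =
      if (op == "==" || op == "<=>") (Bound("left", l.name), Bound("right", r.name))
      else (l, r)

    // rebind("==", Ref("value"), Ref("value"))
    // returns (Bound(left,value), Bound(right,value)), so the two sides no longer collide.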