diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 6d18d411615c..600b383a1e66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1110,53 +1110,61 @@ class Dataset[T] private[sql]( throw new AnalysisException("Invalid join type in joinWith: " + joined.joinType.sql) } - // For both join side, combine all outputs into a single column and alias it with "_1" or "_2", - // to match the schema for the encoder of the join result. - // Note that we do this before joining them, to enable the join operator to return null for one - // side, in cases like outer-join. - val left = { - val combined = if (!this.exprEnc.isSerializedAsStructForTopLevel) { + implicit val tuple2Encoder: Encoder[(T, U)] = + ExpressionEncoder.tuple(this.exprEnc, other.exprEnc) + + val leftResultExpr = { + if (!this.exprEnc.isSerializedAsStructForTopLevel) { assert(joined.left.output.length == 1) Alias(joined.left.output.head, "_1")() } else { Alias(CreateStruct(joined.left.output), "_1")() } - Project(combined :: Nil, joined.left) } - val right = { - val combined = if (!other.exprEnc.isSerializedAsStructForTopLevel) { + val rightResultExpr = { + if (!other.exprEnc.isSerializedAsStructForTopLevel) { assert(joined.right.output.length == 1) Alias(joined.right.output.head, "_2")() } else { Alias(CreateStruct(joined.right.output), "_2")() } - Project(combined :: Nil, joined.right) - } - - // Rewrites the join condition to make the attribute point to correct column/field, after we - // combine the outputs of each join side. 
- val conditionExpr = joined.condition.get transformUp { - case a: Attribute if joined.left.outputSet.contains(a) => - if (!this.exprEnc.isSerializedAsStructForTopLevel) { - left.output.head - } else { - val index = joined.left.output.indexWhere(_.exprId == a.exprId) - GetStructField(left.output.head, index) - } - case a: Attribute if joined.right.outputSet.contains(a) => - if (!other.exprEnc.isSerializedAsStructForTopLevel) { - right.output.head - } else { - val index = joined.right.output.indexWhere(_.exprId == a.exprId) - GetStructField(right.output.head, index) - } } - implicit val tuple2Encoder: Encoder[(T, U)] = - ExpressionEncoder.tuple(this.exprEnc, other.exprEnc) + if (joined.joinType.isInstanceOf[InnerLike]) { + // For inner joins, we can directly perform the join and then project the join + // results into structs. This ensures that data remains flat during shuffles / + // exchanges (unlike the outer join path, which nests the data before shuffling). + withTypedPlan(Project(Seq(leftResultExpr, rightResultExpr), joined)) + } else { // outer joins + // For both join sides, combine all outputs into a single column and alias it with "_1" + // or "_2", to match the schema for the encoder of the join result. + // Note that we do this before joining them, to enable the join operator to return null + // for one side, in cases like outer-join. + val left = Project(leftResultExpr :: Nil, joined.left) + val right = Project(rightResultExpr :: Nil, joined.right) + + // Rewrites the join condition to make the attribute point to correct column/field, + // after we combine the outputs of each join side. 
+ val conditionExpr = joined.condition.get transformUp { + case a: Attribute if joined.left.outputSet.contains(a) => + if (!this.exprEnc.isSerializedAsStructForTopLevel) { + left.output.head + } else { + val index = joined.left.output.indexWhere(_.exprId == a.exprId) + GetStructField(left.output.head, index) + } + case a: Attribute if joined.right.outputSet.contains(a) => + if (!other.exprEnc.isSerializedAsStructForTopLevel) { + right.output.head + } else { + val index = joined.right.output.indexWhere(_.exprId == a.exprId) + GetStructField(right.output.head, index) + } + } - withTypedPlan(Join(left, right, joined.joinType, Some(conditionExpr), JoinHint.NONE)) + withTypedPlan(Join(left, right, joined.joinType, Some(conditionExpr), JoinHint.NONE)) + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index dd7c38011bc9..721ce65bc61d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -426,8 +426,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds1 = Seq(1, 2, 3).toDS().as("a") val ds2 = Seq(1, 2).toDS().as("b") + val joined = ds1.joinWith(ds2, $"a.value" === $"b.value", "inner") + + val expectedSchema = StructType(Seq( + StructField("_1", IntegerType, nullable = false), + StructField("_2", IntegerType, nullable = false) + )) + + assert(joined.schema === expectedSchema) + checkDataset( - ds1.joinWith(ds2, $"a.value" === $"b.value", "inner"), + joined, (1, 1), (2, 2)) } @@ -435,8 +444,21 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds1 = Seq(1, 1, 2).toDS() val ds2 = Seq(("a", 1), ("b", 2)).toDS() + val joined = ds1.joinWith(ds2, $"value" === $"_2") + + // This is an inner join, so both output fields are non-nullable + val expectedSchema = StructType(Seq( + StructField("_1", IntegerType, nullable = false), + StructField("_2", + 
StructType(Seq( + StructField("_1", StringType), + StructField("_2", IntegerType, nullable = false) + )), nullable = false) + )) + assert(joined.schema === expectedSchema) + checkDataset( - ds1.joinWith(ds2, $"value" === $"_2"), + joined, (1, ("a", 1)), (1, ("a", 1)), (2, ("b", 2))) } @@ -1105,6 +1127,23 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDS().as("left") val right = Seq(ClassData("x", 2), ClassData("y", 3)).toDS().as("right") val joined = left.joinWith(right, $"left.b" === $"right.b", "left") + + val expectedSchema = StructType(Seq( + StructField("_1", + StructType(Seq( + StructField("a", StringType), + StructField("b", IntegerType, nullable = false) + )), + nullable = false), + // This is a left join, so the right output is nullable: + StructField("_2", + StructType(Seq( + StructField("a", StringType), + StructField("b", IntegerType, nullable = false) + ))) + )) + assert(joined.schema === expectedSchema) + val result = joined.collect().toSet assert(result == Set(ClassData("a", 1) -> null, ClassData("b", 2) -> ClassData("x", 2))) }