Skip to content

Commit 3831f76

Browse files
committed
[SPARK-37829][SQL] Add if(isnull ...) check for DataFrame.joinWith
Wrap the deserializers of tuple fields in null checks when joinWith is called on DataFrames, as top-level rows are not nullable and won't propagate null values.
1 parent ce6295b commit 3831f76

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,28 @@ object RowEncoder {
7777
ClassTag(cls))
7878
}
7979

80+
/**
81+
* Returns an ExpressionEncoder whose deserializer tolerates null top-level rows.
82+
* @param exprEnc an ExpressionEncoder[Row].
83+
* @return a copy of `exprEnc` whose deserializer yields null when column 0 is null, if the serializer is nullable; otherwise the deserializer is kept as is.
84+
*
85+
* @see SPARK-37829
86+
*/
87+
private[sql] def nullSafe(exprEnc: ExpressionEncoder[Row]): ExpressionEncoder[Row] = {
88+
// The null check reads column 0, typed with the serializer's output data type.
val newDeserializerInput = GetColumnByOrdinal(0, exprEnc.objSerializer.dataType)
89+
val newDeserializer: Expression = if (exprEnc.objSerializer.nullable) {
90+
If(
91+
IsNull(newDeserializerInput),
92+
Literal.create(null, exprEnc.objDeserializer.dataType),
93+
exprEnc.objDeserializer)
94+
} else {
95+
// Serializer output is not nullable: no guard is added.
exprEnc.objDeserializer
96+
}
97+
exprEnc.copy(
98+
objDeserializer = newDeserializer
99+
)
100+
}
101+
80102
private def serializerFor(
81103
inputObject: Expression,
82104
inputType: DataType): Expression = inputType match {

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1173,8 +1173,20 @@ class Dataset[T] private[sql](
11731173
joined = resolveSelfJoinCondition(joined)
11741174
}
11751175

1176-
implicit val tuple2Encoder: Encoder[(T, U)] =
1177-
ExpressionEncoder.tuple(this.exprEnc, other.exprEnc)
1176+
// SPARK-37829: an outer join requires null semantics to represent missing keys.
1177+
// As we might be running on DataFrames, we need a custom encoder that will properly
1178+
// handle null top-level Rows.
1179+
def nullSafe[V](exprEnc: ExpressionEncoder[V]): ExpressionEncoder[V] = {
1180+
if (exprEnc.clsTag.runtimeClass != classOf[Row]) {
1181+
// Encoders for anything other than Row are returned untouched.
exprEnc
1182+
} else {
1183+
// V is Row here (checked through the class tag above), hence the casts around RowEncoder.nullSafe.
RowEncoder.nullSafe(exprEnc.asInstanceOf[ExpressionEncoder[Row]])
1184+
.asInstanceOf[ExpressionEncoder[V]]
1185+
}
1186+
}
1187+
// Combine both sides' encoders into the (T, U) tuple encoder, each passed through nullSafe first.
implicit val tuple2Encoder: Encoder[(T, U)] = ExpressionEncoder.tuple(
1188+
nullSafe(this.exprEnc), nullSafe(other.exprEnc)
1189+
)
11781190

11791191
val leftResultExpr = {
11801192
if (!this.exprEnc.isSerializedAsStructForTopLevel) {

0 commit comments

Comments
 (0)