Skip to content

Commit 9986462

Browse files
chenhao-dbcloud-fan
authored and committed
[SPARK-47385] Fix tuple encoders with Option inputs
## What changes were proposed in this pull request? #40755 adds a null check on the input of the child deserializer in the tuple encoder. It breaks the deserializer for the `Option` type, because null should be deserialized into `None` rather than null. This PR adds a boolean parameter to `ExpressionEncoder.tuple` so that only the user that #40755 intended to fix has this null check. ## How was this patch tested? Unit test. Closes #45508 from chenhao-db/SPARK-47385. Authored-by: Chenhao Li <chenhao.li@databricks.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent e980211 commit 9986462

File tree

3 files changed

+23
-3
lines changed

3 files changed

+23
-3
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -73,8 +73,14 @@ object ExpressionEncoder {
7373
* Given a set of N encoders, constructs a new encoder that produce objects as items in an
7474
* N-tuple. Note that these encoders should be unresolved so that information about
7575
* name/positional binding is preserved.
76+
* When `useNullSafeDeserializer` is true, the deserialization result for a child will be null if
77+
* the input is null. It is false by default as most deserializers handle null input properly and
78+
* don't require an extra null check. Some of them are null-tolerant, such as the deserializer for
79+
* `Option[T]`, and we must not set it to true in this case.
7680
*/
77-
def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
81+
def tuple(
82+
encoders: Seq[ExpressionEncoder[_]],
83+
useNullSafeDeserializer: Boolean = false): ExpressionEncoder[_] = {
7884
if (encoders.length > 22) {
7985
throw QueryExecutionErrors.elementsOfTupleExceedLimitError()
8086
}
@@ -119,7 +125,7 @@ object ExpressionEncoder {
119125
case GetColumnByOrdinal(0, _) => input
120126
}
121127

122-
if (enc.objSerializer.nullable) {
128+
if (useNullSafeDeserializer && enc.objSerializer.nullable) {
123129
nullSafe(input, childDeserializer)
124130
} else {
125131
childDeserializer

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1246,7 +1246,9 @@ class Dataset[T] private[sql](
12461246
JoinHint.NONE)).analyzed.asInstanceOf[Join]
12471247

12481248
implicit val tuple2Encoder: Encoder[(T, U)] =
1249-
ExpressionEncoder.tuple(this.exprEnc, other.exprEnc)
1249+
ExpressionEncoder
1250+
.tuple(Seq(this.exprEnc, other.exprEnc), useNullSafeDeserializer = true)
1251+
.asInstanceOf[Encoder[(T, U)]]
12501252

12511253
withTypedPlan(JoinWith.typedJoinWith(
12521254
joined,

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2576,6 +2576,18 @@ class DatasetSuite extends QueryTest
25762576
assert(result == expected)
25772577
}
25782578

2579+
test("SPARK-47385: Tuple encoder with Option inputs") {
2580+
implicit val enc: Encoder[(SingleData, Option[SingleData])] =
2581+
Encoders.tuple(Encoders.product[SingleData], Encoders.product[Option[SingleData]])
2582+
2583+
val input = Seq(
2584+
(SingleData(1), Some(SingleData(1))),
2585+
(SingleData(2), None)
2586+
)
2587+
val ds = spark.createDataFrame(input).as[(SingleData, Option[SingleData])]
2588+
checkDataset(ds, input: _*)
2589+
}
2590+
25792591
test("SPARK-43124: Show does not trigger job execution on CommandResults") {
25802592
withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> "") {
25812593
withTable("t1") {

0 commit comments

Comments
 (0)