Skip to content

Commit f83b412

Browse files
committed
Push null check into buffered iterator next().
1 parent 7d3cc5d commit f83b412

File tree

1 file changed

+19
-24
lines changed

1 file changed

+19
-24
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ private[joins] class SortMergeJoinScanner(
148148
private[this] var streamedRow: InternalRow = _
149149
private[this] var streamedRowKey: InternalRow = _
150150
private[this] var bufferedRow: InternalRow = _
151+
// Note: this is guaranteed to never have any null columns:
151152
private[this] var bufferedRowKey: InternalRow = _
152153
/**
153154
* The join key for the rows buffered in `bufferedMatches`, or null if `bufferedMatches` is empty
@@ -157,7 +158,7 @@ private[joins] class SortMergeJoinScanner(
157158
private[this] val bufferedMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow]
158159

159160
// Initialization (note: do _not_ want to advance streamed here).
160-
advancedBuffered()
161+
advancedBufferedToRowWithNullFreeJoinKey()
161162

162163
// --- Public methods ---------------------------------------------------------------------------
163164

@@ -196,11 +197,10 @@ private[joins] class SortMergeJoinScanner(
196197
do {
197198
if (streamedRowKey.anyNull) {
198199
advancedStreamed()
199-
} else if (bufferedRowKey.anyNull) {
200-
advancedBuffered()
201200
} else {
201+
assert(!bufferedRowKey.anyNull)
202202
comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
203-
if (comp > 0) advancedBuffered()
203+
if (comp > 0) advancedBufferedToRowWithNullFreeJoinKey()
204204
else if (comp < 0) advancedStreamed()
205205
}
206206
} while (streamedRow != null && bufferedRow != null && comp != 0)
@@ -242,15 +242,10 @@ private[joins] class SortMergeJoinScanner(
242242
if (bufferedRow != null && !streamedRowKey.anyNull) {
243243
// The buffered iterator could still contain matching rows, so we'll need to walk through
244244
// it until we either find matches or pass where they would be found.
245-
var comp =
246-
if (bufferedRowKey.anyNull) 1 else keyOrdering.compare(streamedRowKey, bufferedRowKey)
247-
while (comp > 0 && advancedBuffered()) {
248-
comp = if (bufferedRowKey.anyNull) {
249-
1
250-
} else {
251-
keyOrdering.compare(streamedRowKey, bufferedRowKey)
252-
}
253-
}
245+
var comp = 1
246+
do {
247+
comp = keyOrdering.compare(streamedRowKey, bufferedRowKey)
248+
} while (comp > 0 && advancedBufferedToRowWithNullFreeJoinKey())
254249
if (comp == 0) {
255250
// We have found matches, so buffer them (this updates matchJoinKey)
256251
bufferMatchingRows()
@@ -283,18 +278,22 @@ private[joins] class SortMergeJoinScanner(
283278
}
284279

285280
/**
286-
* Advance the buffered iterator and compute the new row's join key.
281+
* Advance the buffered iterator until we find a row with join key that does not contain nulls.
287282
* @return true if the buffered iterator returned a row and false otherwise.
288283
*/
289-
private def advancedBuffered(): Boolean = {
290-
if (bufferedIter.advanceNext()) {
284+
private def advancedBufferedToRowWithNullFreeJoinKey(): Boolean = {
285+
var foundRow: Boolean = false
286+
while (!foundRow && bufferedIter.advanceNext()) {
291287
bufferedRow = bufferedIter.getRow
292288
bufferedRowKey = bufferedKeyGenerator(bufferedRow)
293-
true
294-
} else {
289+
foundRow = !bufferedRowKey.anyNull
290+
}
291+
if (!foundRow) {
295292
bufferedRow = null
296293
bufferedRowKey = null
297294
false
295+
} else {
296+
true
298297
}
299298
}
300299

@@ -312,11 +311,7 @@ private[joins] class SortMergeJoinScanner(
312311
bufferedMatches.clear()
313312
do {
314313
bufferedMatches += bufferedRow.copy() // need to copy mutable rows before buffering them
315-
advancedBuffered()
316-
} while (
317-
bufferedRow != null &&
318-
!bufferedRowKey.anyNull &&
319-
keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0
320-
)
314+
advancedBufferedToRowWithNullFreeJoinKey()
315+
} while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0)
321316
}
322317
}

0 commit comments

Comments
 (0)