Skip to content

Commit 97027f0

Browse files
author
Davies Liu
committed
fix overflow in LongToUnsafeRowMap
1 parent 03d46aa commit 97027f0

File tree

2 files changed

+49
-4
lines changed

2 files changed

+49
-4
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -459,8 +459,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
459459
*/
460460
def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = {
461461
if (isDense) {
462-
val idx = (key - minKey).toInt
463-
if (idx >= 0 && key <= maxKey && array(idx) > 0) {
462+
val idx = (key - minKey).toInt // could overflow
463+
if (key >= minKey && key <= maxKey && array(idx) > 0) {
464464
return getRow(array(idx), resultRow)
465465
}
466466
} else {
@@ -497,8 +497,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
497497
*/
498498
def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = {
499499
if (isDense) {
500-
val idx = (key - minKey).toInt
501-
if (idx >=0 && key <= maxKey && array(idx) > 0) {
500+
val idx = (key - minKey).toInt // could overflow
501+
if (key >= minKey && key <= maxKey && array(idx) > 0) {
502502
return valueIter(array(idx), resultRow)
503503
}
504504
} else {

sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,51 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
152152
}
153153
}
154154

155+
test("LongToUnsafeRowMap with very wide range") {
156+
val taskMemoryManager = new TaskMemoryManager(
157+
new StaticMemoryManager(
158+
new SparkConf().set("spark.memory.offHeap.enabled", "false"),
159+
Long.MaxValue,
160+
Long.MaxValue,
161+
1),
162+
0)
163+
val unsafeProj = UnsafeProjection.create(Seq(BoundReference(0, LongType, false)))
164+
165+
{
166+
// SPARK-16740
167+
val keys = Seq(0L, Long.MaxValue, Long.MaxValue)
168+
val map = new LongToUnsafeRowMap(taskMemoryManager, 1)
169+
keys.foreach { k =>
170+
map.append(k, unsafeProj(InternalRow(k)))
171+
}
172+
map.optimize()
173+
val row = unsafeProj(InternalRow(0L)).copy()
174+
keys.foreach { k =>
175+
assert(map.getValue(k, row) eq row)
176+
assert(row.getLong(0) === k)
177+
}
178+
map.free()
179+
}
180+
181+
182+
{
183+
// SPARK-16802
184+
val keys = Seq(Long.MaxValue, Long.MaxValue - 10)
185+
val map = new LongToUnsafeRowMap(taskMemoryManager, 1)
186+
keys.foreach { k =>
187+
map.append(k, unsafeProj(InternalRow(k)))
188+
}
189+
map.optimize()
190+
val row = unsafeProj(InternalRow(0L)).copy()
191+
keys.foreach { k =>
192+
assert(map.getValue(k, row) eq row)
193+
assert(row.getLong(0) === k)
194+
}
195+
assert(map.getValue(Long.MinValue, row) eq null)
196+
map.free()
197+
}
198+
}
199+
155200
test("Spark-14521") {
156201
val ser = new KryoSerializer(
157202
(new SparkConf).set("spark.kryo.referenceTracking", "false")).newInstance()

0 commit comments

Comments (0)