From a028f02f81da7a189aa2860d46841b80505da0c4 Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Fri, 31 Oct 2025 12:31:20 +0100 Subject: [PATCH 1/8] LongHashedRelation off-heap --- .../sql/execution/joins/HashedRelation.scala | 146 +++++++++--------- 1 file changed, 77 insertions(+), 69 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index c67b55fd1d50c..3821b772f6c93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -22,7 +22,7 @@ import java.io._ import com.esotericsoftware.kryo.{Kryo, KryoSerializable} import com.esotericsoftware.kryo.io.{Input, Output} -import org.apache.spark.{SparkConf, SparkEnv, SparkException, SparkUnsupportedOperationException} +import org.apache.spark.{SparkConf, SparkEnv, SparkUnsupportedOperationException} import org.apache.spark.internal.config.{BUFFER_PAGESIZE, MEMORY_OFFHEAP_ENABLED} import org.apache.spark.memory._ import org.apache.spark.sql.catalyst.InternalRow @@ -32,6 +32,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types.LongType import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap +import org.apache.spark.unsafe.memory.MemoryBlock import org.apache.spark.util.{KnownSizeEstimation, Utils} /** @@ -535,7 +536,7 @@ private[execution] final class LongToUnsafeRowMap( val mm: TaskMemoryManager, capacity: Int, ignoresDuplicatedKey: Boolean = false) - extends MemoryConsumer(mm, MemoryMode.ON_HEAP) with Externalizable with KryoSerializable { + extends MemoryConsumer(mm, mm.getTungstenMemoryMode) with Externalizable with KryoSerializable { // Whether the keys are stored in dense mode or not. private var isDense = false @@ -550,15 +551,15 @@ private[execution] final class LongToUnsafeRowMap( // // Sparse mode: [key1] [offset1 | size1] [key2] [offset | size2] ... // Dense mode: [offset1 | size1] [offset2 | size2] - private var array: Array[Long] = null + private var array: UnsafeLongArray = null private var mask: Int = 0 // The page to store all bytes of UnsafeRow and the pointer to next rows. // [row1][pointer1] [row2][pointer2] - private var page: Array[Long] = null + private var page: MemoryBlock = null // Current write cursor in the page. - private var cursor: Long = Platform.LONG_ARRAY_OFFSET + private var cursor: Long = -1 // The number of bits for size in address private val SIZE_BITS = 28 @@ -583,24 +584,15 @@ private[execution] final class LongToUnsafeRowMap( 0) } - private def ensureAcquireMemory(size: Long): Unit = { - // do not support spilling - val got = acquireMemory(size) - if (got < size) { - freeMemory(got) - throw QueryExecutionErrors.cannotAcquireMemoryToBuildLongHashedRelationError(size, got) - } - } - private def init(): Unit = { if (mm != null) { require(capacity < 512000000, "Cannot broadcast 512 million or more rows") var n = 1 while (n < capacity) n *= 2 - ensureAcquireMemory(n * 2L * 8 + (1 << 20)) - array = new Array[Long](n * 2) + array = new UnsafeLongArray(n * 2) mask = n * 2 - 2 - page = new Array[Long](1 << 17) // 1M bytes + page = allocatePage(1 << 20)// 1M bytes + cursor = page.getBaseOffset } } @@ -616,7 +608,7 @@ private[execution] final class LongToUnsafeRowMap( /** * Returns total memory consumption. 
*/ - def getTotalMemoryConsumption: Long = array.length * 8L + page.length * 8L + def getTotalMemoryConsumption: Long = array.length * 8L + page.size() /** * Returns the first slot of array that store the keys (sparse mode). @@ -632,11 +624,11 @@ private[execution] final class LongToUnsafeRowMap( private def nextSlot(pos: Int): Int = (pos + 2) & mask private[this] def toAddress(offset: Long, size: Int): Long = { - ((offset - Platform.LONG_ARRAY_OFFSET) << SIZE_BITS) | size + (offset << SIZE_BITS) | size } private[this] def toOffset(address: Long): Long = { - (address >>> SIZE_BITS) + Platform.LONG_ARRAY_OFFSET + (address >>> SIZE_BITS) } private[this] def toSize(address: Long): Int = { @@ -644,7 +636,7 @@ private[execution] final class LongToUnsafeRowMap( } private def getRow(address: Long, resultRow: UnsafeRow): UnsafeRow = { - resultRow.pointTo(page, toOffset(address), toSize(address)) + resultRow.pointTo(page.getBaseObject, page.getBaseOffset + toOffset(address), toSize(address)) resultRow } @@ -681,8 +673,8 @@ private[execution] final class LongToUnsafeRowMap( override def next(): UnsafeRow = { val offset = toOffset(addr) val size = toSize(addr) - resultRow.pointTo(page, offset, size) - addr = Platform.getLong(page, offset + size) + resultRow.pointTo(page.getBaseObject, page.getBaseOffset + offset, size) + addr = Platform.getLong(page.getBaseObject, page.getBaseOffset + offset + size) resultRow } } @@ -777,12 +769,13 @@ private[execution] final class LongToUnsafeRowMap( // copy the bytes of UnsafeRow val offset = cursor - Platform.copyMemory(row.getBaseObject, row.getBaseOffset, page, cursor, row.getSizeInBytes) + Platform.copyMemory(row.getBaseObject, row.getBaseOffset, page.getBaseObject, cursor, + row.getSizeInBytes) cursor += row.getSizeInBytes - Platform.putLong(page, cursor, 0) + Platform.putLong(page.getBaseObject, cursor, 0) cursor += 8 numValues += 1 - updateIndex(key, pos, toAddress(offset, row.getSizeInBytes)) + updateIndex(key, pos, toAddress(offset - page.getBaseOffset, row.getSizeInBytes)) } private def findKeyPosition(key: Long): Int = { @@ -816,26 +809,24 @@ private[execution] final class LongToUnsafeRowMap( } else { // there are some values for this key, put the address in the front of them. 
val pointer = toOffset(address) + toSize(address) - Platform.putLong(page, pointer, array(pos + 1)) + Platform.putLong(page.getBaseObject, page.getBaseOffset + pointer, array(pos + 1)) array(pos + 1) = address } } private def grow(inputRowSize: Int): Unit = { // There is 8 bytes for the pointer to next value - val neededNumWords = (cursor - Platform.LONG_ARRAY_OFFSET + 8 + inputRowSize + 7) / 8 - if (neededNumWords > page.length) { + val neededNumWords = (cursor - page.getBaseOffset + 8 + inputRowSize + 7) / 8 + if (neededNumWords > page.size() / 8) { if (neededNumWords > (1 << 30)) { throw QueryExecutionErrors.cannotBuildHashedRelationLargerThan8GError() } - val newNumWords = math.max(neededNumWords, math.min(page.length * 2, 1 << 30)) - ensureAcquireMemory(newNumWords * 8L) - val newPage = new Array[Long](newNumWords.toInt) - Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET, - cursor - Platform.LONG_ARRAY_OFFSET) - val used = page.length + val newNumWords = math.max(neededNumWords, math.min(page.size() / 8 * 2, 1 << 30)) + val newPage = allocatePage(newNumWords.toInt * 8) + Platform.copyMemory(page.getBaseObject, page.getBaseOffset, newPage.getBaseObject, + newPage.getBaseOffset, cursor - page.getBaseOffset) + freePage(page) page = newPage - freeMemory(used * 8L) } } @@ -843,8 +834,7 @@ private[execution] final class LongToUnsafeRowMap( var old_array = array val n = array.length numKeys = 0 - ensureAcquireMemory(n * 2 * 8L) - array = new Array[Long](n * 2) + array = new UnsafeLongArray(n * 2) mask = n * 2 - 2 var i = 0 while (i < old_array.length) { @@ -854,8 +844,8 @@ private[execution] final class LongToUnsafeRowMap( } i += 2 } + old_array.free() old_array = null // release the reference to old array - freeMemory(n * 8L) } /** @@ -866,14 +856,7 @@ private[execution] final class LongToUnsafeRowMap( // Convert to dense mode if it does not require more memory or could fit within L1 cache // SPARK-16740: Make sure range doesn't overflow if minKey has a large negative value if (range >= 0 && (range < array.length || range < 1024)) { - try { - ensureAcquireMemory((range + 1) * 8L) - } catch { - case e: SparkException => - // there is no enough memory to convert - return - } - val denseArray = new Array[Long]((range + 1).toInt) + val denseArray = new UnsafeLongArray((range + 1).toInt) var i = 0 while (i < array.length) { if (array(i + 1) > 0) { @@ -882,10 +865,9 @@ private[execution] final class LongToUnsafeRowMap( } i += 2 } - val old_length = array.length + array.free() array = denseArray isDense = true - freeMemory(old_length * 8L) } } @@ -894,25 +876,26 @@ private[execution] final class LongToUnsafeRowMap( */ def free(): Unit = { if (page != null) { - freeMemory(page.length * 8L) + freePage(page) page = null } if (array != null) { - freeMemory(array.length * 8L) + array.free() array = null } } - private def writeLongArray( + private def writeBytes( writeBuffer: (Array[Byte], Int, Int) => Unit, - arr: Array[Long], + baseObject: Object, + baseOffset: Long, len: Int): Unit = { val buffer = new Array[Byte](4 << 10) - var offset: Long = Platform.LONG_ARRAY_OFFSET - val end = len * 8L + Platform.LONG_ARRAY_OFFSET + var offset: Long = baseOffset + val end = len * 8L + offset while (offset < end) { val size = Math.min(buffer.length, end - offset) - Platform.copyMemory(arr, offset, buffer, Platform.BYTE_ARRAY_OFFSET, size) + Platform.copyMemory(baseObject, offset, buffer, Platform.BYTE_ARRAY_OFFSET, size) writeBuffer(buffer, 0, size.toInt) offset += size } 
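The toAddress/toOffset/toSize rewrite earlier in this diff is the crux of the conversion: a value address now packs a page-relative offset rather than one biased by Platform.LONG_ARRAY_OFFSET, which keeps the address valid no matter where the page lives. A standalone sketch of the packing and the limits it implies (helper names follow the patch; the wrapping object is illustrative):

    object AddressPacking {
      // Layout of a 64-bit value address after this patch:
      //   [36-bit page-relative offset | 28-bit row size]
      val SIZE_BITS = 28
      val SIZE_MASK = 0xfffffff

      // The offset is relative to page.getBaseOffset, so the packed address works
      // for an on-heap page and for a raw off-heap allocation alike; the old form
      // baked the on-heap Platform.LONG_ARRAY_OFFSET into the offset.
      def toAddress(offset: Long, size: Int): Long = (offset << SIZE_BITS) | size
      def toOffset(address: Long): Long = address >>> SIZE_BITS
      def toSize(address: Long): Int = (address & SIZE_MASK).toInt

      def main(args: Array[String]): Unit = {
        val addr = toAddress(1024L, 48)
        assert(toOffset(addr) == 1024L && toSize(addr) == 48)
        // 28 size bits cap a single row at 2^28 - 1 bytes; the page itself stays
        // under 2^30 words (8 GB), enforced by the
        // cannotBuildHashedRelationLargerThan8GError check in grow().
      }
    }
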
@@ -929,10 +912,11 @@ private[execution] final class LongToUnsafeRowMap( writeLong(numValues) writeLong(array.length) - writeLongArray(writeBuffer, array, array.length) - val used = ((cursor - Platform.LONG_ARRAY_OFFSET) / 8).toInt + writeBytes(writeBuffer, + array.memoryBlock.getBaseObject, array.memoryBlock.getBaseOffset, array.length) + val used = ((cursor - page.getBaseOffset) / 8).toInt writeLong(used) - writeLongArray(writeBuffer, page, used) + writeBytes(writeBuffer, page.getBaseObject, page.getBaseOffset, used) } override def writeExternal(output: ObjectOutput): Unit = { @@ -943,20 +927,20 @@ private[execution] final class LongToUnsafeRowMap( write(out.writeBoolean, out.writeLong, out.write) } - private def readLongArray( + private def readData( readBuffer: (Array[Byte], Int, Int) => Unit, - length: Int): Array[Long] = { - val array = new Array[Long](length) + baseObject: Object, + baseOffset: Long, + length: Int): Unit = { val buffer = new Array[Byte](4 << 10) - var offset: Long = Platform.LONG_ARRAY_OFFSET - val end = length * 8L + Platform.LONG_ARRAY_OFFSET + var offset: Long = baseOffset + val end = length * 8L + baseOffset while (offset < end) { val size = Math.min(buffer.length, end - offset) readBuffer(buffer, 0, size.toInt) - Platform.copyMemory(buffer, Platform.BYTE_ARRAY_OFFSET, array, offset, size) + Platform.copyMemory(buffer, Platform.BYTE_ARRAY_OFFSET, baseObject, offset, size) offset += size } - array } private def read( @@ -971,11 +955,15 @@ private[execution] final class LongToUnsafeRowMap( val length = readLong().toInt mask = length - 2 - array = readLongArray(readBuffer, length) + array.free() + array = new UnsafeLongArray(length) + readData(readBuffer, array.memoryBlock.getBaseObject, array.memoryBlock.getBaseOffset, length) val pageLength = readLong().toInt - page = readLongArray(readBuffer, pageLength) + freePage(page) + page = allocatePage(pageLength * 8) + readData(readBuffer, page.getBaseObject, page.getBaseOffset, pageLength) // Restore cursor variable to make this map able to be serialized again on executors. 
- cursor = pageLength * 8 + Platform.LONG_ARRAY_OFFSET + cursor = pageLength * 8 + page.getBaseOffset } override def readExternal(in: ObjectInput): Unit = { @@ -985,6 +973,26 @@ private[execution] final class LongToUnsafeRowMap( override def read(kryo: Kryo, in: Input): Unit = { read(() => in.readBoolean(), () => in.readLong(), in.readBytes) } + + private class UnsafeLongArray(val length: Int) { + val memoryBlock = allocatePage(length * 8) + + for (i <- 0 until length) { + update(i, 0) + } + + def apply(index: Int): Long = { + Platform.getLong(memoryBlock.getBaseObject, memoryBlock.getBaseOffset + index * 8) + } + + def update(index: Int, value: Long): Unit = { + Platform.putLong(memoryBlock.getBaseObject, memoryBlock.getBaseOffset + index * 8, value) + } + + def free(): Unit = { + freePage(memoryBlock) + } + } } class LongHashedRelation( From ddacb71fdb7030fa460e17a5a7ddcc7c5eb15699 Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Mon, 10 Nov 2025 19:02:25 +0100 Subject: [PATCH 2/8] fixup --- .../execution/joins/HashedRelationSuite.scala | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index e46761f5cd048..df3d34cc13b16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -25,9 +25,10 @@ import scala.util.Random import org.apache.spark.SparkConf import org.apache.spark.SparkException -import org.apache.spark.internal.config._ +import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE} import org.apache.spark.internal.config.Kryo._ import org.apache.spark.memory.{TaskMemoryManager, UnifiedMemoryManager} +import org.apache.spark.network.util.ByteUnit import org.apache.spark.serializer.KryoSerializer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -39,9 +40,13 @@ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.collection.CompactBuffer -class HashedRelationSuite extends SharedSparkSession { +abstract class HashedRelationSuite extends SharedSparkSession { + protected def useOffHeapMemoryMode: Boolean + val umm = new UnifiedMemoryManager( - new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"), + new SparkConf() + .set(MEMORY_OFFHEAP_ENABLED, useOffHeapMemoryMode) + .set(MEMORY_OFFHEAP_SIZE, ByteUnit.GiB.toBytes(1L)), Runtime.getRuntime.maxMemory, Runtime.getRuntime.maxMemory / 2, 1) @@ -754,3 +759,11 @@ class HashedRelationSuite extends SharedSparkSession { } } } + +class HashedRelationOnHeapSuite extends HashedRelationSuite { + override protected def useOffHeapMemoryMode: Boolean = true +} + +class HashedRelationOffHeapSuite extends HashedRelationSuite { + override protected def useOffHeapMemoryMode: Boolean = false +} From d534ed31b3e9bc43f7f7f2e7741ce891a3b54ffb Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Tue, 11 Nov 2025 14:50:12 +0100 Subject: [PATCH 3/8] fixup --- .../org/apache/spark/sql/errors/QueryExecutionErrors.scala | 7 ------- 1 file changed, 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 
27aba1f7f2dfb..989ad8b0dc41e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -1178,13 +1178,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE cause = null) } - def cannotAcquireMemoryToBuildLongHashedRelationError(size: Long, got: Long): Throwable = { - new SparkException( - errorClass = "_LEGACY_ERROR_TEMP_2106", - messageParameters = Map("size" -> size.toString(), "got" -> got.toString()), - cause = null) - } - def cannotAcquireMemoryToBuildUnsafeHashedRelationError(): Throwable = { new SparkOutOfMemoryError( "_LEGACY_ERROR_TEMP_2107", From ee95b9c381c43931e5612a20d4522eef09650470 Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Tue, 11 Nov 2025 15:00:21 +0100 Subject: [PATCH 4/8] fixup --- .../apache/spark/sql/execution/joins/HashedRelation.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 3821b772f6c93..3a8e6198d8c9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -816,7 +816,8 @@ private[execution] final class LongToUnsafeRowMap( private def grow(inputRowSize: Int): Unit = { // There is 8 bytes for the pointer to next value - val neededNumWords = (cursor - page.getBaseOffset + 8 + inputRowSize + 7) / 8 + val usedBytes = cursor - page.getBaseOffset + val neededNumWords = (usedBytes + 8 + inputRowSize + 7) / 8 if (neededNumWords > page.size() / 8) { if (neededNumWords > (1 << 30)) { throw QueryExecutionErrors.cannotBuildHashedRelationLargerThan8GError() @@ -824,9 +825,10 @@ private[execution] final class LongToUnsafeRowMap( val newNumWords = math.max(neededNumWords, math.min(page.size() / 8 * 2, 1 << 30)) val newPage = allocatePage(newNumWords.toInt * 8) Platform.copyMemory(page.getBaseObject, page.getBaseOffset, newPage.getBaseObject, - newPage.getBaseOffset, cursor - page.getBaseOffset) + newPage.getBaseOffset, usedBytes) freePage(page) page = newPage + cursor = page.getBaseOffset + usedBytes } } From 33b92e48bb4c04e039f1cae4ef4b8c3f8a79870e Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Wed, 12 Nov 2025 14:06:24 +0100 Subject: [PATCH 5/8] fixup --- .../org/apache/spark/sql/execution/joins/HashedRelation.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 3a8e6198d8c9e..913465cc9e377 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -977,7 +977,7 @@ private[execution] final class LongToUnsafeRowMap( } private class UnsafeLongArray(val length: Int) { - val memoryBlock = allocatePage(length * 8) + val memoryBlock: MemoryBlock = allocatePage(length * 8) for (i <- 0 until length) { update(i, 0) From e60073a5edd5183a8c92907b0b3656c4b18de5f2 Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Wed, 12 Nov 2025 14:26:51 +0100 Subject: [PATCH 6/8] fixup --- .../spark/sql/execution/joins/HashedRelationSuite.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 
deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index df3d34cc13b16..00aa46c29bfe3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -146,6 +146,9 @@ abstract class HashedRelationSuite extends SharedSparkSession { Seq(BoundReference(0, LongType, false), BoundReference(1, IntegerType, true))) val rows = (0 until 100).map(i => unsafeProj(InternalRow(Int.int2long(i), i + 1)).copy()) val key = Seq(BoundReference(0, LongType, false)) + while (true) { + LongHashedRelation(rows.iterator, key, 10, mm) + } val longRelation = LongHashedRelation(rows.iterator, key, 10, mm) assert(longRelation.keyIsUnique) (0 until 100).foreach { i => @@ -761,9 +764,9 @@ abstract class HashedRelationSuite extends SharedSparkSession { } class HashedRelationOnHeapSuite extends HashedRelationSuite { - override protected def useOffHeapMemoryMode: Boolean = true + override protected def useOffHeapMemoryMode: Boolean = false } class HashedRelationOffHeapSuite extends HashedRelationSuite { - override protected def useOffHeapMemoryMode: Boolean = false + override protected def useOffHeapMemoryMode: Boolean = true } From 2b62870a72a28905feedadc911ae91114beaf048 Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Fri, 21 Nov 2025 17:53:06 +0100 Subject: [PATCH 7/8] fixup --- .../apache/spark/sql/execution/joins/HashedRelationSuite.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 00aa46c29bfe3..f8fa2f5fe35f4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -146,9 +146,6 @@ abstract class HashedRelationSuite extends SharedSparkSession { Seq(BoundReference(0, LongType, false), BoundReference(1, IntegerType, true))) val rows = (0 until 100).map(i => unsafeProj(InternalRow(Int.int2long(i), i + 1)).copy()) val key = Seq(BoundReference(0, LongType, false)) - while (true) { - LongHashedRelation(rows.iterator, key, 10, mm) - } val longRelation = LongHashedRelation(rows.iterator, key, 10, mm) assert(longRelation.keyIsUnique) (0 until 100).foreach { i => From 6b77e26b477dedd8b0a69457514a4104e3d31b79 Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Mon, 24 Nov 2025 19:04:44 +0100 Subject: [PATCH 8/8] fixup --- .../sql/execution/joins/HashedRelation.scala | 117 ++++++++---------- 1 file changed, 50 insertions(+), 67 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 913465cc9e377..242185e803577 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types.LongType import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.array.LongArray import 
org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.unsafe.memory.MemoryBlock import org.apache.spark.util.{KnownSizeEstimation, Utils} @@ -551,7 +552,7 @@ private[execution] final class LongToUnsafeRowMap( // // Sparse mode: [key1] [offset1 | size1] [key2] [offset | size2] ... // Dense mode: [offset1 | size1] [offset2 | size2] - private var array: UnsafeLongArray = null + private var array: LongArray = null private var mask: Int = 0 // The page to store all bytes of UnsafeRow and the pointer to next rows. @@ -589,9 +590,10 @@ private[execution] final class LongToUnsafeRowMap( require(capacity < 512000000, "Cannot broadcast 512 million or more rows") var n = 1 while (n < capacity) n *= 2 - array = new UnsafeLongArray(n * 2) + array = allocateArray(n * 2) + array.zeroOut() mask = n * 2 - 2 - page = allocatePage(1 << 20)// 1M bytes + page = allocatePage(1 << 20) // 1M bytes cursor = page.getBaseOffset } } @@ -608,7 +610,7 @@ private[execution] final class LongToUnsafeRowMap( /** * Returns total memory consumption. */ - def getTotalMemoryConsumption: Long = array.length * 8L + page.size() + def getTotalMemoryConsumption: Long = array.size() * 8L + page.size() /** * Returns the first slot of array that store the keys (sparse mode). @@ -646,16 +648,16 @@ private[execution] final class LongToUnsafeRowMap( def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = { if (isDense) { if (key >= minKey && key <= maxKey) { - val value = array((key - minKey).toInt) + val value = array.get((key - minKey).toInt) if (value > 0) { return getRow(value, resultRow) } } } else { var pos = firstSlot(key) - while (array(pos + 1) != 0) { - if (array(pos) == key) { - return getRow(array(pos + 1), resultRow) + while (array.get(pos + 1) != 0) { + if (array.get(pos) == key) { + return getRow(array.get(pos + 1), resultRow) } pos = nextSlot(pos) } @@ -686,16 +688,16 @@ private[execution] final class LongToUnsafeRowMap( def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = { if (isDense) { if (key >= minKey && key <= maxKey) { - val value = array((key - minKey).toInt) + val value = array.get((key - minKey).toInt) if (value > 0) { return valueIter(value, resultRow) } } } else { var pos = firstSlot(key) - while (array(pos + 1) != 0) { - if (array(pos) == key) { - return valueIter(array(pos + 1), resultRow) + while (array.get(pos + 1) != 0) { + if (array.get(pos) == key) { + return valueIter(array.get(pos + 1), resultRow) } pos = nextSlot(pos) } @@ -720,8 +722,8 @@ private[execution] final class LongToUnsafeRowMap( override def hasNext: Boolean = { // go to the next key if the current key slot is empty - while (pos + step < array.length) { - if (array(pos + step) > 0) { + while (pos + step < array.size()) { + if (array.get(pos + step) > 0) { return true } pos += step + 1 @@ -734,7 +736,7 @@ private[execution] final class LongToUnsafeRowMap( throw QueryExecutionErrors.endOfIteratorError() } else { // the key is retrieved based on the map mode - val ret = if (isDense) minKey + pos else array(pos) + val ret = if (isDense) minKey + pos else array.get(pos) // advance the cursor to the next index pos += step + 1 row.setLong(0, ret) @@ -754,7 +756,7 @@ private[execution] final class LongToUnsafeRowMap( } val pos = findKeyPosition(key) - if (ignoresDuplicatedKey && array(pos + 1) != 0) { + if (ignoresDuplicatedKey && array.get(pos + 1) != 0) { return } @@ -780,8 +782,8 @@ private[execution] final class LongToUnsafeRowMap( private def findKeyPosition(key: Long): Int = { var pos = firstSlot(key) - 
assert(numKeys < array.length / 2) - while (array(pos) != key && array(pos + 1) != 0) { + assert(numKeys < array.size() / 2) + while (array.get(pos) != key && array.get(pos + 1) != 0) { pos = nextSlot(pos) } pos @@ -791,17 +793,17 @@ private[execution] final class LongToUnsafeRowMap( * Update the address in array for given key. */ private def updateIndex(key: Long, pos: Int, address: Long): Unit = { - if (array(pos + 1) == 0) { + if (array.get(pos + 1) == 0) { // this is the first value for this key, put the address in array. - array(pos) = key - array(pos + 1) = address + array.set(pos, key) + array.set(pos + 1, address) numKeys += 1 - if (numKeys * 4 > array.length) { + if (numKeys * 4 > array.size()) { // reach half of the capacity - if (array.length < (1 << 30)) { + if (array.size() < (1 << 30)) { // Cannot allocate an array with 2G elements growArray() - } else if (numKeys > array.length / 2 * 0.75) { + } else if (numKeys > array.size() / 2 * 0.75) { // The fill ratio should be less than 0.75 throw QueryExecutionErrors.cannotBuildHashedRelationWithUniqueKeysExceededError() } @@ -809,8 +811,8 @@ private[execution] final class LongToUnsafeRowMap( } else { // there are some values for this key, put the address in the front of them. val pointer = toOffset(address) + toSize(address) - Platform.putLong(page.getBaseObject, page.getBaseOffset + pointer, array(pos + 1)) - array(pos + 1) = address + Platform.putLong(page.getBaseObject, page.getBaseOffset + pointer, array.get(pos + 1)) + array.set(pos + 1, address) } } @@ -834,19 +836,20 @@ private[execution] final class LongToUnsafeRowMap( private def growArray(): Unit = { var old_array = array - val n = array.length + val n = Math.toIntExact(array.size()) numKeys = 0 - array = new UnsafeLongArray(n * 2) + array = allocateArray(n * 2) + array.zeroOut() mask = n * 2 - 2 var i = 0 - while (i < old_array.length) { - if (old_array(i + 1) > 0) { - val key = old_array(i) - updateIndex(key, findKeyPosition(key), old_array(i + 1)) + while (i < old_array.size()) { + if (old_array.get(i + 1) > 0) { + val key = old_array.get(i) + updateIndex(key, findKeyPosition(key), old_array.get(i + 1)) } i += 2 } - old_array.free() + freeArray(old_array) old_array = null // release the reference to old array } @@ -857,17 +860,18 @@ private[execution] final class LongToUnsafeRowMap( val range = maxKey - minKey // Convert to dense mode if it does not require more memory or could fit within L1 cache // SPARK-16740: Make sure range doesn't overflow if minKey has a large negative value - if (range >= 0 && (range < array.length || range < 1024)) { - val denseArray = new UnsafeLongArray((range + 1).toInt) + if (range >= 0 && (range < array.size() || range < 1024)) { + val denseArray = allocateArray((range + 1).toInt) + denseArray.zeroOut() var i = 0 - while (i < array.length) { - if (array(i + 1) > 0) { - val idx = (array(i) - minKey).toInt - denseArray(idx) = array(i + 1) + while (i < array.size()) { + if (array.get(i + 1) > 0) { + val idx = (array.get(i) - minKey).toInt + denseArray.set(idx, array.get(i + 1)) } i += 2 } - array.free() + freeArray(array) array = denseArray isDense = true } @@ -882,7 +886,7 @@ private[execution] final class LongToUnsafeRowMap( page = null } if (array != null) { - array.free() + freeArray(array) array = null } } @@ -913,9 +917,9 @@ private[execution] final class LongToUnsafeRowMap( writeLong(numKeys) writeLong(numValues) - writeLong(array.length) - writeBytes(writeBuffer, - array.memoryBlock.getBaseObject, array.memoryBlock.getBaseOffset, 
array.length) + writeLong(array.size()) + writeBytes(writeBuffer, array.memoryBlock.getBaseObject, array.memoryBlock.getBaseOffset, + Math.toIntExact(array.size())) val used = ((cursor - page.getBaseOffset) / 8).toInt writeLong(used) writeBytes(writeBuffer, page.getBaseObject, page.getBaseOffset, used) @@ -957,8 +961,8 @@ private[execution] final class LongToUnsafeRowMap( val length = readLong().toInt mask = length - 2 - array.free() - array = new UnsafeLongArray(length) + freeArray(array) + array = allocateArray(length) readData(readBuffer, array.memoryBlock.getBaseObject, array.memoryBlock.getBaseOffset, length) val pageLength = readLong().toInt freePage(page) @@ -975,32 +979,11 @@ private[execution] final class LongToUnsafeRowMap( override def read(kryo: Kryo, in: Input): Unit = { read(() => in.readBoolean(), () => in.readLong(), in.readBytes) } - - private class UnsafeLongArray(val length: Int) { - val memoryBlock: MemoryBlock = allocatePage(length * 8) - - for (i <- 0 until length) { - update(i, 0) - } - - def apply(index: Int): Long = { - Platform.getLong(memoryBlock.getBaseObject, memoryBlock.getBaseOffset + index * 8) - } - - def update(index: Int, value: Long): Unit = { - Platform.putLong(memoryBlock.getBaseObject, memoryBlock.getBaseOffset + index * 8, value) - } - - def free(): Unit = { - freePage(memoryBlock) - } - } } class LongHashedRelation( private var nFields: Int, private var map: LongToUnsafeRowMap) extends HashedRelation with Externalizable { - private var resultRow: UnsafeRow = new UnsafeRow(nFields) // Needed for serialization (it is public to make Java serialization work)
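
Taken together, the series leaves LongToUnsafeRowMap sourcing both its index array and its row page from the task's Tungsten memory, so the map honors spark.memory.offHeap.enabled instead of always building on-heap. Below is a minimal end-to-end sketch of that allocate/zero/use/free lifecycle, under stated assumptions: OffHeapScratch and its demo method are invented names, and the configuration and memory-manager setup mirror the test suite above, so like the suite it must live under an org.apache.spark package to reach the private[spark] constructors.

    package org.apache.spark.sql.execution.joins // hypothetical home, for private[spark] access

    import org.apache.spark.SparkConf
    import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE}
    import org.apache.spark.memory.{MemoryConsumer, TaskMemoryManager, UnifiedMemoryManager}
    import org.apache.spark.network.util.ByteUnit
    import org.apache.spark.unsafe.Platform
    import org.apache.spark.unsafe.array.LongArray
    import org.apache.spark.unsafe.memory.MemoryBlock

    // Illustrative consumer with the same base-class wiring LongToUnsafeRowMap ends
    // up with: the memory mode comes from the TaskMemoryManager rather than being
    // hard-coded to MemoryMode.ON_HEAP.
    class OffHeapScratch(mm: TaskMemoryManager)
        extends MemoryConsumer(mm, mm.getTungstenMemoryMode) {

      // Like LongToUnsafeRowMap, this consumer does not support spilling.
      override def spill(size: Long, trigger: MemoryConsumer): Long = 0L

      def demo(): Unit = {
        // Index array: allocateArray sizes in 8-byte words. Off-heap memory is not
        // zero-initialized, which is why the final patch adds explicit zeroOut() calls.
        val array: LongArray = allocateArray(8)
        array.zeroOut()
        array.set(3, 42L)
        assert(array.get(3) == 42L)

        // Row page: a MemoryBlock addressed via (baseObject, baseOffset). Off-heap,
        // baseObject is null and baseOffset is a raw address, so the write cursor is
        // kept absolute and rebased whenever the page is reallocated (see PATCH 4).
        val page: MemoryBlock = allocatePage(1 << 20)
        var cursor: Long = page.getBaseOffset
        Platform.putLong(page.getBaseObject, cursor, 7L)
        cursor += 8

        freePage(page)
        freeArray(array)
      }
    }

    object OffHeapScratch {
      def main(args: Array[String]): Unit = {
        // Mirrors HashedRelationSuite's off-heap configuration.
        val conf = new SparkConf()
          .set(MEMORY_OFFHEAP_ENABLED, true)
          .set(MEMORY_OFFHEAP_SIZE, ByteUnit.GiB.toBytes(1L))
        val umm = new UnifiedMemoryManager(
          conf, Runtime.getRuntime.maxMemory, Runtime.getRuntime.maxMemory / 2, 1)
        new OffHeapScratch(new TaskMemoryManager(umm, 0)).demo()
      }
    }

Freeing through freePage/freeArray keeps the MemoryConsumer's accounting consistent, which is what lets the series drop the hand-rolled ensureAcquireMemory/freeMemory bookkeeping and its _LEGACY_ERROR_TEMP_2106 error path.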