diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 27aba1f7f2dfb..989ad8b0dc41e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -1178,13 +1178,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
       cause = null)
   }
 
-  def cannotAcquireMemoryToBuildLongHashedRelationError(size: Long, got: Long): Throwable = {
-    new SparkException(
-      errorClass = "_LEGACY_ERROR_TEMP_2106",
-      messageParameters = Map("size" -> size.toString(), "got" -> got.toString()),
-      cause = null)
-  }
-
   def cannotAcquireMemoryToBuildUnsafeHashedRelationError(): Throwable = {
     new SparkOutOfMemoryError(
       "_LEGACY_ERROR_TEMP_2107",
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index c67b55fd1d50c..242185e803577 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -22,7 +22,7 @@ import java.io._
 import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
 import com.esotericsoftware.kryo.io.{Input, Output}
 
-import org.apache.spark.{SparkConf, SparkEnv, SparkException, SparkUnsupportedOperationException}
+import org.apache.spark.{SparkConf, SparkEnv, SparkUnsupportedOperationException}
 import org.apache.spark.internal.config.{BUFFER_PAGESIZE, MEMORY_OFFHEAP_ENABLED}
 import org.apache.spark.memory._
 import org.apache.spark.sql.catalyst.InternalRow
@@ -31,7 +31,9 @@ import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types.LongType
 import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.array.LongArray
 import org.apache.spark.unsafe.map.BytesToBytesMap
+import org.apache.spark.unsafe.memory.MemoryBlock
 import org.apache.spark.util.{KnownSizeEstimation, Utils}
 
 /**
@@ -535,7 +537,7 @@ private[execution] final class LongToUnsafeRowMap(
     val mm: TaskMemoryManager,
     capacity: Int,
     ignoresDuplicatedKey: Boolean = false)
-  extends MemoryConsumer(mm, MemoryMode.ON_HEAP) with Externalizable with KryoSerializable {
+  extends MemoryConsumer(mm, mm.getTungstenMemoryMode) with Externalizable with KryoSerializable {
 
   // Whether the keys are stored in dense mode or not.
   private var isDense = false
@@ -550,15 +552,15 @@ private[execution] final class LongToUnsafeRowMap(
   //
   // Sparse mode: [key1] [offset1 | size1] [key2] [offset | size2] ...
   // Dense mode: [offset1 | size1] [offset2 | size2]
-  private var array: Array[Long] = null
+  private var array: LongArray = null
   private var mask: Int = 0
 
   // The page to store all bytes of UnsafeRow and the pointer to next rows.
   // [row1][pointer1] [row2][pointer2]
-  private var page: Array[Long] = null
+  private var page: MemoryBlock = null
 
   // Current write cursor in the page.
-  private var cursor: Long = Platform.LONG_ARRAY_OFFSET
+  private var cursor: Long = -1
 
   // The number of bits for size in address
   private val SIZE_BITS = 28
@@ -583,24 +585,16 @@ private[execution] final class LongToUnsafeRowMap(
       0)
   }
 
-  private def ensureAcquireMemory(size: Long): Unit = {
-    // do not support spilling
-    val got = acquireMemory(size)
-    if (got < size) {
-      freeMemory(got)
-      throw QueryExecutionErrors.cannotAcquireMemoryToBuildLongHashedRelationError(size, got)
-    }
-  }
-
   private def init(): Unit = {
     if (mm != null) {
       require(capacity < 512000000, "Cannot broadcast 512 million or more rows")
       var n = 1
       while (n < capacity) n *= 2
-      ensureAcquireMemory(n * 2L * 8 + (1 << 20))
-      array = new Array[Long](n * 2)
+      array = allocateArray(n * 2)
+      array.zeroOut()
       mask = n * 2 - 2
-      page = new Array[Long](1 << 17) // 1M bytes
+      page = allocatePage(1 << 20) // 1M bytes
+      cursor = page.getBaseOffset
     }
   }
 
@@ -616,7 +610,7 @@ private[execution] final class LongToUnsafeRowMap(
   /**
    * Returns total memory consumption.
    */
-  def getTotalMemoryConsumption: Long = array.length * 8L + page.length * 8L
+  def getTotalMemoryConsumption: Long = array.size() * 8L + page.size()
 
   /**
    * Returns the first slot of array that store the keys (sparse mode).
@@ -632,11 +626,11 @@ private[execution] final class LongToUnsafeRowMap(
   private def nextSlot(pos: Int): Int = (pos + 2) & mask
 
   private[this] def toAddress(offset: Long, size: Int): Long = {
-    ((offset - Platform.LONG_ARRAY_OFFSET) << SIZE_BITS) | size
+    (offset << SIZE_BITS) | size
   }
 
   private[this] def toOffset(address: Long): Long = {
-    (address >>> SIZE_BITS) + Platform.LONG_ARRAY_OFFSET
+    (address >>> SIZE_BITS)
   }
 
   private[this] def toSize(address: Long): Int = {
@@ -644,7 +638,7 @@ private[execution] final class LongToUnsafeRowMap(
   }
 
   private def getRow(address: Long, resultRow: UnsafeRow): UnsafeRow = {
-    resultRow.pointTo(page, toOffset(address), toSize(address))
+    resultRow.pointTo(page.getBaseObject, page.getBaseOffset + toOffset(address), toSize(address))
     resultRow
   }
 
@@ -654,16 +648,16 @@ private[execution] final class LongToUnsafeRowMap(
   def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = {
     if (isDense) {
       if (key >= minKey && key <= maxKey) {
-        val value = array((key - minKey).toInt)
+        val value = array.get((key - minKey).toInt)
         if (value > 0) {
           return getRow(value, resultRow)
         }
       }
     } else {
       var pos = firstSlot(key)
-      while (array(pos + 1) != 0) {
-        if (array(pos) == key) {
-          return getRow(array(pos + 1), resultRow)
+      while (array.get(pos + 1) != 0) {
+        if (array.get(pos) == key) {
+          return getRow(array.get(pos + 1), resultRow)
         }
         pos = nextSlot(pos)
       }
@@ -681,8 +675,8 @@ private[execution] final class LongToUnsafeRowMap(
       override def next(): UnsafeRow = {
         val offset = toOffset(addr)
         val size = toSize(addr)
-        resultRow.pointTo(page, offset, size)
-        addr = Platform.getLong(page, offset + size)
+        resultRow.pointTo(page.getBaseObject, page.getBaseOffset + offset, size)
+        addr = Platform.getLong(page.getBaseObject, page.getBaseOffset + offset + size)
         resultRow
       }
     }
@@ -694,16 +688,16 @@ private[execution] final class LongToUnsafeRowMap(
   def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = {
     if (isDense) {
       if (key >= minKey && key <= maxKey) {
-        val value = array((key - minKey).toInt)
+        val value = array.get((key - minKey).toInt)
         if (value > 0) {
           return valueIter(value, resultRow)
         }
       }
     } else {
       var pos = firstSlot(key)
-      while (array(pos + 1) != 0) {
-        if (array(pos) == key) {
-          return valueIter(array(pos + 1), resultRow)
+      while (array.get(pos + 1) != 0) {
+        if (array.get(pos) == key) {
+          return valueIter(array.get(pos + 1), resultRow)
         }
         pos = nextSlot(pos)
       }
@@ -728,8 +722,8 @@ private[execution] final class LongToUnsafeRowMap(
 
       override def hasNext: Boolean = {
         // go to the next key if the current key slot is empty
-        while (pos + step < array.length) {
-          if (array(pos + step) > 0) {
+        while (pos + step < array.size()) {
+          if (array.get(pos + step) > 0) {
             return true
           }
           pos += step + 1
@@ -742,7 +736,7 @@ private[execution] final class LongToUnsafeRowMap(
           throw QueryExecutionErrors.endOfIteratorError()
         } else {
           // the key is retrieved based on the map mode
-          val ret = if (isDense) minKey + pos else array(pos)
+          val ret = if (isDense) minKey + pos else array.get(pos)
           // advance the cursor to the next index
           pos += step + 1
           row.setLong(0, ret)
@@ -762,7 +756,7 @@ private[execution] final class LongToUnsafeRowMap(
     }
 
     val pos = findKeyPosition(key)
-    if (ignoresDuplicatedKey && array(pos + 1) != 0) {
+    if (ignoresDuplicatedKey && array.get(pos + 1) != 0) {
       return
     }
 
@@ -777,18 +771,19 @@ private[execution] final class LongToUnsafeRowMap(
 
     // copy the bytes of UnsafeRow
     val offset = cursor
-    Platform.copyMemory(row.getBaseObject, row.getBaseOffset, page, cursor, row.getSizeInBytes)
+    Platform.copyMemory(row.getBaseObject, row.getBaseOffset, page.getBaseObject, cursor,
+      row.getSizeInBytes)
     cursor += row.getSizeInBytes
-    Platform.putLong(page, cursor, 0)
+    Platform.putLong(page.getBaseObject, cursor, 0)
     cursor += 8
     numValues += 1
-    updateIndex(key, pos, toAddress(offset, row.getSizeInBytes))
+    updateIndex(key, pos, toAddress(offset - page.getBaseOffset, row.getSizeInBytes))
   }
 
   private def findKeyPosition(key: Long): Int = {
     var pos = firstSlot(key)
-    assert(numKeys < array.length / 2)
-    while (array(pos) != key && array(pos + 1) != 0) {
+    assert(numKeys < array.size() / 2)
+    while (array.get(pos) != key && array.get(pos + 1) != 0) {
       pos = nextSlot(pos)
     }
     pos
@@ -798,17 +793,17 @@ private[execution] final class LongToUnsafeRowMap(
    * Update the address in array for given key.
    */
   private def updateIndex(key: Long, pos: Int, address: Long): Unit = {
-    if (array(pos + 1) == 0) {
+    if (array.get(pos + 1) == 0) {
       // this is the first value for this key, put the address in array.
-      array(pos) = key
-      array(pos + 1) = address
+      array.set(pos, key)
+      array.set(pos + 1, address)
       numKeys += 1
-      if (numKeys * 4 > array.length) {
+      if (numKeys * 4 > array.size()) {
         // reach half of the capacity
-        if (array.length < (1 << 30)) {
+        if (array.size() < (1 << 30)) {
           // Cannot allocate an array with 2G elements
           growArray()
-        } else if (numKeys > array.length / 2 * 0.75) {
+        } else if (numKeys > array.size() / 2 * 0.75) {
           // The fill ratio should be less than 0.75
           throw QueryExecutionErrors.cannotBuildHashedRelationWithUniqueKeysExceededError()
         }
@@ -816,46 +811,46 @@ private[execution] final class LongToUnsafeRowMap(
     } else {
       // there are some values for this key, put the address in the front of them.
       val pointer = toOffset(address) + toSize(address)
-      Platform.putLong(page, pointer, array(pos + 1))
-      array(pos + 1) = address
+      Platform.putLong(page.getBaseObject, page.getBaseOffset + pointer, array.get(pos + 1))
+      array.set(pos + 1, address)
     }
   }
 
   private def grow(inputRowSize: Int): Unit = {
     // There is 8 bytes for the pointer to next value
-    val neededNumWords = (cursor - Platform.LONG_ARRAY_OFFSET + 8 + inputRowSize + 7) / 8
-    if (neededNumWords > page.length) {
+    val usedBytes = cursor - page.getBaseOffset
+    val neededNumWords = (usedBytes + 8 + inputRowSize + 7) / 8
+    if (neededNumWords > page.size() / 8) {
       if (neededNumWords > (1 << 30)) {
         throw QueryExecutionErrors.cannotBuildHashedRelationLargerThan8GError()
       }
-      val newNumWords = math.max(neededNumWords, math.min(page.length * 2, 1 << 30))
-      ensureAcquireMemory(newNumWords * 8L)
-      val newPage = new Array[Long](newNumWords.toInt)
-      Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET,
-        cursor - Platform.LONG_ARRAY_OFFSET)
-      val used = page.length
+      val newNumWords = math.max(neededNumWords, math.min(page.size() / 8 * 2, 1 << 30))
+      val newPage = allocatePage(newNumWords * 8)
+      Platform.copyMemory(page.getBaseObject, page.getBaseOffset, newPage.getBaseObject,
+        newPage.getBaseOffset, usedBytes)
+      freePage(page)
       page = newPage
-      freeMemory(used * 8L)
+      cursor = page.getBaseOffset + usedBytes
     }
   }
 
   private def growArray(): Unit = {
     var old_array = array
-    val n = array.length
+    val n = Math.toIntExact(array.size())
     numKeys = 0
-    ensureAcquireMemory(n * 2 * 8L)
-    array = new Array[Long](n * 2)
+    array = allocateArray(n * 2)
+    array.zeroOut()
     mask = n * 2 - 2
     var i = 0
-    while (i < old_array.length) {
-      if (old_array(i + 1) > 0) {
-        val key = old_array(i)
-        updateIndex(key, findKeyPosition(key), old_array(i + 1))
+    while (i < old_array.size()) {
+      if (old_array.get(i + 1) > 0) {
+        val key = old_array.get(i)
+        updateIndex(key, findKeyPosition(key), old_array.get(i + 1))
       }
       i += 2
     }
+    freeArray(old_array)
     old_array = null // release the reference to old array
-    freeMemory(n * 8L)
   }
 
   /**
@@ -865,27 +860,20 @@ private[execution] final class LongToUnsafeRowMap(
     val range = maxKey - minKey
     // Convert to dense mode if it does not require more memory or could fit within L1 cache
    // SPARK-16740: Make sure range doesn't overflow if minKey has a large negative value
-    if (range >= 0 && (range < array.length || range < 1024)) {
-      try {
-        ensureAcquireMemory((range + 1) * 8L)
-      } catch {
-        case e: SparkException =>
-          // there is no enough memory to convert
-          return
-      }
-      val denseArray = new Array[Long]((range + 1).toInt)
+    if (range >= 0 && (range < array.size() || range < 1024)) {
+      val denseArray = allocateArray((range + 1).toInt)
+      denseArray.zeroOut()
       var i = 0
-      while (i < array.length) {
-        if (array(i + 1) > 0) {
-          val idx = (array(i) - minKey).toInt
-          denseArray(idx) = array(i + 1)
+      while (i < array.size()) {
+        if (array.get(i + 1) > 0) {
+          val idx = (array.get(i) - minKey).toInt
+          denseArray.set(idx, array.get(i + 1))
         }
         i += 2
       }
-      val old_length = array.length
+      freeArray(array)
       array = denseArray
       isDense = true
-      freeMemory(old_length * 8L)
     }
   }
 
@@ -894,25 +882,26 @@ private[execution] final class LongToUnsafeRowMap(
    */
   def free(): Unit = {
     if (page != null) {
-      freeMemory(page.length * 8L)
+      freePage(page)
       page = null
     }
     if (array != null) {
-      freeMemory(array.length * 8L)
+      freeArray(array)
       array = null
     }
   }
 
-  private def writeLongArray(
+  private def writeBytes(
       writeBuffer: (Array[Byte], Int, Int) => Unit,
-      arr: Array[Long],
+      baseObject: Object,
+      baseOffset: Long,
       len: Int): Unit = {
     val buffer = new Array[Byte](4 << 10)
-    var offset: Long = Platform.LONG_ARRAY_OFFSET
-    val end = len * 8L + Platform.LONG_ARRAY_OFFSET
+    var offset: Long = baseOffset
+    val end = len * 8L + offset
     while (offset < end) {
       val size = Math.min(buffer.length, end - offset)
-      Platform.copyMemory(arr, offset, buffer, Platform.BYTE_ARRAY_OFFSET, size)
+      Platform.copyMemory(baseObject, offset, buffer, Platform.BYTE_ARRAY_OFFSET, size)
       writeBuffer(buffer, 0, size.toInt)
       offset += size
     }
@@ -928,11 +917,12 @@ private[execution] final class LongToUnsafeRowMap(
     writeLong(numKeys)
     writeLong(numValues)
 
-    writeLong(array.length)
-    writeLongArray(writeBuffer, array, array.length)
-    val used = ((cursor - Platform.LONG_ARRAY_OFFSET) / 8).toInt
+    writeLong(array.size())
+    writeBytes(writeBuffer, array.memoryBlock.getBaseObject, array.memoryBlock.getBaseOffset,
+      Math.toIntExact(array.size()))
+    val used = ((cursor - page.getBaseOffset) / 8).toInt
     writeLong(used)
-    writeLongArray(writeBuffer, page, used)
+    writeBytes(writeBuffer, page.getBaseObject, page.getBaseOffset, used)
   }
 
   override def writeExternal(output: ObjectOutput): Unit = {
@@ -943,20 +933,20 @@ private[execution] final class LongToUnsafeRowMap(
     write(out.writeBoolean, out.writeLong, out.write)
   }
 
-  private def readLongArray(
+  private def readData(
       readBuffer: (Array[Byte], Int, Int) => Unit,
-      length: Int): Array[Long] = {
-    val array = new Array[Long](length)
+      baseObject: Object,
+      baseOffset: Long,
+      length: Int): Unit = {
     val buffer = new Array[Byte](4 << 10)
-    var offset: Long = Platform.LONG_ARRAY_OFFSET
-    val end = length * 8L + Platform.LONG_ARRAY_OFFSET
+    var offset: Long = baseOffset
+    val end = length * 8L + baseOffset
     while (offset < end) {
       val size = Math.min(buffer.length, end - offset)
       readBuffer(buffer, 0, size.toInt)
-      Platform.copyMemory(buffer, Platform.BYTE_ARRAY_OFFSET, array, offset, size)
+      Platform.copyMemory(buffer, Platform.BYTE_ARRAY_OFFSET, baseObject, offset, size)
      offset += size
     }
-    array
   }
 
   private def read(
@@ -971,11 +961,15 @@ private[execution] final class LongToUnsafeRowMap(
 
     val length = readLong().toInt
     mask = length - 2
-    array = readLongArray(readBuffer, length)
+    if (array != null) freeArray(array)
+    array = allocateArray(length)
+    readData(readBuffer, array.memoryBlock.getBaseObject, array.memoryBlock.getBaseOffset, length)
     val pageLength = readLong().toInt
-    page = readLongArray(readBuffer, pageLength)
+    if (page != null) freePage(page)
+    page = allocatePage(pageLength * 8L)
+    readData(readBuffer, page.getBaseObject, page.getBaseOffset, pageLength)
     // Restore cursor variable to make this map able to be serialized again on executors.
-    cursor = pageLength * 8 + Platform.LONG_ARRAY_OFFSET
+    cursor = pageLength * 8L + page.getBaseOffset
   }
 
   override def readExternal(in: ObjectInput): Unit = {
@@ -990,7 +984,6 @@ class LongHashedRelation(
     private var nFields: Int,
     private var map: LongToUnsafeRowMap)
   extends HashedRelation with Externalizable {
-
   private var resultRow: UnsafeRow = new UnsafeRow(nFields)
 
   // Needed for serialization (it is public to make Java serialization work)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index e46761f5cd048..f8fa2f5fe35f4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -25,9 +25,10 @@ import scala.util.Random
 
 import org.apache.spark.SparkConf
 import org.apache.spark.SparkException
-import org.apache.spark.internal.config._
+import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, MEMORY_OFFHEAP_SIZE}
 import org.apache.spark.internal.config.Kryo._
 import org.apache.spark.memory.{TaskMemoryManager, UnifiedMemoryManager}
+import org.apache.spark.network.util.ByteUnit
 import org.apache.spark.serializer.KryoSerializer
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
@@ -39,9 +40,13 @@ import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.ArrayImplicits._
 import org.apache.spark.util.collection.CompactBuffer
 
-class HashedRelationSuite extends SharedSparkSession {
+abstract class HashedRelationSuite extends SharedSparkSession {
+  protected def useOffHeapMemoryMode: Boolean
+
   val umm = new UnifiedMemoryManager(
-    new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
+    new SparkConf()
+      .set(MEMORY_OFFHEAP_ENABLED, useOffHeapMemoryMode)
+      .set(MEMORY_OFFHEAP_SIZE, ByteUnit.GiB.toBytes(1L)),
     Runtime.getRuntime.maxMemory,
     Runtime.getRuntime.maxMemory / 2,
     1)
@@ -754,3 +759,11 @@ class HashedRelationSuite extends SharedSparkSession {
     }
   }
 }
+
+class HashedRelationOnHeapSuite extends HashedRelationSuite {
+  override protected def useOffHeapMemoryMode: Boolean = false
+}
+
+class HashedRelationOffHeapSuite extends HashedRelationSuite {
+  override protected def useOffHeapMemoryMode: Boolean = true
+}