This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

[NSE-68][Shuffle] Adaptive compression select in Shuffle. #69

Merged: 3 commits, Feb 14, 2021
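This change lets the shuffle writer pick a compression codec adaptively: the first record batch is compressed once with the default codec and once with the customized codec, and the codec that yields the smaller body is used for the rest of the shuffle. Below is a minimal Scala sketch of the decision rule mirrored from the ColumnarShuffleWriter diff further down; `chooseCodec` is a hypothetical helper and the sizes in the usage line are illustrative, not part of the patch.

```scala
// Sketch of the adaptive codec decision; the native probe returns -1 when a
// compressed size could not be computed for a codec.
def chooseCodec(defaultCodec: String, customCodec: String,
                defaultSize: Long, customSize: Long): String = {
  if (defaultSize == -1L || customSize == -1L) {
    // The writer logs an error and keeps the codec that was set last,
    // i.e. the customized one.
    customCodec
  } else if (customSize > defaultSize) {
    defaultCodec // the customized codec did not compress better
  } else {
    customCodec
  }
}

chooseCodec("lz4", "fastpfor", defaultSize = 4096L, customSize = 3072L) // "fastpfor"
```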
@@ -17,12 +17,17 @@

package com.intel.oap.vectorized;

import io.netty.buffer.ArrowBuf;
import org.apache.arrow.flatbuf.MessageHeader;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.ipc.message.MessageResult;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.DictionaryUtility;
@@ -36,11 +41,16 @@
* ArrowRecordBatches.
*/
public class ArrowCompressedStreamReader extends ArrowStreamReader {
private String compressType;

public ArrowCompressedStreamReader(InputStream in, BufferAllocator allocator) {
super(in, allocator);
}

public String GetCompressType() {
return compressType;
}

protected void initialize() throws IOException {
Schema originalSchema = readSchema();
List<Field> fields = new ArrayList<>();
@@ -60,6 +70,47 @@ protected void initialize() throws IOException {
this.dictionaries = Collections.unmodifiableMap(dictionaries);
}

/**
* Load the next ArrowRecordBatch to the vector schema root if available.
*
* @return true if a batch was read, false on EOS
* @throws IOException on error
*/
public boolean loadNextBatch() throws IOException {
prepareLoadNextBatch();
MessageResult result = messageReader.readNext();

// Reached EOS
if (result == null) {
return false;
}
// Get the compression type from the customMetadata. Currently the customMetadata has only one entry.
compressType = result.getMessage().customMetadata(0).value();

if (result.getMessage().headerType() == MessageHeader.RecordBatch) {
ArrowBuf bodyBuffer = result.getBodyBuffer();

// For zero-length batches, need an empty buffer to deserialize the batch
if (bodyBuffer == null) {
bodyBuffer = allocator.getEmpty();
}

ArrowRecordBatch batch = MessageSerializer.deserializeRecordBatch(result.getMessage(), bodyBuffer);
loadRecordBatch(batch);
checkDictionaries();
return true;
} else if (result.getMessage().headerType() == MessageHeader.DictionaryBatch) {
// If this is a dictionary message, read the dictionary out and keep reading until a record batch or EOS is reached.
ArrowDictionaryBatch dictionaryBatch = readDictionary(result);
loadDictionary(dictionaryBatch);
loadedDictionaryCount++;
return loadNextBatch();
} else {
throw new IOException("Expected RecordBatch or DictionaryBatch but header was " +
result.getMessage().headerType());
}
}

@Override
protected void loadRecordBatch(ArrowRecordBatch batch) {
try {
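A usage sketch for the ArrowCompressedStreamReader changes above: load batches until EOS and read the codec recorded in the message's custom metadata. The file path and allocator limit are illustrative assumptions, not values from the patch.

```scala
import java.io.FileInputStream
import org.apache.arrow.memory.RootAllocator
import com.intel.oap.vectorized.ArrowCompressedStreamReader

val allocator = new RootAllocator(Long.MaxValue)
val in = new FileInputStream("/tmp/shuffle_block.arrow") // illustrative path
val reader = new ArrowCompressedStreamReader(in, allocator)
try {
  while (reader.loadNextBatch()) {           // false on EOS
    val root = reader.getVectorSchemaRoot    // the batch is loaded into this root
    val codec = reader.GetCompressType()     // codec taken from customMetadata
    println(s"rows=${root.getRowCount}, codec=$codec")
  }
} finally {
  reader.close()
}
```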
@@ -96,10 +96,19 @@ public native long nativeMake(
* @param numRows Rows per batch
* @param bufAddrs Addresses of buffers
* @param bufSizes Sizes of buffers
* @param firstRecordBatch whether this record batch is the first
* record batch in the first partition.
* @return the compressed size if firstRecordBatch is true, otherwise -1.
*/
public native void split(long splitterId, int numRows, long[] bufAddrs, long[] bufSizes)
public native long split(
long splitterId, int numRows, long[] bufAddrs, long[] bufSizes, boolean firstRecordBatch)
throws IOException;

/**
* Update the compression type used by the native splitter.
*/
public native void setCompressType(long splitterId, String compressType);

/**
* Write the data remaining in the buffers held by the native splitter to each partition's
* temporary file, and stop splitting.
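Put together, the two natives above form a probe-then-split sequence. The sketch below is illustrative only; splitterId, numRows, bufAddrs and bufSizes are placeholders that the writer obtains from nativeMake and the current ColumnarBatch, as shown in the ColumnarShuffleWriter diff later.

```scala
import com.intel.oap.vectorized.ShuffleSplitterJniWrapper

val jni = new ShuffleSplitterJniWrapper()
// Placeholders; see ColumnarShuffleWriter below for how they are produced.
val splitterId: Long = ???
val numRows: Int = ???
val bufAddrs: Array[Long] = ???
val bufSizes: Array[Long] = ???

// Probe: with firstRecordBatch = true, split() only reports the compressed size.
jni.setCompressType(splitterId, "lz4")
val lz4Size = jni.split(splitterId, numRows, bufAddrs, bufSizes, true)
jni.setCompressType(splitterId, "fastpfor")
val fastpforSize = jni.split(splitterId, numRows, bufAddrs, bufSizes, true)

// Keep the smaller codec, then perform the real split (which returns -1).
if (lz4Size != -1 && fastpforSize != -1 && fastpforSize > lz4Size) {
  jni.setCompressType(splitterId, "lz4")
}
jni.split(splitterId, numRows, bufAddrs, bufSizes, false)
```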
6 changes: 4 additions & 2 deletions core/src/main/scala/com/intel/oap/ColumnarPluginConfig.scala
@@ -66,8 +66,10 @@ class ColumnarPluginConfig(conf: SQLConf) {
// and the cached buffers will be spilled when reach maximum memory.
val columnarShufflePreferSpill: Boolean =
conf.getConfString("spark.oap.sql.columnar.shuffle.preferSpill", "true").toBoolean
val columnarShuffleUseCustomizedCompression: Boolean =
conf.getConfString("spark.oap.sql.columnar.shuffle.customizedCompression", "false").toBoolean

// The supported customized compression codecs are lz4 and fastpfor.
val columnarShuffleUseCustomizedCompressionCodec: String =
conf.getConfString("spark.oap.sql.columnar.shuffle.customizedCompression.codec", "lz4")
val isTesting: Boolean =
conf.getConfString("spark.oap.sql.columnar.testing", "false").toBoolean
val numaBindingInfo: ColumnarNumaBindingInfo = {
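A sketch of how the new codec option might be set, assuming a standard SparkSession setup; the adaptive selection shown later also requires spark.shuffle.compress to be enabled (its default).

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("columnar-shuffle-codec")                      // illustrative name
  .config("spark.shuffle.compress", "true")               // default, shown for clarity
  .config("spark.oap.sql.columnar.shuffle.customizedCompression.codec", "fastpfor")
  .getOrCreate()
```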
@@ -456,6 +456,12 @@ object ConverterUtils extends Logging {
out.toByteArray
}

@throws[IOException]
def getSchemaFromBytesBuf(schema: Array[Byte]): Schema = {
val in: ByteArrayInputStream = new ByteArrayInputStream(schema)
MessageSerializer.deserializeSchema(new ReadChannel(Channels.newChannel(in)))
}

@throws[GandivaException]
def getExprListBytesBuf(exprs: List[ExpressionTree]): Array[Byte] = {
val builder: ExpressionList.Builder = GandivaTypes.ExpressionList.newBuilder
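A round-trip sketch for getSchemaFromBytesBuf above: serialize an Arrow Schema to IPC message bytes and read it back. The single-field schema is illustrative.

```scala
import java.io.ByteArrayOutputStream
import java.nio.channels.Channels
import org.apache.arrow.vector.ipc.WriteChannel
import org.apache.arrow.vector.ipc.message.MessageSerializer
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}
import com.intel.oap.expression.ConverterUtils
import scala.collection.JavaConverters._

val schema = new Schema(List(Field.nullable("id", new ArrowType.Int(32, true))).asJava)
val out = new ByteArrayOutputStream()
MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), schema)
val restored = ConverterUtils.getSchemaFromBytesBuf(out.toByteArray)
assert(restored == schema) // expected to round-trip
```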
@@ -68,14 +68,7 @@ private class ArrowColumnarBatchSerializerInstance(

private val compressionEnabled =
SparkEnv.get.conf.getBoolean("spark.shuffle.compress", true)
private val compressionCodec =
if (ColumnarPluginConfig
.getConf
.columnarShuffleUseCustomizedCompression) {
"fastpfor"
} else {
SparkEnv.get.conf.get("spark.io.compression.codec", "lz4")
}

private val allocator: BufferAllocator = SparkMemoryUtils.contextAllocator()
.newChildAllocator("ArrowColumnarBatch deserialize", 0, Long.MaxValue)

@@ -232,7 +225,7 @@ private class ArrowColumnarBatchSerializerInstance(

val builder = jniWrapper.decompress(
schemaHolderId,
compressionCodec,
reader.asInstanceOf[ArrowCompressedStreamReader].GetCompressType(),
root.getRowCount,
bufAddrs.toArray,
bufSizes.toArray,
@@ -21,8 +21,10 @@ import java.io.IOException

import com.google.common.annotations.VisibleForTesting
import com.intel.oap.ColumnarPluginConfig
import com.intel.oap.expression.ConverterUtils
import com.intel.oap.spark.sql.execution.datasources.v2.arrow.Spiller
import com.intel.oap.vectorized.{ArrowWritableColumnVector, ShuffleSplitterJniWrapper, SplitResult}
import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeID
import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.memory.MemoryConsumer
@@ -57,12 +59,11 @@ class ColumnarShuffleWriter[K, V](
private val localDirs = blockManager.diskBlockManager.localDirs.mkString(",")
private val nativeBufferSize =
conf.getInt("spark.sql.execution.arrow.maxRecordsPerBatch", 4096)
private val compressionCodec = if (conf.getBoolean("spark.shuffle.compress", true)) {
if (ColumnarPluginConfig.getConf.columnarShuffleUseCustomizedCompression) {
"fastpfor"
} else {
conf.get("spark.io.compression.codec", "lz4")
}

private val customizedCompressCodec =
ColumnarPluginConfig.getConf.columnarShuffleUseCustomizedCompressionCodec
private val defaultCompressionCodec = if (conf.getBoolean("spark.shuffle.compress", true)) {
conf.get("spark.io.compression.codec", "lz4")
} else {
"uncompressed"
}
@@ -76,6 +77,8 @@

private var partitionLengths: Array[Long] = _

private var firstRecordBatch: Boolean = true

@throws[IOException]
override def write(records: Iterator[Product2[K, V]]): Unit = {
if (!records.hasNext) {
@@ -90,7 +93,7 @@
nativeSplitter = jniWrapper.make(
dep.nativePartitioning,
nativeBufferSize,
compressionCodec,
defaultCompressionCodec,
dataTmp.getAbsolutePath,
blockManager.subDirsPerLocalDir,
localDirs,
@@ -128,7 +131,39 @@
dep.dataSize.add(bufSizes.sum)

val startTime = System.nanoTime()
jniWrapper.split(nativeSplitter, cb.numRows, bufAddrs.toArray, bufSizes.toArray)

val existingIntType: Boolean = if (firstRecordBatch) {
// Check whether the record batch contains the Int data type.
val arrowSchema = ConverterUtils.getSchemaFromBytesBuf(dep.nativePartitioning.getSchema)
import scala.collection.JavaConverters._
arrowSchema.getFields.asScala.find(_.getType.getTypeID == ArrowTypeID.Int).nonEmpty
} else false

// Choose the compression type based on the compressed size of the first record batch.
if (firstRecordBatch && conf.getBoolean("spark.shuffle.compress", true) &&
customizedCompressCodec != defaultCompressionCodec && existingIntType) {
// Compute the compressed size with the default codec.
jniWrapper.setCompressType(nativeSplitter, defaultCompressionCodec)
val defaultCompressedSize = jniWrapper.split(
nativeSplitter, cb.numRows, bufAddrs.toArray, bufSizes.toArray, firstRecordBatch)

// Compute the compressed size with the customized codec.
jniWrapper.setCompressType(nativeSplitter, customizedCompressCodec)
val customizedCompressedSize = jniWrapper.split(
nativeSplitter, cb.numRows, bufAddrs.toArray, bufSizes.toArray, firstRecordBatch)

// Choose the compression algorithm based on the compressed sizes.
if (customizedCompressedSize != -1 && defaultCompressedSize != -1) {
if (customizedCompressedSize > defaultCompressedSize) {
jniWrapper.setCompressType(nativeSplitter, defaultCompressionCodec)
}
} else {
logError("Failed to compute the compress size in the first record batch")
}
}
firstRecordBatch = false

jniWrapper.split(nativeSplitter, cb.numRows, bufAddrs.toArray, bufSizes.toArray, firstRecordBatch)
dep.splitTime.add(System.nanoTime() - startTime)
dep.numInputRows.add(cb.numRows)
writeMetrics.incRecordsWritten(1)
46 changes: 34 additions & 12 deletions cpp/src/jni/jni_wrapper.cc
@@ -1395,32 +1395,50 @@ Java_com_intel_oap_vectorized_ShuffleSplitterJniWrapper_nativeMake(
return shuffle_splitter_holder_.Insert(std::shared_ptr<Splitter>(splitter));
}

JNIEXPORT void JNICALL Java_com_intel_oap_vectorized_ShuffleSplitterJniWrapper_split(
JNIEnv* env, jobject, jlong splitter_id, jint num_rows, jlongArray buf_addrs,
jlongArray buf_sizes) {
JNIEXPORT void JNICALL Java_com_intel_oap_vectorized_ShuffleSplitterJniWrapper_setCompressType(
JNIEnv* env, jobject, jlong splitter_id, jstring compression_type_jstr) {
auto splitter = shuffle_splitter_holder_.Lookup(splitter_id);
if (!splitter) {
std::string error_message = "Invalid splitter id " + std::to_string(splitter_id);
env->ThrowNew(illegal_argument_exception_class, error_message.c_str());
return;
}

if (compression_type_jstr != NULL) {
auto compression_type_result = GetCompressionType(env, compression_type_jstr);
if (compression_type_result.status().ok()) {
splitter->SetCompressType(compression_type_result.MoveValueUnsafe());
}
}
return;
}

JNIEXPORT jlong JNICALL Java_com_intel_oap_vectorized_ShuffleSplitterJniWrapper_split(
JNIEnv* env, jobject, jlong splitter_id, jint num_rows, jlongArray buf_addrs,
jlongArray buf_sizes, jboolean first_record_batch) {
auto splitter = shuffle_splitter_holder_.Lookup(splitter_id);
if (!splitter) {
std::string error_message = "Invalid splitter id " + std::to_string(splitter_id);
env->ThrowNew(illegal_argument_exception_class, error_message.c_str());
return -1;
}
if (buf_addrs == NULL) {
env->ThrowNew(illegal_argument_exception_class,
std::string("Native split: buf_addrs can't be null").c_str());
return;
return -1;
}
if (buf_sizes == NULL) {
env->ThrowNew(illegal_argument_exception_class,
std::string("Native split: buf_sizes can't be null").c_str());
return;
return -1;
}

int in_bufs_len = env->GetArrayLength(buf_addrs);
if (in_bufs_len != env->GetArrayLength(buf_sizes)) {
env->ThrowNew(
illegal_argument_exception_class,
std::string("Native split: length of buf_addrs and buf_sizes mismatch").c_str());
return;
return -1;
}

jlong* in_buf_addrs = env->GetLongArrayElements(buf_addrs, JNI_FALSE);
@@ -1440,17 +1458,21 @@ JNIEXPORT void JNICALL Java_com_intel_oap_vectorized_ShuffleSplitterJniWrapper_s
std::string("Native split: make record batch failed, error message is " +
status.message())
.c_str());
return;
return -1;
}

status = splitter->Split(*in);

if (!status.ok()) {
// Throw IOException
env->ThrowNew(io_exception_class,
if (first_record_batch) {
return splitter->CompressedSize(*in);
} else {
status = splitter->Split(*in);
if (!status.ok()) {
// Throw IOException
env->ThrowNew(io_exception_class,
std::string("Native split: splitter split failed, error message is " +
status.message())
.c_str());
}
return -1;
}
}

16 changes: 16 additions & 0 deletions cpp/src/shuffle/splitter.cc
@@ -331,6 +331,22 @@ arrow::Status Splitter::Init() {
return arrow::Status::OK();
}

int64_t Splitter::CompressedSize(const arrow::RecordBatch& rb) {
auto payload = std::make_shared<arrow::ipc::internal::IpcPayload>();
auto result = arrow::ipc::internal::GetRecordBatchPayload(
rb, options_.ipc_write_options, payload.get());
if (result.ok()) {
return payload.get()->body_length;
} else {
result.UnknownError("Failed to get the compressed size.");
return -1;
}
}

void Splitter::SetCompressType(arrow::Compression::type compressed_type) {
options_.ipc_write_options.compression = compressed_type;
}

arrow::Status Splitter::Split(const arrow::RecordBatch& rb) {
EVAL_START("split", options_.thread_id)
RETURN_NOT_OK(ComputeAndCountPartitionId(rb));
7 changes: 7 additions & 0 deletions cpp/src/shuffle/splitter.h
@@ -52,6 +52,11 @@ class Splitter {
* id. The largest partition buffer will be spilled if memory allocation failure occurs.
*/
virtual arrow::Status Split(const arrow::RecordBatch&);

/**
* Compute the compressed size of a record batch.
*/
virtual int64_t CompressedSize(const arrow::RecordBatch&);

/**
* For each partition, merge spilled file into shuffle data file and write any cached
@@ -64,6 +69,8 @@
*/
arrow::Status SpillPartition(int32_t partition_id);

void SetCompressType(arrow::Compression::type compressed_type);

/**
* Spill for fixed size of partition data
*/