From 871a15cc5f60a1350e08d5804a82ae3fe3831b50 Mon Sep 17 00:00:00 2001 From: summaryzb Date: Mon, 20 Nov 2023 12:40:24 +0800 Subject: [PATCH] add row based config and fix style --- cpp/core/jni/JniWrapper.cc | 28 +-- .../uniffle/GlutenRssShuffleManager.java | 48 +++--- .../VeloxUniffleColumnarShuffleWriter.java | 162 ++++++++++-------- .../shuffle/writer/PartitionPusher.scala | 4 +- 4 files changed, 125 insertions(+), 117 deletions(-) diff --git a/cpp/core/jni/JniWrapper.cc b/cpp/core/jni/JniWrapper.cc index ad3e23e2e2e6..7d69817fb2bc 100644 --- a/cpp/core/jni/JniWrapper.cc +++ b/cpp/core/jni/JniWrapper.cc @@ -859,20 +859,20 @@ JNIEXPORT jlong JNICALL Java_io_glutenproject_vectorized_ShuffleWriterJniWrapper } else if (partitionWriterType == "uniffle") { shuffleWriterOptions.partition_writer_type = PartitionWriterType::kUniffle; jclass unifflePartitionPusherClass = - createGlobalClassReferenceOrError(env, "Lorg/apache/spark/shuffle/writer/PartitionPusher;"); - jmethodID unifflePushPartitionDataMethod = - getMethodIdOrError(env, unifflePartitionPusherClass, "pushPartitionData", "(I[B)I"); - if (pushBufferMaxSize > 0) { - shuffleWriterOptions.push_buffer_max_size = pushBufferMaxSize; - } - JavaVM* vm; - if (env->GetJavaVM(&vm) != JNI_OK) { - gluten::jniThrow("Unable to get JavaVM instance"); - } - // rename CelebornClient RssClient - std::shared_ptr celebornClient = - std::make_shared(vm, partitionPusher, unifflePushPartitionDataMethod); - partitionWriterCreator = std::make_shared(std::move(celebornClient)); + createGlobalClassReferenceOrError(env, "Lorg/apache/spark/shuffle/writer/PartitionPusher;"); + jmethodID unifflePushPartitionDataMethod = + getMethodIdOrError(env, unifflePartitionPusherClass, "pushPartitionData", "(I[B)I"); + if (pushBufferMaxSize > 0) { + shuffleWriterOptions.push_buffer_max_size = pushBufferMaxSize; + } + JavaVM* vm; + if (env->GetJavaVM(&vm) != JNI_OK) { + gluten::jniThrow("Unable to get JavaVM instance"); + } + // rename CelebornClient RssClient + std::shared_ptr celebornClient = + std::make_shared(vm, partitionPusher, unifflePushPartitionDataMethod); + partitionWriterCreator = std::make_shared(std::move(celebornClient)); } else { throw gluten::GlutenException("Unrecognizable partition writer type: " + partitionWriterType); } diff --git a/gluten-uniffle/velox/src/main/java/org/apache/spark/shuffle/gluten/uniffle/GlutenRssShuffleManager.java b/gluten-uniffle/velox/src/main/java/org/apache/spark/shuffle/gluten/uniffle/GlutenRssShuffleManager.java index 07c3cccc66a2..a5590ea8e488 100644 --- a/gluten-uniffle/velox/src/main/java/org/apache/spark/shuffle/gluten/uniffle/GlutenRssShuffleManager.java +++ b/gluten-uniffle/velox/src/main/java/org/apache/spark/shuffle/gluten/uniffle/GlutenRssShuffleManager.java @@ -16,8 +16,6 @@ */ package org.apache.spark.shuffle.gluten.uniffle; -import java.lang.reflect.Constructor; -import java.util.concurrent.ConcurrentHashMap; import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; import org.apache.spark.SparkEnv; @@ -26,18 +24,18 @@ import org.apache.spark.shuffle.ColumnarShuffleDependency; import org.apache.spark.shuffle.RssShuffleHandle; import org.apache.spark.shuffle.RssShuffleManager; +import org.apache.spark.shuffle.RssSparkConfig; import org.apache.spark.shuffle.ShuffleHandle; -import org.apache.spark.shuffle.ShuffleReadMetricsReporter; -import org.apache.spark.shuffle.ShuffleReader; import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.shuffle.sort.ColumnarShuffleManager; -import org.apache.spark.shuffle.writer.RssShuffleWriter; import org.apache.spark.shuffle.writer.VeloxUniffleColumnarShuffleWriter; import org.apache.uniffle.common.exception.RssException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.lang.reflect.Constructor; + public class GlutenRssShuffleManager extends RssShuffleManager { private static final Logger LOG = LoggerFactory.getLogger(GlutenRssShuffleManager.class); private static final String GLUTEN_SHUFFLE_MANAGER_NAME = @@ -53,8 +51,7 @@ private ColumnarShuffleManager columnarShuffleManager() { synchronized (this) { if (_columnarShuffleManager == null) { _columnarShuffleManager = - initShuffleManager( - GLUTEN_SHUFFLE_MANAGER_NAME, sparkConf, isDriver()); + initShuffleManager(GLUTEN_SHUFFLE_MANAGER_NAME, sparkConf, isDriver()); } } } @@ -65,8 +62,7 @@ private RssShuffleManager vanillaUniffleShuffleManager() { if (_vanillaUniffleShuffleManager == null) { synchronized (this) { if (_vanillaUniffleShuffleManager == null) { - initShuffleManager( - VANILLA_UNIFFLE_SHUFFLE_MANAGER_NAME, sparkConf, isDriver()); + initShuffleManager(VANILLA_UNIFFLE_SHUFFLE_MANAGER_NAME, sparkConf, isDriver()); } } } @@ -77,18 +73,17 @@ private boolean isDriver() { return "driver".equals(SparkEnv.get().executorId()); } - private ColumnarShuffleManager initShuffleManager( - String name, SparkConf conf, boolean isDriver) { + private ColumnarShuffleManager initShuffleManager(String name, SparkConf conf, boolean isDriver) { Constructor constructor; ColumnarShuffleManager instance; try { Class klass = Class.forName(name); try { constructor = klass.getConstructor(conf.getClass(), Boolean.TYPE); - instance = (ColumnarShuffleManager)constructor.newInstance(conf, isDriver); + instance = (ColumnarShuffleManager) constructor.newInstance(conf, isDriver); } catch (NoSuchMethodException var7) { constructor = klass.getConstructor(conf.getClass()); - instance = (ColumnarShuffleManager)constructor.newInstance(conf); + instance = (ColumnarShuffleManager) constructor.newInstance(conf); } } catch (Exception e) { throw new RuntimeException("initColumnManager fail"); @@ -100,20 +95,21 @@ public GlutenRssShuffleManager(SparkConf conf, boolean isDriver) { super(conf, isDriver); // TODO conf set some config } + @Override public ShuffleHandle registerShuffle( int shuffleId, ShuffleDependency dependency) { return super.registerShuffle(shuffleId, dependency); } - - @Override public ShuffleWriter getWriter( ShuffleHandle handle, long mapId, TaskContext context, ShuffleWriteMetricsReporter metrics) { if (!(handle instanceof RssShuffleHandle)) { throw new RssException("Unexpected ShuffleHandle:" + handle.getClass().getName()); } + sparkConf.setIfMissing( + RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + RssSparkConfig.RSS_ROW_BASED, "false"); RssShuffleHandle rssHandle = (RssShuffleHandle) handle; if (rssHandle.getDependency() instanceof ColumnarShuffleDependency) { setPusherAppId(rssHandle); @@ -124,16 +120,18 @@ public ShuffleWriter getWriter( } else { writeMetrics = context.taskMetrics().shuffleWriteMetrics(); } - return new VeloxUniffleColumnarShuffleWriter<>( - rssHandle.getAppId(), rssHandle.getShuffleId(), taskId, - context.taskAttemptId(), - writeMetrics, - this, - sparkConf, - shuffleWriteClient, - rssHandle, - this::markFailedTask, - context); + return new VeloxUniffleColumnarShuffleWriter<>( + rssHandle.getAppId(), + rssHandle.getShuffleId(), + taskId, + context.taskAttemptId(), + writeMetrics, + this, + sparkConf, + shuffleWriteClient, + rssHandle, + this::markFailedTask, + context); } else { return super.getWriter(handle, mapId, context, metrics); } diff --git a/gluten-uniffle/velox/src/main/java/org/apache/spark/shuffle/writer/VeloxUniffleColumnarShuffleWriter.java b/gluten-uniffle/velox/src/main/java/org/apache/spark/shuffle/writer/VeloxUniffleColumnarShuffleWriter.java index e4b5c735ef67..a0425e2fd64d 100644 --- a/gluten-uniffle/velox/src/main/java/org/apache/spark/shuffle/writer/VeloxUniffleColumnarShuffleWriter.java +++ b/gluten-uniffle/velox/src/main/java/org/apache/spark/shuffle/writer/VeloxUniffleColumnarShuffleWriter.java @@ -14,49 +14,50 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.spark.shuffle.writer; import io.glutenproject.GlutenConfig; import io.glutenproject.columnarbatch.ColumnarBatches; -import io.glutenproject.memory.memtarget.Spiller; import io.glutenproject.memory.memtarget.MemoryTarget; +import io.glutenproject.memory.memtarget.Spiller; import io.glutenproject.memory.nmm.NativeMemoryManagers; import io.glutenproject.vectorized.ShuffleWriterJniWrapper; import io.glutenproject.vectorized.SplitResult; + +import org.apache.spark.SparkConf; import org.apache.spark.TaskContext; +import org.apache.spark.executor.ShuffleWriteMetrics; import org.apache.spark.memory.SparkMemoryUtil; import org.apache.spark.shuffle.ColumnarShuffleDependency; import org.apache.spark.shuffle.GlutenShuffleUtils; +import org.apache.spark.shuffle.RssShuffleHandle; import org.apache.spark.shuffle.RssShuffleManager; import org.apache.spark.shuffle.RssSparkConfig; import org.apache.spark.sql.vectorized.ColumnarBatch; -import java.io.IOException; -import java.util.List; -import java.util.concurrent.TimeUnit; -import java.util.function.Function; -import org.apache.spark.SparkConf; -import org.apache.spark.executor.ShuffleWriteMetrics; -import org.apache.spark.shuffle.RssShuffleHandle; import org.apache.spark.util.SparkResourceUtil; import org.apache.uniffle.client.api.ShuffleWriteClient; import org.apache.uniffle.common.ShuffleBlockInfo; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; + import scala.Product2; import scala.collection.Iterator; public class VeloxUniffleColumnarShuffleWriter extends RssShuffleWriter { - private static final Logger LOG = LoggerFactory.getLogger( - VeloxUniffleColumnarShuffleWriter.class); + private static final Logger LOG = + LoggerFactory.getLogger(VeloxUniffleColumnarShuffleWriter.class); private long nativeShuffleWriter = 0L; - private String compreSsionCodec; + private String compreSsionCodec; private int compressThreshold = GlutenConfig.getConf().columnarShuffleCompressionThreshold(); private double reallocThreshold = GlutenConfig.getConf().columnarShuffleReallocThreshold(); - private ShuffleWriterJniWrapper jniWrapper = ShuffleWriterJniWrapper.create(); private SplitResult splitResult; private int nativeBufferSize = GlutenConfig.getConf().maxBatchSize(); @@ -67,8 +68,8 @@ public class VeloxUniffleColumnarShuffleWriter extends RssShuffleWriter rssHandle, Function taskFailureCallback, TaskContext context) { - super(appId, - shuffleId, - taskId, - taskAttemptId, - shuffleWriteMetrics, - shuffleManager, - sparkConf, - shuffleWriteClient, - rssHandle, - taskFailureCallback, context); + super( + appId, + shuffleId, + taskId, + taskAttemptId, + shuffleWriteMetrics, + shuffleManager, + sparkConf, + shuffleWriteClient, + rssHandle, + taskFailureCallback, + context); columnarDep = (ColumnarShuffleDependency) rssHandle.getDependency(); this.sparkConf = sparkConf; - bufferSize = (int)sparkConf.getSizeAsBytes( - RssSparkConfig.RSS_WRITER_BUFFER_SIZE.key(), - RssSparkConfig.RSS_WRITER_BUFFER_SIZE.defaultValue().get()); + bufferSize = + (int) + sparkConf.getSizeAsBytes( + RssSparkConfig.RSS_WRITER_BUFFER_SIZE.key(), + RssSparkConfig.RSS_WRITER_BUFFER_SIZE.defaultValue().get()); compreSsionCodec = GlutenShuffleUtils.getCompressionCodec(sparkConf); } - - @Override protected void writeImpl(Iterator> records) throws IOException { if (!records.hasNext() && !isMemoryShuffleEnabled) { @@ -118,45 +121,48 @@ protected void writeImpl(Iterator> records) throws IOException { } else { long handle = ColumnarBatches.getNativeHandle(cb); if (nativeShuffleWriter == 0) { - nativeShuffleWriter = jniWrapper.makeForRSS( - columnarDep.nativePartitioning(), - nativeBufferSize, - // use field do this - compreSsionCodec, - compressThreshold, - GlutenConfig.getConf().columnarShuffleCompressionMode(), - bufferSize, - partitionPusher, - NativeMemoryManagers - .create( - "UniffleShuffleWriter", - new Spiller() { - @Override - public long spill(MemoryTarget self, long size) { - if (nativeShuffleWriter == -1) { - throw new IllegalStateException( - "Fatal: spill() called before a shuffle shuffle writer " + - "evaluator is created. This behavior should be" + - "optimized by moving memory " + - "allocations from make() to split()"); - } - LOG.info("Gluten shuffle writer: Trying to push {} bytes of data", size); - long pushed = jniWrapper.nativeEvict(nativeShuffleWriter, size, false); - LOG.info("Gluten shuffle writer: Pushed {} / {} bytes of data", pushed, - size); - return pushed; - } - }) - .getNativeInstanceHandle(), - handle, - taskAttemptId, - "uniffle", - reallocThreshold - ); + nativeShuffleWriter = + jniWrapper.makeForRSS( + columnarDep.nativePartitioning(), + nativeBufferSize, + // use field do this + compreSsionCodec, + compressThreshold, + GlutenConfig.getConf().columnarShuffleCompressionMode(), + bufferSize, + partitionPusher, + NativeMemoryManagers.create( + "UniffleShuffleWriter", + new Spiller() { + @Override + public long spill(MemoryTarget self, long size) { + if (nativeShuffleWriter == -1) { + throw new IllegalStateException( + "Fatal: spill() called before a shuffle shuffle writer " + + "evaluator is created. This behavior should be" + + "optimized by moving memory " + + "allocations from make() to split()"); + } + LOG.info( + "Gluten shuffle writer: Trying to push {} bytes of data", size); + long pushed = + jniWrapper.nativeEvict(nativeShuffleWriter, size, false); + LOG.info( + "Gluten shuffle writer: Pushed {} / {} bytes of data", + pushed, + size); + return pushed; + } + }) + .getNativeInstanceHandle(), + handle, + taskAttemptId, + "uniffle", + reallocThreshold); } long startTime = System.nanoTime(); - long bytes = jniWrapper.split(nativeShuffleWriter, cb.numRows(), handle, - availableOffHeapPerTask()); + long bytes = + jniWrapper.split(nativeShuffleWriter, cb.numRows(), handle, availableOffHeapPerTask()); LOG.debug("jniWrapper.split rows {}, split bytes {}", cb.numRows(), bytes); columnarDep.metrics().get("dataSize").get().add(bytes); // this metric replace part of uniffle shuffle write time @@ -173,11 +179,16 @@ public long spill(MemoryTarget self, long size) { } LOG.info("nativeShuffleWriter value {}", nativeShuffleWriter); splitResult = jniWrapper.stop(nativeShuffleWriter); - columnarDep.metrics().get("splitTime").get().add( - System.nanoTime() - startTime - - splitResult.getTotalPushTime() - - splitResult.getTotalWriteTime() - - splitResult.getTotalCompressTime()); + columnarDep + .metrics() + .get("splitTime") + .get() + .add( + System.nanoTime() + - startTime + - splitResult.getTotalPushTime() + - splitResult.getTotalWriteTime() + - splitResult.getTotalCompressTime()); shuffleWriteMetrics.incBytesWritten(splitResult.getTotalBytesWritten()); shuffleWriteMetrics.incWriteTime( @@ -192,7 +203,8 @@ public long spill(MemoryTarget self, long size) { } long writeDurationMs = System.nanoTime() - pushMergedDataTime; shuffleWriteMetrics.incWriteTime(writeDurationMs); - LOG.info("Finish write shuffle with rest write {} ms", + LOG.info( + "Finish write shuffle with rest write {} ms", TimeUnit.MILLISECONDS.toNanos(writeDurationMs)); } @@ -204,8 +216,8 @@ private void sendRestBlockAndWait() { } public int doAddByte(int partitionId, byte[] data, int length) { - List shuffleBlockInfos = super.getBufferManager() - .addPartitionData(partitionId, data, length); + List shuffleBlockInfos = + super.getBufferManager().addPartitionData(partitionId, data, length); super.processShuffleBlockInfos(shuffleBlockInfos); return length; } diff --git a/gluten-uniffle/velox/src/main/scala/org/apache/spark/shuffle/writer/PartitionPusher.scala b/gluten-uniffle/velox/src/main/scala/org/apache/spark/shuffle/writer/PartitionPusher.scala index a55d0911daee..eb99fd23bb8e 100644 --- a/gluten-uniffle/velox/src/main/scala/org/apache/spark/shuffle/writer/PartitionPusher.scala +++ b/gluten-uniffle/velox/src/main/scala/org/apache/spark/shuffle/writer/PartitionPusher.scala @@ -14,10 +14,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.spark.shuffle.writer - import java.io.IOException class PartitionPusher(val uniffleWriter: VeloxUniffleColumnarShuffleWriter[_, _]) { @@ -26,4 +24,4 @@ class PartitionPusher(val uniffleWriter: VeloxUniffleColumnarShuffleWriter[_, _] def pushPartitionData(partitionId: Int, buffer: Array[Byte], length: Int): Int = { uniffleWriter.doAddByte(partitionId, buffer, length) } -} \ No newline at end of file +}