Skip to content

Commit

Permalink
add row based config and fix style
Browse files Browse the repository at this point in the history
  • Loading branch information
summaryzb committed Nov 20, 2023
1 parent f3f4f76 commit 871a15c
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 117 deletions.
28 changes: 14 additions & 14 deletions cpp/core/jni/JniWrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -859,20 +859,20 @@ JNIEXPORT jlong JNICALL Java_io_glutenproject_vectorized_ShuffleWriterJniWrapper
} else if (partitionWriterType == "uniffle") {
shuffleWriterOptions.partition_writer_type = PartitionWriterType::kUniffle;
jclass unifflePartitionPusherClass =
createGlobalClassReferenceOrError(env, "Lorg/apache/spark/shuffle/writer/PartitionPusher;");
jmethodID unifflePushPartitionDataMethod =
getMethodIdOrError(env, unifflePartitionPusherClass, "pushPartitionData", "(I[B)I");
if (pushBufferMaxSize > 0) {
shuffleWriterOptions.push_buffer_max_size = pushBufferMaxSize;
}
JavaVM* vm;
if (env->GetJavaVM(&vm) != JNI_OK) {
gluten::jniThrow("Unable to get JavaVM instance");
}
// rename CelebornClient RssClient
std::shared_ptr<CelebornClient> celebornClient =
std::make_shared<CelebornClient>(vm, partitionPusher, unifflePushPartitionDataMethod);
partitionWriterCreator = std::make_shared<CelebornPartitionWriterCreator>(std::move(celebornClient));
createGlobalClassReferenceOrError(env, "Lorg/apache/spark/shuffle/writer/PartitionPusher;");
jmethodID unifflePushPartitionDataMethod =
getMethodIdOrError(env, unifflePartitionPusherClass, "pushPartitionData", "(I[B)I");
if (pushBufferMaxSize > 0) {
shuffleWriterOptions.push_buffer_max_size = pushBufferMaxSize;
}
JavaVM* vm;
if (env->GetJavaVM(&vm) != JNI_OK) {
gluten::jniThrow("Unable to get JavaVM instance");
}
// rename CelebornClient RssClient
std::shared_ptr<CelebornClient> celebornClient =
std::make_shared<CelebornClient>(vm, partitionPusher, unifflePushPartitionDataMethod);
partitionWriterCreator = std::make_shared<CelebornPartitionWriterCreator>(std::move(celebornClient));
} else {
throw gluten::GlutenException("Unrecognizable partition writer type: " + partitionWriterType);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
*/
package org.apache.spark.shuffle.gluten.uniffle;

import java.lang.reflect.Constructor;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkEnv;
Expand All @@ -26,18 +24,18 @@
import org.apache.spark.shuffle.ColumnarShuffleDependency;
import org.apache.spark.shuffle.RssShuffleHandle;
import org.apache.spark.shuffle.RssShuffleManager;
import org.apache.spark.shuffle.RssSparkConfig;
import org.apache.spark.shuffle.ShuffleHandle;
import org.apache.spark.shuffle.ShuffleReadMetricsReporter;
import org.apache.spark.shuffle.ShuffleReader;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.shuffle.sort.ColumnarShuffleManager;
import org.apache.spark.shuffle.writer.RssShuffleWriter;
import org.apache.spark.shuffle.writer.VeloxUniffleColumnarShuffleWriter;
import org.apache.uniffle.common.exception.RssException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.reflect.Constructor;

public class GlutenRssShuffleManager extends RssShuffleManager {
private static final Logger LOG = LoggerFactory.getLogger(GlutenRssShuffleManager.class);
private static final String GLUTEN_SHUFFLE_MANAGER_NAME =
Expand All @@ -53,8 +51,7 @@ private ColumnarShuffleManager columnarShuffleManager() {
synchronized (this) {
if (_columnarShuffleManager == null) {
_columnarShuffleManager =
initShuffleManager(
GLUTEN_SHUFFLE_MANAGER_NAME, sparkConf, isDriver());
initShuffleManager(GLUTEN_SHUFFLE_MANAGER_NAME, sparkConf, isDriver());
}
}
}
Expand All @@ -65,8 +62,7 @@ private RssShuffleManager vanillaUniffleShuffleManager() {
if (_vanillaUniffleShuffleManager == null) {
synchronized (this) {
if (_vanillaUniffleShuffleManager == null) {
initShuffleManager(
VANILLA_UNIFFLE_SHUFFLE_MANAGER_NAME, sparkConf, isDriver());
initShuffleManager(VANILLA_UNIFFLE_SHUFFLE_MANAGER_NAME, sparkConf, isDriver());
}
}
}
Expand All @@ -77,18 +73,17 @@ private boolean isDriver() {
return "driver".equals(SparkEnv.get().executorId());
}

private ColumnarShuffleManager initShuffleManager(
String name, SparkConf conf, boolean isDriver) {
private ColumnarShuffleManager initShuffleManager(String name, SparkConf conf, boolean isDriver) {
Constructor constructor;
ColumnarShuffleManager instance;
try {
Class klass = Class.forName(name);
try {
constructor = klass.getConstructor(conf.getClass(), Boolean.TYPE);
instance = (ColumnarShuffleManager)constructor.newInstance(conf, isDriver);
instance = (ColumnarShuffleManager) constructor.newInstance(conf, isDriver);
} catch (NoSuchMethodException var7) {
constructor = klass.getConstructor(conf.getClass());
instance = (ColumnarShuffleManager)constructor.newInstance(conf);
instance = (ColumnarShuffleManager) constructor.newInstance(conf);
}
} catch (Exception e) {
throw new RuntimeException("initColumnManager fail");
Expand All @@ -100,20 +95,21 @@ public GlutenRssShuffleManager(SparkConf conf, boolean isDriver) {
super(conf, isDriver);
// TODO conf set some config
}

@Override
public <K, V, C> ShuffleHandle registerShuffle(
int shuffleId, ShuffleDependency<K, V, C> dependency) {
return super.registerShuffle(shuffleId, dependency);
}



@Override
public <K, V> ShuffleWriter<K, V> getWriter(
ShuffleHandle handle, long mapId, TaskContext context, ShuffleWriteMetricsReporter metrics) {
if (!(handle instanceof RssShuffleHandle)) {
throw new RssException("Unexpected ShuffleHandle:" + handle.getClass().getName());
}
sparkConf.setIfMissing(
RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + RssSparkConfig.RSS_ROW_BASED, "false");
RssShuffleHandle<K, V, V> rssHandle = (RssShuffleHandle<K, V, V>) handle;
if (rssHandle.getDependency() instanceof ColumnarShuffleDependency) {
setPusherAppId(rssHandle);
Expand All @@ -124,16 +120,18 @@ public <K, V> ShuffleWriter<K, V> getWriter(
} else {
writeMetrics = context.taskMetrics().shuffleWriteMetrics();
}
return new VeloxUniffleColumnarShuffleWriter<>(
rssHandle.getAppId(), rssHandle.getShuffleId(), taskId,
context.taskAttemptId(),
writeMetrics,
this,
sparkConf,
shuffleWriteClient,
rssHandle,
this::markFailedTask,
context);
return new VeloxUniffleColumnarShuffleWriter<>(
rssHandle.getAppId(),
rssHandle.getShuffleId(),
taskId,
context.taskAttemptId(),
writeMetrics,
this,
sparkConf,
shuffleWriteClient,
rssHandle,
this::markFailedTask,
context);
} else {
return super.getWriter(handle, mapId, context, metrics);
}
Expand Down
Loading

0 comments on commit 871a15c

Please sign in to comment.