diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java index 5c10ac6763..fcce059730 100644 --- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java +++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java @@ -35,6 +35,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import scala.reflect.ClassTag$; +import scala.reflect.ManifestFactory$; import org.apache.uniffle.client.util.ClientUtils; import org.apache.uniffle.common.ShuffleBlockInfo; @@ -112,7 +113,11 @@ public List addRecord(int partitionId, Object key, Object valu final long start = System.currentTimeMillis(); arrayOutputStream.reset(); serializeStream.writeKey(key, ClassTag$.MODULE$.apply(key.getClass())); - serializeStream.writeValue(value, ClassTag$.MODULE$.apply(value.getClass())); + if (value != null) { + serializeStream.writeValue(value, ClassTag$.MODULE$.apply(value.getClass())); + } else { + serializeStream.writeValue(null, ManifestFactory$.MODULE$.Null()); + } serializeStream.flush(); serializeTime += System.currentTimeMillis() - start; byte[] serializedData = arrayOutputStream.getBuf(); diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java index 3fc2039877..4c2be50a4e 100644 --- a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java +++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java @@ -139,6 +139,20 @@ public void addHugeRecordTest() { assertEquals(1, wbm.getBuffers().size()); } + @Test + public void addNullValueRecordTest() { + SparkConf conf = getConf(); + WriteBufferManager wbm = createManager(conf); + String testKey = "key"; + String testValue = null; + List result = wbm.addRecord(0, testKey, testValue); + assertEquals(0, result.size()); + assertEquals(512, wbm.getAllocatedBytes()); + assertEquals(32, wbm.getUsedBytes()); + assertEquals(0, wbm.getInSendListBytes()); + assertEquals(1, wbm.getBuffers().size()); + } + @Test public void createBlockIdTest() { SparkConf conf = getConf();