diff --git a/client-mr/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java b/client-mr/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java index c9cb553f2e..7d7b18065f 100644 --- a/client-mr/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java +++ b/client-mr/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java @@ -36,6 +36,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.uniffle.client.api.ShuffleWriteClient; import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.common.exception.RssException; import org.apache.uniffle.common.util.ByteUnit; @@ -52,6 +53,7 @@ public class RssMapOutputCollector private Set failedBlockIds = Sets.newConcurrentHashSet(); private int partitions; private SortWriteBufferManager bufferManager; + private ShuffleWriteClient shuffleClient; @Override public void init(Context context) throws IOException, ClassNotFoundException { @@ -107,6 +109,7 @@ public void init(Context context) throws IOException, ClassNotFoundException { RssMRConfig.RSS_CLIENT_DEFAULT_SEND_THREAD_NUM); long maxBufferSize = RssMRUtils.getLong(rssJobConf, mrJobConf, RssMRConfig.RSS_WRITER_BUFFER_SIZE, RssMRConfig.RSS_WRITER_BUFFER_SIZE_DEFAULT_VALUE); + shuffleClient = RssMRUtils.createShuffleClient(mrJobConf); bufferManager = new SortWriteBufferManager( (long)(ByteUnit.MiB.toBytes(sortmb) * sortThreshold), taskAttemptId, @@ -116,7 +119,7 @@ public void init(Context context) throws IOException, ClassNotFoundException { comparator, memoryThreshold, appId, - RssMRUtils.createShuffleClient(mrJobConf), + shuffleClient, sendCheckInterval, sendCheckTimeout, partitionToServers, @@ -188,6 +191,7 @@ private void checkRssException() { public void close() throws IOException, InterruptedException { reporter.progress(); bufferManager.freeAllResources(); + shuffleClient.close(); } @Override