diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java
index 20d840baeaf6..0b05857c0482 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -115,7 +115,9 @@ public void setClientId(String id) {
   }
 
   /**
-   * Requests a single chunk from the remote side, from the pre-negotiated streamId.
+   * Requests a chunk from the remote side, from the pre-negotiated streamId. The chunk will be
+   * fetched with a single response, or as a stream if `streamCallback` is not null and the
+   * server supports fetching chunks as streams.
    *
    * Chunk indices go from 0 onwards. It is valid to request the same chunk multiple times, though
    * some streams may not support this.
@@ -128,11 +130,15 @@ public void setClientId(String id) {
    *                 be agreed upon by client and server beforehand.
    * @param chunkIndex 0-based index of the chunk to fetch
    * @param callback Callback invoked upon successful receipt of chunk, or upon any failure.
+   * @param streamCallback If it's not null, we will send a `ChunkFetchRequest` with
+   *                       `fetchAsStream=true`, and this callback will be used to handle the
+   *                       stream response.
    */
   public void fetchChunk(
       long streamId,
       int chunkIndex,
-      ChunkReceivedCallback callback) {
+      ChunkReceivedCallback callback,
+      StreamCallback streamCallback) {
     if (logger.isDebugEnabled()) {
       logger.debug("Sending fetch chunk request {} to {}", chunkIndex, getRemoteAddress(channel));
     }
@@ -142,12 +148,27 @@ public void fetchChunk(
       @Override
       void handleFailure(String errorMsg, Throwable cause) {
         handler.removeFetchRequest(streamChunkId);
+        handler.removeFetchAsStreamRequest(streamChunkId);
         callback.onFailure(chunkIndex, new IOException(errorMsg, cause));
       }
     };
+
+    boolean fetchAsStream = streamCallback != null;
     handler.addFetchRequest(streamChunkId, callback);
+    if (fetchAsStream) {
+      handler.addFetchAsStreamRequest(streamChunkId, streamCallback);
+    }
+
+    ChunkFetchRequest request = new ChunkFetchRequest(streamChunkId, fetchAsStream);
+    channel.writeAndFlush(request).addListener(listener);
+  }
 
-    channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(listener);
+  // This is only used in tests.
+  public void fetchChunk(
+      long streamId,
+      int chunkIndex,
+      ChunkReceivedCallback callback) {
+    fetchChunk(streamId, chunkIndex, callback, null);
   }
 
   /**
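For illustration: the new four-argument `fetchChunk` is opt-in, since passing any non-null `StreamCallback` is what makes the client send `fetchAsStream=true`. Below is a minimal sketch of a callback that spills the streamed chunk straight to disk; the class name and file handling are illustrative, not part of the patch, and only the `StreamCallback` interface and its method signatures are Spark's.

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;

import org.apache.spark.network.client.StreamCallback;

// Writes each streamed buffer to a file as it arrives, so the full chunk is never
// held in memory. A caller would pass an instance as the fourth fetchChunk argument.
class SpillToDiskCallback implements StreamCallback {
  private final FileChannel out;

  SpillToDiskCallback(Path target) throws IOException {
    this.out = FileChannel.open(target, StandardOpenOption.CREATE, StandardOpenOption.WRITE);
  }

  @Override
  public void onData(String streamId, ByteBuffer buf) throws IOException {
    while (buf.hasRemaining()) {
      out.write(buf); // each network frame is written through without accumulation
    }
  }

  @Override
  public void onComplete(String streamId) throws IOException {
    out.close(); // all bytes promised by the response header have been consumed
  }

  @Override
  public void onFailure(String streamId, Throwable cause) throws IOException {
    out.close(); // the caller is responsible for deleting the partial file
  }
}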
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
index 596b0ea5dba9..68e07198c6ab 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
@@ -31,14 +31,7 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import org.apache.spark.network.protocol.ChunkFetchFailure;
-import org.apache.spark.network.protocol.ChunkFetchSuccess;
-import org.apache.spark.network.protocol.ResponseMessage;
-import org.apache.spark.network.protocol.RpcFailure;
-import org.apache.spark.network.protocol.RpcResponse;
-import org.apache.spark.network.protocol.StreamChunkId;
-import org.apache.spark.network.protocol.StreamFailure;
-import org.apache.spark.network.protocol.StreamResponse;
+import org.apache.spark.network.protocol.*;
 import org.apache.spark.network.server.MessageHandler;
 import static org.apache.spark.network.util.NettyUtils.getRemoteAddress;
 import org.apache.spark.network.util.TransportFrameDecoder;
@@ -56,6 +49,8 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
 
   private final Map<StreamChunkId, ChunkReceivedCallback> outstandingFetches;
 
+  private final Map<StreamChunkId, StreamCallback> outstandingFetchAsStreams;
+
   private final Map<Long, RpcResponseCallback> outstandingRpcs;
 
   private final Queue<Pair<String, StreamCallback>> streamCallbacks;
@@ -67,6 +62,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
   public TransportResponseHandler(Channel channel) {
     this.channel = channel;
     this.outstandingFetches = new ConcurrentHashMap<>();
+    this.outstandingFetchAsStreams = new ConcurrentHashMap<>();
     this.outstandingRpcs = new ConcurrentHashMap<>();
     this.streamCallbacks = new ConcurrentLinkedQueue<>();
     this.timeOfLastRequestNs = new AtomicLong(0);
@@ -81,6 +77,17 @@ public void removeFetchRequest(StreamChunkId streamChunkId) {
     outstandingFetches.remove(streamChunkId);
   }
 
+  public void addFetchAsStreamRequest(
+      StreamChunkId streamChunkId,
+      StreamCallback callback) {
+    updateTimeOfLastRequest();
+    outstandingFetchAsStreams.put(streamChunkId, callback);
+  }
+
+  public void removeFetchAsStreamRequest(StreamChunkId streamChunkId) {
+    outstandingFetchAsStreams.remove(streamChunkId);
+  }
+
   public void addRpcRequest(long requestId, RpcResponseCallback callback) {
     updateTimeOfLastRequest();
     outstandingRpcs.put(requestId, callback);
@@ -112,6 +119,13 @@ private void failOutstandingRequests(Throwable cause) {
         logger.warn("ChunkReceivedCallback.onFailure throws exception", e);
       }
     }
+    for (Map.Entry<StreamChunkId, StreamCallback> entry : outstandingFetchAsStreams.entrySet()) {
+      try {
+        entry.getValue().onFailure(entry.getKey().toString(), cause);
+      } catch (Exception e) {
+        logger.warn("ChunkFetchRequest's StreamCallback.onFailure throws exception", e);
+      }
+    }
     for (Map.Entry<Long, RpcResponseCallback> entry : outstandingRpcs.entrySet()) {
       try {
         entry.getValue().onFailure(cause);
@@ -129,6 +143,7 @@ private void failOutstandingRequests(Throwable cause) {
     // It's OK if new fetches appear, as they will fail immediately.
     outstandingFetches.clear();
+    outstandingFetchAsStreams.clear();
     outstandingRpcs.clear();
     streamCallbacks.clear();
   }
@@ -171,6 +186,22 @@ public void handle(ResponseMessage message) throws Exception {
         listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body());
         resp.body().release();
       }
+      // The response is `ChunkFetchSuccess`: either the request was a normal chunk fetch
+      // request, or the server side is an old version that doesn't support fetching chunks
+      // as streams. So the next line is either a no-op, or it removes a callback that would
+      // otherwise never be invoked.
+      outstandingFetchAsStreams.remove(resp.streamChunkId);
+    } else if (message instanceof ChunkFetchStreamResponse) {
+      ChunkFetchStreamResponse resp = (ChunkFetchStreamResponse) message;
+      StreamCallback callback = outstandingFetchAsStreams.get(resp.streamChunkId);
+      if (callback == null) {
+        logger.warn("Ignoring stream response for block {} from {} since it is not outstanding",
+          resp.streamChunkId, getRemoteAddress(channel));
+        resp.body().release();
+      } else {
+        outstandingFetchAsStreams.remove(resp.streamChunkId);
+        readStream(resp.streamChunkId.toString(), resp.byteCount, callback);
+      }
     } else if (message instanceof ChunkFetchFailure) {
       ChunkFetchFailure resp = (ChunkFetchFailure) message;
       ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
@@ -211,25 +242,7 @@ public void handle(ResponseMessage message) throws Exception {
       Pair<String, StreamCallback> entry = streamCallbacks.poll();
       if (entry != null) {
         StreamCallback callback = entry.getValue();
-        if (resp.byteCount > 0) {
-          StreamInterceptor<ResponseMessage> interceptor = new StreamInterceptor<>(
-            this, resp.streamId, resp.byteCount, callback);
-          try {
-            TransportFrameDecoder frameDecoder = (TransportFrameDecoder)
-              channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME);
-            frameDecoder.setInterceptor(interceptor);
-            streamActive = true;
-          } catch (Exception e) {
-            logger.error("Error installing stream handler.", e);
-            deactivateStream();
-          }
-        } else {
-          try {
-            callback.onComplete(resp.streamId);
-          } catch (Exception e) {
-            logger.warn("Error in stream handler onComplete().", e);
-          }
-        }
+        readStream(resp.streamId, resp.byteCount, callback);
       } else {
         logger.error("Could not find callback for StreamResponse.");
       }
@@ -251,10 +264,32 @@ public void handle(ResponseMessage message) throws Exception {
     }
   }
 
+  private void readStream(String streamId, long byteCount, StreamCallback callback) {
+    if (byteCount > 0) {
+      StreamInterceptor<ResponseMessage> interceptor = new StreamInterceptor<>(
+        this, streamId, byteCount, callback);
+      try {
+        TransportFrameDecoder frameDecoder = (TransportFrameDecoder)
+          channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME);
+        frameDecoder.setInterceptor(interceptor);
+        streamActive = true;
+      } catch (Exception e) {
+        logger.error("Error installing stream handler.", e);
+        deactivateStream();
+      }
+    } else {
+      try {
+        callback.onComplete(streamId);
+      } catch (Exception e) {
+        logger.warn("Error in stream handler onComplete().", e);
+      }
+    }
+  }
+
   /** Returns total number of outstanding requests (fetch requests + rpcs) */
   public int numOutstandingRequests() {
-    return outstandingFetches.size() + outstandingRpcs.size() + streamCallbacks.size() +
-      (streamActive ? 1 : 0);
+    return outstandingFetches.size() + outstandingFetchAsStreams.size() + outstandingRpcs.size() +
+      streamCallbacks.size() + (streamActive ? 1 : 0);
   }
 
   /** Returns the time in nanoseconds of when the last request was sent out. */
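The extracted `readStream` installs Spark's `StreamInterceptor` on the frame decoder, so the chunk payload bypasses message decoding entirely. Below is a rough stand-in for the contract that interceptor fulfills; the class name, `handle` signature, and use of `ByteBuffer` are illustrative only, since the real class is `org.apache.spark.network.client.StreamInterceptor` and operates on Netty buffers.

import java.nio.ByteBuffer;

import org.apache.spark.network.client.StreamCallback;

// Stand-in for the installed interceptor: swallow exactly `byteCount` raw bytes from
// the channel, forwarding them to the callback, then let frame decoding resume.
class CountingInterceptor {
  private final String streamId;
  private final long byteCount;
  private final StreamCallback callback;
  private long bytesRead = 0;

  CountingInterceptor(String streamId, long byteCount, StreamCallback callback) {
    this.streamId = streamId;
    this.byteCount = byteCount;
    this.callback = callback;
  }

  /** Returns true while the interceptor should stay installed. */
  boolean handle(ByteBuffer data) throws Exception {
    int toTake = (int) Math.min(data.remaining(), byteCount - bytesRead);
    ByteBuffer slice = data.duplicate();
    slice.limit(slice.position() + toTake);
    callback.onData(streamId, slice);   // forward this frame's bytes
    data.position(data.position() + toTake);
    bytesRead += toTake;
    if (bytesRead == byteCount) {
      callback.onComplete(streamId);    // mirrors deactivateStream() on the real handler
      return false;                     // uninstall: normal frame decoding resumes
    }
    return true;
  }
}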
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
index fe54fcc50dc8..5ff46348652f 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
@@ -27,8 +27,19 @@ public final class ChunkFetchRequest extends AbstractMessage implements RequestMessage {
   public final StreamChunkId streamChunkId;
 
+  // Indicates whether the client wants to fetch this chunk as a stream, to reduce memory
+  // consumption. This field is newly added in Spark 3.0, and is encoded in the message only
+  // when it's true.
+  public final boolean fetchAsStream;
+
+  public ChunkFetchRequest(StreamChunkId streamChunkId, boolean fetchAsStream) {
+    this.streamChunkId = streamChunkId;
+    this.fetchAsStream = fetchAsStream;
+  }
+
+  // This is only used in tests.
   public ChunkFetchRequest(StreamChunkId streamChunkId) {
     this.streamChunkId = streamChunkId;
+    this.fetchAsStream = false;
   }
 
   @Override
@@ -36,28 +47,40 @@ public ChunkFetchRequest(StreamChunkId streamChunkId) {
   @Override
   public int encodedLength() {
-    return streamChunkId.encodedLength();
+    return streamChunkId.encodedLength() + (fetchAsStream ? 1 : 0);
   }
 
   @Override
   public void encode(ByteBuf buf) {
     streamChunkId.encode(buf);
+    if (fetchAsStream) {
+      buf.writeBoolean(true);
+    }
   }
 
   public static ChunkFetchRequest decode(ByteBuf buf) {
-    return new ChunkFetchRequest(StreamChunkId.decode(buf));
+    StreamChunkId streamChunkId = StreamChunkId.decode(buf);
+    boolean fetchAsStream;
+    if (buf.readableBytes() >= 1) {
+      // A sanity check: `encode` only ever writes true, so true is what we should read back.
+      assert buf.readBoolean();
+      fetchAsStream = true;
+    } else {
+      fetchAsStream = false;
+    }
+    return new ChunkFetchRequest(streamChunkId, fetchAsStream);
   }
 
   @Override
   public int hashCode() {
-    return streamChunkId.hashCode();
+    return java.util.Objects.hash(streamChunkId, fetchAsStream);
   }
 
   @Override
   public boolean equals(Object other) {
     if (other instanceof ChunkFetchRequest) {
       ChunkFetchRequest o = (ChunkFetchRequest) other;
-      return streamChunkId.equals(o.streamChunkId);
+      return streamChunkId.equals(o.streamChunkId) && fetchAsStream == o.fetchAsStream;
     }
     return false;
   }
@@ -66,6 +89,7 @@ public boolean equals(Object other) {
   public String toString() {
     return Objects.toStringHelper(this)
       .add("streamChunkId", streamChunkId)
+      .add("fetchAsStream", fetchAsStream)
       .toString();
   }
 }
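The compatibility trick above works because every message arrives in its own length-delimited frame, so `readableBytes()` inside `decode` reveals whether the optional flag byte was written; a frame from an old client simply has nothing left to read. A standalone round trip of that encoding, assuming `StreamChunkId` encodes as a long stream id followed by an int chunk index, as it does in Spark:

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;

public class FetchAsStreamFlagDemo {
  public static void main(String[] args) {
    // New client with fetchAsStream = true: streamId (8 bytes) + chunkIndex (4) + flag (1).
    ByteBuf newFrame = Unpooled.buffer();
    newFrame.writeLong(7L);       // streamId
    newFrame.writeInt(0);         // chunkIndex
    newFrame.writeBoolean(true);  // flag, written only when true

    // Old client: streamId + chunkIndex only, no flag byte.
    ByteBuf oldFrame = Unpooled.buffer();
    oldFrame.writeLong(7L);
    oldFrame.writeInt(0);

    System.out.println(decodeFlag(newFrame)); // true
    System.out.println(decodeFlag(oldFrame)); // false
  }

  static boolean decodeFlag(ByteBuf buf) {
    buf.readLong(); // skip streamId
    buf.readInt();  // skip chunkIndex
    return buf.readableBytes() >= 1 && buf.readBoolean();
  }
}

In the other direction, an old server presumably never reads the extra byte and discards it with the rest of the frame, replying with a plain `ChunkFetchSuccess` that the client handler above already copes with.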
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchStreamResponse.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchStreamResponse.java
new file mode 100644
index 000000000000..06dcbaae62df
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/ChunkFetchStreamResponse.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+/**
+ * Response to {@link ChunkFetchRequest} when its `fetchAsStream` flag is true and the stream has
+ * been successfully opened.
+ * <p>
+ * Note the message itself does not contain the stream data. That is written separately by the
+ * sender. The receiver is expected to set a temporary channel handler that will consume the
+ * number of bytes this message says the stream has.
+ */
+public final class ChunkFetchStreamResponse extends AbstractResponseMessage {
+  public final StreamChunkId streamChunkId;
+  public final long byteCount;
+
+  public ChunkFetchStreamResponse(
+      StreamChunkId streamChunkId,
+      long byteCount,
+      ManagedBuffer buffer) {
+    super(buffer, false);
+    this.streamChunkId = streamChunkId;
+    this.byteCount = byteCount;
+  }
+
+  @Override
+  public Message.Type type() { return Type.ChunkFetchStreamResponse; }
+
+  @Override
+  public int encodedLength() {
+    return streamChunkId.encodedLength() + 8;
+  }
+
+  /** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */
+  @Override
+  public void encode(ByteBuf buf) {
+    streamChunkId.encode(buf);
+    buf.writeLong(byteCount);
+  }
+
+  @Override
+  public ResponseMessage createFailureResponse(String error) {
+    return new ChunkFetchFailure(streamChunkId, error);
+  }
+
+  public static ChunkFetchStreamResponse decode(ByteBuf buf) {
+    StreamChunkId streamChunkId = StreamChunkId.decode(buf);
+    long byteCount = buf.readLong();
+    return new ChunkFetchStreamResponse(streamChunkId, byteCount, null);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hashCode(streamChunkId, byteCount);
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other instanceof ChunkFetchStreamResponse) {
+      ChunkFetchStreamResponse o = (ChunkFetchStreamResponse) other;
+      return streamChunkId.equals(o.streamChunkId) && byteCount == o.byteCount;
+    }
+    return false;
+  }
+
+  @Override
+  public String toString() {
+    return Objects.toStringHelper(this)
+      .add("streamChunkId", streamChunkId)
+      .add("byteCount", byteCount)
+      .add("body", body())
+      .toString();
+  }
+}
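Only this fixed-size header travels inside the message frame; the chunk's `byteCount` bytes follow it on the wire and are consumed by the interceptor installed by `readStream`. A hedged round trip of the 20-byte header alone (the frame length and type byte that `MessageEncoder` prepends are omitted here):

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;

public class ChunkFetchStreamResponseHeaderDemo {
  public static void main(String[] args) {
    // encode(): StreamChunkId (long + int) followed by the byte count = 20 bytes,
    // matching encodedLength() = streamChunkId.encodedLength() + 8.
    ByteBuf buf = Unpooled.buffer();
    buf.writeLong(99L);       // StreamChunkId.streamId
    buf.writeInt(3);          // StreamChunkId.chunkIndex
    buf.writeLong(1L << 20);  // byteCount: a 1 MiB payload follows the frame

    // decode(): read the fields back. The body is null on the receiving side because
    // the payload is consumed by the stream interceptor, not by the message decoder.
    long streamId = buf.readLong();
    int chunkIndex = buf.readInt();
    long byteCount = buf.readLong();
    System.out.printf("stream %d, chunk %d, %d bytes to stream%n",
      streamId, chunkIndex, byteCount);
  }
}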
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java
index 0ccd70c03aba..eb09cb2c4045 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java
@@ -37,7 +37,7 @@ enum Type implements Encodable {
     ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2),
     RpcRequest(3), RpcResponse(4), RpcFailure(5),
     StreamRequest(6), StreamResponse(7), StreamFailure(8),
-    OneWayMessage(9), UploadStream(10), User(-1);
+    OneWayMessage(9), UploadStream(10), ChunkFetchStreamResponse(11), User(-1);
 
     private final byte id;
 
@@ -66,6 +66,7 @@ public static Type decode(ByteBuf buf) {
       case 8: return StreamFailure;
       case 9: return OneWayMessage;
       case 10: return UploadStream;
+      case 11: return ChunkFetchStreamResponse;
       case -1: throw new IllegalArgumentException("User type messages cannot be decoded.");
       default: throw new IllegalArgumentException("Unknown message type: " + id);
     }
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
index bf80aed0afe1..6f2d3a4bbfea 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
@@ -83,6 +83,9 @@ private Message decode(Message.Type msgType, ByteBuf in) {
       case UploadStream:
        return UploadStream.decode(in);
 
+      case ChunkFetchStreamResponse:
+        return ChunkFetchStreamResponse.decode(in);
+
       default:
         throw new IllegalArgumentException("Unexpected message type: " + msgType);
     }
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java
index f08d8b0f984c..12a5b0dd1cb0 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java
@@ -30,10 +30,7 @@
 import org.apache.spark.network.buffer.ManagedBuffer;
 import org.apache.spark.network.client.TransportClient;
-import org.apache.spark.network.protocol.ChunkFetchFailure;
-import org.apache.spark.network.protocol.ChunkFetchRequest;
-import org.apache.spark.network.protocol.ChunkFetchSuccess;
-import org.apache.spark.network.protocol.Encodable;
+import org.apache.spark.network.protocol.*;
 
 import static org.apache.spark.network.util.NettyUtils.*;
 
@@ -101,7 +98,13 @@ protected void channelRead0(
     }
     streamManager.chunkBeingSent(msg.streamChunkId.streamId);
-    respond(channel, new ChunkFetchSuccess(msg.streamChunkId, buf)).addListener(
+    AbstractResponseMessage response;
+    if (msg.fetchAsStream) {
+      response = new ChunkFetchStreamResponse(msg.streamChunkId, buf.size(), buf);
+    } else {
+      response = new ChunkFetchSuccess(msg.streamChunkId, buf);
+    }
+    respond(channel, response).addListener(
       (ChannelFutureListener) future -> streamManager.chunkSent(msg.streamChunkId.streamId));
   }
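The server chooses the response type per request, which is what lets a single server serve old and new clients at once. Distilled into a sketch, where the types are the ones from this patch and only the helper name is illustrative:

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.protocol.AbstractResponseMessage;
import org.apache.spark.network.protocol.ChunkFetchStreamResponse;
import org.apache.spark.network.protocol.ChunkFetchSuccess;
import org.apache.spark.network.protocol.StreamChunkId;

public class ResponseChooser {
  // The whole chunk is attached either way; the difference is that the stream response
  // sends only a small header as the message and lets the transport stream the body,
  // while ChunkFetchSuccess carries the body inside a single message.
  static AbstractResponseMessage pickResponse(
      StreamChunkId id, ManagedBuffer buf, boolean fetchAsStream) {
    return fetchAsStream
      ? new ChunkFetchStreamResponse(id, buf.size(), buf)
      : new ChunkFetchSuccess(id, buf);
  }
}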
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
index 0f6a8824d95e..14038e15e670 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
@@ -99,16 +99,14 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) {
     return nextChunk;
   }
 
+  // This is needed for clients of Spark 2.2, 2.3 and 2.4, which send a stream request to
+  // fetch chunks.
   @Override
   public ManagedBuffer openStream(String streamChunkId) {
     Pair<Long, Integer> streamChunkIdPair = parseStreamChunkId(streamChunkId);
     return getChunk(streamChunkIdPair.getLeft(), streamChunkIdPair.getRight());
   }
 
-  public static String genStreamChunkId(long streamId, int chunkId) {
-    return String.format("%d_%d", streamId, chunkId);
-  }
-
   // Parses streamChunkId into a stream id and a chunk id. This is used when fetching a remote
   // chunk as a stream.
   public static Pair<Long, Integer> parseStreamChunkId(String streamChunkId) {
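`openStream` is kept only for those Spark 2.2 to 2.4 clients, whose stream ids take the form `<streamId>_<chunkIndex>` (formerly produced by the `genStreamChunkId` helper this patch removes). A minimal equivalent of `parseStreamChunkId`, assuming that underscore-separated format:

import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;

public class StreamChunkIdCompat {
  static Pair<Long, Integer> parse(String streamChunkId) {
    String[] parts = streamChunkId.split("_");
    assert parts.length == 2 : "Stream id must be of the form streamId_chunkIndex";
    return ImmutablePair.of(Long.parseLong(parts[0]), Integer.parseInt(parts[1]));
  }

  public static void main(String[] args) {
    System.out.println(parse("99_3")); // (99,3)
  }
}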
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
index 30587023877c..4edbf51aecca 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
@@ -29,7 +29,6 @@
 import org.apache.spark.network.client.RpcResponseCallback;
 import org.apache.spark.network.client.StreamCallback;
 import org.apache.spark.network.client.TransportClient;
-import org.apache.spark.network.server.OneForOneStreamManager;
 import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
 import org.apache.spark.network.shuffle.protocol.OpenBlocks;
 import org.apache.spark.network.shuffle.protocol.StreamHandle;
@@ -58,12 +57,12 @@ public class OneForOneBlockFetcher {
   private StreamHandle streamHandle = null;
 
   public OneForOneBlockFetcher(
-    TransportClient client,
-    String appId,
-    String execId,
-    String[] blockIds,
-    BlockFetchingListener listener,
-    TransportConf transportConf) {
+      TransportClient client,
+      String appId,
+      String execId,
+      String[] blockIds,
+      BlockFetchingListener listener,
+      TransportConf transportConf) {
     this(client, appId, execId, blockIds, listener, transportConf, null);
   }
 
@@ -96,7 +95,7 @@ public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
     public void onFailure(int chunkIndex, Throwable e) {
       // On receipt of a failure, fail every block from chunkIndex onwards.
       String[] remainingBlockIds = Arrays.copyOfRange(blockIds, chunkIndex, blockIds.length);
-      failRemainingBlocks(remainingBlockIds, e);
+      failBlocks(remainingBlockIds, e);
     }
   }
 
@@ -120,29 +119,30 @@ public void onSuccess(ByteBuffer response) {
           // Immediately request all chunks -- we expect that the total size of the request is
           // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]].
           for (int i = 0; i < streamHandle.numChunks; i++) {
+            StreamCallback streamCallback;
             if (downloadFileManager != null) {
-              client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i),
-                new DownloadCallback(i));
+              streamCallback = new DownloadCallback(i);
             } else {
-              client.fetchChunk(streamHandle.streamId, i, chunkCallback);
+              streamCallback = null;
             }
+            client.fetchChunk(streamHandle.streamId, i, chunkCallback, streamCallback);
           }
         } catch (Exception e) {
           logger.error("Failed while starting block fetches after success", e);
-          failRemainingBlocks(blockIds, e);
+          failBlocks(blockIds, e);
         }
       }
 
       @Override
       public void onFailure(Throwable e) {
         logger.error("Failed while starting block fetches", e);
-        failRemainingBlocks(blockIds, e);
+        failBlocks(blockIds, e);
       }
     });
   }
 
   /** Invokes the "onBlockFetchFailure" callback for every listed block id. */
-  private void failRemainingBlocks(String[] failedBlockIds, Throwable e) {
+  private void failBlocks(String[] failedBlockIds, Throwable e) {
     for (String blockId : failedBlockIds) {
       try {
         listener.onBlockFetchFailure(blockId, e);
@@ -184,7 +184,7 @@ public void onFailure(String streamId, Throwable cause) throws IOException {
       channel.close();
       // On receipt of a failure, fail every block from chunkIndex onwards.
       String[] remainingBlockIds = Arrays.copyOfRange(blockIds, chunkIndex, blockIds.length);
-      failRemainingBlocks(remainingBlockIds, cause);
+      failBlocks(remainingBlockIds, cause);
       targetFile.delete();
     }
   }
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
index 95460637db89..9dab3774a91a 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
@@ -165,7 +165,7 @@ private static BlockFetchingListener fetchBlocks(LinkedHashMap
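The test suite hunk is cut off above. As a hypothetical sketch of the kind of assertion such a test would make after this change, namely that a configured `DownloadFileManager` turns every chunk request into a stream fetch, using a Mockito-mocked `TransportClient` (matcher usage assumed, not taken from the actual suite):

import static org.mockito.ArgumentMatchers.*;
import static org.mockito.Mockito.*;

import org.apache.spark.network.client.TransportClient;

public class FetchModeAssertions {
  // `client` must be a Mockito mock that OneForOneBlockFetcher.start() already ran against.
  static void assertStreamedFetches(TransportClient client) {
    // With a downloadFileManager: every chunk is requested with a non-null StreamCallback.
    verify(client, atLeastOnce()).fetchChunk(anyLong(), anyInt(), any(), isNotNull());
  }

  static void assertBufferedFetches(TransportClient client) {
    // Without one: the stream callback is null, so the chunk comes back as a single message.
    verify(client, atLeastOnce()).fetchChunk(anyLong(), anyInt(), any(), isNull());
  }
}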