@@ -115,7 +115,9 @@ public void setClientId(String id) {
}

/**
* Requests a single chunk from the remote side, from the pre-negotiated streamId.
* Requests a chunk from the remote side, from the pre-negotiated streamId. The chunk is returned
* in a single response, or as a stream if `streamCallback` is not null and the server supports
* fetching chunks as streams.
*
* Chunk indices go from 0 onwards. It is valid to request the same chunk multiple times, though
* some streams may not support this.
@@ -128,11 +130,15 @@ public void setClientId(String id) {
* be agreed upon by client and server beforehand.
* @param chunkIndex 0-based index of the chunk to fetch
* @param callback Callback invoked upon successful receipt of chunk, or upon any failure.
* @param streamCallback If not null, the `ChunkFetchRequest` is sent with `fetchAsStream=true`,
* and this callback is used to handle the stream response.
Review comment (Member): We should update this code comment of fetchChunk. Now it can request a stream instead of just a single chunk.
*/
public void fetchChunk(
Review comment (Member): Not a big deal, but maybe rename to fetchChunkOrStream?

Reply (Contributor Author): We still fetch a chunk, but the chunk may be returned as a stream.
long streamId,
int chunkIndex,
ChunkReceivedCallback callback) {
ChunkReceivedCallback callback,
StreamCallback streamCallback) {
if (logger.isDebugEnabled()) {
logger.debug("Sending fetch chunk request {} to {}", chunkIndex, getRemoteAddress(channel));
}
Review comment (Member): Should we have something in the log to show that this is also a stream request when streamCallback != null?

Reply (Contributor Author): It depends on how you interpret it. We can say that this is a special chunk fetch request, and the server side can return a stream response for it.
@@ -142,12 +148,27 @@ public void fetchChunk(
@Override
void handleFailure(String errorMsg, Throwable cause) {
handler.removeFetchRequest(streamChunkId);
handler.removeFetchAsStreamRequest(streamChunkId);
callback.onFailure(chunkIndex, new IOException(errorMsg, cause));
}
};

boolean fetchAsStream = streamCallback != null;
handler.addFetchRequest(streamChunkId, callback);
if (fetchAsStream) {
handler.addFetchAsStreamRequest(streamChunkId, streamCallback);
}

ChunkFetchRequest request = new ChunkFetchRequest(streamChunkId, fetchAsStream);
channel.writeAndFlush(request).addListener(listener);
}

channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(listener);
// This is only used in tests.
public void fetchChunk(
long streamId,
int chunkIndex,
ChunkReceivedCallback callback) {
fetchChunk(streamId, chunkIndex, callback, null);
}

/**
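For context, a minimal usage sketch of the two call paths (none of this code is in the patch; the client, stream and chunk identifiers, and callbacks are assumed to exist at the call site):

void fetchExample(
    TransportClient client,
    long streamId,
    int chunkIndex,
    ChunkReceivedCallback chunkCallback,
    StreamCallback spillCallback) {
  // Existing behavior: the whole chunk arrives as a single buffer in chunkCallback.
  client.fetchChunk(streamId, chunkIndex, chunkCallback);

  // New behavior: the request carries fetchAsStream=true, and if the server supports it the
  // chunk bytes are streamed to spillCallback instead of being returned in one response.
  client.fetchChunk(streamId, chunkIndex, chunkCallback, spillCallback);
}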
@@ -31,14 +31,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.network.protocol.ChunkFetchFailure;
import org.apache.spark.network.protocol.ChunkFetchSuccess;
import org.apache.spark.network.protocol.ResponseMessage;
import org.apache.spark.network.protocol.RpcFailure;
import org.apache.spark.network.protocol.RpcResponse;
import org.apache.spark.network.protocol.StreamChunkId;
import org.apache.spark.network.protocol.StreamFailure;
import org.apache.spark.network.protocol.StreamResponse;
import org.apache.spark.network.protocol.*;
import org.apache.spark.network.server.MessageHandler;
import static org.apache.spark.network.util.NettyUtils.getRemoteAddress;
import org.apache.spark.network.util.TransportFrameDecoder;
@@ -56,6 +49,8 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {

private final Map<StreamChunkId, ChunkReceivedCallback> outstandingFetches;

private final Map<StreamChunkId, StreamCallback> outstandingFetchAsStreams;

private final Map<Long, RpcResponseCallback> outstandingRpcs;

private final Queue<Pair<String, StreamCallback>> streamCallbacks;
@@ -67,6 +62,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
public TransportResponseHandler(Channel channel) {
this.channel = channel;
this.outstandingFetches = new ConcurrentHashMap<>();
this.outstandingFetchAsStreams = new ConcurrentHashMap<>();
this.outstandingRpcs = new ConcurrentHashMap<>();
this.streamCallbacks = new ConcurrentLinkedQueue<>();
this.timeOfLastRequestNs = new AtomicLong(0);
@@ -81,6 +77,17 @@ public void removeFetchRequest(StreamChunkId streamChunkId) {
outstandingFetches.remove(streamChunkId);
}

public void addFetchAsStreamRequest(
StreamChunkId streamChunkId,
StreamCallback callback) {
updateTimeOfLastRequest();
outstandingFetchAsStreams.put(streamChunkId, callback);
}

public void removeFetchAsStreamRequest(StreamChunkId streamChunkId) {
outstandingFetchAsStreams.remove(streamChunkId);
}

public void addRpcRequest(long requestId, RpcResponseCallback callback) {
updateTimeOfLastRequest();
outstandingRpcs.put(requestId, callback);
@@ -112,6 +119,13 @@ private void failOutstandingRequests(Throwable cause) {
logger.warn("ChunkReceivedCallback.onFailure throws exception", e);
}
}
for (Map.Entry<StreamChunkId, StreamCallback> entry : outstandingFetchAsStreams.entrySet()) {
try {
entry.getValue().onFailure(entry.getKey().toString(), cause);
} catch (Exception e) {
logger.warn("ChunkFetchRequest's StreamCallback.onFailure throws exception", e);
}
}
for (Map.Entry<Long, RpcResponseCallback> entry : outstandingRpcs.entrySet()) {
try {
entry.getValue().onFailure(cause);
@@ -129,6 +143,7 @@ private void failOutstandingRequests(Throwable cause) {

// It's OK if new fetches appear, as they will fail immediately.
outstandingFetches.clear();
outstandingFetchAsStreams.clear();
outstandingRpcs.clear();
streamCallbacks.clear();
}
@@ -171,6 +186,22 @@ public void handle(ResponseMessage message) throws Exception {
listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body());
resp.body().release();
}
// The response is `ChunkFetchSuccess`. Either the request was a normal chunk fetch request, or
// the server side is an old version that doesn't support fetching chunks as streams. So the
// next line is either a no-op, or it removes a callback that would otherwise never be called.
outstandingFetchAsStreams.remove(resp.streamChunkId);
} else if (message instanceof ChunkFetchStreamResponse) {
ChunkFetchStreamResponse resp = (ChunkFetchStreamResponse) message;
StreamCallback callback = outstandingFetchAsStreams.get(resp.streamChunkId);
Review comment (Member): Also remove this callback from outstandingFetchAsStreams?

Reply (Contributor Author): Good catch!
if (callback == null) {
logger.warn("Ignoring stream response for block {} from {} since it is not outstanding",
resp.streamChunkId, getRemoteAddress(channel));
resp.body().release();
} else {
outstandingFetchAsStreams.remove(resp.streamChunkId);
readStream(resp.streamChunkId.toString(), resp.byteCount, callback);
}
} else if (message instanceof ChunkFetchFailure) {
ChunkFetchFailure resp = (ChunkFetchFailure) message;
ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
@@ -211,25 +242,7 @@ public void handle(ResponseMessage message) throws Exception {
Pair<String, StreamCallback> entry = streamCallbacks.poll();
if (entry != null) {
StreamCallback callback = entry.getValue();
if (resp.byteCount > 0) {
StreamInterceptor<ResponseMessage> interceptor = new StreamInterceptor<>(
this, resp.streamId, resp.byteCount, callback);
try {
TransportFrameDecoder frameDecoder = (TransportFrameDecoder)
channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME);
frameDecoder.setInterceptor(interceptor);
streamActive = true;
} catch (Exception e) {
logger.error("Error installing stream handler.", e);
deactivateStream();
}
} else {
try {
callback.onComplete(resp.streamId);
} catch (Exception e) {
logger.warn("Error in stream handler onComplete().", e);
}
}
readStream(resp.streamId, resp.byteCount, callback);
} else {
logger.error("Could not find callback for StreamResponse.");
}
@@ -251,10 +264,32 @@ public void handle(ResponseMessage message) throws Exception {
}
}

private void readStream(String streamId, long byteCount, StreamCallback callback) {
if (byteCount > 0) {
StreamInterceptor<ResponseMessage> interceptor = new StreamInterceptor<>(
this, streamId, byteCount, callback);
try {
TransportFrameDecoder frameDecoder = (TransportFrameDecoder)
channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME);
frameDecoder.setInterceptor(interceptor);
streamActive = true;
} catch (Exception e) {
logger.error("Error installing stream handler.", e);
deactivateStream();
}
} else {
try {
callback.onComplete(streamId);
} catch (Exception e) {
logger.warn("Error in stream handler onComplete().", e);
}
}
}

/** Returns total number of outstanding requests (fetch requests + rpcs) */
public int numOutstandingRequests() {
return outstandingFetches.size() + outstandingRpcs.size() + streamCallbacks.size() +
(streamActive ? 1 : 0);
return outstandingFetches.size() + outstandingFetchAsStreams.size() + outstandingRpcs.size() +
streamCallbacks.size() + (streamActive ? 1 : 0);
}

/** Returns the time in nanoseconds of when the last request was sent out. */
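To make the new dispatch concrete, below is a minimal sketch (an assumption, not code from this patch) of a StreamCallback a caller could pass to fetchChunk. When a ChunkFetchStreamResponse arrives, the handler looks the callback up in outstandingFetchAsStreams and feeds it the incoming frames via readStream, so the chunk never has to be held in memory as a whole:

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;

import org.apache.spark.network.client.StreamCallback;

// Sketch only: spills an incoming chunk to a local file as the frames arrive.
class SpillToDiskCallback implements StreamCallback {
  private final FileChannel out;

  SpillToDiskCallback(Path target) throws IOException {
    this.out = FileChannel.open(target, StandardOpenOption.CREATE, StandardOpenOption.WRITE);
  }

  @Override
  public void onData(String streamId, ByteBuffer buf) throws IOException {
    while (buf.hasRemaining()) {
      out.write(buf);  // called repeatedly; only one frame is ever buffered at a time
    }
  }

  @Override
  public void onComplete(String streamId) throws IOException {
    out.close();  // the chunk identified by streamId has been fully received
  }

  @Override
  public void onFailure(String streamId, Throwable cause) throws IOException {
    out.close();  // clean up; the caller decides whether to retry the fetch
  }
}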
@@ -27,37 +27,60 @@
public final class ChunkFetchRequest extends AbstractMessage implements RequestMessage {
public final StreamChunkId streamChunkId;

// Indicates if the client wants to fetch this chunk as a stream, to reduce memory consumption.
// This field is newly added in Spark 3.0, and will be encoded in the message only when it's true.
public final boolean fetchAsStream;

public ChunkFetchRequest(StreamChunkId streamChunkId, boolean fetchAsStream) {
this.streamChunkId = streamChunkId;
this.fetchAsStream = fetchAsStream;
}

// This is only used in tests.
public ChunkFetchRequest(StreamChunkId streamChunkId) {
this.streamChunkId = streamChunkId;
this.fetchAsStream = false;
}

@Override
public Message.Type type() { return Type.ChunkFetchRequest; }

@Override
public int encodedLength() {
return streamChunkId.encodedLength();
return streamChunkId.encodedLength() + (fetchAsStream ? 1 : 0);
}

@Override
public void encode(ByteBuf buf) {
streamChunkId.encode(buf);
if (fetchAsStream) {
buf.writeBoolean(true);
}
}

public static ChunkFetchRequest decode(ByteBuf buf) {
return new ChunkFetchRequest(StreamChunkId.decode(buf));
StreamChunkId streamChunkId = StreamChunkId.decode(buf);
boolean fetchAsStream;
if (buf.readableBytes() >= 1) {
// A sanity check. In `encode` we write true, so here we should read true.
assert buf.readBoolean();
fetchAsStream = true;
} else {
fetchAsStream = false;
}
return new ChunkFetchRequest(streamChunkId, fetchAsStream);
}

@Override
public int hashCode() {
return streamChunkId.hashCode();
return java.util.Objects.hash(streamChunkId, fetchAsStream);
}

@Override
public boolean equals(Object other) {
if (other instanceof ChunkFetchRequest) {
ChunkFetchRequest o = (ChunkFetchRequest) other;
return streamChunkId.equals(o.streamChunkId);
return streamChunkId.equals(o.streamChunkId) && fetchAsStream == o.fetchAsStream;
}
return false;
}
@@ -66,6 +89,7 @@ public boolean equals(Object other) {
public String toString() {
return Objects.toStringHelper(this)
.add("streamChunkId", streamChunkId)
.add("fetchAsStream", fetchAsStream)
.toString();
}
}
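The optional trailing byte is what keeps ChunkFetchRequest wire-compatible with older peers: an old-format buffer simply has no extra byte, so decode() yields fetchAsStream = false. A small round-trip sketch (assuming it is compiled against this patch and run with assertions enabled):

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;

import org.apache.spark.network.protocol.ChunkFetchRequest;
import org.apache.spark.network.protocol.StreamChunkId;

public class ChunkFetchRequestCompatSketch {
  public static void main(String[] args) {
    StreamChunkId id = new StreamChunkId(7L, 0);

    // New format, fetchAsStream = true: 12 bytes of StreamChunkId plus one flag byte.
    ByteBuf newFormat = Unpooled.buffer();
    new ChunkFetchRequest(id, true).encode(newFormat);
    assert ChunkFetchRequest.decode(newFormat).fetchAsStream;

    // Old format: only the StreamChunkId is on the wire, so the flag decodes to false.
    ByteBuf oldFormat = Unpooled.buffer();
    id.encode(oldFormat);
    assert !ChunkFetchRequest.decode(oldFormat).fetchAsStream;
  }
}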
@@ -0,0 +1,94 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.network.protocol;

import com.google.common.base.Objects;
import io.netty.buffer.ByteBuf;

import org.apache.spark.network.buffer.ManagedBuffer;

/**
* Response to {@link ChunkFetchRequest} when its `fetchAsStream` flag is true and the stream has
* been successfully opened.
* <p>
* Note the message itself does not contain the stream data. That is written separately by the
* sender. The receiver is expected to set a temporary channel handler that will consume the
* number of bytes this message says the stream has.
*/
public final class ChunkFetchStreamResponse extends AbstractResponseMessage {
Review comment (Contributor Author): This is very similar to StreamResponse, except that here we use StreamChunkId streamChunkId instead of String streamId.
public final StreamChunkId streamChunkId;
public final long byteCount;

public ChunkFetchStreamResponse(
StreamChunkId streamChunkId,
long byteCount,
ManagedBuffer buffer) {
super(buffer, false);
this.streamChunkId = streamChunkId;
this.byteCount = byteCount;
}

@Override
public Message.Type type() { return Type.ChunkFetchStreamResponse; }

@Override
public int encodedLength() {
return streamChunkId.encodedLength() + 8;
}

/** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */
@Override
public void encode(ByteBuf buf) {
streamChunkId.encode(buf);
buf.writeLong(byteCount);
}

@Override
public ResponseMessage createFailureResponse(String error) {
return new ChunkFetchFailure(streamChunkId, error);
}

public static ChunkFetchStreamResponse decode(ByteBuf buf) {
StreamChunkId streamChunkId = StreamChunkId.decode(buf);
long byteCount = buf.readLong();
return new ChunkFetchStreamResponse(streamChunkId, byteCount, null);
}

@Override
public int hashCode() {
return Objects.hashCode(streamChunkId, byteCount);
}

@Override
public boolean equals(Object other) {
if (other instanceof ChunkFetchStreamResponse) {
ChunkFetchStreamResponse o = (ChunkFetchStreamResponse) other;
return streamChunkId.equals(o.streamChunkId) && byteCount == o.byteCount;
}
return false;
}

@Override
public String toString() {
return Objects.toStringHelper(this)
.add("streamChunkId", streamChunkId)
.add("byteCount", byteCount)
.add("body", body())
.toString();
}
}
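A companion sketch (again an assumption, compiled against this patch) showing that only the header of the new message round-trips through encode/decode; the stream payload is written separately by the sender, which is why decode() returns a message with a null body:

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;

import org.apache.spark.network.protocol.ChunkFetchStreamResponse;
import org.apache.spark.network.protocol.StreamChunkId;

public class ChunkFetchStreamResponseSketch {
  public static void main(String[] args) {
    ChunkFetchStreamResponse resp =
      new ChunkFetchStreamResponse(new StreamChunkId(7L, 0), 1024L, /* buffer = */ null);

    ByteBuf buf = Unpooled.buffer();
    resp.encode(buf);  // 12 bytes of StreamChunkId followed by 8 bytes of byteCount

    ChunkFetchStreamResponse decoded = ChunkFetchStreamResponse.decode(buf);
    assert decoded.equals(resp);    // equality covers streamChunkId and byteCount only
    assert decoded.body() == null;  // the 1024 stream bytes arrive as raw frames afterwards
  }
}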
@@ -37,7 +37,7 @@ enum Type implements Encodable {
ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2),
RpcRequest(3), RpcResponse(4), RpcFailure(5),
StreamRequest(6), StreamResponse(7), StreamFailure(8),
OneWayMessage(9), UploadStream(10), User(-1);
OneWayMessage(9), UploadStream(10), ChunkFetchStreamResponse(11), User(-1);

private final byte id;

@@ -66,6 +66,7 @@ public static Type decode(ByteBuf buf) {
case 8: return StreamFailure;
case 9: return OneWayMessage;
case 10: return UploadStream;
case 11: return ChunkFetchStreamResponse;
case -1: throw new IllegalArgumentException("User type messages cannot be decoded.");
default: throw new IllegalArgumentException("Unknown message type: " + id);
}
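Registering the new Type constant implies that the module's MessageDecoder dispatches on it as well. That change is not visible in this hunk, but the dispatch presumably has roughly the following shape (a sketch, not the actual code in the patch):

// Sketch only: the shape of the decoder dispatch for the new message type.
private Message decode(Message.Type msgType, ByteBuf in) {
  switch (msgType) {
    case ChunkFetchRequest:
      return ChunkFetchRequest.decode(in);
    case ChunkFetchStreamResponse:
      return ChunkFetchStreamResponse.decode(in);  // new case for type id 11
    // ... cases for the other existing message types ...
    default:
      throw new IllegalArgumentException("Unexpected message type: " + msgType);
  }
}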