From 1fbd6435b8f3541eb7fb2f4dc5a12ca542e62e50 Mon Sep 17 00:00:00 2001 From: turboFei Date: Fri, 12 Apr 2019 14:22:14 +0800 Subject: [PATCH 1/2] [SPARK-27562][Shuffle] Complete the verification mechanism for shuffle transmitted data --- .../DigestFileSegmentManagedBuffer.java | 50 ++++++++ .../buffer/FileSegmentManagedBuffer.java | 2 +- .../network/client/ChunkReceivedCallback.java | 5 + .../spark/network/client/StreamCallback.java | 5 + .../network/client/StreamInterceptor.java | 17 ++- .../client/TransportResponseHandler.java | 41 +++++++ .../protocol/DigestChunkFetchSuccess.java | 93 +++++++++++++++ .../protocol/DigestStreamResponse.java | 95 +++++++++++++++ .../spark/network/protocol/Message.java | 5 +- .../network/protocol/MessageDecoder.java | 6 + .../server/ChunkFetchRequestHandler.java | 18 ++- .../server/TransportRequestHandler.java | 14 ++- .../spark/network/util/DigestUtils.java | 64 ++++++++++ .../shuffle/BlockFetchingListener.java | 9 ++ .../shuffle/ExternalShuffleBlockResolver.java | 29 +++-- .../shuffle/OneForOneBlockFetcher.java | 14 +++ .../network/shuffle/RetryingBlockFetcher.java | 18 +++ .../shuffle/ShuffleIndexInformation.java | 47 +++++++- .../network/shuffle/ShuffleIndexRecord.java | 10 +- .../spark/internal/config/package.scala | 7 ++ .../shuffle/BlockStoreShuffleReader.scala | 3 +- .../shuffle/IndexShuffleBlockResolver.scala | 112 +++++++++++++++--- .../storage/ShuffleBlockFetcherIterator.scala | 3 +- .../scala/org/apache/spark/ShuffleSuite.scala | 16 +++ .../sort/IndexShuffleBlockResolverSuite.scala | 27 ++++- docs/configuration.md | 31 +++++ 26 files changed, 698 insertions(+), 43 deletions(-) create mode 100644 common/network-common/src/main/java/org/apache/spark/network/buffer/DigestFileSegmentManagedBuffer.java create mode 100644 common/network-common/src/main/java/org/apache/spark/network/protocol/DigestChunkFetchSuccess.java create mode 100644 common/network-common/src/main/java/org/apache/spark/network/protocol/DigestStreamResponse.java create mode 100644 common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/DigestFileSegmentManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/DigestFileSegmentManagedBuffer.java new file mode 100644 index 000000000000..d58e3ce2b2c8 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/DigestFileSegmentManagedBuffer.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.buffer; + +import java.io.File; + +import com.google.common.base.Objects; + +import org.apache.spark.network.util.TransportConf; + +/** + * A {@link ManagedBuffer} backed by a segment in a file with digest. + */ +public final class DigestFileSegmentManagedBuffer extends FileSegmentManagedBuffer { + + private final long digest; + + public DigestFileSegmentManagedBuffer(TransportConf conf, File file, long offset, long length, + long digest) { + super(conf, file, offset, length); + this.digest = digest; + } + + public long getDigest() { return digest; } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("file", getFile()) + .add("offset", getOffset()) + .add("length", getLength()) + .add("digest", digest) + .toString(); + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java index 66566b67870f..ed64867ad6c1 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -38,7 +38,7 @@ /** * A {@link ManagedBuffer} backed by a segment in a file. */ -public final class FileSegmentManagedBuffer extends ManagedBuffer { +public class FileSegmentManagedBuffer extends ManagedBuffer { private final TransportConf conf; private final File file; private final long offset; diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java index 519e6cb470d0..db1682c086f5 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java @@ -36,6 +36,11 @@ public interface ChunkReceivedCallback { */ void onSuccess(int chunkIndex, ManagedBuffer buffer); + /** Called with a extra digest parameter upon receipt of a particular chunk.*/ + default void onSuccess(int chunkIndex, ManagedBuffer buffer, long digest) { + onSuccess(chunkIndex, buffer); + } + /** * Called upon failure to fetch a particular chunk. Note that this may actually be called due * to failure to fetch a prior chunk in this stream. diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java index d322aec28793..2b9b6cce2562 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/StreamCallback.java @@ -35,6 +35,11 @@ public interface StreamCallback { /** Called when all data from the stream has been received. */ void onComplete(String streamId) throws IOException; + /** Called with a extra digest when all data from the stream has been received. */ + default void onComplete(String streamId, long digest) throws IOException { + onComplete(streamId); + } + /** Called if there's an error reading data from the stream. 
*/ void onFailure(String streamId, Throwable cause) throws IOException; } diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java index f3eb744ff734..f9d4ca1addcd 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/StreamInterceptor.java @@ -37,6 +37,7 @@ public class StreamInterceptor implements TransportFrameDecod private final long byteCount; private final StreamCallback callback; private long bytesRead; + private long digest = -1L; public StreamInterceptor( MessageHandler handler, @@ -50,6 +51,16 @@ public StreamInterceptor( this.bytesRead = 0; } + public StreamInterceptor( + MessageHandler handler, + String streamId, + long byteCount, + StreamCallback callback, + long digest) { + this(handler, streamId, byteCount, callback); + this.digest = digest; + } + @Override public void exceptionCaught(Throwable cause) throws Exception { deactivateStream(); @@ -86,7 +97,11 @@ public boolean handle(ByteBuf buf) throws Exception { throw re; } else if (bytesRead == byteCount) { deactivateStream(); - callback.onComplete(streamId); + if (digest < 0) { + callback.onComplete(streamId); + } else { + callback.onComplete(streamId, digest); + } } return bytesRead != byteCount; diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 2f143f77fa4a..e83a352d0821 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -33,6 +33,8 @@ import org.apache.spark.network.protocol.ChunkFetchFailure; import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.DigestChunkFetchSuccess; +import org.apache.spark.network.protocol.DigestStreamResponse; import org.apache.spark.network.protocol.ResponseMessage; import org.apache.spark.network.protocol.RpcFailure; import org.apache.spark.network.protocol.RpcResponse; @@ -246,6 +248,45 @@ public void handle(ResponseMessage message) throws Exception { } else { logger.warn("Stream failure with unknown callback: {}", resp.error); } + } else if (message instanceof DigestChunkFetchSuccess) { + DigestChunkFetchSuccess resp = (DigestChunkFetchSuccess) message; + ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId); + if (listener == null) { + logger.warn("Ignoring response for block {} from {} since it is not outstanding", + resp.streamChunkId, getRemoteAddress(channel)); + resp.body().release(); + } else { + outstandingFetches.remove(resp.streamChunkId); + listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body(), resp.digest); + resp.body().release(); + } + } else if (message instanceof DigestStreamResponse) { + DigestStreamResponse resp = (DigestStreamResponse) message; + Pair entry = streamCallbacks.poll(); + if (entry != null) { + StreamCallback callback = entry.getValue(); + if (resp.byteCount > 0) { + StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount, + callback, resp.digest); + try { + TransportFrameDecoder frameDecoder = (TransportFrameDecoder) + 
channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); + frameDecoder.setInterceptor(interceptor); + streamActive = true; + } catch (Exception e) { + logger.error("Error installing stream handler.", e); + deactivateStream(); + } + } else { + try { + callback.onComplete(resp.streamId, resp.digest); + } catch (Exception e) { + logger.warn("Error in stream handler onComplete().", e); + } + } + } else { + logger.error("Could not find callback for StreamResponse."); + } } else { throw new IllegalStateException("Unknown response type: " + message.type()); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/DigestChunkFetchSuccess.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/DigestChunkFetchSuccess.java new file mode 100644 index 000000000000..733231704fde --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/DigestChunkFetchSuccess.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; + +/** + * Response to {@link ChunkFetchRequest} when a chunk exists with a digest and has been + * successfully fetched. + * + * Note that the server-side encoding of this messages does NOT include the buffer itself, as this + * may be written by Netty in a more efficient manner (i.e., zero-copy write). + * Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer. + */ +public final class DigestChunkFetchSuccess extends AbstractResponseMessage { + public final StreamChunkId streamChunkId; + public final long digest; + public DigestChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer, long digest) { + super(buffer, true); + this.streamChunkId = streamChunkId; + this.digest = digest; + } + + @Override + public Type type() { return Type.DigestChunkFetchSuccess; } + + @Override + public int encodedLength() { + return streamChunkId.encodedLength() + 8; + } + + /** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */ + @Override + public void encode(ByteBuf buf) { + streamChunkId.encode(buf); + buf.writeLong(digest); + } + + @Override + public ResponseMessage createFailureResponse(String error) { + return new ChunkFetchFailure(streamChunkId, error); + } + + /** Decoding uses the given ByteBuf as our data, and will retain() it. 
*/ + public static DigestChunkFetchSuccess decode(ByteBuf buf) { + StreamChunkId streamChunkId = StreamChunkId.decode(buf); + long digest = buf.readLong(); + buf.retain(); + NettyManagedBuffer managedBuf = new NettyManagedBuffer(buf.duplicate()); + return new DigestChunkFetchSuccess(streamChunkId, managedBuf, digest); + } + + @Override + public int hashCode() { + return Objects.hashCode(streamChunkId, body(), digest); + } + + @Override + public boolean equals(Object other) { + if (other instanceof DigestChunkFetchSuccess) { + DigestChunkFetchSuccess o = (DigestChunkFetchSuccess) other; + return streamChunkId.equals(o.streamChunkId) && super.equals(o) && digest == o.digest; + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamChunkId", streamChunkId) + .add("digest", digest) + .add("buffer", body()) + .toString(); + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/DigestStreamResponse.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/DigestStreamResponse.java new file mode 100644 index 000000000000..7ef4a28c0ebc --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/DigestStreamResponse.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * Response to {@link StreamRequest} with digest when the stream has been successfully opened. + *

+ * Note the message itself does not contain the stream data. That is written separately by the + * sender. The receiver is expected to set a temporary channel handler that will consume the + * number of bytes this message says the stream has. + */ +public final class DigestStreamResponse extends AbstractResponseMessage { + public final String streamId; + public final long byteCount; + public final long digest; + + public DigestStreamResponse(String streamId, long byteCount, ManagedBuffer buffer, long digest) { + super(buffer, false); + this.streamId = streamId; + this.byteCount = byteCount; + this.digest = digest; + } + + @Override + public Type type() { return Type.DigestStreamResponse; } + + @Override + public int encodedLength() { + return 8 + Encoders.Strings.encodedLength(streamId) + 8; + } + + /** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */ + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, streamId); + buf.writeLong(byteCount); + buf.writeLong(digest); + } + + @Override + public ResponseMessage createFailureResponse(String error) { + return new StreamFailure(streamId, error); + } + + public static DigestStreamResponse decode(ByteBuf buf) { + String streamId = Encoders.Strings.decode(buf); + long byteCount = buf.readLong(); + long digest = buf.readLong(); + return new DigestStreamResponse(streamId, byteCount, null, digest); + } + + @Override + public int hashCode() { + return Objects.hashCode(byteCount, streamId, body(), digest); + } + + @Override + public boolean equals(Object other) { + if (other instanceof DigestStreamResponse) { + DigestStreamResponse o = (DigestStreamResponse) other; + return byteCount == o.byteCount && streamId.equals(o.streamId) && digest == o.digest; + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamId", streamId) + .add("byteCount", byteCount) + .add("digest", digest) + .add("body", body()) + .toString(); + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java index 0ccd70c03aba..cd3efdc59d38 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/Message.java @@ -37,7 +37,8 @@ enum Type implements Encodable { ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2), RpcRequest(3), RpcResponse(4), RpcFailure(5), StreamRequest(6), StreamResponse(7), StreamFailure(8), - OneWayMessage(9), UploadStream(10), User(-1); + OneWayMessage(9), UploadStream(10), DigestChunkFetchSuccess(11), + DigestStreamResponse(12), User(-1); private final byte id; @@ -66,6 +67,8 @@ public static Type decode(ByteBuf buf) { case 8: return StreamFailure; case 9: return OneWayMessage; case 10: return UploadStream; + case 11: return DigestChunkFetchSuccess; + case 12: return DigestStreamResponse; case -1: throw new IllegalArgumentException("User type messages cannot be decoded."); default: throw new IllegalArgumentException("Unknown message type: " + id); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java index bf80aed0afe1..0d98f0161015 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java +++ 
b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -50,6 +50,12 @@ public void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { private Message decode(Message.Type msgType, ByteBuf in) { switch (msgType) { + case DigestChunkFetchSuccess: + return DigestChunkFetchSuccess.decode(in); + + case DigestStreamResponse: + return DigestStreamResponse.decode(in); + case ChunkFetchRequest: return ChunkFetchRequest.decode(in); diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java index 82810dacdad8..cf06d88af158 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java @@ -25,15 +25,13 @@ import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; +import org.apache.spark.network.buffer.DigestFileSegmentManagedBuffer; +import org.apache.spark.network.protocol.*; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.protocol.ChunkFetchFailure; -import org.apache.spark.network.protocol.ChunkFetchRequest; -import org.apache.spark.network.protocol.ChunkFetchSuccess; -import org.apache.spark.network.protocol.Encodable; import static org.apache.spark.network.util.NettyUtils.*; @@ -111,8 +109,16 @@ public void processFetchRequest( } streamManager.chunkBeingSent(msg.streamChunkId.streamId); - respond(channel, new ChunkFetchSuccess(msg.streamChunkId, buf)).addListener( - (ChannelFutureListener) future -> streamManager.chunkSent(msg.streamChunkId.streamId)); + if (buf instanceof DigestFileSegmentManagedBuffer) { + respond(channel, new DigestChunkFetchSuccess(msg.streamChunkId, buf, + ((DigestFileSegmentManagedBuffer)buf).getDigest())) + .addListener((ChannelFutureListener) future -> + streamManager.chunkSent(msg.streamChunkId.streamId)); + } else { + respond(channel, new ChunkFetchSuccess(msg.streamChunkId, buf)).addListener( + (ChannelFutureListener) future -> streamManager.chunkSent(msg.streamChunkId.streamId)); + } + } /** diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index f17892800690..226852d675bb 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -28,6 +28,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.network.buffer.DigestFileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.*; @@ -143,9 +144,16 @@ private void processStreamRequest(final StreamRequest req) { if (buf != null) { streamManager.streamBeingSent(req.streamId); - respond(new StreamResponse(req.streamId, buf.size(), buf)).addListener(future -> { - streamManager.streamSent(req.streamId); - }); + if (buf instanceof DigestFileSegmentManagedBuffer) { + respond(new 
DigestStreamResponse(req.streamId, buf.size(), buf, + ((DigestFileSegmentManagedBuffer) buf).getDigest())).addListener(future -> { + streamManager.streamSent(req.streamId); + }); + } else { + respond(new StreamResponse(req.streamId, buf.size(), buf)).addListener(future -> { + streamManager.streamSent(req.streamId); + }); + } } else { // org.apache.spark.repl.ExecutorClassLoader.STREAM_NOT_FOUND_REGEX should also be updated // when the following error message is changed. diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java new file mode 100644 index 000000000000..4c57c37be6c7 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.zip.CRC32; + +public class DigestUtils { + private static final int STREAM_BUFFER_LENGTH = 2048; + private static final int DIGEST_LENGTH = 8; + + public static int getDigestLength() { + return DIGEST_LENGTH; + } + + public DigestUtils() { + } + + public static long getDigest(InputStream data) throws IOException { + return updateCRC32(getCRC32(), data); + } + + public static long getDigest(File file, long offset, long length) { + try { + LimitedInputStream inputStream = new LimitedInputStream(new FileInputStream(file), + offset + length, true); + inputStream.skip(offset); + return getDigest(inputStream); + } catch (IOException e) { + return -1; + } + } + + public static CRC32 getCRC32() { + return new CRC32(); + } + + public static long updateCRC32(CRC32 crc32, InputStream data) throws IOException { + byte[] buffer = new byte[STREAM_BUFFER_LENGTH]; + int len = 0; + while ((len = data.read(buffer)) >= 0) { + crc32.update(buffer, 0, len); + } + return crc32.getValue(); + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java index 138fd5389c20..b5f76aab11d3 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java @@ -29,6 +29,15 @@ public interface BlockFetchingListener extends EventListener { */ void onBlockFetchSuccess(String blockId, ManagedBuffer data); + /** + * Called once per successfully fetch block during shuffle, which has a parameter present the + * checkSum of shuffle block. 
Here provide a default method body for that not every + * blockFetchingListener need to implement one onBlockFetchSuccess method. + */ + default void onBlockFetchSuccess(String blockId, ManagedBuffer data, long digest) { + onBlockFetchSuccess(blockId, data); + } + /** * Called at least once per block upon failures. */ diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index ba1a17bf7e5e..2740373548d5 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -39,19 +39,17 @@ import com.google.common.cache.LoadingCache; import com.google.common.cache.Weigher; import com.google.common.collect.Maps; +import org.apache.spark.network.util.*; import org.iq80.leveldb.DB; import org.iq80.leveldb.DBIterator; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.network.buffer.DigestFileSegmentManagedBuffer; import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; -import org.apache.spark.network.util.LevelDBProvider; import org.apache.spark.network.util.LevelDBProvider.StoreVersion; -import org.apache.spark.network.util.JavaUtils; -import org.apache.spark.network.util.NettyUtils; -import org.apache.spark.network.util.TransportConf; /** * Manages converting shuffle BlockIds into physical segments of local files, from a process outside @@ -320,12 +318,25 @@ private ManagedBuffer getSortBasedShuffleBlockData( ShuffleIndexInformation shuffleIndexInformation = shuffleIndexCache.get(indexFile); ShuffleIndexRecord shuffleIndexRecord = shuffleIndexInformation.getIndex( startReduceId, endReduceId); - return new FileSegmentManagedBuffer( - conf, - ExecutorDiskUtils.getFile(executor.localDirs, executor.subDirsPerLocalDir, + if (shuffleIndexInformation.isHasDigest()) { + File dataFile = ExecutorDiskUtils.getFile(executor.localDirs, executor.subDirsPerLocalDir, + "shuffle_" + shuffleId + "_" + mapId + "_0.data"); + return new DigestFileSegmentManagedBuffer( + conf, + dataFile, + shuffleIndexRecord.getOffset(), + shuffleIndexRecord.getLength(), + shuffleIndexRecord.getDigest().orElse(DigestUtils.getDigest( + dataFile, shuffleIndexRecord.getOffset(), shuffleIndexRecord.getLength()))); + + } else { + return new FileSegmentManagedBuffer( + conf, + ExecutorDiskUtils.getFile(executor.localDirs, executor.subDirsPerLocalDir, "shuffle_" + shuffleId + "_" + mapId + "_0.data"), - shuffleIndexRecord.getOffset(), - shuffleIndexRecord.getLength()); + shuffleIndexRecord.getOffset(), + shuffleIndexRecord.getLength()); + } } catch (ExecutionException e) { throw new RuntimeException("Failed to open file: " + indexFile, e); } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index ec2e3dce661d..d7e1e096c423 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -165,6 +165,12 @@ public void onSuccess(int chunkIndex, 
ManagedBuffer buffer) { listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer); } + @Override + public void onSuccess(int chunkIndex, ManagedBuffer buffer, long digest) { + // On receipt of a chunk, pass it upwards as a block. + listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer, digest); + } + @Override public void onFailure(int chunkIndex, Throwable e) { // On receipt of a failure, fail every block from chunkIndex onwards. @@ -248,6 +254,14 @@ public void onComplete(String streamId) throws IOException { } } + @Override + public void onComplete(String streamId, long digest) throws IOException { + listener.onBlockFetchSuccess(blockIds[chunkIndex], channel.closeAndRead(), digest); + if (!downloadFileManager.registerTempFileToClean(targetFile)) { + targetFile.delete(); + } + } + @Override public void onFailure(String streamId, Throwable cause) throws IOException { channel.close(); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java index 6bf3da94030d..d9e6f9251b6a 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java @@ -205,6 +205,24 @@ public void onBlockFetchSuccess(String blockId, ManagedBuffer data) { } } + @Override + public void onBlockFetchSuccess(String blockId, ManagedBuffer data, long digest) { + // We will only forward this success message to our parent listener if this block request is + // outstanding and we are still the active listener. + boolean shouldForwardSuccess = false; + synchronized (RetryingBlockFetcher.this) { + if (this == currentListener && outstandingBlocksIds.contains(blockId)) { + outstandingBlocksIds.remove(blockId); + shouldForwardSuccess = true; + } + } + + // Now actually invoke the parent listener, outside of the synchronized block. 
+ if (shouldForwardSuccess) { + listener.onBlockFetchSuccess(blockId, data, digest); + } + } + @Override public void onBlockFetchFailure(String blockId, Throwable exception) { // We will only forward this failure to our parent listener if this block request is diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java index b65aacfcc4b9..8f2852ea0278 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java @@ -23,6 +23,9 @@ import java.nio.ByteBuffer; import java.nio.LongBuffer; import java.nio.file.Files; +import java.util.Optional; + +import org.apache.spark.network.util.DigestUtils; /** * Keeps the index information for a particular map output @@ -31,17 +34,44 @@ public class ShuffleIndexInformation { /** offsets as long buffer */ private final LongBuffer offsets; + private final boolean hasDigest; + /** digests as long buffer */ + private final LongBuffer digests; private int size; public ShuffleIndexInformation(File indexFile) throws IOException { + ByteBuffer offsetsBuffer, digestsBuffer; size = (int)indexFile.length(); - ByteBuffer buffer = ByteBuffer.allocate(size); - offsets = buffer.asLongBuffer(); + int offsetsSize, digestsSize; + if (size % 8 == 0) { + hasDigest = false; + offsetsSize = size; + digestsSize = 0; + } else { + hasDigest = true; + offsetsSize = ((size - 8 - 1) / (8 + DigestUtils.getDigestLength()) + 1) * 8; + digestsSize = size - offsetsSize -1; + } + offsetsBuffer = ByteBuffer.allocate(offsetsSize); + digestsBuffer = ByteBuffer.allocate(digestsSize); + offsets = offsetsBuffer.asLongBuffer(); + digests = digestsBuffer.asLongBuffer(); try (DataInputStream dis = new DataInputStream(Files.newInputStream(indexFile.toPath()))) { - dis.readFully(buffer.array()); + dis.readFully(offsetsBuffer.array()); + if (hasDigest) { + dis.readByte(); + } + dis.readFully(digestsBuffer.array()); } } + /** + * If this indexFile has digest + */ + public boolean isHasDigest() { + return hasDigest; + } + /** * Size of the index file * @return size @@ -63,6 +93,15 @@ public ShuffleIndexRecord getIndex(int reduceId) { public ShuffleIndexRecord getIndex(int startReduceId, int endReduceId) { long offset = offsets.get(startReduceId); long nextOffset = offsets.get(endReduceId); - return new ShuffleIndexRecord(offset, nextOffset - offset); + /** Default digest is -1L.*/ + Optional digest = Optional.of(-1L); + if (hasDigest) { + if (endReduceId - startReduceId == 1) { + digest = Optional.of(digests.get(startReduceId)); + } else { + digest = Optional.empty(); + } + } + return new ShuffleIndexRecord(offset, nextOffset - offset, digest); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexRecord.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexRecord.java index 6a4fac150a6b..f64871c2bea3 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexRecord.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexRecord.java @@ -17,16 +17,20 @@ package org.apache.spark.network.shuffle; +import java.util.Optional; + /** * Contains offset and length of the shuffle block data. 
*/ public class ShuffleIndexRecord { private final long offset; private final long length; + private final Optional digest; - public ShuffleIndexRecord(long offset, long length) { + public ShuffleIndexRecord(long offset, long length, Optional digest) { this.offset = offset; this.length = length; + this.digest = digest; } public long getOffset() { @@ -36,5 +40,9 @@ public long getOffset() { public long getLength() { return length; } + + public Optional getDigest() { + return digest; + } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 4cda4b180d97..61b52f797c2a 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1836,4 +1836,11 @@ package object config { .version("3.1.0") .booleanConf .createWithDefault(false) + + private[spark] val SHUFFLE_DIGEST_ENABLED = + ConfigBuilder("spark.shuffle.digest.enabled") + .internal() + .doc("The parameter to control whether check the transmitted data during shuffle.") + .booleanConf + .createWithDefault(false) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index bc2a0fbc36d5..dd6bea1a1a7a 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -80,7 +80,8 @@ private[spark] class BlockStoreShuffleReader[K, C]( SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY), readMetrics, - fetchContinuousBlocksInBatch).toCompletionIterator + fetchContinuousBlocksInBatch, + SparkEnv.get.conf.get(config.SHUFFLE_DIGEST_ENABLED)).toCompletionIterator val serializerInstance = dep.serializer.newInstance() diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index af2c82e77197..84e2f0bef867 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -22,11 +22,12 @@ import java.nio.channels.Channels import java.nio.file.Files import org.apache.spark.{SparkConf, SparkEnv} -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.io.NioBufferedFileInputStream -import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.buffer.{DigestFileSegmentManagedBuffer, FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.ExecutorDiskUtils +import org.apache.spark.network.util.{DigestUtils, LimitedInputStream} import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID import org.apache.spark.storage._ import org.apache.spark.util.Utils @@ -52,6 +53,9 @@ private[spark] class IndexShuffleBlockResolver( private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") + // The digest conf for shuffle block check + private final val digestEnable = conf.getBoolean(config.SHUFFLE_DIGEST_ENABLED.key, false); + private final val digestLength = DigestUtils.getDigestLength() def getDataFile(shuffleId: Int, mapId: Long): File = getDataFile(shuffleId, 
mapId, None) @@ -107,12 +111,16 @@ private[spark] class IndexShuffleBlockResolver( * Check whether the given index and data files match each other. * If so, return the partition lengths in the data file. Otherwise return null. */ - private def checkIndexAndDataFile(index: File, data: File, blocks: Int): Array[Long] = { - // the index file should have `block + 1` longs as offset. - if (index.length() != (blocks + 1) * 8L) { + private def checkIndexAndDataFile(index: File, data: File, blocks: Int, digests: Array[Long]): + (Array[Long], Array[Long]) = { + // Id digestEnable is false, the index file should have `blocks + 1` longs as offset. + // Otherwise, it should have a byte as flag, `blocks + 1` longs as offset and `blocks` digests + if ((!digestEnable && index.length() != (blocks + 1) * 8L) || + (digestEnable && index.length() != blocks * (8L + digestLength) + 8L + 1L)) { return null } val lengths = new Array[Long](blocks) + val digestArr = new Array[Long](blocks) // Read the lengths of blocks val in = try { new DataInputStream(new NioBufferedFileInputStream(index)) @@ -133,6 +141,18 @@ private[spark] class IndexShuffleBlockResolver( offset = off i += 1 } + if (digestEnable) { + val flag = in.readByte() + // the flag for digestEnable should be 1 + if (flag != 1) { + return null + } + i = 0 + while (i < blocks) { + digestArr(i) = in.readLong() + i += 1 + } + } } catch { case e: IOException => return null @@ -141,8 +161,8 @@ private[spark] class IndexShuffleBlockResolver( } // the size of data file should match with index file - if (data.length() == lengths.sum) { - lengths + if (data.length() == lengths.sum && !(0 until blocks).exists(i => digests(i) != digestArr(i))) { + (lengths, digestArr) } else { null } @@ -170,11 +190,38 @@ private[spark] class IndexShuffleBlockResolver( // There is only one IndexShuffleBlockResolver per executor, this synchronization make sure // the following check and rename are atomic. synchronized { - val existingLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length) - if (existingLengths != null) { + val digests = new Array[Long](lengths.length) + val dateIn = if (dataTmp != null && dataTmp.exists()) { + new FileInputStream(dataTmp) + } else { + null + } + Utils.tryWithSafeFinally { + if (digestEnable && dateIn != null) { + for (i <- (0 until lengths.length)) { + val length = lengths(i) + if (length == 0) { + digests(i) = -1L + } else { + digests(i) = DigestUtils.getDigest(new LimitedInputStream(dateIn, length)) + } + } + } + } { + if (dateIn != null) { + dateIn.close() + } + } + + val existingLengthsDigests = + checkIndexAndDataFile(indexFile, dataFile, lengths.length, digests) + if (existingLengthsDigests != null) { + val existingLengths = existingLengthsDigests._1 + val existingDigests = existingLengthsDigests._2 // Another attempt for the same task has already written our map outputs successfully, // so just use the existing partition lengths and delete our temporary map outputs. 
System.arraycopy(existingLengths, 0, lengths, 0, lengths.length) + System.arraycopy(existingDigests, 0, digests, 0, digests.length) if (dataTmp != null && dataTmp.exists()) { dataTmp.delete() } @@ -190,6 +237,13 @@ private[spark] class IndexShuffleBlockResolver( offset += length out.writeLong(offset) } + if (digestEnable) { + // we write a byte present digest enable + out.writeByte(1) + for (digest <- digests) { + out.writeLong(digest) + } + } } { out.close() } @@ -237,6 +291,10 @@ private[spark] class IndexShuffleBlockResolver( // class of issue from re-occurring in the future which is why they are left here even though // SPARK-22982 is fixed. val channel = Files.newByteChannel(indexFile.toPath) + var blocks = (indexFile.length() - 8) / 8 + if (digestEnable) { + blocks = (indexFile.length() - 8 - 1) / (8 + digestLength) + } channel.position(startReduceId * 8L) val in = new DataInputStream(Channels.newInputStream(channel)) try { @@ -249,11 +307,37 @@ private[spark] class IndexShuffleBlockResolver( throw new Exception(s"SPARK-22982: Incorrect channel position after index file reads: " + s"expected $expectedPosition but actual position was $actualPosition.") } - new FileSegmentManagedBuffer( - transportConf, - getDataFile(shuffleId, mapId, dirs), - startOffset, - endOffset - startOffset) + + if (digestEnable) { + val digestValue = if (endReduceId - startReduceId == 1) { + channel.position(1 + (blocks + 1) * 8L + startReduceId * digestLength) + val digest = in.readLong() + val actualDigestPosition = channel.position() + val expectedDigestLength = 1 + (blocks + 1) * 8L + (startReduceId + 1) * digestLength + if (actualDigestPosition != expectedDigestLength) { + throw new Exception(s"SPARK-22982: Incorrect channel position after index file " + + s"reads: expected $expectedDigestLength but actual position was " + + s" $actualDigestPosition.") + } + digest + } else { + DigestUtils.getDigest(getDataFile(shuffleId, mapId, dirs), startOffset, + endOffset - startOffset) + } + + new DigestFileSegmentManagedBuffer( + transportConf, + getDataFile(shuffleId, mapId, dirs), + startOffset, + endOffset - startOffset, + digestValue) + } else { + new FileSegmentManagedBuffer( + transportConf, + getDataFile(shuffleId, mapId, dirs), + startOffset, + endOffset - startOffset) + } } finally { in.close() } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 5efbc0703f72..8a5ce70c1121 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -80,7 +80,8 @@ final class ShuffleBlockFetcherIterator( detectCorrupt: Boolean, detectCorruptUseExtraMemory: Boolean, shuffleMetrics: ShuffleReadMetricsReporter, - doBatchFetch: Boolean) + doBatchFetch: Boolean, + digest: Boolean) extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging { import ShuffleBlockFetcherIterator._ diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 9e39271bdf9e..37e169c5eef0 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -418,6 +418,22 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC manager.unregisterShuffle(0) } + + test("[SPARK-27562]: test shuffle with shuffle digest enabled is 
true") { + conf.set(config.SHUFFLE_DIGEST_ENABLED, true) + val sc = new SparkContext("local", "test", conf) + val numRecords = 10000 + + val wordCount = sc.parallelize(1 to numRecords, 4) + .map(key => (key, 1)) + .reduceByKey(_ + _) + .collect() + val count = wordCount.length + val sum = wordCount.map(value => value._1).sum + assert(count == numRecords) + assert(sum == (1 to numRecords).sum) + sc.stop() + } } /** diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala index 27bb06b4e063..a4ed7761b714 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.shuffle.sort -import java.io.{DataInputStream, File, FileInputStream, FileOutputStream} +import java.io.{ByteArrayInputStream, DataInputStream, File, FileInputStream, FileOutputStream} import org.mockito.{Mock, MockitoAnnotations} import org.mockito.Answers.RETURNS_SMART_NULLS @@ -27,6 +27,9 @@ import org.mockito.invocation.InvocationOnMock import org.scalatest.BeforeAndAfterEach import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.internal.config +import org.apache.spark.network.buffer.DigestFileSegmentManagedBuffer +import org.apache.spark.network.util.DigestUtils import org.apache.spark.shuffle.IndexShuffleBlockResolver import org.apache.spark.storage._ import org.apache.spark.util.Utils @@ -155,4 +158,26 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa indexIn2.close() } } + + test("[SPARK-27562]: check the digest when shuffle digest enabled is true") { + val confClone = conf.clone + confClone.set(config.SHUFFLE_DIGEST_ENABLED, true) + val resolver = new IndexShuffleBlockResolver(confClone, blockManager) + val lengths = Array[Long](10, 0, 20) + val dataTmp = File.createTempFile("shuffle", null, tempDir) + val out = new FileOutputStream(dataTmp) + Utils.tryWithSafeFinally { + out.write(new Array[Byte](30)) + } { + out.close() + } + val digest = DigestUtils.getDigest(new ByteArrayInputStream(new Array[Byte](10))) + + resolver.writeIndexFileAndCommit(1, 2, lengths, dataTmp) + val managedBuffer = resolver.getBlockData(ShuffleBlockId(1, 2, 0)) + assert(managedBuffer.isInstanceOf[DigestFileSegmentManagedBuffer]) + assert(managedBuffer.asInstanceOf[DigestFileSegmentManagedBuffer].getDigest == digest) + + } + } diff --git a/docs/configuration.md b/docs/configuration.md index fce04b940594..ed0e637c60ba 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -998,6 +998,37 @@ Apart from these, the following properties are also available, and may be useful 2.3.0 + + spark.shuffle.digest.enabled + false + + The parameter to control whether check the transmitted data during shuffle. + + + + spark.io.encryption.enabled + false + + Enable IO encryption. Currently supported by all modes except Mesos. It's recommended that RPC encryption + be enabled when using this feature. + + + + spark.io.encryption.keySizeBits + 128 + + IO encryption key size in bits. Supported values are 128, 192 and 256. + + + + spark.io.encryption.keygen.algorithm + HmacSHA1 + + The algorithm to use when generating the IO encryption key. The supported algorithms are + described in the KeyGenerator section of the Java Cryptography Architecture Standard Algorithm + Name Documentation. 
+ + ### Spark UI From 8aa9ab5f926dcba95bd4173077a412428b2bca4c Mon Sep 17 00:00:00 2001 From: hustfeiwang Date: Fri, 26 Apr 2019 11:49:37 +0800 Subject: [PATCH 2/2] fix code --- .../client/TransportResponseHandler.java | 8 +++---- .../protocol/DigestChunkFetchSuccess.java | 3 ++- .../protocol/DigestStreamResponse.java | 2 +- .../server/ChunkFetchRequestHandler.java | 7 +++--- .../server/TransportRequestHandler.java | 2 +- .../spark/network/util/DigestUtils.java | 7 ++---- .../network/shuffle/RetryingBlockFetcher.java | 21 +++++----------- .../sort/IndexShuffleBlockResolverSuite.scala | 3 --- docs/configuration.md | 24 ------------------- 9 files changed, 19 insertions(+), 58 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index e83a352d0821..c7e648a71c31 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -253,7 +253,7 @@ public void handle(ResponseMessage message) throws Exception { ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId); if (listener == null) { logger.warn("Ignoring response for block {} from {} since it is not outstanding", - resp.streamChunkId, getRemoteAddress(channel)); + resp.streamChunkId, getRemoteAddress(channel)); resp.body().release(); } else { outstandingFetches.remove(resp.streamChunkId); @@ -266,11 +266,11 @@ public void handle(ResponseMessage message) throws Exception { if (entry != null) { StreamCallback callback = entry.getValue(); if (resp.byteCount > 0) { - StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount, - callback, resp.digest); + StreamInterceptor interceptor = new StreamInterceptor( + this, resp.streamId, resp.byteCount, callback, resp.digest); try { TransportFrameDecoder frameDecoder = (TransportFrameDecoder) - channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); + channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); frameDecoder.setInterceptor(interceptor); streamActive = true; } catch (Exception e) { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/DigestChunkFetchSuccess.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/DigestChunkFetchSuccess.java index 733231704fde..b7e326d9457c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/DigestChunkFetchSuccess.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/DigestChunkFetchSuccess.java @@ -33,6 +33,7 @@ public final class DigestChunkFetchSuccess extends AbstractResponseMessage { public final StreamChunkId streamChunkId; public final long digest; + public DigestChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer, long digest) { super(buffer, true); this.streamChunkId = streamChunkId; @@ -40,7 +41,7 @@ public DigestChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer } @Override - public Type type() { return Type.DigestChunkFetchSuccess; } + public Message.Type type() { return Type.DigestChunkFetchSuccess; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/DigestStreamResponse.java 
b/common/network-common/src/main/java/org/apache/spark/network/protocol/DigestStreamResponse.java index 7ef4a28c0ebc..a184cba6654a 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/DigestStreamResponse.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/DigestStreamResponse.java @@ -41,7 +41,7 @@ public DigestStreamResponse(String streamId, long byteCount, ManagedBuffer buffe } @Override - public Type type() { return Type.DigestStreamResponse; } + public Message.Type type() { return Type.DigestStreamResponse; } @Override public int encodedLength() { diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java index cf06d88af158..f648a0fe6d0b 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java @@ -111,12 +111,11 @@ public void processFetchRequest( streamManager.chunkBeingSent(msg.streamChunkId.streamId); if (buf instanceof DigestFileSegmentManagedBuffer) { respond(channel, new DigestChunkFetchSuccess(msg.streamChunkId, buf, - ((DigestFileSegmentManagedBuffer)buf).getDigest())) - .addListener((ChannelFutureListener) future -> - streamManager.chunkSent(msg.streamChunkId.streamId)); + ((DigestFileSegmentManagedBuffer)buf).getDigest())).addListener( + (ChannelFutureListener) future -> streamManager.chunkSent(msg.streamChunkId.streamId)); } else { respond(channel, new ChunkFetchSuccess(msg.streamChunkId, buf)).addListener( - (ChannelFutureListener) future -> streamManager.chunkSent(msg.streamChunkId.streamId)); + (ChannelFutureListener) future -> streamManager.chunkSent(msg.streamChunkId.streamId)); } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 226852d675bb..db7a43875cdc 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -146,7 +146,7 @@ private void processStreamRequest(final StreamRequest req) { streamManager.streamBeingSent(req.streamId); if (buf instanceof DigestFileSegmentManagedBuffer) { respond(new DigestStreamResponse(req.streamId, buf.size(), buf, - ((DigestFileSegmentManagedBuffer) buf).getDigest())).addListener(future -> { + ((DigestFileSegmentManagedBuffer) buf).getDigest())).addListener(future -> { streamManager.streamSent(req.streamId); }); } else { diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java index 4c57c37be6c7..18cc6f8d3abe 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/DigestUtils.java @@ -24,16 +24,13 @@ import java.util.zip.CRC32; public class DigestUtils { - private static final int STREAM_BUFFER_LENGTH = 2048; + private static final int STREAM_BUFFER_LENGTH = 8192; private static final int DIGEST_LENGTH = 8; public static int getDigestLength() { return DIGEST_LENGTH; } - public DigestUtils() { - } - public static long 
getDigest(InputStream data) throws IOException { return updateCRC32(getCRC32(), data); } @@ -55,7 +52,7 @@ public static CRC32 getCRC32() { public static long updateCRC32(CRC32 crc32, InputStream data) throws IOException { byte[] buffer = new byte[STREAM_BUFFER_LENGTH]; - int len = 0; + int len; while ((len = data.read(buffer)) >= 0) { crc32.update(buffer, 0, len); } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java index d9e6f9251b6a..dba15e0076dd 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java @@ -189,20 +189,7 @@ private synchronized boolean shouldRetry(Throwable e) { private class RetryingBlockFetchListener implements BlockFetchingListener { @Override public void onBlockFetchSuccess(String blockId, ManagedBuffer data) { - // We will only forward this success message to our parent listener if this block request is - // outstanding and we are still the active listener. - boolean shouldForwardSuccess = false; - synchronized (RetryingBlockFetcher.this) { - if (this == currentListener && outstandingBlocksIds.contains(blockId)) { - outstandingBlocksIds.remove(blockId); - shouldForwardSuccess = true; - } - } - - // Now actually invoke the parent listener, outside of the synchronized block. - if (shouldForwardSuccess) { - listener.onBlockFetchSuccess(blockId, data); - } + onBlockFetchSuccess(blockId, data, -1L); } @Override @@ -219,7 +206,11 @@ public void onBlockFetchSuccess(String blockId, ManagedBuffer data, long digest) // Now actually invoke the parent listener, outside of the synchronized block. if (shouldForwardSuccess) { - listener.onBlockFetchSuccess(blockId, data, digest); + if (digest < 0) { + listener.onBlockFetchSuccess(blockId, data); + } else { + listener.onBlockFetchSuccess(blockId, data, digest); + } } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala index a4ed7761b714..c5ba21834cc8 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala @@ -172,12 +172,9 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa out.close() } val digest = DigestUtils.getDigest(new ByteArrayInputStream(new Array[Byte](10))) - resolver.writeIndexFileAndCommit(1, 2, lengths, dataTmp) val managedBuffer = resolver.getBlockData(ShuffleBlockId(1, 2, 0)) assert(managedBuffer.isInstanceOf[DigestFileSegmentManagedBuffer]) assert(managedBuffer.asInstanceOf[DigestFileSegmentManagedBuffer].getDigest == digest) - } - } diff --git a/docs/configuration.md b/docs/configuration.md index ed0e637c60ba..3a91d4f7531a 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1005,30 +1005,6 @@ Apart from these, the following properties are also available, and may be useful The parameter to control whether check the transmitted data during shuffle. - - spark.io.encryption.enabled - false - - Enable IO encryption. Currently supported by all modes except Mesos. It's recommended that RPC encryption - be enabled when using this feature. 
- - - - spark.io.encryption.keySizeBits - 128 - - IO encryption key size in bits. Supported values are 128, 192 and 256. - - - - spark.io.encryption.keygen.algorithm - HmacSHA1 - - The algorithm to use when generating the IO encryption key. The supported algorithms are - described in the KeyGenerator section of the Java Cryptography Architecture Standard Algorithm - Name Documentation. - - ### Spark UI