diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index 210a581db466e..dcbda5a8515dd 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -73,6 +73,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo def uploadBlock( hostname: String, port: Int, + execId: String, blockId: BlockId, blockData: ManagedBuffer, level: StorageLevel): Future[Unit] @@ -110,9 +111,10 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo def uploadBlockSync( hostname: String, port: Int, + execId: String, blockId: BlockId, blockData: ManagedBuffer, level: StorageLevel): Unit = { - Await.result(uploadBlock(hostname, port, blockId, blockData, level), Duration.Inf) + Await.result(uploadBlock(hostname, port, execId, blockId, blockData, level), Duration.Inf) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index 1950e7bd634ee..b089da8596e2b 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -26,18 +26,10 @@ import org.apache.spark.network.BlockDataManager import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager} -import org.apache.spark.network.shuffle.ShuffleStreamHandle +import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, OpenBlocks, StreamHandle, UploadBlock} import org.apache.spark.serializer.Serializer import org.apache.spark.storage.{BlockId, StorageLevel} -object NettyMessages { - /** Request to read a set of blocks. Returns [[ShuffleStreamHandle]] to identify the stream. */ - case class OpenBlocks(blockIds: Seq[BlockId]) - - /** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */ - case class UploadBlock(blockId: BlockId, blockData: Array[Byte], level: StorageLevel) -} - /** * Serves requests to open blocks by simply registering one chunk per block requested. * Handles opening and uploading arbitrary BlockManager blocks. 
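To make the new wire format concrete, here is a minimal round-trip sketch (not part of the patch; the app, executor, and block ids are made up) of the exchange that both NettyBlockRpcServer and the external shuffle service now rely on: a message is framed as a type byte followed by its binary-encoded fields via toByteArray(), and recovered with BlockTransferMessage.Decoder.fromByteArray(), replacing the old Java-serialized NettyMessages case classes.

import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
import org.apache.spark.network.shuffle.protocol.OpenBlocks;

public class OpenBlocksRoundTripSketch {
  public static void main(String[] args) {
    // Hypothetical ids; any appId/execId/blockId strings round-trip the same way.
    OpenBlocks open = new OpenBlocks("app-1", "exec-2", new String[] { "shuffle_0_0_0" });

    byte[] wire = open.toByteArray();
    assert wire[0] == BlockTransferMessage.Type.OPEN_BLOCKS.id();  // the type byte comes first

    BlockTransferMessage decoded = BlockTransferMessage.Decoder.fromByteArray(wire);
    assert decoded.equals(open);  // each message defines value-based equals/hashCode
  }
}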
@@ -50,28 +42,29 @@ class NettyBlockRpcServer( blockManager: BlockDataManager) extends RpcHandler with Logging { - import NettyMessages._ - private val streamManager = new OneForOneStreamManager() override def receive( client: TransportClient, messageBytes: Array[Byte], responseContext: RpcResponseCallback): Unit = { - val ser = serializer.newInstance() - val message = ser.deserialize[AnyRef](ByteBuffer.wrap(messageBytes)) + val message = BlockTransferMessage.Decoder.fromByteArray(messageBytes) logTrace(s"Received request: $message") message match { - case OpenBlocks(blockIds) => - val blocks: Seq[ManagedBuffer] = blockIds.map(blockManager.getBlockData) + case openBlocks: OpenBlocks => + val blocks: Seq[ManagedBuffer] = + openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData) val streamId = streamManager.registerStream(blocks.iterator) logTrace(s"Registered streamId $streamId with ${blocks.size} buffers") - responseContext.onSuccess( - ser.serialize(new ShuffleStreamHandle(streamId, blocks.size)).array()) + responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray) - case UploadBlock(blockId, blockData, level) => - blockManager.putBlockData(blockId, new NioManagedBuffer(ByteBuffer.wrap(blockData)), level) + case uploadBlock: UploadBlock => + // StorageLevel is serialized as bytes using our JavaSerializer. + val level: StorageLevel = + serializer.newInstance().deserialize(ByteBuffer.wrap(uploadBlock.metadata)) + val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData)) + blockManager.putBlockData(BlockId(uploadBlock.blockId), data, level) responseContext.onSuccess(new Array[Byte](0)) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index b937ea825f49e..f8a7f640689a2 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -24,10 +24,10 @@ import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClientFactory} -import org.apache.spark.network.netty.NettyMessages.{OpenBlocks, UploadBlock} import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap} import org.apache.spark.network.server._ import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher} +import org.apache.spark.network.shuffle.protocol.UploadBlock import org.apache.spark.serializer.JavaSerializer import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.Utils @@ -46,6 +46,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage private[this] var transportContext: TransportContext = _ private[this] var server: TransportServer = _ private[this] var clientFactory: TransportClientFactory = _ + private[this] var appId: String = _ override def init(blockDataManager: BlockDataManager): Unit = { val (rpcHandler: RpcHandler, bootstrap: Option[TransportClientBootstrap]) = { @@ -60,6 +61,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage transportContext = new TransportContext(transportConf, rpcHandler) clientFactory = transportContext.createClientFactory(bootstrap.toList) server = 
transportContext.createServer() + appId = conf.getAppId logInfo("Server created on " + server.getPort) } @@ -74,8 +76,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter { override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) { val client = clientFactory.createClient(host, port) - new OneForOneBlockFetcher(client, blockIds.toArray, listener) - .start(OpenBlocks(blockIds.map(BlockId.apply))) + new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener).start() } } @@ -101,12 +102,17 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage override def uploadBlock( hostname: String, port: Int, + execId: String, blockId: BlockId, blockData: ManagedBuffer, level: StorageLevel): Future[Unit] = { val result = Promise[Unit]() val client = clientFactory.createClient(hostname, port) + // StorageLevel is serialized as bytes using our JavaSerializer. Everything else is encoded + // using our binary protocol. + val levelBytes = serializer.newInstance().serialize(level).array() + // Convert or copy nio buffer into array in order to serialize it. val nioBuffer = blockData.nioByteBuffer() val array = if (nioBuffer.hasArray) { @@ -117,8 +123,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage data } - val ser = serializer.newInstance() - client.sendRpc(ser.serialize(new UploadBlock(blockId, array, level)).array(), + client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteArray, new RpcResponseCallback { override def onSuccess(response: Array[Byte]): Unit = { logTrace(s"Successfully uploaded block $blockId") diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala index f56d165daba55..b2aec160635c7 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -137,6 +137,7 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa override def uploadBlock( hostname: String, port: Int, + execId: String, blockId: BlockId, blockData: ManagedBuffer, level: StorageLevel) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index e48d7772d6ee9..39434f473a9d8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -35,7 +35,8 @@ import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.netty.{SparkTransportConf, NettyBlockTransferService} -import org.apache.spark.network.shuffle.{ExecutorShuffleInfo, ExternalShuffleClient} +import org.apache.spark.network.shuffle.ExternalShuffleClient +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo import org.apache.spark.network.util.{ConfigProvider, TransportConf} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleManager @@ -939,7 +940,7 @@ private[spark] class BlockManager( data.rewind() logTrace(s"Trying to replicate $blockId of ${data.limit()} bytes to $peer") blockTransferService.uploadBlockSync( - peer.host, peer.port, 
blockId, new NioManagedBuffer(data), tLevel) + peer.host, peer.port, peer.executorId, blockId, new NioManagedBuffer(data), tLevel) logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %s ms" .format(System.currentTimeMillis - onePeerStartTime)) peersReplicatedTo += peer diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 9162ec9801663..530f5d6db5a29 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -36,7 +36,9 @@ import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite, ShouldMat class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with ShouldMatchers { test("security default off") { - testConnection(new SparkConf, new SparkConf) match { + val conf = new SparkConf() + .set("spark.app.id", "app-id") + testConnection(conf, conf) match { case Success(_) => // expected case Failure(t) => fail(t) } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java index 152af98ced7ce..986957c1509fd 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java @@ -38,23 +38,19 @@ public ChunkFetchFailure(StreamChunkId streamChunkId, String errorString) { @Override public int encodedLength() { - return streamChunkId.encodedLength() + 4 + errorString.getBytes(Charsets.UTF_8).length; + return streamChunkId.encodedLength() + Encoders.Strings.encodedLength(errorString); } @Override public void encode(ByteBuf buf) { streamChunkId.encode(buf); - byte[] errorBytes = errorString.getBytes(Charsets.UTF_8); - buf.writeInt(errorBytes.length); - buf.writeBytes(errorBytes); + Encoders.Strings.encode(buf, errorString); } public static ChunkFetchFailure decode(ByteBuf buf) { StreamChunkId streamChunkId = StreamChunkId.decode(buf); - int numErrorStringBytes = buf.readInt(); - byte[] errorBytes = new byte[numErrorStringBytes]; - buf.readBytes(errorBytes); - return new ChunkFetchFailure(streamChunkId, new String(errorBytes, Charsets.UTF_8)); + String errorString = Encoders.Strings.decode(buf); + return new ChunkFetchFailure(streamChunkId, errorString); } @Override diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java b/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java new file mode 100644 index 0000000000000..873c694250942 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + + +import com.google.common.base.Charsets; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +/** Provides a canonical set of Encoders for simple types. */ +public class Encoders { + + /** Strings are encoded with their length followed by UTF-8 bytes. */ + public static class Strings { + public static int encodedLength(String s) { + return 4 + s.getBytes(Charsets.UTF_8).length; + } + + public static void encode(ByteBuf buf, String s) { + byte[] bytes = s.getBytes(Charsets.UTF_8); + buf.writeInt(bytes.length); + buf.writeBytes(bytes); + } + + public static String decode(ByteBuf buf) { + int length = buf.readInt(); + byte[] bytes = new byte[length]; + buf.readBytes(bytes); + return new String(bytes, Charsets.UTF_8); + } + } + + /** Byte arrays are encoded with their length followed by bytes. */ + public static class ByteArrays { + public static int encodedLength(byte[] arr) { + return 4 + arr.length; + } + + public static void encode(ByteBuf buf, byte[] arr) { + buf.writeInt(arr.length); + buf.writeBytes(arr); + } + + public static byte[] decode(ByteBuf buf) { + int length = buf.readInt(); + byte[] bytes = new byte[length]; + buf.readBytes(bytes); + return bytes; + } + } + + /** String arrays are encoded with the number of strings followed by per-String encoding. */ + public static class StringArrays { + public static int encodedLength(String[] strings) { + int totalLength = 4; + for (String s : strings) { + totalLength += Strings.encodedLength(s); + } + return totalLength; + } + + public static void encode(ByteBuf buf, String[] strings) { + buf.writeInt(strings.length); + for (String s : strings) { + Strings.encode(buf, s); + } + } + + public static String[] decode(ByteBuf buf) { + int numStrings = buf.readInt(); + String[] strings = new String[numStrings]; + for (int i = 0; i < strings.length; i ++) { + strings[i] = Strings.decode(buf); + } + return strings; + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java index e239d4ffbd29c..ebd764eb5eb5f 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java @@ -36,23 +36,19 @@ public RpcFailure(long requestId, String errorString) { @Override public int encodedLength() { - return 8 + 4 + errorString.getBytes(Charsets.UTF_8).length; + return 8 + Encoders.Strings.encodedLength(errorString); } @Override public void encode(ByteBuf buf) { buf.writeLong(requestId); - byte[] errorBytes = errorString.getBytes(Charsets.UTF_8); - buf.writeInt(errorBytes.length); - buf.writeBytes(errorBytes); + Encoders.Strings.encode(buf, errorString); } public static RpcFailure decode(ByteBuf buf) { long requestId = buf.readLong(); - int numErrorStringBytes = buf.readInt(); - byte[] errorBytes = new byte[numErrorStringBytes]; - buf.readBytes(errorBytes); - return new RpcFailure(requestId, new String(errorBytes, Charsets.UTF_8)); + String 
errorString = Encoders.Strings.decode(buf); + return new RpcFailure(requestId, errorString); } @Override diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java index 099e934ae018c..cdee0b0e0316b 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java @@ -44,21 +44,18 @@ public RpcRequest(long requestId, byte[] message) { @Override public int encodedLength() { - return 8 + 4 + message.length; + return 8 + Encoders.ByteArrays.encodedLength(message); } @Override public void encode(ByteBuf buf) { buf.writeLong(requestId); - buf.writeInt(message.length); - buf.writeBytes(message); + Encoders.ByteArrays.encode(buf, message); } public static RpcRequest decode(ByteBuf buf) { long requestId = buf.readLong(); - int messageLen = buf.readInt(); - byte[] message = new byte[messageLen]; - buf.readBytes(message); + byte[] message = Encoders.ByteArrays.decode(buf); return new RpcRequest(requestId, message); } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java index ed479478325b6..0a62e09a8115c 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java @@ -36,20 +36,17 @@ public RpcResponse(long requestId, byte[] response) { public Type type() { return Type.RpcResponse; } @Override - public int encodedLength() { return 8 + 4 + response.length; } + public int encodedLength() { return 8 + Encoders.ByteArrays.encodedLength(response); } @Override public void encode(ByteBuf buf) { buf.writeLong(requestId); - buf.writeInt(response.length); - buf.writeBytes(response); + Encoders.ByteArrays.encode(buf, response); } public static RpcResponse decode(ByteBuf buf) { long requestId = buf.readLong(); - int responseLen = buf.readInt(); - byte[] response = new byte[responseLen]; - buf.readBytes(response); + byte[] response = Encoders.ByteArrays.decode(buf); return new RpcResponse(requestId, response); } diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java index 2856d1c8c9337..2ebdccb80549b 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -44,33 +44,6 @@ public static void closeQuietly(Closeable closeable) { } } - // TODO: Make this configurable, do not use Java serialization! - public static T deserialize(byte[] bytes) { - try { - ObjectInputStream is = new ObjectInputStream(new ByteArrayInputStream(bytes)); - Object out = is.readObject(); - is.close(); - return (T) out; - } catch (ClassNotFoundException e) { - throw new RuntimeException("Could not deserialize object", e); - } catch (IOException e) { - throw new RuntimeException("Could not deserialize object", e); - } - } - - // TODO: Make this configurable, do not use Java serialization! 
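As a rough illustration of the length-prefixed framing that the Encoders helpers introduced above provide (the payload values here are invented), each field is written as a 4-byte length followed by its bytes; this is what ChunkFetchFailure, RpcFailure, RpcRequest and RpcResponse now delegate to instead of hand-rolled readInt/readBytes code.

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;

import org.apache.spark.network.protocol.Encoders;

public class EncodersFramingSketch {
  public static void main(String[] args) {
    String error = "block shuffle_0_0_0 not found";  // invented error string
    byte[] payload = new byte[] { 1, 2, 3 };          // invented byte payload

    ByteBuf buf = Unpooled.buffer(
        Encoders.Strings.encodedLength(error) + Encoders.ByteArrays.encodedLength(payload));
    Encoders.Strings.encode(buf, error);       // writes a 4-byte length, then the UTF-8 bytes
    Encoders.ByteArrays.encode(buf, payload);  // writes a 4-byte length, then the raw bytes

    assert Encoders.Strings.decode(buf).equals(error);
    assert Encoders.ByteArrays.decode(buf).length == payload.length;
  }
}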
- public static byte[] serialize(Object object) { - try { - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - ObjectOutputStream os = new ObjectOutputStream(baos); - os.writeObject(object); - os.close(); - return baos.toByteArray(); - } catch (IOException e) { - throw new RuntimeException("Could not serialize object", e); - } - } - /** Returns a hash consistent with Spark's Utils.nonNegativeHash(). */ public static int nonNegativeHash(Object obj) { if (obj == null) { return 0; } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java index 599cc6428c90e..cad76ab7aa54e 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java @@ -17,10 +17,10 @@ package org.apache.spark.network.sasl; -import com.google.common.base.Charsets; import io.netty.buffer.ByteBuf; import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.Encoders; /** * Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged @@ -42,18 +42,14 @@ public SaslMessage(String appId, byte[] payload) { @Override public int encodedLength() { - // tag + appIdLength + appId + payloadLength + payload - return 1 + 4 + appId.getBytes(Charsets.UTF_8).length + 4 + payload.length; + return 1 + Encoders.Strings.encodedLength(appId) + Encoders.ByteArrays.encodedLength(payload); } @Override public void encode(ByteBuf buf) { buf.writeByte(TAG_BYTE); - byte[] idBytes = appId.getBytes(Charsets.UTF_8); - buf.writeInt(idBytes.length); - buf.writeBytes(idBytes); - buf.writeInt(payload.length); - buf.writeBytes(payload); + Encoders.Strings.encode(buf, appId); + Encoders.ByteArrays.encode(buf, payload); } public static SaslMessage decode(ByteBuf buf) { @@ -62,14 +58,8 @@ public static SaslMessage decode(ByteBuf buf) { + " (maybe your client does not have SASL enabled?)"); } - int idLength = buf.readInt(); - byte[] idBytes = new byte[idLength]; - buf.readBytes(idBytes); - - int payloadLength = buf.readInt(); - byte[] payload = new byte[payloadLength]; - buf.readBytes(payload); - - return new SaslMessage(new String(idBytes, Charsets.UTF_8), payload); + String appId = Encoders.Strings.decode(buf); + byte[] payload = Encoders.ByteArrays.decode(buf); + return new SaslMessage(appId, payload); } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index cd3fea85b19a4..11619f8663cbf 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -24,15 +24,16 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import static org.apache.spark.network.shuffle.ExternalShuffleMessages.*; - import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.StreamManager; -import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +import 
org.apache.spark.network.shuffle.protocol.OpenBlocks; +import org.apache.spark.network.shuffle.protocol.RegisterExecutor; +import org.apache.spark.network.shuffle.protocol.StreamHandle; /** * RPC Handler for a server which can serve shuffle blocks from outside of an Executor process. @@ -62,12 +63,10 @@ public ExternalShuffleBlockHandler() { @Override public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { - Object msgObj = JavaUtils.deserialize(message); - - logger.trace("Received message: " + msgObj); + BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteArray(message); - if (msgObj instanceof OpenShuffleBlocks) { - OpenShuffleBlocks msg = (OpenShuffleBlocks) msgObj; + if (msgObj instanceof OpenBlocks) { + OpenBlocks msg = (OpenBlocks) msgObj; List blocks = Lists.newArrayList(); for (String blockId : msg.blockIds) { @@ -75,8 +74,7 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback } long streamId = streamManager.registerStream(blocks.iterator()); logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length); - callback.onSuccess(JavaUtils.serialize( - new ShuffleStreamHandle(streamId, msg.blockIds.length))); + callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteArray()); } else if (msgObj instanceof RegisterExecutor) { RegisterExecutor msg = (RegisterExecutor) msgObj; @@ -84,8 +82,7 @@ public void receive(TransportClient client, byte[] message, RpcResponseCallback callback.onSuccess(new byte[0]); } else { - throw new UnsupportedOperationException(String.format( - "Unexpected message: %s (class = %s)", msgObj, msgObj.getClass())); + throw new UnsupportedOperationException("Unexpected message: " + msgObj); } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java index 6589889fe1be7..f7125273a97bb 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java @@ -29,6 +29,7 @@ import org.apache.spark.network.buffer.FileSegmentManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; import org.apache.spark.network.util.JavaUtils; /** diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 27884b82c8cb9..6e8018b723dc6 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -31,8 +31,8 @@ import org.apache.spark.network.sasl.SaslClientBootstrap; import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.server.NoOpRpcHandler; -import org.apache.spark.network.shuffle.ExternalShuffleMessages.RegisterExecutor; -import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.shuffle.protocol.RegisterExecutor; import org.apache.spark.network.util.TransportConf; /** @@ -91,8 +91,7 @@ public void fetchBlocks( public void createAndStart(String[] blockIds, BlockFetchingListener listener) throws 
IOException { TransportClient client = clientFactory.createClient(host, port); - new OneForOneBlockFetcher(client, blockIds, listener) - .start(new ExternalShuffleMessages.OpenShuffleBlocks(appId, execId, blockIds)); + new OneForOneBlockFetcher(client, appId, execId, blockIds, listener).start(); } }; @@ -128,9 +127,8 @@ public void registerWithShuffleServer( ExecutorShuffleInfo executorInfo) throws IOException { assert appId != null : "Called before init()"; TransportClient client = clientFactory.createClient(host, port); - byte[] registerExecutorMessage = - JavaUtils.serialize(new RegisterExecutor(appId, execId, executorInfo)); - client.sendRpcSync(registerExecutorMessage, 5000 /* timeoutMs */); + byte[] registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteArray(); + client.sendRpcSync(registerMessage, 5000 /* timeoutMs */); } @Override diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java deleted file mode 100644 index e79420ed8254f..0000000000000 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleMessages.java +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle; - -import java.io.Serializable; -import java.util.Arrays; - -import com.google.common.base.Objects; - -/** Messages handled by the {@link ExternalShuffleBlockHandler}. */ -public class ExternalShuffleMessages { - - /** Request to read a set of shuffle blocks. Returns [[ShuffleStreamHandle]]. */ - public static class OpenShuffleBlocks implements Serializable { - public final String appId; - public final String execId; - public final String[] blockIds; - - public OpenShuffleBlocks(String appId, String execId, String[] blockIds) { - this.appId = appId; - this.execId = execId; - this.blockIds = blockIds; - } - - @Override - public int hashCode() { - return Objects.hashCode(appId, execId) * 41 + Arrays.hashCode(blockIds); - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("appId", appId) - .add("execId", execId) - .add("blockIds", Arrays.toString(blockIds)) - .toString(); - } - - @Override - public boolean equals(Object other) { - if (other != null && other instanceof OpenShuffleBlocks) { - OpenShuffleBlocks o = (OpenShuffleBlocks) other; - return Objects.equal(appId, o.appId) - && Objects.equal(execId, o.execId) - && Arrays.equals(blockIds, o.blockIds); - } - return false; - } - } - - /** Initial registration message between an executor and its local shuffle server. 
*/ - public static class RegisterExecutor implements Serializable { - public final String appId; - public final String execId; - public final ExecutorShuffleInfo executorInfo; - - public RegisterExecutor( - String appId, - String execId, - ExecutorShuffleInfo executorInfo) { - this.appId = appId; - this.execId = execId; - this.executorInfo = executorInfo; - } - - @Override - public int hashCode() { - return Objects.hashCode(appId, execId, executorInfo); - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("appId", appId) - .add("execId", execId) - .add("executorInfo", executorInfo) - .toString(); - } - - @Override - public boolean equals(Object other) { - if (other != null && other instanceof RegisterExecutor) { - RegisterExecutor o = (RegisterExecutor) other; - return Objects.equal(appId, o.appId) - && Objects.equal(execId, o.execId) - && Objects.equal(executorInfo, o.executorInfo); - } - return false; - } - } -} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 9e77a1f68c4b0..8ed2e0b39ad23 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -26,6 +26,9 @@ import org.apache.spark.network.client.ChunkReceivedCallback; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +import org.apache.spark.network.shuffle.protocol.OpenBlocks; +import org.apache.spark.network.shuffle.protocol.StreamHandle; import org.apache.spark.network.util.JavaUtils; /** @@ -41,17 +44,21 @@ public class OneForOneBlockFetcher { private final Logger logger = LoggerFactory.getLogger(OneForOneBlockFetcher.class); private final TransportClient client; + private final OpenBlocks openMessage; private final String[] blockIds; private final BlockFetchingListener listener; private final ChunkReceivedCallback chunkCallback; - private ShuffleStreamHandle streamHandle = null; + private StreamHandle streamHandle = null; public OneForOneBlockFetcher( TransportClient client, + String appId, + String execId, String[] blockIds, BlockFetchingListener listener) { this.client = client; + this.openMessage = new OpenBlocks(appId, execId, blockIds); this.blockIds = blockIds; this.listener = listener; this.chunkCallback = new ChunkCallback(); @@ -76,18 +83,18 @@ public void onFailure(int chunkIndex, Throwable e) { /** * Begins the fetching process, calling the listener with every block fetched. * The given message will be serialized with the Java serializer, and the RPC must return a - * {@link ShuffleStreamHandle}. We will send all fetch requests immediately, without throttling. + * {@link StreamHandle}. We will send all fetch requests immediately, without throttling. 
*/ - public void start(Object openBlocksMessage) { + public void start() { if (blockIds.length == 0) { throw new IllegalArgumentException("Zero-sized blockIds array"); } - client.sendRpc(JavaUtils.serialize(openBlocksMessage), new RpcResponseCallback() { + client.sendRpc(openMessage.toByteArray(), new RpcResponseCallback() { @Override public void onSuccess(byte[] response) { try { - streamHandle = JavaUtils.deserialize(response); + streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response); logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle); // Immediately request all chunks -- we expect that the total size of the request is diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java new file mode 100644 index 0000000000000..b4b13b8a6ef5d --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.protocol.Encodable; + +/** + * Messages handled by the {@link org.apache.spark.network.shuffle.ExternalShuffleBlockHandler}, or + * by Spark's NettyBlockTransferService. + * + * At a high level: + * - OpenBlock is handled by both services, but only services shuffle files for the external + * shuffle service. It returns a StreamHandle. + * - UploadBlock is only handled by the NettyBlockTransferService. + * - RegisterExecutor is only handled by the external shuffle service. + */ +public abstract class BlockTransferMessage implements Encodable { + protected abstract Type type(); + + /** Preceding every serialized message is its type, which allows us to deserialize it. */ + public static enum Type { + OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3); + + private final byte id; + + private Type(int id) { + assert id < 128 : "Cannot have more than 128 message types"; + this.id = (byte) id; + } + + public byte id() { return id; } + } + + // NB: Java does not support static methods in interfaces, so we must put this in a static class. + public static class Decoder { + /** Deserializes the 'type' byte followed by the message itself. 
*/ + public static BlockTransferMessage fromByteArray(byte[] msg) { + ByteBuf buf = Unpooled.wrappedBuffer(msg); + byte type = buf.readByte(); + switch (type) { + case 0: return OpenBlocks.decode(buf); + case 1: return UploadBlock.decode(buf); + case 2: return RegisterExecutor.decode(buf); + case 3: return StreamHandle.decode(buf); + default: throw new IllegalArgumentException("Unknown message type: " + type); + } + } + } + + /** Serializes the 'type' byte followed by the message itself. */ + public byte[] toByteArray() { + ByteBuf buf = Unpooled.buffer(encodedLength()); + buf.writeByte(type().id); + encode(buf); + assert buf.writableBytes() == 0 : "Writable bytes remain: " + buf.writableBytes(); + return buf.array(); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorShuffleInfo.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java similarity index 68% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorShuffleInfo.java rename to network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java index d45e64656a0e3..cadc8e8369c6a 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExecutorShuffleInfo.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java @@ -15,21 +15,24 @@ * limitations under the License. */ -package org.apache.spark.network.shuffle; +package org.apache.spark.network.shuffle.protocol; -import java.io.Serializable; import java.util.Arrays; import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.Encoders; /** Contains all configuration necessary for locating the shuffle files of an executor. */ -public class ExecutorShuffleInfo implements Serializable { +public class ExecutorShuffleInfo implements Encodable { /** The base set of local directories that the executor stores its shuffle files in. */ - final String[] localDirs; + public final String[] localDirs; /** Number of subdirectories created within each localDir. */ - final int subDirsPerLocalDir; + public final int subDirsPerLocalDir; /** Shuffle manager (SortShuffleManager or HashShuffleManager) that the executor is using. 
*/ - final String shuffleManager; + public final String shuffleManager; public ExecutorShuffleInfo(String[] localDirs, int subDirsPerLocalDir, String shuffleManager) { this.localDirs = localDirs; @@ -61,4 +64,25 @@ public boolean equals(Object other) { } return false; } + + @Override + public int encodedLength() { + return Encoders.StringArrays.encodedLength(localDirs) + + 4 // int + + Encoders.Strings.encodedLength(shuffleManager); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.StringArrays.encode(buf, localDirs); + buf.writeInt(subDirsPerLocalDir); + Encoders.Strings.encode(buf, shuffleManager); + } + + public static ExecutorShuffleInfo decode(ByteBuf buf) { + String[] localDirs = Encoders.StringArrays.decode(buf); + int subDirsPerLocalDir = buf.readInt(); + String shuffleManager = Encoders.Strings.decode(buf); + return new ExecutorShuffleInfo(localDirs, subDirsPerLocalDir, shuffleManager); + } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java new file mode 100644 index 0000000000000..60485bace643c --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; + +/** Request to read a set of blocks. Returns {@link StreamHandle}. 
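A small sketch (values invented) of the ExecutorShuffleInfo round trip enabled by the Encodable implementation above; this is the payload a RegisterExecutor message carries when an executor registers with its local shuffle server.

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;

import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;

public class ExecutorShuffleInfoRoundTripSketch {
  public static void main(String[] args) {
    // Invented directory layout and shuffle manager name.
    ExecutorShuffleInfo info = new ExecutorShuffleInfo(
        new String[] { "/tmp/spark-local-1", "/tmp/spark-local-2" }, 64, "sort");

    ByteBuf buf = Unpooled.buffer(info.encodedLength());
    info.encode(buf);  // String[] localDirs, int subDirsPerLocalDir, String shuffleManager

    assert ExecutorShuffleInfo.decode(buf).equals(info);
  }
}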
*/ +public class OpenBlocks extends BlockTransferMessage { + public final String appId; + public final String execId; + public final String[] blockIds; + + public OpenBlocks(String appId, String execId, String[] blockIds) { + this.appId = appId; + this.execId = execId; + this.blockIds = blockIds; + } + + @Override + protected Type type() { return Type.OPEN_BLOCKS; } + + @Override + public int hashCode() { + return Objects.hashCode(appId, execId) * 41 + Arrays.hashCode(blockIds); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .add("blockIds", Arrays.toString(blockIds)) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof OpenBlocks) { + OpenBlocks o = (OpenBlocks) other; + return Objects.equal(appId, o.appId) + && Objects.equal(execId, o.execId) + && Arrays.equals(blockIds, o.blockIds); + } + return false; + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + + Encoders.StringArrays.encodedLength(blockIds); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + Encoders.StringArrays.encode(buf, blockIds); + } + + public static OpenBlocks decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + String[] blockIds = Encoders.StringArrays.decode(buf); + return new OpenBlocks(appId, execId, blockIds); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java new file mode 100644 index 0000000000000..38acae3b31d64 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; + +/** + * Initial registration message between an executor and its local shuffle server. + * Returns nothing (empty byte array). 
+ */ +public class RegisterExecutor extends BlockTransferMessage { + public final String appId; + public final String execId; + public final ExecutorShuffleInfo executorInfo; + + public RegisterExecutor( + String appId, + String execId, + ExecutorShuffleInfo executorInfo) { + this.appId = appId; + this.execId = execId; + this.executorInfo = executorInfo; + } + + @Override + protected Type type() { return Type.REGISTER_EXECUTOR; } + + @Override + public int hashCode() { + return Objects.hashCode(appId, execId, executorInfo); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .add("executorInfo", executorInfo) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof RegisterExecutor) { + RegisterExecutor o = (RegisterExecutor) other; + return Objects.equal(appId, o.appId) + && Objects.equal(execId, o.execId) + && Objects.equal(executorInfo, o.executorInfo); + } + return false; + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + + executorInfo.encodedLength(); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + executorInfo.encode(buf); + } + + public static RegisterExecutor decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + ExecutorShuffleInfo executorShuffleInfo = ExecutorShuffleInfo.decode(buf); + return new RegisterExecutor(appId, execId, executorShuffleInfo); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleStreamHandle.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java similarity index 65% rename from network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleStreamHandle.java rename to network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java index 9c94691224328..21369c8cfb0d6 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleStreamHandle.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java @@ -15,26 +15,29 @@ * limitations under the License. */ -package org.apache.spark.network.shuffle; +package org.apache.spark.network.shuffle.protocol; import java.io.Serializable; -import java.util.Arrays; import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; /** * Identifier for a fixed number of chunks to read from a stream created by an "open blocks" - * message. This is used by {@link OneForOneBlockFetcher}. + * message. This is used by {@link org.apache.spark.network.shuffle.OneForOneBlockFetcher}. 
*/ -public class ShuffleStreamHandle implements Serializable { +public class StreamHandle extends BlockTransferMessage { public final long streamId; public final int numChunks; - public ShuffleStreamHandle(long streamId, int numChunks) { + public StreamHandle(long streamId, int numChunks) { this.streamId = streamId; this.numChunks = numChunks; } + @Override + protected Type type() { return Type.STREAM_HANDLE; } + @Override public int hashCode() { return Objects.hashCode(streamId, numChunks); @@ -50,11 +53,28 @@ public String toString() { @Override public boolean equals(Object other) { - if (other != null && other instanceof ShuffleStreamHandle) { - ShuffleStreamHandle o = (ShuffleStreamHandle) other; + if (other != null && other instanceof StreamHandle) { + StreamHandle o = (StreamHandle) other; return Objects.equal(streamId, o.streamId) && Objects.equal(numChunks, o.numChunks); } return false; } + + @Override + public int encodedLength() { + return 8 + 4; + } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(streamId); + buf.writeInt(numChunks); + } + + public static StreamHandle decode(ByteBuf buf) { + long streamId = buf.readLong(); + int numChunks = buf.readInt(); + return new StreamHandle(streamId, numChunks); + } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java new file mode 100644 index 0000000000000..38abe29cc585f --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; + +/** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */ +public class UploadBlock extends BlockTransferMessage { + public final String appId; + public final String execId; + public final String blockId; + // TODO: StorageLevel is serialized separately in here because StorageLevel is not available in + // this package. We should avoid this hack. + public final byte[] metadata; + public final byte[] blockData; + + /** + * @param metadata Meta-information about block, typically StorageLevel. + * @param blockData The actual block's bytes. 
+ */ + public UploadBlock( + String appId, + String execId, + String blockId, + byte[] metadata, + byte[] blockData) { + this.appId = appId; + this.execId = execId; + this.blockId = blockId; + this.metadata = metadata; + this.blockData = blockData; + } + + @Override + protected Type type() { return Type.UPLOAD_BLOCK; } + + @Override + public int hashCode() { + int objectsHashCode = Objects.hashCode(appId, execId, blockId); + return (objectsHashCode * 41 + Arrays.hashCode(metadata)) * 41 + Arrays.hashCode(blockData); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .add("blockId", blockId) + .add("metadata size", metadata.length) + .add("block size", blockData.length) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof UploadBlock) { + UploadBlock o = (UploadBlock) other; + return Objects.equal(appId, o.appId) + && Objects.equal(execId, o.execId) + && Objects.equal(blockId, o.blockId) + && Arrays.equals(metadata, o.metadata) + && Arrays.equals(blockData, o.blockData); + } + return false; + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + + Encoders.Strings.encodedLength(blockId) + + Encoders.ByteArrays.encodedLength(metadata) + + Encoders.ByteArrays.encodedLength(blockData); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + Encoders.Strings.encode(buf, blockId); + Encoders.ByteArrays.encode(buf, metadata); + Encoders.ByteArrays.encode(buf, blockData); + } + + public static UploadBlock decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + String blockId = Encoders.Strings.decode(buf); + byte[] metadata = Encoders.ByteArrays.decode(buf); + byte[] blockData = Encoders.ByteArrays.decode(buf); + return new UploadBlock(appId, execId, blockId, metadata, blockData); + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java similarity index 55% rename from network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java rename to network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java index ee9482b49cfc3..d65de9ca550a3 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ShuffleMessagesSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java @@ -21,31 +21,24 @@ import static org.junit.Assert.*; -import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.shuffle.protocol.*; -import static org.apache.spark.network.shuffle.ExternalShuffleMessages.*; - -public class ShuffleMessagesSuite { +/** Verifies that all BlockTransferMessages can be serialized correctly. 
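For the upload path, a hedged sketch (all values invented) of how an UploadBlock message is assembled; on the Scala side the metadata bytes are the StorageLevel serialized with Spark's JavaSerializer, but any opaque byte array round-trips the same way through the binary protocol.

import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
import org.apache.spark.network.shuffle.protocol.UploadBlock;

public class UploadBlockSketch {
  public static void main(String[] args) {
    byte[] levelBytes = new byte[] { 0, 1 };        // stands in for the serialized StorageLevel
    byte[] blockBytes = new byte[] { 9, 9, 9, 9 };  // stands in for the block contents

    UploadBlock upload = new UploadBlock("app-1", "exec-2", "rdd_0_0", levelBytes, blockBytes);

    UploadBlock decoded =
        (UploadBlock) BlockTransferMessage.Decoder.fromByteArray(upload.toByteArray());
    assert decoded.equals(upload);  // equality compares the byte arrays element-wise
  }
}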
*/ +public class BlockTransferMessagesSuite { @Test public void serializeOpenShuffleBlocks() { - OpenShuffleBlocks msg = new OpenShuffleBlocks("app-1", "exec-2", - new String[] { "block0", "block1" }); - OpenShuffleBlocks msg2 = JavaUtils.deserialize(JavaUtils.serialize(msg)); - assertEquals(msg, msg2); + checkSerializeDeserialize(new OpenBlocks("app-1", "exec-2", new String[] { "b1", "b2" })); + checkSerializeDeserialize(new RegisterExecutor("app-1", "exec-2", new ExecutorShuffleInfo( + new String[] { "/local1", "/local2" }, 32, "MyShuffleManager"))); + checkSerializeDeserialize(new UploadBlock("app-1", "exec-2", "block-3", new byte[] { 1, 2 }, + new byte[] { 4, 5, 6, 7} )); + checkSerializeDeserialize(new StreamHandle(12345, 16)); } - @Test - public void serializeRegisterExecutor() { - RegisterExecutor msg = new RegisterExecutor("app-1", "exec-2", new ExecutorShuffleInfo( - new String[] { "/local1", "/local2" }, 32, "MyShuffleManager")); - RegisterExecutor msg2 = JavaUtils.deserialize(JavaUtils.serialize(msg)); - assertEquals(msg, msg2); - } - - @Test - public void serializeShuffleStreamHandle() { - ShuffleStreamHandle msg = new ShuffleStreamHandle(12345, 16); - ShuffleStreamHandle msg2 = JavaUtils.deserialize(JavaUtils.serialize(msg)); + private void checkSerializeDeserialize(BlockTransferMessage msg) { + BlockTransferMessage msg2 = BlockTransferMessage.Decoder.fromByteArray(msg.toByteArray()); assertEquals(msg, msg2); + assertEquals(msg.hashCode(), msg2.hashCode()); + assertEquals(msg.toString(), msg2.toString()); } } diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java index 7939cb4d32690..3f9fe1681cf27 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -24,8 +24,6 @@ import org.junit.Test; import org.mockito.ArgumentCaptor; -import static org.apache.spark.network.shuffle.ExternalShuffleMessages.OpenShuffleBlocks; -import static org.apache.spark.network.shuffle.ExternalShuffleMessages.RegisterExecutor; import static org.junit.Assert.*; import static org.mockito.Matchers.any; import static org.mockito.Mockito.*; @@ -36,7 +34,12 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.shuffle.protocol.OpenBlocks; +import org.apache.spark.network.shuffle.protocol.RegisterExecutor; +import org.apache.spark.network.shuffle.protocol.StreamHandle; +import org.apache.spark.network.shuffle.protocol.UploadBlock; public class ExternalShuffleBlockHandlerSuite { TransportClient client = mock(TransportClient.class); @@ -57,8 +60,7 @@ public void testRegisterExecutor() { RpcResponseCallback callback = mock(RpcResponseCallback.class); ExecutorShuffleInfo config = new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort"); - byte[] registerMessage = JavaUtils.serialize( - new RegisterExecutor("app0", "exec1", config)); + byte[] registerMessage = new RegisterExecutor("app0", "exec1", config).toByteArray(); handler.receive(client, 
     verify(blockManager, times(1)).registerExecutor("app0", "exec1", config);
@@ -75,9 +77,8 @@ public void testOpenShuffleBlocks() {
     ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7]));
     when(blockManager.getBlockData("app0", "exec1", "b0")).thenReturn(block0Marker);
     when(blockManager.getBlockData("app0", "exec1", "b1")).thenReturn(block1Marker);
-    byte[] openBlocksMessage = JavaUtils.serialize(
-      new OpenShuffleBlocks("app0", "exec1", new String[] { "b0", "b1" }));
-    handler.receive(client, openBlocksMessage, callback);
+    byte[] openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }).toByteArray();
+    handler.receive(client, openBlocks, callback);
 
     verify(blockManager, times(1)).getBlockData("app0", "exec1", "b0");
     verify(blockManager, times(1)).getBlockData("app0", "exec1", "b1");
@@ -85,7 +86,8 @@
     verify(callback, times(1)).onSuccess(response.capture());
     verify(callback, never()).onFailure((Throwable) any());
-    ShuffleStreamHandle handle = JavaUtils.deserialize(response.getValue());
+    StreamHandle handle =
+      (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response.getValue());
     assertEquals(2, handle.numChunks);
 
     ArgumentCaptor<Iterator> stream = ArgumentCaptor.forClass(Iterator.class);
@@ -100,18 +102,17 @@ public void testBadMessages() {
     RpcResponseCallback callback = mock(RpcResponseCallback.class);
 
-    byte[] unserializableMessage = new byte[] { 0x12, 0x34, 0x56 };
+    byte[] unserializableMsg = new byte[] { 0x12, 0x34, 0x56 };
     try {
-      handler.receive(client, unserializableMessage, callback);
+      handler.receive(client, unserializableMsg, callback);
       fail("Should have thrown");
     } catch (Exception e) {
       // pass
     }
 
-    byte[] unexpectedMessage = JavaUtils.serialize(
-      new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort"));
+    byte[] unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], new byte[2]).toByteArray();
     try {
-      handler.receive(client, unexpectedMessage, callback);
+      handler.receive(client, unexpectedMsg, callback);
       fail("Should have thrown");
     } catch (UnsupportedOperationException e) {
       // pass
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
index 06294fef19621..fcc36662176db 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
@@ -42,6 +42,7 @@
 import org.apache.spark.network.buffer.ManagedBuffer;
 import org.apache.spark.network.buffer.NioManagedBuffer;
 import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
 import org.apache.spark.network.util.SystemPropertyConfigProvider;
 import org.apache.spark.network.util.TransportConf;
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
index 848c88f743d50..8afceab1d585a 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
@@ -31,6 +31,7 @@
 import org.apache.spark.network.sasl.SecretKeyHolder;
 import org.apache.spark.network.server.RpcHandler;
 import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
 import org.apache.spark.network.util.SystemPropertyConfigProvider;
 import org.apache.spark.network.util.TransportConf;
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
index c18346f6966d6..842741e3d354f 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
@@ -40,7 +40,9 @@
 import org.apache.spark.network.client.ChunkReceivedCallback;
 import org.apache.spark.network.client.RpcResponseCallback;
 import org.apache.spark.network.client.TransportClient;
-import org.apache.spark.network.util.JavaUtils;
+import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
+import org.apache.spark.network.shuffle.protocol.OpenBlocks;
+import org.apache.spark.network.shuffle.protocol.StreamHandle;
 
 public class OneForOneBlockFetcherSuite {
   @Test
@@ -119,17 +121,19 @@ public void testEmptyBlockFetch() {
   private BlockFetchingListener fetchBlocks(final LinkedHashMap<String, ManagedBuffer> blocks) {
     TransportClient client = mock(TransportClient.class);
     BlockFetchingListener listener = mock(BlockFetchingListener.class);
-    String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
-    OneForOneBlockFetcher fetcher = new OneForOneBlockFetcher(client, blockIds, listener);
+    final String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
+    OneForOneBlockFetcher fetcher =
+      new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener);
 
     // Respond to the "OpenBlocks" message with an appropirate ShuffleStreamHandle with streamId 123
     doAnswer(new Answer<Void>() {
       @Override
       public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
-        String message = JavaUtils.deserialize((byte[]) invocationOnMock.getArguments()[0]);
+        BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteArray(
+          (byte[]) invocationOnMock.getArguments()[0]);
         RpcResponseCallback callback = (RpcResponseCallback) invocationOnMock.getArguments()[1];
-        callback.onSuccess(JavaUtils.serialize(new ShuffleStreamHandle(123, blocks.size())));
-        assertEquals("OpenZeBlocks", message);
+        callback.onSuccess(new StreamHandle(123, blocks.size()).toByteArray());
+        assertEquals(new OpenBlocks("app-id", "exec-id", blockIds), message);
         return null;
       }
     }).when(client).sendRpc((byte[]) any(), (RpcResponseCallback) any());
@@ -161,7 +165,7 @@ public Void answer(InvocationOnMock invocation) throws Throwable {
       }
     }).when(client).fetchChunk(anyLong(), anyInt(), (ChunkReceivedCallback) any());
 
-    fetcher.start("OpenZeBlocks");
+    fetcher.start();
     return listener;
   }
 }
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java
index 442b756467442..933decdc5716a 100644
--- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java
@@ -25,6 +25,8 @@
 
 import com.google.common.io.Files;
 
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
+
 /**
  * Manages some sort- and hash-based shuffle data, including the creation
  * and cleanup of directories that can be read by the {@link ExternalShuffleBlockManager}.