Skip to content

Commit

Permalink
UCX simplify receiving frames in comms (#3651)
Browse files Browse the repository at this point in the history
* Prefix `for`-loop variables with `each_*`

Should make it easier to disambiguate things like `frame` and `frames`
as they are now `each_frame` and `frames`.

* Allocate frames the same way in 0-length case

* Always allocate frames; receive only the non-trivial ones

* Allocate all frames to fill before receiving

* Select only the non-trivial (non-empty) frames to transmit
  • Loading branch information
jakirkham authored Mar 27, 2020
1 parent 7802bf3 commit 3fceec6
Showing 1 changed file with 18 additions and 22 deletions.
40 changes: 18 additions & 22 deletions distributed/comm/ucx.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@ async def write(
frames = await to_frames(
msg, serializers=serializers, on_error=on_error
)
send_frames = [
each_frame for each_frame in frames if len(each_frame) > 0
]

# Send meta data
cuda_frames = np.array(
Expand All @@ -167,6 +170,7 @@ async def write(
await self.ep.send(
np.array([nbytes(f) for f in frames], dtype=np.uint64)
)

# Send frames

# It is necessary to first synchronize the default stream before starting to send
Expand All @@ -177,10 +181,9 @@ async def write(
if cuda_frames.any():
synchronize_stream(0)

for frame in frames:
if nbytes(frame) > 0:
await self.ep.send(frame)
return sum(map(nbytes, frames))
for each_frame in send_frames:
await self.ep.send(each_frame)
return sum(map(nbytes, send_frames))
except (ucp.exceptions.UCXBaseException):
self.abort()
raise CommClosedError("While writing, the connection was closed")
Expand All @@ -206,30 +209,23 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")):
raise CommClosedError("While reading, the connection was closed")
else:
# Recv frames
frames = []
for is_cuda, size in zip(is_cudas.tolist(), sizes.tolist()):
if size > 0:
if is_cuda:
frame = cuda_array(size)
else:
frame = np.empty(size, dtype=np.uint8)
frames.append(frame)
else:
if is_cuda:
frames.append(cuda_array(size))
else:
frames.append(b"")
frames = [
cuda_array(each_size)
if is_cuda
else np.empty(each_size, dtype=np.uint8)
for is_cuda, each_size in zip(is_cudas.tolist(), sizes.tolist())
]
recv_frames = [
each_frame for each_frame in frames if len(each_frame) > 0
]

# It is necessary to first populate `frames` with CUDA arrays and synchronize
# the default stream before starting to receive, to ensure buffers have been allocated
if is_cudas.any():
synchronize_stream(0)
for i, (is_cuda, size) in enumerate(
zip(is_cudas.tolist(), sizes.tolist())
):
if size > 0:
await self.ep.recv(frames[i])

for each_frame in recv_frames:
await self.ep.recv(each_frame)
msg = await from_frames(
frames, deserialize=self.deserialize, deserializers=deserializers
)
Expand Down

0 comments on commit 3fceec6

Please sign in to comment.