
Commit dc2defd

TroyGarden authored and facebook-github-bot committed
demonstration of cuda memory footprint with multi-stream use case
Differential Revision: D85399705
1 parent 7ddc21d commit dc2defd
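
For context, the new benchmarks below exercise the standard PyTorch multi-stream overlap recipe: issue the host-to-device copy (or the collective) on a side stream, order the streams with wait_stream, and protect cross-stream tensors with record_stream so the caching allocator does not hand their memory back too early. A minimal standalone sketch of that pattern, not part of this commit (the function and tensor names are illustrative):

import torch


def h2d_overlap_sketch(dim: int = 1024) -> torch.Tensor:
    # pinned host memory is required for the copy to be truly asynchronous
    host = torch.rand(dim, dim).pin_memory()
    main = torch.cuda.current_stream()
    copy_stream = torch.cuda.Stream()

    with torch.cuda.stream(copy_stream):
        # enqueued on copy_stream; can overlap with kernels queued on the main stream
        device = host.to("cuda", non_blocking=True)

    # ... unrelated main-stream work could be queued here ...

    main.wait_stream(copy_stream)  # main stream waits for the copy to finish
    device.record_stream(main)     # tell the allocator the main stream also uses this tensor
    return device * 2              # now safe to consume on the main stream

The multi_stream_memory and multi_stream_optimized benchmarks added here apply the same pattern around an async all_to_all_single and compare the resulting CUDA memory footprint.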

File tree

1 file changed: +156 −1 lines changed


torchrec/distributed/benchmark/benchmark_comms.py

Lines changed: 156 additions & 1 deletion
@@ -21,6 +21,7 @@
 see README.md for more details
 """

+from contextlib import nullcontext
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional

@@ -54,6 +55,8 @@ class AllToAllSingleRunConfig(BenchFuncConfig):
     num_profiles: int = 2
     num_mul: int = 5
     num_concat: int = 100
+    multi_stream: bool = True
+    main_stream_allocation: bool = False


 def _compute(
@@ -94,6 +97,7 @@ def a2a_sync_base(
     num_mul: int,
     num_concat: int,
     ctx: MultiProcessContext,
+    **_kwargs: Dict[str, Any],
 ) -> None:
     with record_function("## pre-comms compute ##"):
         pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
@@ -186,6 +190,7 @@ def a2a_async_twice(
     num_mul: int,
     num_concat: int,
     ctx: MultiProcessContext,
+    **_kwargs: Dict[str, Any],
 ) -> None:
     with record_function("## pre-comms compute ##"):
         pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
@@ -254,13 +259,14 @@ def a2a_async_twice(
         assert checks1 and checks2


-# all_to_all_single with sync and single stream
+# LazyAwaitable
 def lazyawaitable(
     _batch_inputs: List[Dict[str, Any]],
     dim: int,
     num_mul: int,
     num_concat: int,
     ctx: MultiProcessContext,
+    **_kwargs: Dict[str, Any],
 ) -> None:
     with record_function("## pre-comms compute ##"):
         pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
@@ -294,6 +300,149 @@ def lazyawaitable(
         assert check_awaitable.item()


+# multi-stream memory footprint
+def multi_stream_memory(
+    _batch_inputs: List[Dict[str, Any]],
+    dim: int,
+    num_mul: int,
+    num_concat: int,
+    ctx: MultiProcessContext,
+    multi_stream: bool,
+    **_kwargs: Dict[str, Any],
+) -> None:
+    with record_function("## setup ##"):
+        main_stream = torch.cuda.current_stream()
+        data_copy_stream = torch.cuda.Stream() if multi_stream else nullcontext()
+        data_dist_stream = torch.cuda.Stream() if multi_stream else nullcontext()
+        irrelevant_data = torch.rand(dim, dim, device=ctx.device) - 0.5
+
+        # the host-to-device transfer would block CUDA execution without `pin_memory()`
+        host_data = (torch.rand(dim, dim) - 0.5).pin_memory()
+
+    with record_function("## irrelevant compute before h2d ##"):
+        pre_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
+        )
+
+    with record_function("## copy data to device ##"):
+        with data_copy_stream:
+            device_data = host_data.to(ctx.device, non_blocking=True)
+
+    with record_function("## irrelevant compute after h2d ##"):
+        pre_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
+        )
+
+    with record_function("## pre-comms compute ##"):
+        if isinstance(data_copy_stream, torch.cuda.Stream):
+            main_stream.wait_stream(data_copy_stream)
+            device_data.record_stream(main_stream)
+        pre_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=device_data
+        )
+
+    with data_dist_stream:
+        with record_function("## all_to_all_single ##"):
+            if isinstance(data_dist_stream, torch.cuda.Stream):
+                data_dist_stream.wait_stream(main_stream)  # pyre-ignore[16]
+            post_comms = torch.zeros_like(pre_comms)
+            req = dist.all_to_all_single(
+                output=post_comms,
+                input=pre_comms,
+                group=ctx.pg,
+                async_op=True,
+            )
+        with record_function("## a2a comm validation ##"):
+            req.wait()
+            checks = DeviceToHostTensorAwaitable(_validate(post_comms, ctx))
+
+    with record_function("## irrelevant compute after a2a ##"):
+        pre_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
+        )
+
+    with record_function("## post-comms compute ##"):
+        req.wait()
+        post_comms.record_stream(main_stream)
+        post_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=post_comms[0]
+        )
+
+    with record_function("## assert ##"):
+        assert checks.item()
+
+
+def multi_stream_optimized(
+    _batch_inputs: List[Dict[str, Any]],
+    dim: int,
+    num_mul: int,
+    num_concat: int,
+    ctx: MultiProcessContext,
+    **_kwargs: Dict[str, Any],
+) -> None:
+    with record_function("## setup ##"):
+        main_stream = torch.cuda.current_stream()
+        data_copy_stream = torch.cuda.Stream()
+        data_dist_stream = torch.cuda.Stream()
+        irrelevant_data = torch.rand(dim, dim, device=ctx.device) - 0.5
+
+        # the host-to-device transfer would block CUDA execution without `pin_memory()`
+        host_data = (torch.rand(dim, dim) - 0.5).pin_memory()
+        device_data = torch.empty_like(host_data, device=ctx.device)
+
+    with record_function("## irrelevant compute before h2d ##"):
+        pre_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
+        )
+
+    with record_function("## copy data to device ##"):
+        with data_copy_stream:
+            device_data.record_stream(data_copy_stream)
+            device_data.copy_(host_data, non_blocking=True)
+
+    with record_function("## irrelevant compute after h2d ##"):
+        pre_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
+        )
+
+    with record_function("## pre-comms compute ##"):
+        if isinstance(data_copy_stream, torch.cuda.Stream):
+            main_stream.wait_stream(data_copy_stream)
+        pre_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=device_data
+        )
+
+    with record_function("## pre-allocate memory for a2a on main stream ##"):
+        post_comms = torch.zeros_like(pre_comms)
+
+    with data_dist_stream:
+        with record_function("## all_to_all_single ##"):
+            data_dist_stream.wait_stream(main_stream)
+            req = dist.all_to_all_single(
+                output=post_comms,
+                input=pre_comms,
+                group=ctx.pg,
+                async_op=True,
+            )
+
+    with record_function("## irrelevant compute after a2a ##"):
+        pre_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
+        )
+
+    with record_function("## a2a comm validation ##"):
+        req.wait()
+        checks = DeviceToHostTensorAwaitable(_validate(post_comms, ctx))
+
+    with record_function("## post-comms compute ##"):
+        post_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=post_comms[0]
+        )
+
+    with record_function("## assert ##"):
+        assert checks.item()
+
+
 # single-rank runner
 def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig) -> None:
     # Ensure GPUs are available and we have enough of them
@@ -317,6 +466,10 @@ def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig)
         func = a2a_async_twice
     elif arg.name.startswith("lazyawaitable"):
         func = lazyawaitable
+    elif arg.name.startswith("multi_stream_memory"):
+        func = multi_stream_memory
+    elif arg.name.startswith("multi_stream_optimized"):
+        func = multi_stream_optimized
     else:
         raise ValueError(f"Unknown benchmark name: {arg.name}")

@@ -328,6 +481,8 @@ def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig)
             "dim": arg.dim,
             "num_mul": arg.num_mul,
             "num_concat": arg.num_concat,
+            "multi_stream": arg.multi_stream,
+            "main_stream_allocation": arg.main_stream_allocation,
         },
         func_to_benchmark=func,
         rank=rank,

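The main_stream_allocation knob and the "## pre-allocate memory for a2a on main stream ##" section exist because the PyTorch CUDA caching allocator keeps separate block pools per stream: memory freed from a tensor that was allocated under a side stream is generally not reused for later main-stream allocations, so allocating the collective's output buffer on the side stream inflates the overall reserved footprint. A minimal single-GPU sketch of the effect, not part of this commit (it uses a plain buffer instead of the all_to_all_single output, and the function names are made up):

import torch


def reserved_mb() -> float:
    # bytes currently held by the CUDA caching allocator, in MB
    return torch.cuda.memory_reserved() / 2**20


def demo(alloc_on_side_stream: bool, dim: int = 8192) -> float:
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    side = torch.cuda.Stream()

    # step 1: allocate and drop a large buffer, either under a side stream or the main stream
    if alloc_on_side_stream:
        with torch.cuda.stream(side):
            buf = torch.zeros(dim, dim, device="cuda")
    else:
        buf = torch.zeros(dim, dim, device="cuda")
    del buf
    torch.cuda.synchronize()

    # step 2: allocate the same size on the main stream; the freed block can only be
    # reused if it sits in the main stream's pool
    buf2 = torch.zeros(dim, dim, device="cuda")
    reserved = reserved_mb()
    del buf2
    return reserved


# expected: roughly 2x the buffer size reserved when the first allocation was on the
# side stream, roughly 1x when both allocations were on the main stream
print("side-stream alloc:", demo(True), "MB reserved")
print("main-stream alloc:", demo(False), "MB reserved")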