
Commit e2dbe39

TroyGarden authored and facebook-github-bot committed
demonstration of cuda memory footprint in multi-stream scenario (#3480)
Summary:

Google Document: https://docs.google.com/document/d/1Odt6oJJgvPDeVSQmQqrl3iUAl2yu_HPUVLJ92EqDiTI
Workplace Post: https://fb.workplace.com/groups/429376538334034/permalink/1488841649054179/

# context
* high-level design and technical discussions are in the document/post
* this diff added three benchmark jobs to demonstrate the memory footprint in multi-stream vs single-stream scenarios
* other changes in the benchmark function:
  * **a**. make `reset_accumulated_memory_stats` default to True
  * **b**. call `torch.cuda.empty_cache()` before the memory snapshot so that the snapshot won't include the residual effects from the previous benchmark runs
* benchmark stats (see the measurement sketch after this summary):

| name | GPU Runtime | CPU Runtime | GPU Peak Memory **alloc** | GPU Peak Memory **reserved** | CPU Peak RSS |
|--|--|--|--|--|--|
| single_stream_memory | **158.58 ms** | 233.79 ms | 5.06 GB | 5.11 GB | 1.70 GB |
| multi_stream_memory | 145.43 ms | 241.68 ms | 5.06 GB | **10.13 GB** | 1.70 GB |
| multi_stream_optimized | 146.98 ms | 244.72 ms | 5.06 GB | 5.11 GB | 1.70 GB |

Reviewed By: spmex

Differential Revision: D85399705
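The interesting column pair is **alloc** vs. **reserved**: peak allocated memory is 5.06 GB in all three jobs, but peak reserved memory jumps to 10.13 GB in `multi_stream_memory` because PyTorch's caching allocator keeps a separate block pool per CUDA stream, so buffers created on a side stream cannot reuse blocks cached for the main stream. Below is a minimal, illustrative sketch of that effect and of how the two columns are measured; it is not code from this diff, and the tensor sizes and the `peak_memory` helper name are made up.

```python
from typing import Tuple

import torch


def peak_memory(use_side_stream: bool) -> Tuple[int, int]:
    """Return (peak allocated, peak reserved) in bytes for a toy allocate/free/allocate cycle."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    x = torch.rand(4096, 4096, device="cuda")  # ~64 MiB allocated on the main stream
    del x  # the block is returned to the main stream's cache, not to the driver

    stream = torch.cuda.Stream() if use_side_stream else torch.cuda.current_stream()
    with torch.cuda.stream(stream):
        # on a side stream this cannot reuse the cached main-stream block,
        # so the allocator reserves a second segment
        y = torch.rand(4096, 4096, device="cuda")
    torch.cuda.synchronize()
    del y
    return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()


if __name__ == "__main__":
    for side in (False, True):
        alloc, reserved = peak_memory(side)
        print(f"side_stream={side}: alloc={alloc >> 20} MiB, reserved={reserved >> 20} MiB")
```

On a typical allocator configuration this prints roughly 64/64 MiB without the side stream and 64/128 MiB with it, the same shape as the 5.06 GB vs. 10.13 GB row in the table.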
1 parent b7f8426 commit e2dbe39

2 files changed: +202 / -14 lines


torchrec/distributed/benchmark/base.py

Lines changed: 13 additions & 11 deletions
@@ -606,7 +606,7 @@ def _run_benchmark_core(
     output_dir: str,
     pre_gpu_load: int = 0,
     export_stacks: bool = False,
-    reset_accumulated_memory_stats: bool = False,
+    reset_accumulated_memory_stats: bool = True,
     all_rank_traces: bool = False,
     memory_snapshot: bool = False,
 ) -> BenchmarkResult:
@@ -635,24 +635,25 @@ def _run_benchmark_core(
         stats in addition to peak memory stats.
     """

-    # Preparation & memory reset
-    if device_type == "cuda":
-        if rank == -1:
-            for di in range(world_size):
-                torch.cuda.reset_peak_memory_stats(di)
-                if reset_accumulated_memory_stats:
-                    torch.cuda.reset_accumulated_memory_stats(di)
-        else:
-            torch.cuda.reset_peak_memory_stats(rank)
+    def _reset_memory_stats() -> None:
+        if device_type != "cuda":
+            return
+        ranks = range(world_size) if rank == -1 else [rank]
+        for di in ranks:
+            torch.cuda.reset_peak_memory_stats(di)
             if reset_accumulated_memory_stats:
-                torch.cuda.reset_accumulated_memory_stats(rank)
+                torch.cuda.reset_accumulated_memory_stats(di)

+    # Preparation & memory reset
+    if device_type == "cuda":
         # Optional allocator warm-up to create fragmentation similar to production
         if pre_gpu_load:
             _tmp = torch.rand(16384, 16384, device="cuda")
             for _ in range(pre_gpu_load):
                 _tmp = _tmp * torch.rand(16384, 16384, device="cuda")

+        _reset_memory_stats()
+
     # Timings
     start_events, end_events, times = [], [], []

@@ -745,6 +746,7 @@ def _trace_handler(prof: torch.profiler.profile) -> None:
         )

     if memory_snapshot:
+        torch.cuda.empty_cache()
         torch.cuda.memory._record_memory_history(
             max_entries=MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT
         )
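The two hunks above tighten the measurement hygiene rather than the benchmark itself: accumulated memory stats are now reset by default, and `torch.cuda.empty_cache()` runs before the memory snapshot starts so cached blocks left over from earlier benchmark runs do not show up in it. A condensed sketch of how those calls fit around a single run is below; it is illustrative only, `run_once` and `snapshot_path` are placeholder names, and the underscored `torch.cuda.memory` helpers are the same ones the diff uses.

```python
import torch


def profile_memory(run_once, snapshot_path: str = "memory_snapshot.pickle") -> None:
    # reset counters so the peaks reflect only this run
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.reset_accumulated_memory_stats()
    # drop cached blocks from previous runs before recording starts
    torch.cuda.empty_cache()
    torch.cuda.memory._record_memory_history(max_entries=100_000)

    run_once()
    torch.cuda.synchronize()

    torch.cuda.memory._dump_snapshot(snapshot_path)
    torch.cuda.memory._record_memory_history(enabled=None)  # stop recording
    print(
        f"peak alloc={torch.cuda.max_memory_allocated() >> 20} MiB, "
        f"peak reserved={torch.cuda.max_memory_reserved() >> 20} MiB"
    )
```

The snapshot file can be inspected with the viewer at https://pytorch.org/memory_viz.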

torchrec/distributed/benchmark/benchmark_comms.py

Lines changed: 189 additions & 3 deletions
@@ -21,6 +21,7 @@
 see README.md for more details
 """

+from contextlib import nullcontext
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional

@@ -50,7 +51,7 @@ class AllToAllSingleRunConfig(BenchFuncConfig):
     world_size: int = 2
     dim: int = 2048
     profile_dir: str = "."
-    num_benchmarks: int = 1
+    num_benchmarks: int = 2
     num_profiles: int = 2
     num_mul: int = 5
     num_concat: int = 100
@@ -94,6 +95,7 @@ def a2a_sync_base(
     num_mul: int,
     num_concat: int,
     ctx: MultiProcessContext,
+    **_kwargs: Dict[str, Any],
 ) -> None:
     with record_function("## pre-comms compute ##"):
         pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
@@ -186,6 +188,7 @@ def a2a_async_twice(
     num_mul: int,
     num_concat: int,
     ctx: MultiProcessContext,
+    **_kwargs: Dict[str, Any],
 ) -> None:
     with record_function("## pre-comms compute ##"):
         pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
@@ -254,13 +257,14 @@ def a2a_async_twice(
         assert checks1 and checks2


-# all_to_all_single with sync and single stream
+# LazyAwaitable
 def lazyawaitable(
     _batch_inputs: List[Dict[str, Any]],
     dim: int,
     num_mul: int,
     num_concat: int,
     ctx: MultiProcessContext,
+    **_kwargs: Dict[str, Any],
 ) -> None:
     with record_function("## pre-comms compute ##"):
         pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
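The renamed `lazyawaitable` job defers the device-to-host sync of the validation result until the value is actually read. `DeviceToHostTensorAwaitable` is TorchRec's own class; the sketch below is a hypothetical, stripped-down stand-in meant only to show the mechanism (pinned host buffer, non-blocking copy, CUDA event, sync on first read), not the real implementation.

```python
import torch


class _D2HAwaitable:
    """Hypothetical stand-in for DeviceToHostTensorAwaitable, for illustration only."""

    def __init__(self, tensor: torch.Tensor) -> None:
        # pinned destination so the device-to-host copy can actually run asynchronously
        self._host = torch.empty(tensor.shape, dtype=tensor.dtype, device="cpu", pin_memory=True)
        self._host.copy_(tensor, non_blocking=True)
        self._copy_done = torch.cuda.Event()
        self._copy_done.record()  # recorded on the current stream, after the copy is enqueued

    def item(self):
        # block only at the point of use, not where the check was created
        self._copy_done.synchronize()
        return self._host.item()
```

Used the way the diff uses it: build the awaitable right after the collective result is validated on device, keep launching other work, and call `.item()` only inside the final `## assert ##` block.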
@@ -294,6 +298,183 @@ def lazyawaitable(
         assert check_awaitable.item()


+# multi-stream memory footprint
+def multi_stream_memory(
+    _batch_inputs: List[Dict[str, Any]],
+    dim: int,
+    num_mul: int,
+    num_concat: int,
+    ctx: MultiProcessContext,
+    multi_stream: bool = True,
+    **_kwargs: Dict[str, Any],
+) -> None:
+    with record_function("## setup ##"):
+        main_stream = torch.cuda.current_stream()
+        data_copy_stream = torch.cuda.Stream() if multi_stream else nullcontext()
+        data_dist_stream = torch.cuda.Stream() if multi_stream else nullcontext()
+        irrelevant_data = torch.rand(dim, dim, device=ctx.device) - 0.5
+
+        # the host-to-device data transfer will block cuda execution without the `pin_memory()`
+        host_data = (torch.rand(dim, dim) - 0.5).pin_memory()
+
+    with record_function("## irrelevant compute before h2d ##"):
+        pre_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
+        )
+
+    with record_function("## copy data to device ##"):
+        # use a separate stream to copy data to device; this will not block the main stream
+        with data_copy_stream:
+            device_data = host_data.to(ctx.device, non_blocking=True)
+            # record the data to the main stream, so it won't be freed accidentally in the data_copy_stream
+            device_data.record_stream(main_stream)
+
+    with record_function("## irrelevant compute after h2d ##"):
+        irrelevant_data = torch.rand(dim, dim, device=ctx.device) - 0.5
+        pre_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
+        )
+
+    with record_function("## pre-comms compute ##"):
+        if isinstance(data_copy_stream, torch.cuda.Stream):
+            # make sure the data copy is done before the pre-comms compute
+            main_stream.wait_stream(data_copy_stream)
+        pre_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=device_data
+        )
+
+    # use a separate stream to do the comms; this will not block the main stream
+    with data_dist_stream:
+        with record_function("## all_to_all_single ##"):
+            if isinstance(data_dist_stream, torch.cuda.Stream):
+                # make sure the pre-comms compute is done before the comms
+                data_dist_stream.wait_stream(main_stream)
+            post_comms = torch.zeros_like(pre_comms)
+            req = dist.all_to_all_single(
+                output=post_comms,
+                input=pre_comms,
+                group=ctx.pg,
+                async_op=True,
+            )
+            # record the data to the main stream, so it won't be freed accidentally in the data_dist_stream
+            post_comms.record_stream(main_stream)
+        with record_function("## a2a comm validation ##"):
+            # the comm validation is also done in this separate stream since
+            # there's no data dependency afterwards
+            req.wait()
+            checks = DeviceToHostTensorAwaitable(_validate(post_comms, ctx))
+
+    with record_function("## irrelevant compute after a2a ##"):
+        irrelevant_data = torch.rand(dim, dim, device=ctx.device) - 0.5
+        pre_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
+        )
+
+    with record_function("## post-comms compute ##"):
+        req.wait()
+        post_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=post_comms[0]
+        )
+
+    with record_function("## assert ##"):
+        assert checks.item()
+
+
+def single_stream_memory(
+    _batch_inputs: List[Dict[str, Any]],
+    dim: int,
+    num_mul: int,
+    num_concat: int,
+    ctx: MultiProcessContext,
+    **_kwargs: Dict[str, Any],
+) -> None:
+    return multi_stream_memory(
+        _batch_inputs=_batch_inputs,
+        dim=dim,
+        num_mul=num_mul,
+        num_concat=num_concat,
+        ctx=ctx,
+        multi_stream=False,
+    )
+
+
+# an optimized version of the multi-stream memory footprint
+def multi_stream_optimized(
+    _batch_inputs: List[Dict[str, Any]],
+    dim: int,
+    num_mul: int,
+    num_concat: int,
+    ctx: MultiProcessContext,
+    **_kwargs: Dict[str, Any],
+) -> None:
+    with record_function("## setup ##"):
+        main_stream = torch.cuda.current_stream()
+        data_copy_stream = torch.cuda.Stream()
+        irrelevant_data = torch.rand(dim, dim, device=ctx.device) - 0.5
+
+        # the host-to-device data transfer will block cuda execution without the `pin_memory()`
+        host_data = (torch.rand(dim, dim) - 0.5).pin_memory()
+        # pre-allocate memory on the device for the incoming data transfer from the host
+        device_data = torch.empty_like(host_data, device=ctx.device)
+
+    with record_function("## irrelevant compute before h2d ##"):
+        pre_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
+        )
+
+    with record_function("## copy data to device ##"):
+        with data_copy_stream:
+            # copy data to device; this will not block the main stream
+            device_data.copy_(host_data, non_blocking=True)
+
+    with record_function("## irrelevant compute after h2d ##"):
+        irrelevant_data = torch.rand(dim, dim, device=ctx.device) - 0.5
+        pre_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
+        )
+
+    with record_function("## pre-comms compute ##"):
+        # make sure the data copy is done before the pre-comms compute
+        main_stream.wait_stream(data_copy_stream)
+        pre_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=device_data
+        )
+
+    with record_function("## pre-allocate memory for a2a on main stream ##"):
+        post_comms = torch.zeros_like(pre_comms)
+
+    with record_function("## all_to_all_single ##"):
+        # the all_to_all_single from torch.dist has an async feature:
+        # it automatically uses a separate stream to do the comms
+        # without introducing extra memory footprint
+        req = dist.all_to_all_single(
+            output=post_comms,
+            input=pre_comms,
+            group=ctx.pg,
+            async_op=True,
+        )
+
+    with record_function("## irrelevant compute after a2a ##"):
+        irrelevant_data = torch.rand(dim, dim, device=ctx.device) - 0.5
+        pre_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
+        )
+
+    with record_function("## a2a comm validation ##"):
+        # this req.wait() can be wrapped into a LazyAwaitable
+        req.wait()
+        # still want the compute on the main stream if possible
+        checks = DeviceToHostTensorAwaitable(_validate(post_comms, ctx))
+
+    with record_function("## post-comms compute ##"):
+        post_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=post_comms[0]
+        )
+
+    with record_function("## assert ##"):
+        assert checks.item()
+
+
 # single-rank runner
 def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig) -> None:
     # Ensure GPUs are available and we have enough of them
@@ -308,7 +489,6 @@ def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig)
         backend="nccl",
         use_deterministic_algorithms=False,
     ) as ctx:
-
         if arg.name.startswith("a2a_sync_base"):
             func = a2a_sync_base
         elif arg.name.startswith("a2a_async_base"):
@@ -317,6 +497,12 @@ def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig)
             func = a2a_async_twice
         elif arg.name.startswith("lazyawaitable"):
            func = lazyawaitable
+        elif arg.name.startswith("multi_stream_memory"):
+            func = multi_stream_memory
+        elif arg.name.startswith("single_stream_memory"):
+            func = single_stream_memory
+        elif arg.name.startswith("multi_stream_optimized"):
+            func = multi_stream_optimized
         else:
             raise ValueError(f"Unknown benchmark name: {arg.name}")

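Taken together, the three new jobs isolate one design choice: where the buffers that cross streams are allocated. `multi_stream_memory` allocates `post_comms` inside the side `data_dist_stream`, so it lands in a second allocator pool and peak reserved memory roughly doubles; `multi_stream_optimized` pre-allocates the same buffer on the main stream before issuing the async `all_to_all_single`, and reserved memory stays at the single-stream level. The contrast, boiled down to plain `torch.cuda` calls (illustrative only, not code from this commit; the sizes and the stand-in side-stream work are placeholders):

```python
import torch

dim = 2048
main = torch.cuda.current_stream()
side = torch.cuda.Stream()

# Pattern A (multi_stream_memory): the output buffer is allocated inside the side stream,
# so it comes from the side stream's allocator pool and grows reserved memory.
side.wait_stream(main)  # inputs produced on the main stream must be ready first
with torch.cuda.stream(side):
    out_a = torch.zeros(dim, dim, device="cuda")
    out_a.record_stream(main)  # the main stream reads it later; don't recycle the block early
main.wait_stream(side)

# Pattern B (multi_stream_optimized): pre-allocate on the main stream, then let the side
# stream (or an async collective) fill the buffer; no second allocator pool has to grow.
out_b = torch.zeros(dim, dim, device="cuda")
side.wait_stream(main)  # the side stream must see the allocation before writing into it
with torch.cuda.stream(side):
    out_b.add_(1.0)  # stand-in for the side-stream work that fills out_b
    out_b.record_stream(side)  # a main-stream tensor is being used on the side stream
main.wait_stream(side)
```

In the diff, the side-stream work in Pattern B is the async `dist.all_to_all_single`, which runs on the process group's own internal stream, so pre-allocating `post_comms` on the main stream is enough to keep reserved memory at 5.11 GB in the table above.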