diff --git a/torchrec/distributed/benchmark/base.py b/torchrec/distributed/benchmark/base.py
index 0821fe579..858abde85 100644
--- a/torchrec/distributed/benchmark/base.py
+++ b/torchrec/distributed/benchmark/base.py
@@ -606,7 +606,7 @@ def _run_benchmark_core(
     output_dir: str,
     pre_gpu_load: int = 0,
     export_stacks: bool = False,
-    reset_accumulated_memory_stats: bool = False,
+    reset_accumulated_memory_stats: bool = True,
     all_rank_traces: bool = False,
     memory_snapshot: bool = False,
 ) -> BenchmarkResult:
@@ -635,24 +635,25 @@ def _run_benchmark_core(
         stats in addition to peak memory stats.
     """
 
-    # Preparation & memory reset
-    if device_type == "cuda":
-        if rank == -1:
-            for di in range(world_size):
-                torch.cuda.reset_peak_memory_stats(di)
-                if reset_accumulated_memory_stats:
-                    torch.cuda.reset_accumulated_memory_stats(di)
-        else:
-            torch.cuda.reset_peak_memory_stats(rank)
+    def _reset_memory_stats() -> None:
+        if device_type != "cuda":
+            return
+        ranks = range(world_size) if rank == -1 else [rank]
+        for di in ranks:
+            torch.cuda.reset_peak_memory_stats(di)
             if reset_accumulated_memory_stats:
-                torch.cuda.reset_accumulated_memory_stats(rank)
+                torch.cuda.reset_accumulated_memory_stats(di)
 
+    # Preparation & memory reset
+    if device_type == "cuda":
         # Optional allocator warm-up to create fragmentation similar to production
         if pre_gpu_load:
             _tmp = torch.rand(16384, 16384, device="cuda")
             for _ in range(pre_gpu_load):
                 _tmp = _tmp * torch.rand(16384, 16384, device="cuda")
 
+    _reset_memory_stats()
+
     # Timings
     start_events, end_events, times = [], [], []
 
@@ -745,6 +746,7 @@ def _trace_handler(prof: torch.profiler.profile) -> None:
         )
 
     if memory_snapshot:
+        torch.cuda.empty_cache()
        torch.cuda.memory._record_memory_history(
            max_entries=MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT
        )
diff --git a/torchrec/distributed/benchmark/benchmark_comms.py b/torchrec/distributed/benchmark/benchmark_comms.py
index f3e9bb27b..da3969352 100644
--- a/torchrec/distributed/benchmark/benchmark_comms.py
+++ b/torchrec/distributed/benchmark/benchmark_comms.py
@@ -21,6 +21,7 @@
 see README.md for more details
 """
 
+from contextlib import nullcontext
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional
 
@@ -50,7 +51,7 @@ class AllToAllSingleRunConfig(BenchFuncConfig):
     world_size: int = 2
     dim: int = 2048
     profile_dir: str = "."
-    num_benchmarks: int = 1
+    num_benchmarks: int = 2
     num_profiles: int = 2
     num_mul: int = 5
     num_concat: int = 100
@@ -94,6 +95,7 @@ def a2a_sync_base(
     num_mul: int,
     num_concat: int,
     ctx: MultiProcessContext,
+    **_kwargs: Dict[str, Any],
 ) -> None:
     with record_function("## pre-comms compute ##"):
         pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
@@ -186,6 +188,7 @@ def a2a_async_twice(
     num_mul: int,
     num_concat: int,
     ctx: MultiProcessContext,
+    **_kwargs: Dict[str, Any],
 ) -> None:
     with record_function("## pre-comms compute ##"):
         pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
@@ -254,13 +257,14 @@
         assert checks1 and checks2
 
 
-# all_to_all_single with sync and single stream
+# LazyAwaitable
 def lazyawaitable(
     _batch_inputs: List[Dict[str, Any]],
     dim: int,
     num_mul: int,
     num_concat: int,
     ctx: MultiProcessContext,
+    **_kwargs: Dict[str, Any],
 ) -> None:
     with record_function("## pre-comms compute ##"):
         pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
@@ -294,6 +298,183 @@
         assert check_awaitable.item()
 
 
+# multi-stream memory footprint
+def multi_stream_memory(
+    _batch_inputs: List[Dict[str, Any]],
+    dim: int,
+    num_mul: int,
+    num_concat: int,
+    ctx: MultiProcessContext,
+    multi_stream: bool = True,
+    **_kwargs: Dict[str, Any],
+) -> None:
+    with record_function("## setup ##"):
+        main_stream = torch.cuda.current_stream()
+        data_copy_stream = torch.cuda.Stream() if multi_stream else nullcontext()
+        data_dist_stream = torch.cuda.Stream() if multi_stream else nullcontext()
+        irrelevant_data = torch.rand(dim, dim, device=ctx.device) - 0.5
+
+        # the host-to-device data transfer would block CUDA execution without `pin_memory()`
+        host_data = (torch.rand(dim, dim) - 0.5).pin_memory()
+
+    with record_function("## irrelevant compute before h2d ##"):
+        pre_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
+        )
+
+    with record_function("## copy data to device ##"):
+        # use a separate stream to copy data to the device; this will not block the main stream
+        with data_copy_stream:
+            device_data = host_data.to(ctx.device, non_blocking=True)
+            # record the tensor on the main stream, so it won't be freed accidentally in the data_copy_stream
+            device_data.record_stream(main_stream)
+
+    with record_function("## irrelevant compute after h2d ##"):
+        irrelevant_data = torch.rand(dim, dim, device=ctx.device) - 0.5
+        pre_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
+        )
+
+    with record_function("## pre-comms compute ##"):
+        if isinstance(data_copy_stream, torch.cuda.Stream):
+            # make sure the data copy is done before the pre-comms compute
+            main_stream.wait_stream(data_copy_stream)
+        pre_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=device_data
+        )
+
+    # use a separate stream to do the comms; this will not block the main stream
+    with data_dist_stream:
+        with record_function("## all_to_all_single ##"):
+            if isinstance(data_dist_stream, torch.cuda.Stream):
+                # make sure the pre-comms compute is done before the comms
+                data_dist_stream.wait_stream(main_stream)
+            post_comms = torch.zeros_like(pre_comms)
+            req = dist.all_to_all_single(
+                output=post_comms,
+                input=pre_comms,
+                group=ctx.pg,
+                async_op=True,
+            )
+            # record the tensor on the main stream, so it won't be freed accidentally in the data_dist_stream
+            post_comms.record_stream(main_stream)
+        with record_function("## a2a comm validation ##"):
+            # the comm validation is also done in this separate stream since
+            # there's no data dependency afterwards
+            req.wait()
+            checks = DeviceToHostTensorAwaitable(_validate(post_comms, ctx))
+
+    with record_function("## irrelevant compute after a2a ##"):
+        irrelevant_data = torch.rand(dim, dim, device=ctx.device) - 0.5
+        pre_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
+        )
+
+    with record_function("## post-comms compute ##"):
+        req.wait()
+        post_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=post_comms[0]
+        )
+
+    with record_function("## assert ##"):
+        assert checks.item()
+
+
+def single_stream_memory(
+    _batch_inputs: List[Dict[str, Any]],
+    dim: int,
+    num_mul: int,
+    num_concat: int,
+    ctx: MultiProcessContext,
+    **_kwargs: Dict[str, Any],
+) -> None:
+    return multi_stream_memory(
+        _batch_inputs=_batch_inputs,
+        dim=dim,
+        num_mul=num_mul,
+        num_concat=num_concat,
+        ctx=ctx,
+        multi_stream=False,
+    )
+
+
+# an optimized version of the multi-stream memory footprint
+def multi_stream_optimized(
+    _batch_inputs: List[Dict[str, Any]],
+    dim: int,
+    num_mul: int,
+    num_concat: int,
+    ctx: MultiProcessContext,
+    **_kwargs: Dict[str, Any],
+) -> None:
+    with record_function("## setup ##"):
+        main_stream = torch.cuda.current_stream()
+        data_copy_stream = torch.cuda.Stream()
+        irrelevant_data = torch.rand(dim, dim, device=ctx.device) - 0.5
+
+        # the host-to-device data transfer would block CUDA execution without `pin_memory()`
+        host_data = (torch.rand(dim, dim) - 0.5).pin_memory()
+        # pre-allocate memory on the device for the incoming data transfer from the host
+        device_data = torch.empty_like(host_data, device=ctx.device)
+
+    with record_function("## irrelevant compute before h2d ##"):
+        pre_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
+        )
+
+    with record_function("## copy data to device ##"):
+        with data_copy_stream:
+            # copy data to the device; this will not block the main stream
+            device_data.copy_(host_data, non_blocking=True)
+
+    with record_function("## irrelevant compute after h2d ##"):
+        irrelevant_data = torch.rand(dim, dim, device=ctx.device) - 0.5
+        pre_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
+        )
+
+    with record_function("## pre-comms compute ##"):
+        # make sure the data copy is done before the pre-comms compute
+        main_stream.wait_stream(data_copy_stream)
+        pre_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=device_data
+        )
+
+    with record_function("## pre-allocate memory for a2a on main stream ##"):
+        post_comms = torch.zeros_like(pre_comms)
+
+    with record_function("## all_to_all_single ##"):
+        # the all_to_all_single from torch.dist has an async option;
+        # it automatically uses a separate stream to do the comms
+        # without introducing extra memory footprint
+        req = dist.all_to_all_single(
+            output=post_comms,
+            input=pre_comms,
+            group=ctx.pg,
+            async_op=True,
+        )
+
+    with record_function("## irrelevant compute after a2a ##"):
+        irrelevant_data = torch.rand(dim, dim, device=ctx.device) - 0.5
+        pre_comms = _compute(
+            dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
+        )
+
+    with record_function("## a2a comm validation ##"):
+        # this req.wait() can be wrapped into a LazyAwaitable
+        req.wait()
+        # still want the compute on the main stream if possible
+        checks = DeviceToHostTensorAwaitable(_validate(post_comms, ctx))
+
+    with record_function("## post-comms compute ##"):
record_function("## post-comms compute ##"): + post_comms = _compute( + dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=post_comms[0] + ) + + with record_function("## assert ##"): + assert checks.item() + + # single-rank runner def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig) -> None: # Ensure GPUs are available and we have enough of them @@ -308,7 +489,6 @@ def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig) backend="nccl", use_deterministic_algorithms=False, ) as ctx: - if arg.name.startswith("a2a_sync_base"): func = a2a_sync_base elif arg.name.startswith("a2a_async_base"): @@ -317,6 +497,12 @@ def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig) func = a2a_async_twice elif arg.name.startswith("lazyawaitable"): func = lazyawaitable + elif arg.name.startswith("multi_stream_memory"): + func = multi_stream_memory + elif arg.name.startswith("single_stream_memory"): + func = single_stream_memory + elif arg.name.startswith("multi_stream_optimized"): + func = multi_stream_optimized else: raise ValueError(f"Unknown benchmark name: {arg.name}")