24 changes: 13 additions & 11 deletions torchrec/distributed/benchmark/base.py
@@ -606,7 +606,7 @@ def _run_benchmark_core(
output_dir: str,
pre_gpu_load: int = 0,
export_stacks: bool = False,
-    reset_accumulated_memory_stats: bool = False,
+    reset_accumulated_memory_stats: bool = True,
all_rank_traces: bool = False,
memory_snapshot: bool = False,
) -> BenchmarkResult:
@@ -635,24 +635,25 @@ def _run_benchmark_core(
stats in addition to peak memory stats.
"""

-    # Preparation & memory reset
-    if device_type == "cuda":
-        if rank == -1:
-            for di in range(world_size):
-                torch.cuda.reset_peak_memory_stats(di)
-                if reset_accumulated_memory_stats:
-                    torch.cuda.reset_accumulated_memory_stats(di)
-        else:
-            torch.cuda.reset_peak_memory_stats(rank)
+    def _reset_memory_stats() -> None:
+        if device_type != "cuda":
+            return
+        ranks = range(world_size) if rank == -1 else [rank]
+        for di in ranks:
+            torch.cuda.reset_peak_memory_stats(di)
             if reset_accumulated_memory_stats:
-                torch.cuda.reset_accumulated_memory_stats(rank)
+                torch.cuda.reset_accumulated_memory_stats(di)
+
+    # Preparation & memory reset
+    if device_type == "cuda":
         # Optional allocator warm-up to create fragmentation similar to production
         if pre_gpu_load:
             _tmp = torch.rand(16384, 16384, device="cuda")
             for _ in range(pre_gpu_load):
                 _tmp = _tmp * torch.rand(16384, 16384, device="cuda")
 
+    _reset_memory_stats()

# Timings
start_events, end_events, times = [], [], []

@@ -745,6 +746,7 @@ def _trace_handler(prof: torch.profiler.profile) -> None:
)

if memory_snapshot:
torch.cuda.empty_cache()
torch.cuda.memory._record_memory_history(
max_entries=MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT
)
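
For context: with `reset_accumulated_memory_stats` now defaulting to `True`, the accumulated allocator counters (alloc retries, OOM counts) are cleared between runs alongside the peak stats. A minimal sketch of how these counters are typically read around a benchmarked workload, assuming a single-GPU setup; `measure_cuda_memory` and the example workload are illustrative names, not part of this PR:

import torch

def measure_cuda_memory(fn, device: int = 0) -> dict:
    # clear peak counters so max_memory_allocated() reflects only `fn`
    torch.cuda.reset_peak_memory_stats(device)
    # clear accumulated counters (num_alloc_retries, num_ooms, ...), which
    # otherwise keep growing across benchmark iterations
    torch.cuda.reset_accumulated_memory_stats(device)

    fn()
    torch.cuda.synchronize(device)

    stats = torch.cuda.memory_stats(device)
    return {
        "peak_allocated_bytes": torch.cuda.max_memory_allocated(device),
        "alloc_retries": stats.get("num_alloc_retries", 0),
        "ooms": stats.get("num_ooms", 0),
    }

# example: measure a large matmul
# measure_cuda_memory(lambda: torch.rand(8192, 8192, device="cuda") @ torch.rand(8192, 8192, device="cuda"))

The `torch.cuda.empty_cache()` added before `_record_memory_history(...)` plausibly serves to drop cached-but-unused blocks so the recorded snapshot starts from a clean allocator state.
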
192 changes: 189 additions & 3 deletions torchrec/distributed/benchmark/benchmark_comms.py
@@ -21,6 +21,7 @@
see README.md for more details
"""

from contextlib import nullcontext
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

@@ -50,7 +51,7 @@ class AllToAllSingleRunConfig(BenchFuncConfig):
world_size: int = 2
dim: int = 2048
profile_dir: str = "."
-    num_benchmarks: int = 1
+    num_benchmarks: int = 2
num_profiles: int = 2
num_mul: int = 5
num_concat: int = 100
@@ -94,6 +95,7 @@ def a2a_sync_base(
num_mul: int,
num_concat: int,
ctx: MultiProcessContext,
**_kwargs: Dict[str, Any],
) -> None:
with record_function("## pre-comms compute ##"):
pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
@@ -186,6 +188,7 @@ def a2a_async_twice(
num_mul: int,
num_concat: int,
ctx: MultiProcessContext,
**_kwargs: Dict[str, Any],
) -> None:
with record_function("## pre-comms compute ##"):
pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
@@ -254,13 +257,14 @@ def a2a_async_twice(
assert checks1 and checks2


-# all_to_all_single with sync and single stream
+# LazyAwaitable
def lazyawaitable(
_batch_inputs: List[Dict[str, Any]],
dim: int,
num_mul: int,
num_concat: int,
ctx: MultiProcessContext,
**_kwargs: Dict[str, Any],
) -> None:
with record_function("## pre-comms compute ##"):
pre_comms = _compute(dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx)
@@ -294,6 +298,183 @@ def lazyawaitable(
assert check_awaitable.item()
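
The body of `lazyawaitable` is collapsed above. As a rough illustration of the pattern it exercises — deferring `req.wait()` until the result is first consumed — a generic sketch follows; `LazyTensorAwaitable` is a made-up name, not TorchRec's `LazyAwaitable` or the `DeviceToHostTensorAwaitable` used in this file:

import torch

class LazyTensorAwaitable:
    """Defers work.wait() until the tensor is actually requested (sketch)."""

    def __init__(self, work, tensor: torch.Tensor) -> None:
        # `work` is e.g. the req returned by dist.all_to_all_single(..., async_op=True)
        self._work = work
        self._tensor = tensor

    def wait(self) -> torch.Tensor:
        if self._work is not None:
            # with NCCL this enqueues a stream dependency rather than
            # blocking the host, so calling it late maximizes overlap
            self._work.wait()
            self._work = None
        return self._tensor

The later `wait()` is called, the more independent compute can overlap the collective.
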


# multi-stream memory footprint
def multi_stream_memory(
_batch_inputs: List[Dict[str, Any]],
dim: int,
num_mul: int,
num_concat: int,
ctx: MultiProcessContext,
multi_stream: bool = True,
**_kwargs: Dict[str, Any],
) -> None:
with record_function("## setup ##"):
main_stream = torch.cuda.current_stream()
data_copy_stream = torch.cuda.Stream() if multi_stream else nullcontext()
data_dist_stream = torch.cuda.Stream() if multi_stream else nullcontext()
irrelevant_data = torch.rand(dim, dim, device=ctx.device) - 0.5

        # without `pin_memory()`, the host-to-device transfer cannot be truly async and would block CUDA execution
host_data = (torch.rand(dim, dim) - 0.5).pin_memory()

with record_function("## irrelevant compute before h2d ##"):
pre_comms = _compute(
dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
)

with record_function("## copy data to device ##"):
        # use a separate stream for the host-to-device copy so it does not block the main stream
with data_copy_stream:
device_data = host_data.to(ctx.device, non_blocking=True)
        # record the tensor on the main stream, so it won't be freed accidentally in the data_copy_stream
device_data.record_stream(main_stream)

with record_function("## irrelevant compute after h2d ##"):
irrelevant_data = torch.rand(dim, dim, device=ctx.device) - 0.5
pre_comms = _compute(
dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
)

with record_function("## pre-comms compute ##"):
if isinstance(data_copy_stream, torch.cuda.Stream):
# make sure the data copy is done before the pre-comms compute
main_stream.wait_stream(data_copy_stream)
pre_comms = _compute(
dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=device_data
)

    # use a separate stream for the comms so they do not block the main stream
with data_dist_stream:
with record_function("## all_to_all_single ##"):
if isinstance(data_dist_stream, torch.cuda.Stream):
# make sure the pre-comms compute is done before the comms
data_dist_stream.wait_stream(main_stream)
post_comms = torch.zeros_like(pre_comms)
req = dist.all_to_all_single(
output=post_comms,
input=pre_comms,
group=ctx.pg,
async_op=True,
)
            # record the tensor on the main stream, so it won't be freed accidentally in the data_dist_stream
post_comms.record_stream(main_stream)
with record_function("## a2a comm validation ##"):
# the comm validation is also done in this separate stream since
# there's no data dependency afterwards
req.wait()
checks = DeviceToHostTensorAwaitable(_validate(post_comms, ctx))

with record_function("## irrelevant compute after a2a ##"):
irrelevant_data = torch.rand(dim, dim, device=ctx.device) - 0.5
pre_comms = _compute(
dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
)

with record_function("## post-comms compute ##"):
req.wait()
post_comms = _compute(
dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=post_comms[0]
)

with record_function("## assert ##"):
assert checks.item()
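
`multi_stream_memory` combines four ingredients: pinned host memory, a dedicated copy stream, `record_stream`, and `wait_stream`. A condensed standalone sketch of just the H2D-overlap portion; `overlapped_h2d` and the `compute` callback are illustrative names, not part of this PR:

import torch

def overlapped_h2d(host_data: torch.Tensor, compute) -> torch.Tensor:
    # host_data must come from .pin_memory(); otherwise the non_blocking
    # copy silently degrades to a synchronous one
    main_stream = torch.cuda.current_stream()
    copy_stream = torch.cuda.Stream()

    with torch.cuda.stream(copy_stream):
        device_data = host_data.to("cuda", non_blocking=True)
    # mark the tensor as used on the main stream so the caching allocator
    # does not recycle its memory while the copy is still in flight
    device_data.record_stream(main_stream)

    compute()  # independent work overlaps with the copy

    # order subsequent main-stream kernels after the copy
    main_stream.wait_stream(copy_stream)
    return device_data
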


def single_stream_memory(
_batch_inputs: List[Dict[str, Any]],
dim: int,
num_mul: int,
num_concat: int,
ctx: MultiProcessContext,
**_kwargs: Dict[str, Any],
) -> None:
return multi_stream_memory(
_batch_inputs=_batch_inputs,
dim=dim,
num_mul=num_mul,
num_concat=num_concat,
ctx=ctx,
multi_stream=False,
)


# an optimized version of the multi-stream memory footprint benchmark
def multi_stream_optimized(
_batch_inputs: List[Dict[str, Any]],
dim: int,
num_mul: int,
num_concat: int,
ctx: MultiProcessContext,
**_kwargs: Dict[str, Any],
) -> None:
with record_function("## setup ##"):
main_stream = torch.cuda.current_stream()
data_copy_stream = torch.cuda.Stream()
irrelevant_data = torch.rand(dim, dim, device=ctx.device) - 0.5

        # without `pin_memory()`, the host-to-device transfer cannot be truly async and would block CUDA execution
host_data = (torch.rand(dim, dim) - 0.5).pin_memory()
# pre-allocate memory on the device for the incoming data transfer from the host
device_data = torch.empty_like(host_data, device=ctx.device)

with record_function("## irrelevant compute before h2d ##"):
pre_comms = _compute(
dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
)

with record_function("## copy data to device ##"):
with data_copy_stream:
            # copy data to the device; this does not block the main stream
device_data.copy_(host_data, non_blocking=True)

with record_function("## irrelevant compute after h2d ##"):
irrelevant_data = torch.rand(dim, dim, device=ctx.device) - 0.5
pre_comms = _compute(
dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
)

with record_function("## pre-comms compute ##"):
# make sure the data copy is done before the pre-comms compute
main_stream.wait_stream(data_copy_stream)
pre_comms = _compute(
dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=device_data
)

with record_function("## pre-allocate memory for a2a on main stream ##"):
post_comms = torch.zeros_like(pre_comms)

with record_function("## all_to_all_single ##"):
        # all_to_all_single from torch.distributed supports async_op=True;
        # it automatically runs the comms on a separate internal stream
        # without introducing extra memory footprint
req = dist.all_to_all_single(
output=post_comms,
input=pre_comms,
group=ctx.pg,
async_op=True,
)

with record_function("## irrelevant compute after a2a ##"):
irrelevant_data = torch.rand(dim, dim, device=ctx.device) - 0.5
pre_comms = _compute(
dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=irrelevant_data
)

with record_function("## a2a comm validation ##"):
# this req.wait() can be wrapped into a LazyAwaitable
req.wait()
        # we still want the validation compute on the main stream if possible
checks = DeviceToHostTensorAwaitable(_validate(post_comms, ctx))

with record_function("## post-comms compute ##"):
post_comms = _compute(
dim=dim, num_mul=num_mul, num_concat=num_concat, ctx=ctx, x=post_comms[0]
)

with record_function("## assert ##"):
assert checks.item()
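
Relative to `multi_stream_memory`, this variant allocates both the device buffer and the a2a output on the main stream, so no `record_stream` bookkeeping is needed and no side-stream allocator pools are created. A trimmed sketch of the core overlap, assuming an initialized process group; `a2a_with_overlap` is an illustrative name:

import torch
import torch.distributed as dist

def a2a_with_overlap(pre_comms: torch.Tensor, pg: dist.ProcessGroup) -> torch.Tensor:
    # output pre-allocated on the main stream: no side stream, no record_stream
    post_comms = torch.zeros_like(pre_comms)
    req = dist.all_to_all_single(
        output=post_comms,
        input=pre_comms,
        group=pg,
        async_op=True,
    )

    # ... independent compute here runs concurrently with the collective ...

    # for NCCL, wait() makes the current stream wait on the collective
    # instead of blocking the host thread
    req.wait()
    return post_comms
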


# single-rank runner
def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig) -> None:
# Ensure GPUs are available and we have enough of them
@@ -308,7 +489,6 @@ def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig)
backend="nccl",
use_deterministic_algorithms=False,
) as ctx:

if arg.name.startswith("a2a_sync_base"):
func = a2a_sync_base
elif arg.name.startswith("a2a_async_base"):
Expand All @@ -317,6 +497,12 @@ def a2a_single_runner(rank: int, world_size: int, arg: AllToAllSingleRunConfig)
func = a2a_async_twice
elif arg.name.startswith("lazyawaitable"):
func = lazyawaitable
elif arg.name.startswith("multi_stream_memory"):
func = multi_stream_memory
elif arg.name.startswith("single_stream_memory"):
func = single_stream_memory
elif arg.name.startswith("multi_stream_optimized"):
func = multi_stream_optimized
else:
raise ValueError(f"Unknown benchmark name: {arg.name}")
