Improve Comms Benchmark Timing #833

Merged (3 commits) on Dec 20, 2023
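
Every file in this diff applies the same change: the per-trial timing loop switches from host-side time.perf_counter() to CUDA events recorded on the device stream, and the elapsed time (reported by elapsed_time() in milliseconds) is divided by 1000 to keep the rest of the bookkeeping in seconds. The snippet below is a minimal, self-contained sketch of that pattern for reference only; the timed_op name, the trials default, and the plain torch.cuda.synchronize() call (the benchmarks use their own sync_all() helper) are illustrative assumptions, and an already-initialized torch.distributed process group on a CUDA device is presumed.

import torch
import torch.distributed as dist


def timed_op(tensor, trials=10):
    # Both events need enable_timing=True so elapsed_time() can be queried later.
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()         # marker enqueued on the current CUDA stream
    for _ in range(trials):
        dist.all_reduce(tensor)  # the communication op being measured
    end_event.record()           # marker enqueued right after the last op
    torch.cuda.synchronize()     # wait until both events (and the ops) have completed
    duration = start_event.elapsed_time(end_event) / 1000  # elapsed_time() returns ms
    return duration / trials     # average seconds per trial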
14 changes: 9 additions & 5 deletions benchmarks/communication/all_gather.py
@@ -16,7 +16,7 @@


 # Run all_gather and print metrics
-def timed_all_gather(input, output, args):
+def timed_all_gather(input, output, start_event, end_event, args):
     if args.dist == 'torch':
         import torch.distributed as dist

@@ -33,11 +33,12 @@ def timed_all_gather(input, output, args):
     sync_all()

     # time the actual comm op trials times and average it
-    pre = time.perf_counter()
+    start_event.record()
     for i in range(args.trials):
         all_gather_func(output, input, group=None, async_op=args.async_op)
+    end_event.record()
     sync_all()
-    duration = time.perf_counter() - pre
+    duration = start_event.elapsed_time(end_event) / 1000

     # maintain and clean performance data
     avg_duration = duration / args.trials
@@ -63,6 +64,9 @@ def run_all_gather(local_rank, args):
     global_rank = dist.get_rank()
     world_size = dist.get_world_size()

+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+
     if args.scan:
         # Create list of message sizes
         M_LIST = []
@@ -92,7 +96,7 @@ def run_all_gather(local_rank, args):
                 else:
                     raise e
             sync_all()
-            timed_all_gather(input, output, args)
+            timed_all_gather(input, output, start_event, end_event, args)
     else:
         # all_gather_into_tensor saves memory
         if ((args.dist == 'torch' or args.dist == 'deepspeed') and dist.has_all_gather_into_tensor()):
@@ -126,7 +130,7 @@ def run_all_gather(local_rank, args):
                 raise e

         sync_all()
-        timed_all_gather(input, output, args)
+        timed_all_gather(input, output, start_event, end_event, args)


 if __name__ == "__main__":
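
One detail worth noting about the hunk above (an observation about the API, not text from the PR): Event.elapsed_time() reports milliseconds between two recorded events and is only valid once both events have completed, which is why sync_all() runs before the value is read and why the result is divided by 1000. With purely hypothetical numbers:

ms_total = start_event.elapsed_time(end_event)  # e.g. 512.0 ms for the whole trial loop
duration = ms_total / 1000                      # 0.512 s
avg_duration = duration / args.trials           # 0.01024 s per trial when args.trials == 50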
14 changes: 9 additions & 5 deletions benchmarks/communication/all_reduce.py
@@ -14,7 +14,7 @@
 from deepspeed.accelerator import get_accelerator


-def timed_all_reduce(input, args):
+def timed_all_reduce(input, start_event, end_event, args):
     if args.dist == 'torch':
         import torch.distributed as dist
     elif args.dist == 'deepspeed':
@@ -27,11 +27,12 @@ def timed_all_reduce(input, args):
     sync_all()

     # time the actual comm op trials times and average it
-    pre = time.perf_counter()
+    start_event.record()
     for i in range(args.trials):
         dist.all_reduce(input, async_op=args.async_op)
+    end_event.record()
     sync_all()
-    duration = time.perf_counter() - pre
+    duration = start_event.elapsed_time(end_event) / 1000

     # maintain and clean performance data
     avg_duration = duration / args.trials
@@ -59,6 +60,9 @@ def run_all_reduce(local_rank, args):
     world_size = dist.get_world_size()
     global_rank = dist.get_rank()

+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+
     if args.scan:
         M_LIST = []
         for x in (2**p for p in range(1, args.maxsize)):
@@ -82,7 +86,7 @@ def run_all_reduce(local_rank, args):
                 else:
                     raise e
             sync_all()
-            timed_all_reduce(input, args)
+            timed_all_reduce(input, start_event, end_event, args)
     else:
         # Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
         # Don't need output tensor, so we double mem_factor
@@ -104,7 +108,7 @@ def run_all_reduce(local_rank, args):
             else:
                 raise e
         sync_all()
-        timed_all_reduce(input, args)
+        timed_all_reduce(input, start_event, end_event, args)


 if __name__ == "__main__":
14 changes: 9 additions & 5 deletions benchmarks/communication/all_to_all.py
@@ -14,7 +14,7 @@
 from deepspeed.accelerator import get_accelerator


-def timed_all_to_all(input, output, args):
+def timed_all_to_all(input, output, start_event, end_event, args):
     if args.dist == 'torch':
         import torch.distributed as dist
     elif args.dist == 'deepspeed':
@@ -27,11 +27,12 @@ def timed_all_to_all(input, output, args):
     sync_all()

     # time the actual comm op trials times and average it
-    pre = time.perf_counter()
+    start_event.record()
     for i in range(args.trials):
         dist.all_to_all_single(output, input, async_op=args.async_op)
+    end_event.record()
     sync_all()
-    duration = time.perf_counter() - pre
+    duration = start_event.elapsed_time(end_event) / 1000

     # maintain and clean performance data
     avg_duration = duration / args.trials
@@ -58,6 +59,9 @@ def run_all_to_all(local_rank, args):
     # Prepare benchmark header
     print_header(args, 'all_to_all')

+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+
     if args.scan:
         M_LIST = []
         for x in (2**p for p in range(1, args.maxsize)):
@@ -83,7 +87,7 @@ def run_all_to_all(local_rank, args):
                 else:
                     raise e
             sync_all()
-            timed_all_to_all(input, output, args)
+            timed_all_to_all(input, output, start_event, end_event, args)
     else:
         # Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
         elements_per_gpu = max_numel(comm_op='all_to_all',
@@ -118,7 +122,7 @@ def run_all_to_all(local_rank, args):
                     print(f"Before AllToAll Input List at rank {global_rank}: {input}")
                 dist.barrier()

-        timed_all_to_all(input, output, args)
+        timed_all_to_all(input, output, start_event, end_event, args)

         if args.debug:
             for i in range(world_size):
14 changes: 9 additions & 5 deletions benchmarks/communication/broadcast.py
@@ -14,7 +14,7 @@
 from deepspeed.accelerator import get_accelerator


-def timed_broadcast(input, args):
+def timed_broadcast(input, start_event, end_event, args):
     if args.dist == 'torch':
         import torch.distributed as dist
     elif args.dist == 'deepspeed':
@@ -27,11 +27,12 @@ def timed_broadcast(input, args):
     sync_all()

     # time the actual comm op trials times and average it
-    pre = time.perf_counter()
+    start_event.record()
     for i in range(args.trials):
         dist.broadcast(input, 0, async_op=args.async_op)
+    end_event.record()
     sync_all()
-    duration = time.perf_counter() - pre
+    duration = start_event.elapsed_time(end_event) / 1000

     # maintain and clean performance data
     avg_duration = duration / args.trials
@@ -59,6 +60,9 @@ def run_broadcast(local_rank, args):
     world_size = dist.get_world_size()
     global_rank = dist.get_rank()

+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+
     if args.scan:
         M_LIST = []
         for x in (2**p for p in range(1, args.maxsize)):
@@ -82,7 +86,7 @@ def run_broadcast(local_rank, args):
                 else:
                     raise e
             sync_all()
-            timed_broadcast(input, args)
+            timed_broadcast(input, start_event, end_event, args)
     else:
         # Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
         # Don't need output tensor, so we double mem_factor
@@ -102,7 +106,7 @@ def run_broadcast(local_rank, args):
                 sync_all()
                 return
         sync_all()
-        timed_broadcast(input, args)
+        timed_broadcast(input, start_event, end_event, args)


 if __name__ == "__main__":
14 changes: 9 additions & 5 deletions benchmarks/communication/pt2pt.py
@@ -14,7 +14,7 @@
 from deepspeed.accelerator import get_accelerator


-def timed_pt2pt(input, args):
+def timed_pt2pt(input, start_event, end_event, args):
     if args.dist == 'torch':
         import torch.distributed as dist
     elif args.dist == 'deepspeed':
@@ -36,7 +36,7 @@ def timed_pt2pt(input, args):
     sync_all()

     # time the actual comm op trials times and average it
-    pre = time.perf_counter()
+    start_event.record()
     for i in range(args.trials):
         if dist.get_rank() == 0:
             if args.async_op:
@@ -49,8 +49,9 @@
             else:
                 dist.recv(input, src=0)

+    end_event.record()
     sync_all()
-    duration = time.perf_counter() - pre
+    duration = start_event.elapsed_time(end_event) / 1000

     # maintain and clean performance data
     avg_duration = duration / args.trials
@@ -77,6 +78,9 @@ def run_pt2pt(local_rank, args):
     global_rank = dist.get_rank()
     world_size = dist.get_world_size()

+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+
     if args.scan:
         # Create list of message sizes
         M_LIST = []
@@ -101,7 +105,7 @@ def run_pt2pt(local_rank, args):
                 else:
                     raise e
             sync_all()
-            timed_pt2pt(input, args)
+            timed_pt2pt(input, start_event, end_event, args)
     else:
         # Send the biggest message size our GPUs can fit. If you're facing OOM errors, reduce the mem_factor
         # Don't need output tensor, so double mem_factor
@@ -121,7 +125,7 @@ def run_pt2pt(local_rank, args):
                 sync_all()
                 return
         sync_all()
-        timed_pt2pt(input, args)
+        timed_pt2pt(input, start_event, end_event, args)


 if __name__ == "__main__":