add example files for NCCL all_to_all_v/all_gather_v (#222)
ghostplant authored Jan 6, 2024
1 parent c2b2271 commit fdf6e59
Showing 4 changed files with 61 additions and 2 deletions.
15 changes: 14 additions & 1 deletion README.md
@@ -9,6 +9,19 @@ Tutel MoE: An Optimized Mixture-of-Experts Implementation.

### What's New:

- Tutel v0.3.1: Add NCCL all_to_all_v and all_gather_v for arbitrary-length message transfers:
```py
>> Example:
# All_to_All_v:
python3 -m torch.distributed.run --nproc_per_node=2 --master_port=7340 -m tutel.examples.nccl_all_to_all_v
# All_Gather_v:
python3 -m torch.distributed.run --nproc_per_node=2 --master_port=7340 -m tutel.examples.nccl_all_gather_v

>> How to:
net.batch_all_to_all_v([t_x_cuda, t_y_cuda, ...], common_send_counts)
net.batch_all_gather_v([t_x_cuda, t_y_cuda, ...])
```
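
For intuition, here is a pure-Python sketch of the exchange pattern these calls follow, assuming the usual all_to_all_v semantics (rank i sends its next `send_counts[j]` elements to rank j, and receives contributions in rank order); the function and variable names below are illustrative, not part of the Tutel API:
```py
# Single-process reference of the all_to_all_v exchange pattern (illustrative only).
def all_to_all_v_reference(inputs, send_counts):
    # inputs[i]: the flat list rank i sends.
    # send_counts[i][j]: how many elements rank i sends to rank j.
    world = len(inputs)
    outputs = [[] for _ in range(world)]
    for i in range(world):
        offset = 0
        for j in range(world):
            outputs[j].extend(inputs[i][offset:offset + send_counts[i][j]])
            offset += send_counts[i][j]
    return outputs  # outputs[j] is what rank j receives, ordered by source rank

# Mirrors the example files added below: rank 0 holds five 10s, rank 1 holds three 20s.
print(all_to_all_v_reference([[10] * 5, [20] * 3], [[1, 4], [2, 1]]))
# -> [[10, 20, 20], [10, 10, 10, 10, 20]]
```
Under the same assumptions, all_gather_v is the case where every rank receives every input in full, so each rank ends up with the rank-ordered concatenation.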

- Tutel v0.3: Add Megablocks solution to improve decoder inference on single-GPU with num_local_expert >= 2:
```py
>> Example (capacity_factor=0 for dropless-MoE):
@@ -46,7 +59,7 @@ Tutel MoE: An Optimized Mixture-of-Experts Implementation.
```
* Prepare recommended Pytorch >= 2.0.0 (minimum version == 1.8.0):
# Windows/Linux Pytorch for NVIDIA CUDA >= 11.7:
python3 -m pip install torch==2.0.0 --index-url https://download.pytorch.org/whl/cu117
python3 -m pip install torch==2.0.0 --index-url https://download.pytorch.org/whl/cu118
# Linux Pytorch for AMD ROCm == 5.4.2:
python3 -m pip install torch==2.0.0 --index-url https://download.pytorch.org/whl/rocm5.4.2
# Windows/Linux Pytorch for CPU:
2 changes: 1 addition & 1 deletion setup.py
@@ -24,7 +24,7 @@
IS_HIP_EXTENSION = False

if len(sys.argv) <= 1:
sys.argv += ['install', '--user']
sys.argv += ['install']

root_path = os.path.dirname(sys.argv[0])
root_path = root_path if root_path else '.'
22 changes: 22 additions & 0 deletions tutel/examples/nccl_all_gather_v.py
@@ -0,0 +1,22 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
from tutel import system, net

parallel_env = system.init_data_model_parallel(backend='nccl', group_count=1)
local_device = parallel_env.local_device

assert parallel_env.global_size == 2, "This test case is set for World Size == 2 only"

if parallel_env.global_rank == 0:
input = torch.tensor([10, 10, 10, 10, 10], device=local_device)
else:
input = torch.tensor([20, 20, 20], device=local_device)

print(f'Device-{parallel_env.global_rank} sends: {[input,]}')

net.barrier()

print(f'Device-{parallel_env.global_rank} recvs: {net.batch_all_gather_v([input,])[0]}')
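
With two ranks, and assuming `batch_all_gather_v` concatenates the variable-length inputs in rank order (standard all_gather_v behavior; the commit itself does not print a reference result), both ranks should print the same gathered tensor, device annotations omitted:
```py
# Expected (assumed) output of the example above:
# Device-0 sends: [tensor([10, 10, 10, 10, 10])]
# Device-1 sends: [tensor([20, 20, 20])]
# Device-0 recvs: tensor([10, 10, 10, 10, 10, 20, 20, 20])
# Device-1 recvs: tensor([10, 10, 10, 10, 10, 20, 20, 20])
```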
24 changes: 24 additions & 0 deletions tutel/examples/nccl_all_to_all_v.py
@@ -0,0 +1,24 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
from tutel import system, net

parallel_env = system.init_data_model_parallel(backend='nccl', group_count=1)
local_device = parallel_env.local_device

assert parallel_env.global_size == 2, "This test case is set for World Size == 2 only"

if parallel_env.global_rank == 0:
input = torch.tensor([10, 10, 10, 10, 10], device=local_device)
send_counts = torch.tensor([1, 4], dtype=torch.int32, device=local_device)
else:
input = torch.tensor([20, 20, 20], device=local_device)
send_counts = torch.tensor([2, 1], dtype=torch.int32, device=local_device)

print(f'Device-{parallel_env.global_rank} sends: {[input,]}')

net.barrier()

print(f'Device-{parallel_env.global_rank} recvs: {net.batch_all_to_all_v([input,], send_counts)[0]}')
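
Tracing send_counts through the same assumed semantics (rank i sends its next send_counts[j] elements to rank j):
```py
# Rank 0 splits [10, 10, 10, 10, 10]: 1 element to rank 0, 4 elements to rank 1.
# Rank 1 splits [20, 20, 20]: 2 elements to rank 0, 1 element to rank 1.
# Expected (assumed) output, device annotations omitted:
# Device-0 recvs: tensor([10, 20, 20])
# Device-1 recvs: tensor([10, 10, 10, 10, 20])
```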
