TP SP examples improvement #1354

Open · wants to merge 4 commits into main
6 changes: 3 additions & 3 deletions distributed/tensor_parallelism/log_utils.py
@@ -17,6 +17,6 @@ def rank_log(_rank, logger, msg):

 def verify_min_gpu_count(min_gpus: int = 2) -> bool:
     """ verification that we have at least 2 gpus to run dist examples """
-    has_cuda = torch.cuda.is_available()
-    gpu_count = torch.cuda.device_count()
-    return has_cuda and gpu_count >= min_gpus
+    has_gpu = torch.accelerator.is_available()
+    gpu_count = torch.accelerator.device_count()
+    return has_gpu and gpu_count >= min_gpus
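
The switch from `torch.cuda.*` to `torch.accelerator.*` makes this gating helper backend-agnostic (CUDA, XPU, MPS, ...). Below is a minimal, self-contained sketch of how an example script can gate itself on the updated helper; the guard message and exit code are illustrative, not taken from this PR.

```python
# Sketch only: the updated helper plus an assumed guard at the top of an example script.
import sys

import torch


def verify_min_gpu_count(min_gpus: int = 2) -> bool:
    """Check that at least `min_gpus` accelerator devices (CUDA, XPU, ...) are visible."""
    has_gpu = torch.accelerator.is_available()
    gpu_count = torch.accelerator.device_count()
    return has_gpu and gpu_count >= min_gpus


if not verify_min_gpu_count(min_gpus=2):
    print("This example needs at least 2 accelerator devices. Exiting.")  # illustrative message
    sys.exit(0)
```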
4 changes: 3 additions & 1 deletion distributed/tensor_parallelism/requirements.txt
@@ -3,4 +3,6 @@
 --pre
 --extra-index-url https://download.pytorch.org/whl/nightly/cu118
 --extra-index-url https://download.pytorch.org/whl/nightly/cu121
-torch >= 2.3.0.dev0; sys_platform == "linux"
+--extra-index-url https://download.pytorch.org/whl/nightly/cu126
+--extra-index-url https://download.pytorch.org/whl/nightly/cu128
+torch >= 2.7.1; sys_platform == "linux"
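
Since the examples now call `torch.accelerator`, the pin moves from a 2.3 nightly to `torch >= 2.7.1`, with index URLs added for the newer cu126/cu128 wheels. A quick sanity check, assumed here and not part of the PR, that the installed wheel is new enough:

```python
# Assumed sanity check, not part of this PR: fail early if the installed torch
# predates the torch.accelerator API that the examples now rely on.
import torch

assert hasattr(torch, "accelerator"), (
    f"torch {torch.__version__} has no torch.accelerator; "
    "install torch >= 2.7.1 from one of the index URLs above"
)
print(f"torch {torch.__version__}, accelerator: {torch.accelerator.current_accelerator()}")
```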
9 changes: 6 additions & 3 deletions distributed/tensor_parallelism/sequence_parallel_example.py
@@ -1,3 +1,4 @@
+# torchrun --nnodes 1 --nproc-per-node 4 <fn>
 import os
 import sys
 import torch
@@ -63,9 +64,10 @@ def forward(self, x):
"""
logger = get_logger()

device_type = torch.accelerator.current_accelerator().type
# create a device mesh based on the given world_size.
device_mesh = init_device_mesh(
device_type="cuda", mesh_shape=(int(os.environ["WORLD_SIZE"]),)
device_type=device_type, mesh_shape=(int(os.environ["WORLD_SIZE"]),)
)

_rank = device_mesh.get_rank()
@@ -75,7 +77,7 @@ def forward(self, x):
 rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}")
 
 # create model and move it to GPU. Init_device_mesh has already assigned gpu ids...
-model = ToyModel().to("cuda")
+model = ToyModel().to(device_type)
 
 # Custom parallelization plan for the model
 sp_model = parallelize_module(
@@ -100,7 +102,8 @@ def forward(self, x):

 for i in range(num_iters):
     # For SP, input can be different across all ranks.
-    inp = torch.rand(20, 10, device="cuda")
+    #inp = torch.rand(20, 10, device=device_type)
+    inp = torch.rand(1, 10, device=device_type)
     output = sp_model(inp)
     output.sum().backward()
     optimizer.step()
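
For context on the `(1, 10)` input above: in the sequence-parallel example each rank feeds its own shard of the sequence dimension, so inputs may differ across ranks. The sketch below reconstructs the elided `parallelize_module(...)` plan under the assumption that the ToyModel keeps its `in_proj`/`out_proj` layer names; treat it as an illustration of the pattern, not the exact file contents.

```python
# sketch_sequence_parallel.py -- illustrative sketch (file name and plan details assumed).
# Run with: torchrun --nnodes 1 --nproc-per-node 4 sketch_sequence_parallel.py
import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyModel(nn.Module):
    """Same shape as the example's toy MLP; layer names assumed to match."""

    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(10, 32)
        self.relu = nn.ReLU()
        self.out_proj = nn.Linear(32, 5)

    def forward(self, x):
        return self.out_proj(self.relu(self.in_proj(x)))


device_type = torch.accelerator.current_accelerator().type
device_mesh = init_device_mesh(device_type, (int(os.environ["WORLD_SIZE"]),))

sp_model = parallelize_module(
    module=ToyModel().to(device_type),
    device_mesh=device_mesh,
    parallelize_plan={
        # each rank's local input is declared a shard of the sequence (first) dim
        "in_proj": ColwiseParallel(input_layouts=Shard(0)),
        # the final output is sharded back onto the sequence dim
        "out_proj": RowwiseParallel(output_layouts=Shard(0)),
    },
)

# per-rank inputs can differ; (1, 10) is one sequence element per rank
inp = torch.rand(1, 10, device=device_type)
sp_model(inp).sum().backward()
```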
11 changes: 6 additions & 5 deletions distributed/tensor_parallelism/tensor_parallel_example.py
@@ -1,3 +1,4 @@
+# torchrun --nnodes 1 --nproc-per-node 4 <fn>
 import os
 import sys
 import torch
@@ -76,8 +77,8 @@ def forward(self, x):

 # create a device mesh based on the given world_size.
 _world_size = int(os.environ["WORLD_SIZE"])
-
-device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
+device_type = torch.accelerator.current_accelerator().type
+device_mesh = init_device_mesh(device_type=device_type, mesh_shape=(_world_size,))
 _rank = device_mesh.get_rank()


@@ -88,8 +89,8 @@ def forward(self, x):

 rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}")
 
-# create model and move it to GPU - init"cuda"_mesh has already mapped GPU ids.
-tp_model = ToyModel().to("cuda")
+# create model and move it to GPU - init_device_mesh has already mapped GPU ids.
+tp_model = ToyModel().to(device_type)
 
 
 # Custom parallelization plan for the model
@@ -116,7 +117,7 @@ def forward(self, x):
     # For TP, input needs to be same across all TP ranks.
     # Setting the random seed is to mimic the behavior of dataloader.
     torch.manual_seed(i)
-    inp = torch.rand(20, 10, device="cuda")
+    inp = torch.rand(20, 10, device=device_type)
     output = tp_model(inp)
     output.sum().backward()
     optimizer.step()
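
The tensor-parallel diff makes the complementary point: unlike SP, every TP rank must see the identical input, which is why the loop reseeds the RNG before each `torch.rand`. A compact, hedged sketch of that pattern follows (file name, ToyModel layer names, and optimizer settings are assumptions, not lifted from this PR).

```python
# sketch_tensor_parallel.py -- illustrative sketch only.
# Run with: torchrun --nnodes 1 --nproc-per-node 4 sketch_tensor_parallel.py
import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(10, 32)
        self.relu = nn.ReLU()
        self.out_proj = nn.Linear(32, 5)

    def forward(self, x):
        return self.out_proj(self.relu(self.in_proj(x)))


device_type = torch.accelerator.current_accelerator().type
device_mesh = init_device_mesh(device_type, (int(os.environ["WORLD_SIZE"]),))

tp_model = parallelize_module(
    module=ToyModel().to(device_type),
    device_mesh=device_mesh,
    # column-parallel first linear, row-parallel second linear; inputs stay replicated
    parallelize_plan={"in_proj": ColwiseParallel(), "out_proj": RowwiseParallel()},
)
optimizer = torch.optim.AdamW(tp_model.parameters(), lr=0.25, foreach=True)

for i in range(10):
    torch.manual_seed(i)  # mimic a dataloader handing every TP rank the same batch
    inp = torch.rand(20, 10, device=device_type)
    tp_model(inp).sum().backward()
    optimizer.step()
    optimizer.zero_grad()
```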