Add Tensor Parallel to torch_native_llama #1876

Open · wants to merge 12 commits into main
22 changes: 18 additions & 4 deletions python/sglang/bench_latency.py
@@ -220,8 +220,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
     return reqs


-@torch.inference_mode()
-def extend(reqs, model_runner):
+def _extend(reqs, model_runner):
     batch = ScheduleBatch.init_new(
         reqs=reqs,
         req_to_token_pool=model_runner.req_to_token_pool,
@@ -237,8 +236,15 @@ def extend(reqs, model_runner):
     return next_token_ids, logits_output.next_token_logits, batch


-@torch.inference_mode()
-def decode(input_token_ids, batch, model_runner):
+def extend(reqs, model_runner):
+    # Disable inference mode for now when torch TP is applied. We can remove
+    # this workaround once DTensor adds support for inference mode.
+    use_inf_mode = not model_runner.torch_tp_applied
+    with torch.inference_mode(use_inf_mode):
+        return _extend(reqs, model_runner)
+
+
+def _decode(input_token_ids, batch, model_runner):
     batch.output_ids = input_token_ids
     batch.prepare_for_decode()
     model_worker_batch = batch.get_model_worker_batch()
@@ -248,6 +254,14 @@ def decode(input_token_ids, batch, model_runner):
     return next_token_ids, logits_output.next_token_logits


+def decode(input_token_ids, batch, model_runner):
+    # Disable inference mode for now when torch TP is applied. We can remove
+    # this workaround once DTensor adds support for inference mode.
+    use_inf_mode = not model_runner.torch_tp_applied
+    with torch.inference_mode(use_inf_mode):
+        return _decode(input_token_ids, batch, model_runner)
+
+
 def correctness_test(
     server_args,
     port_args,
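For context on the wrapper functions above: `torch.inference_mode` takes a boolean `mode` argument, so passing `False` leaves inference mode off and ordinary autograd bookkeeping in place, which is what DTensor currently requires. A minimal standalone sketch of the same conditional toggle (the `run_step` and `fn` names are illustrative, not part of this PR):

import torch


def run_step(fn, *args, torch_tp_applied: bool = False):
    # With mode=False, torch.inference_mode is effectively a no-op context,
    # so the same call path works with or without torch TP (DTensor).
    with torch.inference_mode(mode=not torch_tp_applied):
        return fn(*args)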
16 changes: 16 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
@@ -148,6 +148,15 @@ def __init__(
         min_per_gpu_memory = self.init_torch_distributed()
         self.sampler = Sampler()
         self.load_model()
+
+        # Apply torch TP if model supports it
+        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
+        if self.tp_size > 1 and supports_torch_tp:
+            self.apply_torch_tp()
+            self.torch_tp_applied = True
+        else:
+            self.torch_tp_applied = False
+
         if server_args.lora_paths is not None:
             self.init_lora_manager()
         self.init_memory_pool(
@@ -548,6 +557,13 @@ def init_cuda_graphs(self):
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
         self.cuda_graph_runner = CudaGraphRunner(self)

+    def apply_torch_tp(self):
+        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
+        from sglang.srt.model_parallel import tensor_parallel
+
+        device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
+        tensor_parallel(self.model, device_mesh)
+
     def forward_decode(self, forward_batch: ForwardBatch):
         if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
             return self.cuda_graph_runner.replay(forward_batch)
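For reference, `apply_torch_tp` builds a one-dimensional device mesh whose single dimension spans all `tp_size` ranks and then hands the loaded model to `tensor_parallel` (added in the new file below). A minimal sketch of the mesh construction alone, assuming the script is launched with torchrun so the usual rank/world-size environment variables are set:

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# Illustrative stand-in for the PR's
# torch.distributed.init_device_mesh(self.device, (self.tp_size,)).
dist.init_process_group(backend="nccl")
tp_size = dist.get_world_size()
mesh = init_device_mesh("cuda", (tp_size,))  # shape (tp_size,): one mesh dim to shard over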
98 changes: 98 additions & 0 deletions python/sglang/srt/model_parallel.py
@@ -0,0 +1,98 @@
"""
Common utilities for torch model parallelism.
"""

from typing import Optional, Sequence

import torch
from torch.distributed.device_mesh import DeviceMesh

try:
from torch.distributed.tensor import DTensor, Shard
except ImportError:
# torch 2.4 or older
from torch.distributed._tensor import DTensor, Shard

from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
parallelize_module,
)


class ColwiseParallelSharded(ColwiseParallel):
    """
    A version of ColwiseParallel where the local weight has already been
    sharded. This is used for the fused wqkv case, where during loading we
    already shard wq, wk, and wv before fusing them.
    """

    # Override the _partition_linear_fn in ColwiseParallel
    def _partition_linear_fn(self, name, module, device_mesh):
        # Colwise shards weight/bias to Shard(0): since Linear computes
        # input * weight^T + bias, a weight sharded on dim 0 behaves as
        # Shard(1) after the transpose, i.e. column-wise parallelism.
        for name, param in module.named_parameters():
            dtensor = DTensor.from_local(param, device_mesh, [Shard(0)])
            dist_param = torch.nn.Parameter(dtensor, requires_grad=False)
            module.register_parameter(name, dist_param)


class RowwiseParallelMaybeWait(RowwiseParallel):
    """
    A version of RowwiseParallel that waits for the output (establishes a dependency
    between the comm stream and the compute stream, in the CUDA sense) before going
    into the next op. This is needed to work around the current interaction between
    AsyncCollectiveTensor and custom ops, such as `class RMSNorm(CustomOp)`.
    """

    @staticmethod
    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
        outputs = super(
            RowwiseParallelMaybeWait, RowwiseParallelMaybeWait
        )._prepare_output_fn(
            output_layouts, use_local_output, mod, outputs, device_mesh
        )
        # Wait for the output to be ready
        if isinstance(outputs, AsyncCollectiveTensor):
            return outputs.wait()
        else:
            return outputs


def tensor_parallel(
    module: torch.nn.Module,
    device_mesh: Optional[DeviceMesh] = None,
):
    """
    Tensor parallelize the model across the given device mesh.
    Args:
        module (`torch.nn.Module`):
            The module to tensor parallelize.
        device_mesh (`torch.distributed.DeviceMesh`):
            The device mesh to use for tensor parallelism.
    """

    # Tensor parallelize an nn.Module based on the `_tp_plan` attribute of the module.
    # No-op if the `_tp_plan` attribute does not exist on the module.
    # This is a helper function to be used with `model.apply` to recursively
    # parallelize a model.
    def tplize(mod: torch.nn.Module) -> None:
        tp_plan = getattr(mod, "_tp_plan", None)
        if tp_plan is None:
            return
        for child_name, tp_style in tp_plan.items():
            submod = mod.get_submodule(child_name)
            if tp_style == "Colwise":
                parallelize_module(submod, device_mesh, ColwiseParallel())
            elif tp_style == "Rowwise":
                parallelize_module(submod, device_mesh, RowwiseParallelMaybeWait())
            elif tp_style == "Colwise_Sharded":
                parallelize_module(submod, device_mesh, ColwiseParallelSharded())
            else:
                raise ValueError(f"Unknown TP style {tp_style}")

    # `apply` is a native method of `nn.Module` that recursively applies a
    # function to every submodule.
    module.apply(tplize)
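To illustrate the convention this module relies on: a model opts in by exposing `supports_torch_tp` (checked in `ModelRunner.__init__` above) and by declaring per-module `_tp_plan` dicts that map child names to the styles handled in `tplize`. A rough end-to-end sketch, meant to be launched with torchrun; the `ToyMLP`/`ToyModel` modules and their plan are made up for illustration and are not taken from torch_native_llama:

# torchrun --nproc-per-node=2 toy_tp.py   (illustrative launch command)
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

from sglang.srt.model_parallel import tensor_parallel


class ToyMLP(torch.nn.Module):
    # Hypothetical plan: shard the up-projection column-wise and the
    # down-projection row-wise, using the styles dispatched in tplize().
    _tp_plan = {"up_proj": "Colwise", "down_proj": "Rowwise"}

    def __init__(self, dim=1024, hidden=4096):
        super().__init__()
        self.up_proj = torch.nn.Linear(dim, hidden, bias=False)
        self.down_proj = torch.nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.down_proj(torch.relu(self.up_proj(x)))


class ToyModel(torch.nn.Module):
    supports_torch_tp = True  # the flag ModelRunner checks before applying TP

    def __init__(self):
        super().__init__()
        self.mlp = ToyMLP()

    def forward(self, x):
        return self.mlp(x)


if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    mesh = init_device_mesh("cuda", (dist.get_world_size(),))

    model = ToyModel().cuda()
    tensor_parallel(model, mesh)  # module.apply(tplize) shards up_proj/down_proj

    torch.manual_seed(0)  # same input on every rank; Colwise treats it as replicated
    x = torch.randn(4, 1024, device="cuda")
    # no_grad rather than inference_mode, matching the DTensor limitation
    # noted in bench_latency.py above.
    with torch.no_grad():
        y = model(x)
    print(rank, y.shape)

Because ColwiseParallel's default output layout (a local tensor sharded on the last dim) matches RowwiseParallel's default input layout, the two linears compose without extra redistribution, and `RowwiseParallelMaybeWait` ensures the closing all-reduce has completed before `y` is used.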