From defd5d170dd8aabd3c53a80d2e58178ba825bf44 Mon Sep 17 00:00:00 2001 From: cascade812 Date: Sun, 16 Mar 2025 23:19:15 +0000 Subject: [PATCH 1/6] support sequence parallel Signed-off-by: cascade812 --- tests/distributed/test_comm_ops.py | 50 +++++++++- tests/distributed/test_pipeline_parallel.py | 92 ++++++++++++++++++- vllm/config.py | 4 + vllm/distributed/communication_op.py | 5 + .../base_device_communicator.py | 15 +++ vllm/distributed/parallel_state.py | 7 ++ vllm/engine/arg_utils.py | 7 ++ vllm/entrypoints/llm.py | 4 + vllm/forward_context.py | 15 ++- vllm/model_executor/layers/linear.py | 17 +++- .../model_executor/layers/logits_processor.py | 7 ++ .../layers/vocab_parallel_embedding.py | 12 ++- vllm/model_executor/models/llama.py | 2 - vllm/v1/worker/gpu_model_runner.py | 57 ++++++++++-- vllm/worker/model_runner.py | 31 ++++++- 15 files changed, 301 insertions(+), 24 deletions(-) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index ac6d6aae3006..0218e4ec7649 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -14,7 +14,8 @@ from vllm.distributed import (broadcast_tensor_dict, get_pp_group, tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce) + tensor_model_parallel_all_reduce, + tensor_model_parallel_reduce_scatter) from ..utils import init_test_distributed_environment, multi_process_parallel @@ -47,6 +48,38 @@ def all_reduce_test_worker( torch.testing.assert_close(t, expected) +@ray.remote(num_gpus=1, max_calls=1) +def reduce_scatter_test_worker( + monkeypatch: pytest.MonkeyPatch, + tp_size: int, + pp_size: int, + rank: int, + distributed_init_port: str +): + # it is important to delete the CUDA_VISIBLE_DEVICES environment variable + # so that each worker can see all the GPUs + # they will be able to set the device to the correct GPU + monkeypatch.delenv("CUDA_VISIBLE_DEVICES", raising=False) + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + init_test_distributed_environment(tp_size, pp_size, rank, + distributed_init_port) + + num_elements = 8 + all_tensors = [ + torch.arange(num_elements, dtype=torch.float32, device="cuda") * + (r + 1) for r in range(tp_size) + ] + + index = rank % tp_size + partition_size = num_elements // tp_size + all_reduce = torch.sum(torch.stack(all_tensors, dim=0), dim=0) + expected = all_reduce[index * partition_size:(index + 1) * partition_size] + t = all_tensors[index] + t = tensor_model_parallel_reduce_scatter(t) + torch.testing.assert_close(t, expected) + + @ray.remote(num_gpus=1, max_calls=1) def all_gather_test_worker( monkeypatch: pytest.MonkeyPatch, @@ -211,6 +244,21 @@ def test_multi_process_tensor_parallel( multi_process_parallel(monkeypatch, tp_size, 1, test_target) +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +@pytest.mark.parametrize("tp_size", [2]) +@pytest.mark.parametrize("test_target", [ + all_reduce_test_worker, all_gather_test_worker, reduce_scatter_test_worker, + broadcast_tensor_dict_test_worker +]) +def test_multi_process_tesor_parallel_sequence_parallel( + tp_size: int, + test_target: Callable[..., Any], + monkeypatch: pytest.MonkeyPatch, +): + multi_process_parallel(monkeypatch, tp_size, 1, test_target) + + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test.") @pytest.mark.parametrize("pp_size", [2]) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 
1342f0da29d8..4c5567212d50 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -39,6 +39,7 @@ def use_v0_only(monkeypatch): class ParallelSetup(NamedTuple): tp_size: int pp_size: int + sp_enabled: bool eager_mode: bool chunked_prefill: bool @@ -81,22 +82,27 @@ def detailed( parallel_setups=[ ParallelSetup(tp_size=tp_base, pp_size=pp_base, + sp_enabled=False, eager_mode=False, chunked_prefill=False), ParallelSetup(tp_size=tp_base, pp_size=2 * pp_base, + sp_enabled=False, eager_mode=False, chunked_prefill=True), ParallelSetup(tp_size=tp_base, pp_size=2 * pp_base, + sp_enabled=False, eager_mode=True, chunked_prefill=False), ParallelSetup(tp_size=2 * tp_base, pp_size=pp_base, + sp_enabled=False, eager_mode=False, chunked_prefill=True), ParallelSetup(tp_size=2 * tp_base, pp_size=pp_base, + sp_enabled=False, eager_mode=True, chunked_prefill=False), ], @@ -121,8 +127,9 @@ def fast( parallel_setups=[ ParallelSetup(tp_size=tp_base, pp_size=pp_base, + sp_enabled=False, eager_mode=True, - chunked_prefill=False), + chunked_prefill=False) ], distributed_backends=["mp"], vllm_major_versions=["0"], @@ -131,6 +138,42 @@ def fast( load_format=load_format), ) + @staticmethod + def sp( + *, + tp_base: int = 2, + pp_base: int = 1, + task: TaskOption = "auto", + multi_node_only: bool = False, + load_format: Optional[str] = None, + ): + return PPTestSettings( + parallel_setups=[ + ParallelSetup(tp_size=tp_base, + pp_size=pp_base, + sp_enabled=True, + eager_mode=False, + chunked_prefill=False), + ParallelSetup(tp_size=2 * tp_base, + pp_size=pp_base, + sp_enabled=True, + eager_mode=False, + chunked_prefill=True), + + # current sp doesn't support combination with pp + # ParallelSetup(tp_size=2 * tp_base, + # pp_size=2 * pp_base, + # sp_enabled=True, + # eager_mode=True, + # chunked_prefill=False), + ], + distributed_backends=["mp", "mp"], + vllm_major_versions=["0", "1"], + task=task, + test_options=PPTestOptions(multi_node_only=multi_node_only, + load_format=load_format), + ) + def iter_params(self, model_id: str): opts = self.test_options @@ -271,10 +314,10 @@ def _compare_tp( ( tp_size, pp_size, + sp_enabled, eager_mode, chunked_prefill, ) = parallel_setup - multi_node_only, load_format = test_options model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) @@ -364,6 +407,9 @@ def _compare_tp( distributed_backend, ] + if sp_enabled: + pp_args.append("--enable-sequence-parallel") + # compare without pipeline parallelism # NOTE: use mp backend for TP # PP tests might involve multiple nodes, and ray might @@ -479,3 +525,45 @@ def test_tp_multimodal_generation( num_gpus_available, method="generate", is_multimodal=True) + + +SP_TEXT_GENERATION_MODELS = { + # [Decoder-only] + "meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.sp(), +} + +SP_TEST_MODELS = [ + # TODO support other models + # [LANGUAGE GENERATION] + "meta-llama/Llama-3.2-1B-Instruct", +] + + +@pytest.mark.parametrize( + ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version", + "task", "test_options"), + [ + params for model_id, settings in SP_TEXT_GENERATION_MODELS.items() + for params in settings.iter_params(model_id) + if model_id in SP_TEST_MODELS + ], +) +@fork_new_process_for_each_test +def test_tp_sp_generation( + model_id: str, + parallel_setup: ParallelSetup, + distributed_backend: str, + vllm_major_version: str, + task: TaskOption, + test_options: PPTestOptions, + num_gpus_available, +): + _compare_tp(model_id, + parallel_setup, + distributed_backend, + 
vllm_major_version, + task, + test_options, + num_gpus_available, + method="generate", + is_multimodal=True) diff --git a/vllm/config.py b/vllm/config.py index c510677d64ea..422a9543caad 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1358,6 +1358,8 @@ class ParallelConfig: tensor_parallel_size: int = 1 # Number of tensor parallel groups. data_parallel_size: int = 1 # Number of data parallel groups. data_parallel_rank: int = 0 # Rank of the data parallel group. + enable_sequence_parallel: bool = False # Enable sequence parallelism. + # IP of the data parallel master. data_parallel_master_ip: str = "127.0.0.1" data_parallel_master_port: int = 29500 # Port of the data parallel master. @@ -2150,6 +2152,8 @@ def create_draft_parallel_config( pipeline_parallel_size=target_parallel_config. pipeline_parallel_size, tensor_parallel_size=speculative_draft_tensor_parallel_size, + enable_sequence_parallel=target_parallel_config. + enable_sequence_parallel, distributed_executor_backend=target_parallel_config. distributed_executor_backend, max_parallel_loading_workers=target_parallel_config. diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 0228264f91f9..1b48324009cf 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -13,6 +13,11 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: return get_tp_group().all_reduce(input_) +def tensor_model_parallel_reduce_scatter(input_: torch.Tensor) -> torch.Tensor: + """Reduce-scatter the input tensor across model parallel group.""" + return get_tp_group().reduce_scatter(input_) + + def tensor_model_parallel_all_gather(input_: torch.Tensor, dim: int = -1) -> torch.Tensor: """All-gather the input tensor across model parallel group.""" diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index eb12f8834b41..fcfca72471bf 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -35,6 +35,21 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: dist.all_reduce(input_, group=self.device_group) return input_ + def reduce_scatter(self, input_: torch.Tensor) -> torch.Tensor: + input_size = input_.size() + assert input_size[0] % self.world_size == 0, ( + f"reduce scatter doesn't work when input size {input_size} is not " + f"divisible by world size {self.world_size}") + output_size = (input_size[0] // self.world_size, ) + input_size[1:] + # Allocate output tensor. + output_tensor = torch.empty(output_size, + dtype=input_.dtype, + device=input_.device) + dist.reduce_scatter_tensor(output_tensor, + input_, + group=self.device_group) + return output_tensor + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: if dim < 0: # Convert negative dim to positive. diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index f897f1950e4c..246bcd005ff1 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -312,6 +312,13 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor: return self.device_communicator.all_reduce(input_) + def reduce_scatter(self, input_: torch.Tensor) -> torch.Tensor: + # Bypass the function if we are using only 1 GPU. 
+ if self.world_size == 1: + return input_ + + return self.device_communicator.reduce_scatter(input_) + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: world_size = self.world_size # Bypass the function if we are using only 1 GPU. diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 02a9ec46939c..b7d0407f164c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -115,6 +115,7 @@ class EngineArgs: # number of P/D disaggregation (or other disaggregation) workers pipeline_parallel_size: int = 1 tensor_parallel_size: int = 1 + enable_sequence_parallel: bool = False enable_expert_parallel: bool = False max_parallel_loading_workers: Optional[int] = None block_size: Optional[int] = None @@ -435,6 +436,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=int, default=EngineArgs.tensor_parallel_size, help='Number of tensor parallel replicas.') + parser.add_argument('--enable-sequence-parallel', + '-sp', + action='store_true', + default=False, + help='If enable sequence parallel') parser.add_argument( '--enable-expert-parallel', action='store_true', @@ -1243,6 +1249,7 @@ def create_engine_config( parallel_config = ParallelConfig( pipeline_parallel_size=self.pipeline_parallel_size, tensor_parallel_size=self.tensor_parallel_size, + enable_sequence_parallel=self.enable_sequence_parallel, enable_expert_parallel=self.enable_expert_parallel, max_parallel_loading_workers=self.max_parallel_loading_workers, disable_custom_all_reduce=self.disable_custom_all_reduce, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index a0e2fa2918bd..819a99284065 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -74,6 +74,8 @@ class LLM: environments. tensor_parallel_size: The number of GPUs to use for distributed execution with tensor parallelism. + enable_sequence_parallel: Enable sequence parallelism on top of tensor + parallelism. dtype: The data type for the model weights and activations. Currently, we support `float32`, `float16`, and `bfloat16`. If `auto`, we use the `torch_dtype` attribute specified in the model config file. 
@@ -164,6 +166,7 @@ def __init__( trust_remote_code: bool = False, allowed_local_media_path: str = "", tensor_parallel_size: int = 1, + enable_sequence_parallel: bool = False, dtype: str = "auto", quantization: Optional[str] = None, revision: Optional[str] = None, @@ -219,6 +222,7 @@ def __init__( trust_remote_code=trust_remote_code, allowed_local_media_path=allowed_local_media_path, tensor_parallel_size=tensor_parallel_size, + enable_sequence_parallel=enable_sequence_parallel, dtype=dtype, quantization=quantization, revision=revision, diff --git a/vllm/forward_context.py b/vllm/forward_context.py index e195a03c5cac..dc76c62fe6d8 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -38,6 +38,7 @@ class ForwardContext: attn_metadata: "AttentionMetadata" # set dynamically for each forward pass # TODO: remove after making all virtual_engines share the same kv cache virtual_engine: int # set dynamically for each forward pass + enable_sequence_parallel: bool # If enable sequence_parallelism # set dynamically for each forward pass dp_metadata: Optional[DPMetadata] = None @@ -53,11 +54,16 @@ def get_forward_context() -> ForwardContext: return _forward_context +def try_get_forward_context() -> ForwardContext: + return _forward_context + + @contextmanager def set_forward_context(attn_metadata: Any, vllm_config: VllmConfig, virtual_engine: int = 0, - num_tokens: int = 0): + num_tokens: int = 0, + enable_sequence_parallel: bool = False): """A context manager that stores the current forward context, can be attention metadata, etc. Here we can inject common logic for every model forward pass. @@ -90,6 +96,12 @@ def set_forward_context(attn_metadata: Any, cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0) dp_metadata = DPMetadata(cu_tokens_across_dp_cpu) + # TODO: support pipeline parallel + if vllm_config.parallel_config.enable_sequence_parallel: + assert vllm_config.parallel_config.pipeline_parallel_size == 1, ( + "sequence parallel doesn't work correctly when " + "combined with pipeline parallel") + global _forward_context prev_context = _forward_context _forward_context = ForwardContext( @@ -97,6 +109,7 @@ def set_forward_context(attn_metadata: Any, static_forward_context, virtual_engine=virtual_engine, attn_metadata=attn_metadata, + enable_sequence_parallel=enable_sequence_parallel, dp_metadata=dp_metadata) try: yield diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 1ae574072b8f..85c4e3fcd2ae 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -13,7 +13,9 @@ get_tensor_model_parallel_world_size, split_tensor_along_last_dim, tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce) + tensor_model_parallel_all_reduce, + tensor_model_parallel_reduce_scatter) +from vllm.forward_context import try_get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) @@ -469,6 +471,11 @@ def forward( ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: bias = self.bias if not self.skip_bias_add else None + forward_context = try_get_forward_context() + if (forward_context is not None + and forward_context.enable_sequence_parallel): + input_ = tensor_model_parallel_all_gather(input_, 0) + # Matrix multiply. 
assert self.quant_method is not None output_parallel = self.quant_method.apply(self, input_, bias) @@ -1258,8 +1265,14 @@ def forward( output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_) + if self.reduce_results and self.tp_size > 1: - output = tensor_model_parallel_all_reduce(output_parallel) + forward_context = try_get_forward_context() + if (forward_context is not None + and forward_context.enable_sequence_parallel): + output = tensor_model_parallel_reduce_scatter(output_parallel) + else: + output = tensor_model_parallel_all_reduce(output_parallel) else: output = output_parallel diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 4a359725bad0..2a070d231330 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -10,6 +10,7 @@ import vllm.envs as envs from vllm.distributed import (tensor_model_parallel_all_gather, tensor_model_parallel_gather) +from vllm.forward_context import try_get_forward_context from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -59,6 +60,11 @@ def forward( sampling_metadata: Optional[SamplingMetadata] = None, embedding_bias: Optional[torch.Tensor] = None, ) -> Optional[torch.Tensor]: + forward_context = try_get_forward_context() + if (forward_context is not None + and forward_context.enable_sequence_parallel): + hidden_states = tensor_model_parallel_all_gather(hidden_states, + dim=0) if self.logits_as_input: logits = hidden_states else: @@ -105,6 +111,7 @@ def _get_logits( embedding_bias: Optional[torch.Tensor], ) -> Optional[torch.Tensor]: # Get the logits for the next tokens. + logits = lm_head.quant_method.apply(lm_head, hidden_states, bias=embedding_bias) diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index f65dfc3cb329..7a4878eedb7b 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -9,7 +9,9 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) + tensor_model_parallel_all_reduce, + tensor_model_parallel_reduce_scatter) +from vllm.forward_context import try_get_forward_context from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding) from vllm.model_executor.parameter import BasevLLMParameter @@ -204,7 +206,6 @@ def __init__(self, quant_config: Optional[QuantizationConfig] = None, prefix: str = ""): super().__init__() - # Keep the input dimensions. tp_rank = get_tensor_model_parallel_rank() self.tp_size = get_tensor_model_parallel_world_size() @@ -418,7 +419,12 @@ def forward(self, input_): if self.tp_size > 1: output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) # Reduce across all the model parallel GPUs. 
- output = tensor_model_parallel_all_reduce(output_parallel) + forward_context = try_get_forward_context() + if (forward_context is not None + and forward_context.enable_sequence_parallel): + output = tensor_model_parallel_reduce_scatter(output_parallel) + else: + output = tensor_model_parallel_all_reduce(output_parallel) return output def extra_repr(self) -> str: diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 81b5d9bda9ac..c4b5fbf67638 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -479,7 +479,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = self._init_model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) - if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size if lora_config: @@ -500,7 +499,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if config.tie_word_embeddings: self.lm_head = self.lm_head.tie_weights( self.model.embed_tokens) - logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 657333c6d84c..7f77cf164026 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -856,7 +856,6 @@ def _execute_encoder(self, scheduler_output: "SchedulerOutput"): # depending on the input multimodal items. curr_group_outputs = self.model.get_multimodal_embeddings( **batched_mm_inputs) - for output in curr_group_outputs: encoder_outputs.append(output) @@ -1026,9 +1025,22 @@ def execute_model( for k, v in self.intermediate_tensors.items() }) + # only do sequence parallelism when num of tokens + # is divisible by parallel size. + # sequence parallelism uses torch.distributed.reduce_scatter which only + # supports the case when size is divisible by parallel size + enable_sequence_parallel = ( + self.vllm_config.parallel_config.enable_sequence_parallel + and num_input_tokens % + self.vllm_config.parallel_config.tensor_parallel_size == 0) + # Run the decoder. # Use persistent buffers for CUDA graphs. 
- with set_forward_context(attn_metadata, self.vllm_config): + with set_forward_context( + attn_metadata, + self.vllm_config, + enable_sequence_parallel=enable_sequence_parallel, + ): hidden_states = self.model( input_ids=input_ids, positions=positions, @@ -1041,7 +1053,11 @@ def execute_model( hidden_states = hidden_states[:num_scheduled_tokens] sample_hidden_states = hidden_states[logits_indices] - logits = self.model.compute_logits(sample_hidden_states, None) + with set_forward_context( + attn_metadata, + self.vllm_config, + enable_sequence_parallel=enable_sequence_parallel): + logits = self.model.compute_logits(sample_hidden_states, None) # Apply structured output bitmasks if present if scheduler_output.grammar_bitmask is not None: @@ -1218,7 +1234,16 @@ def _get_prompt_logprobs_dict( req_idx = self.input_batch.req_id_to_index[req_id] offset = self.query_start_loc_np[req_idx].item() prompt_hidden_states = hidden_states[offset:offset + num_logits] - logits = self.model.compute_logits(prompt_hidden_states, None) + + enable_sequence_parallel = ( + self.vllm_config.parallel_config.enable_sequence_parallel + and num_tokens % + self.vllm_config.parallel_config.tensor_parallel_size == 0) + with set_forward_context( + None, + self.vllm_config, + enable_sequence_parallel=enable_sequence_parallel): + logits = self.model.compute_logits(prompt_hidden_states, None) # Get the "target" tokens for each index. For prompt at index i, # the token at prompt index i+1 is the "sampled" token we want @@ -1295,9 +1320,17 @@ def _dummy_run( for k, v in self.intermediate_tensors.items() }) - with set_forward_context(None, - self.vllm_config, - num_tokens=num_tokens): + enable_sequence_parallel = ( + self.vllm_config.parallel_config.enable_sequence_parallel + and num_tokens % + self.vllm_config.parallel_config.tensor_parallel_size == 0) + + with set_forward_context( + None, + self.vllm_config, + num_tokens=num_tokens, + enable_sequence_parallel=enable_sequence_parallel, + ): hidden_states = model( input_ids=input_ids, positions=positions, @@ -1314,7 +1347,15 @@ def _dummy_sampler_run( hidden_states: torch.Tensor, ) -> torch.Tensor: - logits = self.model.compute_logits(hidden_states, None) + enable_sequence_parallel = ( + self.vllm_config.parallel_config.enable_sequence_parallel + and hidden_states.size()[0] % + self.vllm_config.parallel_config.tensor_parallel_size == 0) + with set_forward_context( + None, + self.vllm_config, + enable_sequence_parallel=enable_sequence_parallel): + logits = self.model.compute_logits(hidden_states, None) num_reqs = logits.size(0) dummy_tensors = lambda v: torch.full( diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 473bd901b5b2..3cd7c994f809 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1556,7 +1556,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: capture_inputs) with set_forward_context(attn_metadata, self.vllm_config, - virtual_engine): + virtual_engine, batch_size): graph_runner.capture(**capture_inputs) self.graph_memory_pool = graph_runner.graph.pool() self.graph_runners[virtual_engine][batch_size] = ( @@ -1736,9 +1736,24 @@ def execute_model( model_forward_end = torch.cuda.Event(enable_timing=True) model_forward_start.record() + num_tokens = (model_input.input_tokens.shape[0] + if model_input.input_tokens is not None else None) + + # only do sequence parallelism when num of tokens + # is divisible by parallel size. 
+ # sequence parallelism uses torch.distributed.reduce_scatter which only + # supports the case when size is divisible by parallel size + enable_sequence_parallel = ( + self.vllm_config.parallel_config.enable_sequence_parallel + and num_tokens is not None and num_tokens % + self.vllm_config.parallel_config.tensor_parallel_size == 0) + if not bypass_model_exec: - with set_forward_context(model_input.attn_metadata, - self.vllm_config, virtual_engine): + with set_forward_context( + model_input.attn_metadata, + self.vllm_config, + virtual_engine, + enable_sequence_parallel=enable_sequence_parallel): hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, @@ -1785,8 +1800,14 @@ def execute_model( torch.tensor(model_forward_time + orig_model_forward_time)) return hidden_or_intermediate_states - logits = self.model.compute_logits(hidden_or_intermediate_states, - model_input.sampling_metadata) + with set_forward_context( + model_input.attn_metadata, + self.vllm_config, + virtual_engine, + enable_sequence_parallel=enable_sequence_parallel, + ): + logits = self.model.compute_logits(hidden_or_intermediate_states, + model_input.sampling_metadata) if not self.is_driver_worker: return [] From b0a9c0107a473d0a5f6107335a7889cc2480a642 Mon Sep 17 00:00:00 2001 From: cascade812 Date: Mon, 17 Mar 2025 00:31:56 +0000 Subject: [PATCH 2/6] fix Signed-off-by: cascade812 --- vllm/forward_context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/forward_context.py b/vllm/forward_context.py index dc76c62fe6d8..5359550c6319 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -54,7 +54,7 @@ def get_forward_context() -> ForwardContext: return _forward_context -def try_get_forward_context() -> ForwardContext: +def try_get_forward_context() -> Optional[ForwardContext]: return _forward_context From 3045445db4615d7c3107281c403bbfc1373cb09d Mon Sep 17 00:00:00 2001 From: cascade812 Date: Wed, 19 Mar 2025 05:33:27 +0000 Subject: [PATCH 3/6] update Signed-off-by: cascade812 --- tests/distributed/test_comm_ops.py | 10 +++------- tests/distributed/test_pipeline_parallel.py | 2 +- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index 0218e4ec7649..351b160d8e6e 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -49,13 +49,9 @@ def all_reduce_test_worker( @ray.remote(num_gpus=1, max_calls=1) -def reduce_scatter_test_worker( - monkeypatch: pytest.MonkeyPatch, - tp_size: int, - pp_size: int, - rank: int, - distributed_init_port: str -): +def reduce_scatter_test_worker(monkeypatch: pytest.MonkeyPatch, tp_size: int, + pp_size: int, rank: int, + distributed_init_port: str): # it is important to delete the CUDA_VISIBLE_DEVICES environment variable # so that each worker can see all the GPUs # they will be able to set the device to the correct GPU diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 4c5567212d50..6c476ce5fa13 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -548,7 +548,7 @@ def test_tp_multimodal_generation( if model_id in SP_TEST_MODELS ], ) -@fork_new_process_for_each_test +@create_new_process_for_each_test def test_tp_sp_generation( model_id: str, parallel_setup: ParallelSetup, From bd7f3d463d0eaed628c66737ef757056e1758476 Mon Sep 17 00:00:00 2001 From: cascade812 Date: 
Wed, 19 Mar 2025 20:19:35 +0000 Subject: [PATCH 4/6] fix Signed-off-by: cascade812 --- tests/distributed/test_pipeline_parallel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 6c476ce5fa13..fc00699cbcc5 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -548,7 +548,7 @@ def test_tp_multimodal_generation( if model_id in SP_TEST_MODELS ], ) -@create_new_process_for_each_test +@create_new_process_for_each_test() def test_tp_sp_generation( model_id: str, parallel_setup: ParallelSetup, @@ -566,4 +566,4 @@ def test_tp_sp_generation( test_options, num_gpus_available, method="generate", - is_multimodal=True) + is_multimodal=False) From 62c43f96b9c0550dc9a95f769cb52df86f4d981c Mon Sep 17 00:00:00 2001 From: cascade812 Date: Fri, 21 Mar 2025 04:33:18 +0000 Subject: [PATCH 5/6] update test and fix v1 Signed-off-by: cascade812 --- tests/distributed/test_sequence_parallel.py | 283 ++++++++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 33 +-- 2 files changed, 293 insertions(+), 23 deletions(-) create mode 100644 tests/distributed/test_sequence_parallel.py diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py new file mode 100644 index 000000000000..53a59a665980 --- /dev/null +++ b/tests/distributed/test_sequence_parallel.py @@ -0,0 +1,283 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +WARNING: This test runs in both single-node (4 GPUs) and multi-node + (2 node with 2 GPUs each) modes. If the test only uses 2 GPUs, it is + important to set the distributed backend to "mp" to avoid Ray scheduling + all workers in a node other than the head node, which can cause the test + to fail. +""" +import json +import os +from dataclasses import dataclass +from typing import Literal, NamedTuple, Optional + +import pytest + +from vllm.config import TaskOption +from vllm.logger import init_logger + +from ..models.registry import HF_EXAMPLE_MODELS +from ..utils import compare_two_settings, create_new_process_for_each_test + +logger = init_logger("test_sequence_parallel") + +VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" + + +class ParallelSetup(NamedTuple): + tp_size: int + sp_enabled: bool + eager_mode: bool + chunked_prefill: bool + + +class SPTestOptions(NamedTuple): + multi_node_only: bool + load_format: Optional[str] = None + + +@dataclass +class SPTestSettings: + parallel_setups: list[ParallelSetup] + # NOTE: the length of distributed_backends and + # vllm_major_versions should be the same, and they + # are first zipped together to iterate over all + # test settings. 
+ distributed_backends: list[str] + # vllm major version: "0" for V0, "1" for V1 + vllm_major_versions: list[str] + task: TaskOption + test_options: SPTestOptions + + def __post_init__(self): + if len(self.distributed_backends) != len(self.vllm_major_versions): + raise ValueError( + f"Length mismatch: distributed_backends " + f"({len(self.distributed_backends)}) != " + f"vllm_major_versions ({len(self.vllm_major_versions)})") + + @staticmethod + def detailed( + *, + tp_base: int = 2, + multi_node_only: bool = False, + task: TaskOption = "auto", + load_format: Optional[str] = None, + ): + return SPTestSettings( + parallel_setups=[ + ParallelSetup(tp_size=tp_base, + sp_enabled=True, + eager_mode=False, + chunked_prefill=False), + ParallelSetup(tp_size=tp_base, + sp_enabled=True, + eager_mode=False, + chunked_prefill=True), + ParallelSetup(tp_size=tp_base, + sp_enabled=True, + eager_mode=True, + chunked_prefill=False), + ParallelSetup(tp_size=2 * tp_base, + sp_enabled=True, + eager_mode=True, + chunked_prefill=True) + ], + # only ray is supported for V1 + # distributed_backends=["mp", "mp", "ray", "ray"], + # vllm_major_versions=["0", "1", "0", "1"], + distributed_backends=["mp", "mp"], + vllm_major_versions=["0", "1"], + task=task, + test_options=SPTestOptions(multi_node_only=multi_node_only, + load_format=load_format), + ) + + @staticmethod + def fast( + *, + tp_base: int = 2, + task: TaskOption = "auto", + multi_node_only: bool = False, + load_format: Optional[str] = None, + ): + return SPTestSettings( + parallel_setups=[ + ParallelSetup(tp_size=tp_base, + sp_enabled=True, + eager_mode=True, + chunked_prefill=False), + ], + distributed_backends=["mp"], + vllm_major_versions=["1"], + task=task, + test_options=SPTestOptions(multi_node_only=multi_node_only, + load_format=load_format), + ) + + def iter_params(self, model_id: str): + opts = self.test_options + + for parallel_setup in self.parallel_setups: + for backend, vllm_major_version in zip(self.distributed_backends, + self.vllm_major_versions): + yield (model_id, parallel_setup, backend, vllm_major_version, + self.task, opts) + + +def _compare_sp( + model_id: str, + parallel_setup: ParallelSetup, + distributed_backend: str, + vllm_major_version: str, + task: TaskOption, + test_options: SPTestOptions, + num_gpus_available: int, + *, + method: Literal["generate", "encode"], + is_multimodal: bool, +): + ( + tp_size, + sp_enabled, + eager_mode, + chunked_prefill, + ) = parallel_setup + + multi_node_only, load_format = test_options + + model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) + model_info.check_transformers_version(on_fail="skip") + + trust_remote_code = model_info.trust_remote_code + tokenizer_mode = model_info.tokenizer_mode + hf_overrides = model_info.hf_overrides + + if load_format == "dummy": + # Avoid OOM + text_overrides = { + "num_hidden_layers": 4, + "hidden_size": 512, + "intermediate_size": 800, + "num_attention_heads": 4, + "num_key_value_heads": 1, + } + + if is_multimodal: + hf_overrides.update({"text_config": text_overrides}) + else: + hf_overrides.update(text_overrides) + else: + model_info.check_available_online(on_fail="skip") + + pp_size = 1 + if num_gpus_available < tp_size * pp_size: + pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs") + if VLLM_MULTI_NODE and distributed_backend == "mp": + pytest.skip("Skipping multi-node pipeline parallel test for " + "multiprocessing distributed backend") + if multi_node_only and not VLLM_MULTI_NODE: + pytest.skip("Not in multi-node setting") + + common_args = [ + # 
use half precision for speed and memory savings in CI environment + "--dtype", + "float16", + "--max-model-len", + "2048", + "--max-num-seqs", + "8", + ] + if chunked_prefill: + common_args.append("--enable-chunked-prefill") + if eager_mode: + common_args.append("--enforce-eager") + if task != "auto": + common_args.extend(["--task", task]) + if trust_remote_code: + common_args.append("--trust-remote-code") + if tokenizer_mode: + common_args.extend(["--tokenizer-mode", tokenizer_mode]) + if load_format: + common_args.extend(["--load-format", load_format]) + if hf_overrides: + common_args.extend(["--hf-overrides", json.dumps(hf_overrides)]) + + sp_env = None + sp_args = [ + *common_args, + "--enable-sequence-parallel", + "--tensor-parallel-size", + str(tp_size), + "--distributed-executor-backend", + distributed_backend, + ] + + tp_env = { + "VLLM_USE_V1": vllm_major_version, + } + tp_args = [ + *common_args, + "--tensor-parallel-size", + str(tp_size), + "--distributed-executor-backend", + "mp", + ] + + try: + compare_two_settings(model_id, + sp_args, + tp_args, + sp_env, + tp_env, + method=method) + except Exception: + testing_ray_compiled_graph = sp_env is not None + if testing_ray_compiled_graph and vllm_major_version == "0": + # Ray Compiled Graph tests are flaky for V0, + # so we don't want to fail the test + logger.exception("Ray Compiled Graph tests failed") + else: + raise + + +SP_TEXT_GENERATION_MODELS = { + # [Decoder-only] + "unsloth/Llama-3.2-1B-Instruct": SPTestSettings.detailed(), +} + +SP_TEST_MODELS = [ + # TODO support other models + # [LANGUAGE GENERATION] + "unsloth/Llama-3.2-1B-Instruct", +] + + +@pytest.mark.parametrize( + ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version", + "task", "test_options"), + [ + params for model_id, settings in SP_TEXT_GENERATION_MODELS.items() + for params in settings.iter_params(model_id) + if model_id in SP_TEST_MODELS + ], +) +@create_new_process_for_each_test() +def test_tp_sp_generation( + model_id: str, + parallel_setup: ParallelSetup, + distributed_backend: str, + vllm_major_version: str, + task: TaskOption, + test_options: SPTestOptions, + num_gpus_available, +): + _compare_sp(model_id, + parallel_setup, + distributed_backend, + vllm_major_version, + task, + test_options, + num_gpus_available, + method="generate", + is_multimodal=False) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7f77cf164026..bdf7eefd4046 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -13,6 +13,7 @@ from vllm.attention import AttentionType, get_attn_backend from vllm.attention.layer import Attention from vllm.config import CompilationLevel, VllmConfig +from vllm.distributed import tensor_model_parallel_all_gather from vllm.distributed.parallel_state import get_pp_group, graph_capture from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY @@ -1051,13 +1052,12 @@ def execute_model( # For mid-pipeline stages, return the hidden states. 
return hidden_states + if enable_sequence_parallel: + hidden_states = tensor_model_parallel_all_gather(hidden_states, + dim=0) hidden_states = hidden_states[:num_scheduled_tokens] sample_hidden_states = hidden_states[logits_indices] - with set_forward_context( - attn_metadata, - self.vllm_config, - enable_sequence_parallel=enable_sequence_parallel): - logits = self.model.compute_logits(sample_hidden_states, None) + logits = self.model.compute_logits(sample_hidden_states, None) # Apply structured output bitmasks if present if scheduler_output.grammar_bitmask is not None: @@ -1235,15 +1235,7 @@ def _get_prompt_logprobs_dict( offset = self.query_start_loc_np[req_idx].item() prompt_hidden_states = hidden_states[offset:offset + num_logits] - enable_sequence_parallel = ( - self.vllm_config.parallel_config.enable_sequence_parallel - and num_tokens % - self.vllm_config.parallel_config.tensor_parallel_size == 0) - with set_forward_context( - None, - self.vllm_config, - enable_sequence_parallel=enable_sequence_parallel): - logits = self.model.compute_logits(prompt_hidden_states, None) + logits = self.model.compute_logits(prompt_hidden_states, None) # Get the "target" tokens for each index. For prompt at index i, # the token at prompt index i+1 is the "sampled" token we want @@ -1337,6 +1329,9 @@ def _dummy_run( intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) + if get_pp_group().is_last_rank and enable_sequence_parallel: + hidden_states = tensor_model_parallel_all_gather(hidden_states, + dim=0) logit_indices = np.cumsum(num_scheduled_tokens) - 1 return hidden_states[logit_indices] @@ -1347,15 +1342,7 @@ def _dummy_sampler_run( hidden_states: torch.Tensor, ) -> torch.Tensor: - enable_sequence_parallel = ( - self.vllm_config.parallel_config.enable_sequence_parallel - and hidden_states.size()[0] % - self.vllm_config.parallel_config.tensor_parallel_size == 0) - with set_forward_context( - None, - self.vllm_config, - enable_sequence_parallel=enable_sequence_parallel): - logits = self.model.compute_logits(hidden_states, None) + logits = self.model.compute_logits(hidden_states, None) num_reqs = logits.size(0) dummy_tensors = lambda v: torch.full( From 9cebd6f7ab68a93158acf810195e9dbaf4434774 Mon Sep 17 00:00:00 2001 From: cascade812 Date: Sat, 22 Mar 2025 04:12:59 +0000 Subject: [PATCH 6/6] update reduce_scatter and tests Signed-off-by: cascade812 --- .buildkite/test-pipeline.yaml | 2 + tests/distributed/test_comm_ops.py | 2 +- tests/distributed/test_pipeline_parallel.py | 92 +------------------ tests/distributed/test_sequence_parallel.py | 33 ++++--- vllm/distributed/communication_op.py | 5 +- .../base_device_communicator.py | 47 +++++++--- .../device_communicators/cuda_communicator.py | 25 +++++ vllm/distributed/parallel_state.py | 34 ++++++- vllm/model_executor/layers/linear.py | 3 +- .../layers/vocab_parallel_embedding.py | 3 +- vllm/model_executor/models/llama.py | 2 + 11 files changed, 119 insertions(+), 129 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 230dd8383420..b383812ea274 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -527,6 +527,8 @@ steps: # - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py - VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py - VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/test_disagg.py + # test sequence parallel + - pytest -v -s distributed/test_sequence_parallel.py - label: Plugin Tests (2 GPUs) # 40min 
working_dir: "/vllm-workspace/tests" diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py index 351b160d8e6e..d97598b23fab 100644 --- a/tests/distributed/test_comm_ops.py +++ b/tests/distributed/test_comm_ops.py @@ -72,7 +72,7 @@ def reduce_scatter_test_worker(monkeypatch: pytest.MonkeyPatch, tp_size: int, all_reduce = torch.sum(torch.stack(all_tensors, dim=0), dim=0) expected = all_reduce[index * partition_size:(index + 1) * partition_size] t = all_tensors[index] - t = tensor_model_parallel_reduce_scatter(t) + t = tensor_model_parallel_reduce_scatter(t, 0) torch.testing.assert_close(t, expected) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index fc00699cbcc5..1342f0da29d8 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -39,7 +39,6 @@ def use_v0_only(monkeypatch): class ParallelSetup(NamedTuple): tp_size: int pp_size: int - sp_enabled: bool eager_mode: bool chunked_prefill: bool @@ -82,27 +81,22 @@ def detailed( parallel_setups=[ ParallelSetup(tp_size=tp_base, pp_size=pp_base, - sp_enabled=False, eager_mode=False, chunked_prefill=False), ParallelSetup(tp_size=tp_base, pp_size=2 * pp_base, - sp_enabled=False, eager_mode=False, chunked_prefill=True), ParallelSetup(tp_size=tp_base, pp_size=2 * pp_base, - sp_enabled=False, eager_mode=True, chunked_prefill=False), ParallelSetup(tp_size=2 * tp_base, pp_size=pp_base, - sp_enabled=False, eager_mode=False, chunked_prefill=True), ParallelSetup(tp_size=2 * tp_base, pp_size=pp_base, - sp_enabled=False, eager_mode=True, chunked_prefill=False), ], @@ -127,9 +121,8 @@ def fast( parallel_setups=[ ParallelSetup(tp_size=tp_base, pp_size=pp_base, - sp_enabled=False, eager_mode=True, - chunked_prefill=False) + chunked_prefill=False), ], distributed_backends=["mp"], vllm_major_versions=["0"], @@ -138,42 +131,6 @@ def fast( load_format=load_format), ) - @staticmethod - def sp( - *, - tp_base: int = 2, - pp_base: int = 1, - task: TaskOption = "auto", - multi_node_only: bool = False, - load_format: Optional[str] = None, - ): - return PPTestSettings( - parallel_setups=[ - ParallelSetup(tp_size=tp_base, - pp_size=pp_base, - sp_enabled=True, - eager_mode=False, - chunked_prefill=False), - ParallelSetup(tp_size=2 * tp_base, - pp_size=pp_base, - sp_enabled=True, - eager_mode=False, - chunked_prefill=True), - - # current sp doesn't support combination with pp - # ParallelSetup(tp_size=2 * tp_base, - # pp_size=2 * pp_base, - # sp_enabled=True, - # eager_mode=True, - # chunked_prefill=False), - ], - distributed_backends=["mp", "mp"], - vllm_major_versions=["0", "1"], - task=task, - test_options=PPTestOptions(multi_node_only=multi_node_only, - load_format=load_format), - ) - def iter_params(self, model_id: str): opts = self.test_options @@ -314,10 +271,10 @@ def _compare_tp( ( tp_size, pp_size, - sp_enabled, eager_mode, chunked_prefill, ) = parallel_setup + multi_node_only, load_format = test_options model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) @@ -407,9 +364,6 @@ def _compare_tp( distributed_backend, ] - if sp_enabled: - pp_args.append("--enable-sequence-parallel") - # compare without pipeline parallelism # NOTE: use mp backend for TP # PP tests might involve multiple nodes, and ray might @@ -525,45 +479,3 @@ def test_tp_multimodal_generation( num_gpus_available, method="generate", is_multimodal=True) - - -SP_TEXT_GENERATION_MODELS = { - # [Decoder-only] - "meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.sp(), 
-} - -SP_TEST_MODELS = [ - # TODO support other models - # [LANGUAGE GENERATION] - "meta-llama/Llama-3.2-1B-Instruct", -] - - -@pytest.mark.parametrize( - ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version", - "task", "test_options"), - [ - params for model_id, settings in SP_TEXT_GENERATION_MODELS.items() - for params in settings.iter_params(model_id) - if model_id in SP_TEST_MODELS - ], -) -@create_new_process_for_each_test() -def test_tp_sp_generation( - model_id: str, - parallel_setup: ParallelSetup, - distributed_backend: str, - vllm_major_version: str, - task: TaskOption, - test_options: PPTestOptions, - num_gpus_available, -): - _compare_tp(model_id, - parallel_setup, - distributed_backend, - vllm_major_version, - task, - test_options, - num_gpus_available, - method="generate", - is_multimodal=False) diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py index 53a59a665980..7a828b1b6525 100644 --- a/tests/distributed/test_sequence_parallel.py +++ b/tests/distributed/test_sequence_parallel.py @@ -66,14 +66,15 @@ def detailed( ): return SPTestSettings( parallel_setups=[ - ParallelSetup(tp_size=tp_base, - sp_enabled=True, - eager_mode=False, - chunked_prefill=False), - ParallelSetup(tp_size=tp_base, - sp_enabled=True, - eager_mode=False, - chunked_prefill=True), + # TODO support eager_mode = False + # ParallelSetup(tp_size=tp_base, + # sp_enabled=True, + # eager_mode=False, + # chunked_prefill=False), + # ParallelSetup(tp_size=tp_base, + # sp_enabled=True, + # eager_mode=False, + # chunked_prefill=True), ParallelSetup(tp_size=tp_base, sp_enabled=True, eager_mode=True, @@ -84,10 +85,8 @@ def detailed( chunked_prefill=True) ], # only ray is supported for V1 - # distributed_backends=["mp", "mp", "ray", "ray"], - # vllm_major_versions=["0", "1", "0", "1"], - distributed_backends=["mp", "mp"], - vllm_major_versions=["0", "1"], + distributed_backends=["mp", "mp", "ray", "ray"], + vllm_major_versions=["0", "1", "0", "1"], task=task, test_options=SPTestOptions(multi_node_only=multi_node_only, load_format=load_format), @@ -96,7 +95,7 @@ def detailed( @staticmethod def fast( *, - tp_base: int = 2, + tp_base: int = 4, task: TaskOption = "auto", multi_node_only: bool = False, load_format: Optional[str] = None, @@ -108,8 +107,8 @@ def fast( eager_mode=True, chunked_prefill=False), ], - distributed_backends=["mp"], - vllm_major_versions=["1"], + distributed_backends=["mp", "mp"], + vllm_major_versions=["0", "1"], task=task, test_options=SPTestOptions(multi_node_only=multi_node_only, load_format=load_format), @@ -243,13 +242,13 @@ def _compare_sp( SP_TEXT_GENERATION_MODELS = { # [Decoder-only] - "unsloth/Llama-3.2-1B-Instruct": SPTestSettings.detailed(), + "meta-llama/Llama-3.2-1B-Instruct": SPTestSettings.fast(), } SP_TEST_MODELS = [ # TODO support other models # [LANGUAGE GENERATION] - "unsloth/Llama-3.2-1B-Instruct", + "meta-llama/Llama-3.2-1B-Instruct", ] diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 1b48324009cf..0552254ba423 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -13,9 +13,10 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: return get_tp_group().all_reduce(input_) -def tensor_model_parallel_reduce_scatter(input_: torch.Tensor) -> torch.Tensor: +def tensor_model_parallel_reduce_scatter(input_: torch.Tensor, + dim: int) -> torch.Tensor: """Reduce-scatter the input tensor across model parallel 
group.""" - return get_tp_group().reduce_scatter(input_) + return get_tp_group().reduce_scatter(input_, dim) def tensor_model_parallel_all_gather(input_: torch.Tensor, diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index fcfca72471bf..5f0d631a6bfe 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -35,20 +35,39 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: dist.all_reduce(input_, group=self.device_group) return input_ - def reduce_scatter(self, input_: torch.Tensor) -> torch.Tensor: - input_size = input_.size() - assert input_size[0] % self.world_size == 0, ( - f"reduce scatter doesn't work when input size {input_size} is not " - f"divisible by world size {self.world_size}") - output_size = (input_size[0] // self.world_size, ) + input_size[1:] - # Allocate output tensor. - output_tensor = torch.empty(output_size, - dtype=input_.dtype, - device=input_.device) - dist.reduce_scatter_tensor(output_tensor, - input_, - group=self.device_group) - return output_tensor + def reduce_scatter(self, + input_: torch.Tensor, + dim: int = -1) -> torch.Tensor: + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + + # Note: This will produce an incorrect answer if we don't make + # the input_tensor contiguous. Possible bug in reduce_scatter_tensor? + input_tensor = input_.movedim(0, dim).contiguous() + + assert input_tensor.shape[0] % world_size == 0 + chunk_size = input_tensor.shape[0] // world_size + output_shape = (chunk_size, ) + input_tensor.shape[1:] + + output_tensor = torch.empty(output_shape, + dtype=input_tensor.dtype, + device=input_tensor.device) + + # Perform reduce-scatter operation + torch.distributed.reduce_scatter_tensor(output_tensor, + input_tensor, + group=self.device_group) + + # Reshape before returning + return output_tensor.movedim(0, dim).contiguous() def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: if dim < 0: diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 07c9ff506092..8bca278f3888 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -70,6 +70,31 @@ def all_reduce(self, input_): torch.distributed.all_reduce(out, group=self.device_group) return out + def reduce_scatter(self, input_: torch.Tensor, dim: int = -1): + world_size = self.world_size + pynccl_comm = self.pynccl_comm + assert pynccl_comm is not None + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + + # Note: This will produce an incorrect answer if we don't make + # the input_tensor contiguous. Possible bug in reduce_scatter_tensor? 
+ input_tensor = input_.movedim(0, dim).contiguous() + + assert input_tensor.shape[0] % world_size == 0 + chunk_size = input_tensor.shape[0] // world_size + output_shape = (chunk_size, ) + input_tensor.shape[1:] + + output = torch.empty(output_shape, + dtype=input_tensor.dtype, + device=input_tensor.device) + + pynccl_comm.reduce_scatter(output, input_) + + # Reshape before returning + return output.movedim(0, dim).contiguous() + def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: """Sends a tensor to the destination rank in a non-blocking way""" """NOTE: `dst` is the local rank of the destination rank.""" diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 246bcd005ff1..4899cbb8b7e8 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -118,6 +118,22 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor: return torch.empty_like(tensor) +def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int, + group_name: str) -> torch.Tensor: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + return group.reduce_scatter(tensor, dim) + + +def reduce_scatter_fake(tensor: torch.Tensor, dim: int, world_size: int, + group_name: str) -> torch.Tensor: + new_shape = list(tensor.shape) + new_shape[dim] = tensor.shape[dim] // world_size + return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device) + + if supports_custom_op(): direct_register_custom_op( op_name="all_reduce", @@ -126,6 +142,13 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor: fake_impl=all_reduce_fake, ) + direct_register_custom_op( + op_name="reduce_scatter", + op_func=reduce_scatter, + mutates_args=[], + fake_impl=reduce_scatter_fake, + ) + class GroupCoordinator: """ @@ -312,12 +335,17 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor: return self.device_communicator.all_reduce(input_) - def reduce_scatter(self, input_: torch.Tensor) -> torch.Tensor: + def reduce_scatter(self, + input_: torch.Tensor, + dim: int = -1) -> torch.Tensor: + world_size = self.world_size # Bypass the function if we are using only 1 GPU. 
- if self.world_size == 1: + if world_size == 1: return input_ + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") - return self.device_communicator.reduce_scatter(input_) + return self.device_communicator.reduce_scatter(input_, dim) def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: world_size = self.world_size diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 85c4e3fcd2ae..6c680a525b7d 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -1270,7 +1270,8 @@ def forward( forward_context = try_get_forward_context() if (forward_context is not None and forward_context.enable_sequence_parallel): - output = tensor_model_parallel_reduce_scatter(output_parallel) + output = tensor_model_parallel_reduce_scatter(output_parallel, + dim=0) else: output = tensor_model_parallel_all_reduce(output_parallel) else: diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 7a4878eedb7b..b6a92e4db850 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -422,7 +422,8 @@ def forward(self, input_): forward_context = try_get_forward_context() if (forward_context is not None and forward_context.enable_sequence_parallel): - output = tensor_model_parallel_reduce_scatter(output_parallel) + output = tensor_model_parallel_reduce_scatter(output_parallel, + dim=0) else: output = tensor_model_parallel_all_reduce(output_parallel) return output diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index c4b5fbf67638..81b5d9bda9ac 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -479,6 +479,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = self._init_model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) + if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size if lora_config: @@ -499,6 +500,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): if config.tie_word_embeddings: self.lm_head = self.lm_head.tie_weights( self.model.embed_tokens) + logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size,
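
For reference, a minimal, self-contained sketch (not part of the patches above) of the reduce-scatter semantics that the new `reduce_scatter` primitive provides: each rank ends up with the element-wise sum, over all ranks, of its own contiguous chunk along the scattered dimension. This is also why the model runners only enable sequence parallelism when the number of tokens is divisible by the tensor-parallel size. The helper name `emulate_reduce_scatter` is illustrative only; the sketch emulates the collective with plain tensors and requires no process group.

import torch


def emulate_reduce_scatter(per_rank_inputs: list[torch.Tensor],
                           dim: int = 0) -> list[torch.Tensor]:
    """Return what dist.reduce_scatter_tensor would leave on each rank."""
    world_size = len(per_rank_inputs)
    # Element-wise sum of every rank's contribution (the "reduce" part).
    summed = torch.stack(per_rank_inputs, dim=0).sum(dim=0)
    assert summed.size(dim) % world_size == 0, (
        "reduce-scatter needs the scattered dim divisible by world size")
    # Rank r keeps chunk r of the reduced tensor (the "scatter" part).
    return list(summed.chunk(world_size, dim=dim))


if __name__ == "__main__":
    tp_size, num_tokens, hidden = 2, 8, 4
    inputs = [
        torch.arange(num_tokens * hidden,
                     dtype=torch.float32).reshape(num_tokens, hidden) * (r + 1)
        for r in range(tp_size)
    ]
    outputs = emulate_reduce_scatter(inputs, dim=0)
    # Each rank holds num_tokens // tp_size rows of the summed activations.
    assert all(o.shape == (num_tokens // tp_size, hidden) for o in outputs)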