diff --git a/tests/distributed/test_comm_ops.py b/tests/distributed/test_comm_ops.py
index 7b0346b8ab50..66b8c26ab767 100644
--- a/tests/distributed/test_comm_ops.py
+++ b/tests/distributed/test_comm_ops.py
@@ -11,7 +11,8 @@
 from vllm.distributed import (broadcast_tensor_dict, get_pp_group,
                               tensor_model_parallel_all_gather,
-                              tensor_model_parallel_all_reduce)
+                              tensor_model_parallel_all_reduce,
+                              tensor_model_parallel_reduce_scatter)
 
 from ..utils import init_test_distributed_environment, multi_process_parallel
 
@@ -38,6 +39,33 @@ def all_reduce_test_worker(tp_size: int, pp_size: int, rank: int,
     torch.testing.assert_close(t, expected)
 
 
+@ray.remote(num_gpus=1, max_calls=1)
+def reduce_scatter_test_worker(tp_size: int, pp_size: int, rank: int,
+                               distributed_init_port: str):
+    # it is important to delete the CUDA_VISIBLE_DEVICES environment variable
+    # so that each worker can see all the GPUs
+    # they will be able to set the device to the correct GPU
+    os.environ.pop("CUDA_VISIBLE_DEVICES", None)
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    init_test_distributed_environment(tp_size, pp_size, rank,
+                                      distributed_init_port)
+
+    num_elements = 8
+    all_tensors = [
+        torch.arange(num_elements, dtype=torch.float32, device="cuda") *
+        (r + 1) for r in range(tp_size)
+    ]
+
+    index = rank % tp_size
+    partition_size = num_elements // tp_size
+    all_reduce = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
+    expected = all_reduce[index * partition_size:(index + 1) * partition_size]
+    t = all_tensors[index]
+    t = tensor_model_parallel_reduce_scatter(t)
+    torch.testing.assert_close(t, expected)
+
+
 @ray.remote(num_gpus=1, max_calls=1)
 def all_gather_test_worker(tp_size: int, pp_size: int, rank: int,
                            distributed_init_port: str):
@@ -178,6 +206,17 @@ def test_multi_process_tensor_parallel(tp_size, test_target):
     multi_process_parallel(tp_size, 1, test_target)
 
 
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+@pytest.mark.parametrize("tp_size", [2])
+@pytest.mark.parametrize("test_target", [
+    all_reduce_test_worker, all_gather_test_worker, reduce_scatter_test_worker,
+    broadcast_tensor_dict_test_worker
+])
+def test_multi_process_tensor_parallel_sequence_parallel(tp_size, test_target):
+    multi_process_parallel(tp_size, 1, test_target)
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
                     reason="Need at least 2 GPUs to run the test.")
 @pytest.mark.parametrize("pp_size", [2])
@@ -199,3 +238,17 @@ def test_multi_process_pipeline_parallel(pp_size, test_target):
 def test_multi_process_tensor_parallel_pipeline_parallel(
         tp_size, pp_size, test_target):
     multi_process_parallel(tp_size, pp_size, test_target)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 4,
+                    reason="Need at least 4 GPUs to run the test.")
+@pytest.mark.parametrize("tp_size", [2])
+@pytest.mark.parametrize("pp_size", [2])
+@pytest.mark.parametrize("test_target", [
+    send_recv_test_worker, send_recv_tensor_dict_test_worker,
+    all_reduce_test_worker, all_gather_test_worker, reduce_scatter_test_worker,
+    broadcast_tensor_dict_test_worker
+])
+def test_multi_process_tensor_parallel_sequence_parallel_pipeline_parallel(
+        tp_size, pp_size, test_target):
+    multi_process_parallel(tp_size, pp_size, test_target)
diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py
index 05b6ba40506a..0be3fc76f93f 100644
--- a/tests/distributed/test_pipeline_parallel.py
+++ b/tests/distributed/test_pipeline_parallel.py
@@ -39,6 +39,7 @@ def use_v0_only(monkeypatch):
 class ParallelSetup(NamedTuple):
     tp_size: int
     pp_size: int
+    sp_enabled: bool
     eager_mode: bool
     chunked_prefill: bool
 
@@ -81,22 +82,27 @@ def detailed(
         parallel_setups=[
             ParallelSetup(tp_size=tp_base,
                           pp_size=pp_base,
+                          sp_enabled=False,
                           eager_mode=False,
                           chunked_prefill=False),
             ParallelSetup(tp_size=tp_base,
                           pp_size=2 * pp_base,
+                          sp_enabled=False,
                           eager_mode=False,
                           chunked_prefill=True),
             ParallelSetup(tp_size=tp_base,
                           pp_size=2 * pp_base,
+                          sp_enabled=False,
                           eager_mode=True,
                           chunked_prefill=False),
             ParallelSetup(tp_size=2 * tp_base,
                           pp_size=pp_base,
+                          sp_enabled=False,
                           eager_mode=False,
                           chunked_prefill=True),
             ParallelSetup(tp_size=2 * tp_base,
                           pp_size=pp_base,
+                          sp_enabled=False,
                           eager_mode=True,
                           chunked_prefill=False),
         ],
@@ -121,8 +127,9 @@ def fast(
         parallel_setups=[
             ParallelSetup(tp_size=tp_base,
                           pp_size=pp_base,
+                          sp_enabled=False,
                           eager_mode=True,
-                          chunked_prefill=False),
+                          chunked_prefill=False)
         ],
         distributed_backends=["mp"],
         vllm_major_versions=["0"],
@@ -131,6 +138,42 @@ def fast(
             load_format=load_format),
     )
 
+    @staticmethod
+    def sp(
+        *,
+        tp_base: int = 2,
+        pp_base: int = 1,
+        task: TaskOption = "auto",
+        multi_node_only: bool = False,
+        load_format: Optional[str] = None,
+    ):
+        return PPTestSettings(
+            parallel_setups=[
+                ParallelSetup(tp_size=tp_base,
+                              pp_size=pp_base,
+                              sp_enabled=True,
+                              eager_mode=False,
+                              chunked_prefill=False),
+                ParallelSetup(tp_size=2 * tp_base,
+                              pp_size=pp_base,
+                              sp_enabled=True,
+                              eager_mode=False,
+                              chunked_prefill=True),
+
+                # current SP doesn't support combination with PP
+                # ParallelSetup(tp_size=2 * tp_base,
+                #               pp_size=2 * pp_base,
+                #               sp_enabled=True,
+                #               eager_mode=True,
+                #               chunked_prefill=False),
+            ],
+            distributed_backends=["mp", "mp"],
+            vllm_major_versions=["0", "1"],
+            task=task,
+            test_options=PPTestOptions(multi_node_only=multi_node_only,
+                                       load_format=load_format),
+        )
+
     def iter_params(self, model_id: str):
         opts = self.test_options
 
@@ -271,10 +314,10 @@ def _compare_tp(
     (
         tp_size,
         pp_size,
+        sp_enabled,
         eager_mode,
         chunked_prefill,
     ) = parallel_setup
-
     multi_node_only, load_format = test_options
 
     model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
@@ -360,6 +403,9 @@ def _compare_tp(
         distributed_backend,
     ]
 
+    if sp_enabled:
+        pp_args.append("--enable-sequence-parallel")
+
     # compare without pipeline parallelism
     # NOTE: use mp backend for TP
     # PP tests might involve multiple nodes, and ray might
@@ -469,3 +515,45 @@ def test_tp_multimodal_generation(
         num_gpus_available,
         method="generate",
         is_multimodal=True)
+
+
+SP_TEXT_GENERATION_MODELS = {
+    # [Decoder-only]
+    "meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.sp(),
+}
+
+SP_TEST_MODELS = [
+    # TODO support other models
+    # [LANGUAGE GENERATION]
+    "meta-llama/Llama-3.2-1B-Instruct",
+]
+
+
+@pytest.mark.parametrize(
+    ("model_id", "parallel_setup", "distributed_backend", "vllm_major_version",
+     "task", "test_options"),
+    [
+        params for model_id, settings in SP_TEXT_GENERATION_MODELS.items()
+        for params in settings.iter_params(model_id)
+        if model_id in SP_TEST_MODELS
+    ],
+)
+@fork_new_process_for_each_test
+def test_tp_sp_generation(
+    model_id: str,
+    parallel_setup: ParallelSetup,
+    distributed_backend: str,
+    vllm_major_version: str,
+    task: TaskOption,
+    test_options: PPTestOptions,
+    num_gpus_available,
+):
+    _compare_tp(model_id,
+                parallel_setup,
+                distributed_backend,
+                vllm_major_version,
+                task,
+                test_options,
+                num_gpus_available,
+                method="generate",
+                is_multimodal=True)
diff --git a/vllm/config.py b/vllm/config.py
index 70cc0affe998..0c50fe22ff7d 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1342,6 +1342,8 @@ class ParallelConfig:
     tensor_parallel_size: int = 1  # Number of tensor parallel groups.
     data_parallel_size: int = 1  # Number of data parallel groups.
     data_parallel_rank: int = 0  # Rank of the data parallel group.
+    enable_sequence_parallel: bool = False  # Enable sequence parallelism.
+
     # IP of the data parallel master.
     data_parallel_master_ip: str = "127.0.0.1"
     data_parallel_master_port: int = 29500  # Port of the data parallel master.
@@ -2134,6 +2136,8 @@ def create_draft_parallel_config(
             pipeline_parallel_size=target_parallel_config.
             pipeline_parallel_size,
             tensor_parallel_size=speculative_draft_tensor_parallel_size,
+            enable_sequence_parallel=target_parallel_config.
+            enable_sequence_parallel,
             distributed_executor_backend=target_parallel_config.
             distributed_executor_backend,
             max_parallel_loading_workers=target_parallel_config.
diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py
index 0228264f91f9..1b48324009cf 100644
--- a/vllm/distributed/communication_op.py
+++ b/vllm/distributed/communication_op.py
@@ -13,6 +13,11 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
     return get_tp_group().all_reduce(input_)
 
 
+def tensor_model_parallel_reduce_scatter(input_: torch.Tensor) -> torch.Tensor:
+    """Reduce-scatter the input tensor across model parallel group."""
+    return get_tp_group().reduce_scatter(input_)
+
+
 def tensor_model_parallel_all_gather(input_: torch.Tensor,
                                      dim: int = -1) -> torch.Tensor:
     """All-gather the input tensor across model parallel group."""
diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py
index eb12f8834b41..fcfca72471bf 100644
--- a/vllm/distributed/device_communicators/base_device_communicator.py
+++ b/vllm/distributed/device_communicators/base_device_communicator.py
@@ -35,6 +35,21 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         dist.all_reduce(input_, group=self.device_group)
         return input_
 
+    def reduce_scatter(self, input_: torch.Tensor) -> torch.Tensor:
+        input_size = input_.size()
+        assert input_size[0] % self.world_size == 0, (
+            f"reduce scatter doesn't work when input size {input_size} is not "
+            f"divisible by world size {self.world_size}")
+        output_size = (input_size[0] // self.world_size, ) + input_size[1:]
+        # Allocate output tensor.
+        output_tensor = torch.empty(output_size,
+                                    dtype=input_.dtype,
+                                    device=input_.device)
+        dist.reduce_scatter_tensor(output_tensor,
+                                   input_,
+                                   group=self.device_group)
+        return output_tensor
+
     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         if dim < 0:
             # Convert negative dim to positive.
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 86166dd5bb83..ad8f8bacb3ec 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -312,6 +312,13 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
     def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
         return self.device_communicator.all_reduce(input_)
 
+    def reduce_scatter(self, input_: torch.Tensor) -> torch.Tensor:
+        # Bypass the function if we are using only 1 GPU.
+        if self.world_size == 1:
+            return input_
+
+        return self.device_communicator.reduce_scatter(input_)
+
     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         world_size = self.world_size
         # Bypass the function if we are using only 1 GPU.
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 31d567de0efa..7343634f4641 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -114,6 +114,7 @@ class EngineArgs:
     # number of P/D disaggregation (or other disaggregation) workers
     pipeline_parallel_size: int = 1
     tensor_parallel_size: int = 1
+    enable_sequence_parallel: bool = False
     enable_expert_parallel: bool = False
     max_parallel_loading_workers: Optional[int] = None
     block_size: Optional[int] = None
@@ -434,6 +435,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                             type=int,
                             default=EngineArgs.tensor_parallel_size,
                             help='Number of tensor parallel replicas.')
+        parser.add_argument('--enable-sequence-parallel',
+                            '-sp',
+                            action='store_true',
+                            default=False,
+                            help='Enable sequence parallelism.')
         parser.add_argument(
             '--enable-expert-parallel',
             action='store_true',
@@ -1242,6 +1248,7 @@ def create_engine_config(
         parallel_config = ParallelConfig(
             pipeline_parallel_size=self.pipeline_parallel_size,
             tensor_parallel_size=self.tensor_parallel_size,
+            enable_sequence_parallel=self.enable_sequence_parallel,
             enable_expert_parallel=self.enable_expert_parallel,
             max_parallel_loading_workers=self.max_parallel_loading_workers,
             disable_custom_all_reduce=self.disable_custom_all_reduce,
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index a0e2fa2918bd..819a99284065 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -74,6 +74,8 @@ class LLM:
             environments.
         tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
+        enable_sequence_parallel: Enable sequence parallelism on top of tensor
+            parallelism.
         dtype: The data type for the model weights and activations. Currently,
             we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
             the `torch_dtype` attribute specified in the model config file.
@@ -164,6 +166,7 @@ def __init__(
         trust_remote_code: bool = False,
         allowed_local_media_path: str = "",
         tensor_parallel_size: int = 1,
+        enable_sequence_parallel: bool = False,
         dtype: str = "auto",
         quantization: Optional[str] = None,
         revision: Optional[str] = None,
@@ -219,6 +222,7 @@ def __init__(
             trust_remote_code=trust_remote_code,
             allowed_local_media_path=allowed_local_media_path,
             tensor_parallel_size=tensor_parallel_size,
+            enable_sequence_parallel=enable_sequence_parallel,
             dtype=dtype,
             quantization=quantization,
             revision=revision,
diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index e195a03c5cac..dc76c62fe6d8 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -38,6 +38,7 @@ class ForwardContext:
     attn_metadata: "AttentionMetadata"  # set dynamically for each forward pass
     # TODO: remove after making all virtual_engines share the same kv cache
     virtual_engine: int  # set dynamically for each forward pass
+    enable_sequence_parallel: bool  # whether sequence parallelism is enabled
     # set dynamically for each forward pass
     dp_metadata: Optional[DPMetadata] = None
 
@@ -53,11 +54,16 @@ def get_forward_context() -> ForwardContext:
     return _forward_context
 
 
+def try_get_forward_context() -> ForwardContext:
+    return _forward_context
+
+
 @contextmanager
 def set_forward_context(attn_metadata: Any,
                         vllm_config: VllmConfig,
                         virtual_engine: int = 0,
-                        num_tokens: int = 0):
+                        num_tokens: int = 0,
+                        enable_sequence_parallel: bool = False):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     Here we can inject common logic for every model forward pass.
@@ -90,6 +96,12 @@ def set_forward_context(attn_metadata: Any,
             cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0)
         dp_metadata = DPMetadata(cu_tokens_across_dp_cpu)
 
+    # TODO: support pipeline parallel
+    if vllm_config.parallel_config.enable_sequence_parallel:
+        assert vllm_config.parallel_config.pipeline_parallel_size == 1, (
+            "sequence parallel doesn't work correctly when "
+            "combined with pipeline parallel")
+
     global _forward_context
     prev_context = _forward_context
     _forward_context = ForwardContext(
@@ -97,6 +109,7 @@ def set_forward_context(attn_metadata: Any,
         static_forward_context,
         virtual_engine=virtual_engine,
         attn_metadata=attn_metadata,
+        enable_sequence_parallel=enable_sequence_parallel,
         dp_metadata=dp_metadata)
     try:
         yield
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 1ae574072b8f..85c4e3fcd2ae 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -13,7 +13,9 @@
                               get_tensor_model_parallel_world_size,
                               split_tensor_along_last_dim,
                               tensor_model_parallel_all_gather,
-                              tensor_model_parallel_all_reduce)
+                              tensor_model_parallel_all_reduce,
+                              tensor_model_parallel_reduce_scatter)
+from vllm.forward_context import try_get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
@@ -469,6 +471,11 @@ def forward(
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
         bias = self.bias if not self.skip_bias_add else None
 
+        forward_context = try_get_forward_context()
+        if (forward_context is not None
+                and forward_context.enable_sequence_parallel):
+            input_ = tensor_model_parallel_all_gather(input_, 0)
+
         # Matrix multiply.
         assert self.quant_method is not None
         output_parallel = self.quant_method.apply(self, input_, bias)
@@ -1258,8 +1265,14 @@ def forward(
         output_parallel = self.quant_method.apply(self,
                                                   input_parallel,
                                                   bias=bias_)
+
         if self.reduce_results and self.tp_size > 1:
-            output = tensor_model_parallel_all_reduce(output_parallel)
+            forward_context = try_get_forward_context()
+            if (forward_context is not None
+                    and forward_context.enable_sequence_parallel):
+                output = tensor_model_parallel_reduce_scatter(output_parallel)
+            else:
+                output = tensor_model_parallel_all_reduce(output_parallel)
         else:
             output = output_parallel
diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py
index 4a359725bad0..2a070d231330 100644
--- a/vllm/model_executor/layers/logits_processor.py
+++ b/vllm/model_executor/layers/logits_processor.py
@@ -10,6 +10,7 @@
 import vllm.envs as envs
 from vllm.distributed import (tensor_model_parallel_all_gather,
                               tensor_model_parallel_gather)
+from vllm.forward_context import try_get_forward_context
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -59,6 +60,11 @@ def forward(
         sampling_metadata: Optional[SamplingMetadata] = None,
         embedding_bias: Optional[torch.Tensor] = None,
     ) -> Optional[torch.Tensor]:
+        forward_context = try_get_forward_context()
+        if (forward_context is not None
+                and forward_context.enable_sequence_parallel):
+            hidden_states = tensor_model_parallel_all_gather(hidden_states,
+                                                             dim=0)
         if self.logits_as_input:
             logits = hidden_states
         else:
@@ -105,6 +111,7 @@ def _get_logits(
         embedding_bias: Optional[torch.Tensor],
     ) -> Optional[torch.Tensor]:
         # Get the logits for the next tokens.
+
         logits = lm_head.quant_method.apply(lm_head,
                                             hidden_states,
                                             bias=embedding_bias)
diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py
index f65dfc3cb329..7a4878eedb7b 100644
--- a/vllm/model_executor/layers/vocab_parallel_embedding.py
+++ b/vllm/model_executor/layers/vocab_parallel_embedding.py
@@ -9,7 +9,9 @@
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
-                              tensor_model_parallel_all_reduce)
+                              tensor_model_parallel_all_reduce,
+                              tensor_model_parallel_reduce_scatter)
+from vllm.forward_context import try_get_forward_context
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
 from vllm.model_executor.parameter import BasevLLMParameter
@@ -204,7 +206,6 @@ def __init__(self,
                  quant_config: Optional[QuantizationConfig] = None,
                  prefix: str = ""):
         super().__init__()
-
         # Keep the input dimensions.
         tp_rank = get_tensor_model_parallel_rank()
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -418,7 +419,12 @@ def forward(self, input_):
         if self.tp_size > 1:
             output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
         # Reduce across all the model parallel GPUs.
-        output = tensor_model_parallel_all_reduce(output_parallel)
+        forward_context = try_get_forward_context()
+        if (forward_context is not None
+                and forward_context.enable_sequence_parallel):
+            output = tensor_model_parallel_reduce_scatter(output_parallel)
+        else:
+            output = tensor_model_parallel_all_reduce(output_parallel)
         return output
 
     def extra_repr(self) -> str:
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 81b5d9bda9ac..c4b5fbf67638 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -479,7 +479,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.model = self._init_model(vllm_config=vllm_config,
                                       prefix=maybe_prefix(prefix, "model"))
-
         if get_pp_group().is_last_rank:
             self.unpadded_vocab_size = config.vocab_size
             if lora_config:
@@ -500,7 +499,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             if config.tie_word_embeddings:
                 self.lm_head = self.lm_head.tie_weights(
                     self.model.embed_tokens)
-
             logit_scale = getattr(config, "logit_scale", 1.0)
             self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                     config.vocab_size,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index c2a976108e4d..1c3961a6cc50 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -815,7 +815,6 @@ def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
         # depending on the input multimodal items.
         curr_group_outputs = self.model.get_multimodal_embeddings(
             **batched_mm_inputs)
-
         for output in curr_group_outputs:
             encoder_outputs.append(output)
 
@@ -984,9 +983,22 @@ def execute_model(
                 for k, v in self.intermediate_tensors.items()
             })
 
+        # Only do sequence parallelism when the number of tokens is
+        # divisible by the tensor parallel size, since sequence parallelism
+        # uses torch.distributed.reduce_scatter, which only supports sizes
+        # that are divisible by the parallel size.
+        enable_sequence_parallel = (
+            self.vllm_config.parallel_config.enable_sequence_parallel
+            and num_input_tokens %
+            self.vllm_config.parallel_config.tensor_parallel_size == 0)
+
         # Run the decoder.
         # Use persistent buffers for CUDA graphs.
-        with set_forward_context(attn_metadata, self.vllm_config):
+        with set_forward_context(
+                attn_metadata,
+                self.vllm_config,
+                enable_sequence_parallel=enable_sequence_parallel,
+        ):
             hidden_states = self.model(
                 input_ids=input_ids,
                 positions=positions,
@@ -999,7 +1011,11 @@ def execute_model(
             hidden_states = hidden_states[:num_scheduled_tokens]
 
         sample_hidden_states = hidden_states[logits_indices]
-        logits = self.model.compute_logits(sample_hidden_states, None)
+        with set_forward_context(
+                attn_metadata,
+                self.vllm_config,
+                enable_sequence_parallel=enable_sequence_parallel):
+            logits = self.model.compute_logits(sample_hidden_states, None)
 
         # Apply structured output bitmasks if present
         if scheduler_output.grammar_bitmask is not None:
@@ -1166,7 +1182,16 @@ def _get_prompt_logprobs_dict(
             req_idx = self.input_batch.req_id_to_index[req_id]
             offset = self.query_start_loc_np[req_idx].item()
             prompt_hidden_states = hidden_states[offset:offset + num_logits]
-            logits = self.model.compute_logits(prompt_hidden_states, None)
+
+            enable_sequence_parallel = (
+                self.vllm_config.parallel_config.enable_sequence_parallel
+                and num_tokens %
+                self.vllm_config.parallel_config.tensor_parallel_size == 0)
+            with set_forward_context(
+                    None,
+                    self.vllm_config,
+                    enable_sequence_parallel=enable_sequence_parallel):
+                logits = self.model.compute_logits(prompt_hidden_states, None)
 
             # Get the "target" tokens for each index. For prompt at index i,
             # the token at prompt index i+1 is the "sampled" token we want
@@ -1243,9 +1268,17 @@ def _dummy_run(
                     for k, v in self.intermediate_tensors.items()
                 })
 
-        with set_forward_context(None,
-                                 self.vllm_config,
-                                 num_tokens=num_tokens):
+        enable_sequence_parallel = (
+            self.vllm_config.parallel_config.enable_sequence_parallel
+            and num_tokens %
+            self.vllm_config.parallel_config.tensor_parallel_size == 0)
+
+        with set_forward_context(
+                None,
+                self.vllm_config,
+                num_tokens=num_tokens,
+                enable_sequence_parallel=enable_sequence_parallel,
+        ):
             hidden_states = model(
                 input_ids=input_ids,
                 positions=positions,
@@ -1262,7 +1295,15 @@ def _dummy_sampler_run(
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:
 
-        logits = self.model.compute_logits(hidden_states, None)
+        enable_sequence_parallel = (
+            self.vllm_config.parallel_config.enable_sequence_parallel
+            and hidden_states.size()[0] %
+            self.vllm_config.parallel_config.tensor_parallel_size == 0)
+        with set_forward_context(
+                None,
+                self.vllm_config,
+                enable_sequence_parallel=enable_sequence_parallel):
+            logits = self.model.compute_logits(hidden_states, None)
 
         num_reqs = logits.size(0)
         dummy_tensors = lambda v: torch.full(
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 473bd901b5b2..3cd7c994f809 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1556,7 +1556,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
                             capture_inputs)
                     with set_forward_context(attn_metadata, self.vllm_config,
-                                             virtual_engine):
+                                             virtual_engine, batch_size):
                         graph_runner.capture(**capture_inputs)
                     self.graph_memory_pool = graph_runner.graph.pool()
                     self.graph_runners[virtual_engine][batch_size] = (
@@ -1736,9 +1736,24 @@ def execute_model(
             model_forward_end = torch.cuda.Event(enable_timing=True)
             model_forward_start.record()
 
+        num_tokens = (model_input.input_tokens.shape[0]
+                      if model_input.input_tokens is not None else None)
+
+        # Only do sequence parallelism when the number of tokens is
+        # divisible by the tensor parallel size, since sequence parallelism
+        # uses torch.distributed.reduce_scatter, which only supports sizes
+        # that are divisible by the parallel size.
+        enable_sequence_parallel = (
+            self.vllm_config.parallel_config.enable_sequence_parallel
+            and num_tokens is not None and num_tokens %
+            self.vllm_config.parallel_config.tensor_parallel_size == 0)
+
         if not bypass_model_exec:
-            with set_forward_context(model_input.attn_metadata,
-                                     self.vllm_config, virtual_engine):
+            with set_forward_context(
+                    model_input.attn_metadata,
+                    self.vllm_config,
+                    virtual_engine,
+                    enable_sequence_parallel=enable_sequence_parallel):
                 hidden_or_intermediate_states = model_executable(
                     input_ids=model_input.input_tokens,
                     positions=model_input.input_positions,
@@ -1785,8 +1800,14 @@ def execute_model(
                     torch.tensor(model_forward_time +
                                  orig_model_forward_time))
             return hidden_or_intermediate_states
 
-        logits = self.model.compute_logits(hidden_or_intermediate_states,
-                                           model_input.sampling_metadata)
+        with set_forward_context(
+                model_input.attn_metadata,
+                self.vllm_config,
+                virtual_engine,
+                enable_sequence_parallel=enable_sequence_parallel,
+        ):
+            logits = self.model.compute_logits(hidden_or_intermediate_states,
+                                               model_input.sampling_metadata)
 
         if not self.is_driver_worker:
             return []
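
Usage sketch (illustrative only, based on the options this patch adds to `LLM` and `EngineArgs`): the snippet assumes a host with at least 2 GPUs, and the model name and prompt are merely examples.

    from vllm import LLM

    # Tensor parallelism across 2 GPUs, with sequence parallelism enabled via
    # the new flag. Batches whose token count is not divisible by the
    # tensor-parallel size fall back to the regular all-reduce path (see the
    # enable_sequence_parallel checks in the model runners above).
    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
              tensor_parallel_size=2,
              enable_sequence_parallel=True)
    outputs = llm.generate(["Sequence parallelism splits activations along"])
    print(outputs[0].outputs[0].text)

The same switch is exposed on the command line as --enable-sequence-parallel (or -sp), mirroring the engine argument added in vllm/engine/arg_utils.py.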