From 8d3ec7465a861a5462d87edaa838c8b2f6030d53 Mon Sep 17 00:00:00 2001
From: bapatra
Date: Fri, 3 May 2024 14:53:53 -0700
Subject: [PATCH 1/2] minor change for LongRoPE config to account for rename from longrope -> su

---
 vllm/config.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/config.py b/vllm/config.py
index 77f0eaddca1c7..d28664ebae84b 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -972,7 +972,7 @@ def _get_and_verify_max_len(
         derived_max_model_len = default_max_len
 
     rope_scaling = getattr(hf_config, "rope_scaling", None)
-    if rope_scaling is not None and rope_scaling["type"] != "longrope":
+    if rope_scaling is not None and rope_scaling["type"] not in ("longrope", "su"):
         assert "factor" in rope_scaling
         scaling_factor = rope_scaling["factor"]
         if rope_scaling["type"] == "yarn":

From e1dd365aeecdf98ff3526b8674a308c2b3ab16d8 Mon Sep 17 00:00:00 2001
From: bapatra
Date: Fri, 3 May 2024 16:33:38 -0700
Subject: [PATCH 2/2] handling TP slicing on the vllm side for dummy tokens fix

---
 vllm/model_executor/models/phi3small/phi3small.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/phi3small/phi3small.py b/vllm/model_executor/models/phi3small/phi3small.py
index 2772c87906f4a..5ba7c618db4d7 100644
--- a/vllm/model_executor/models/phi3small/phi3small.py
+++ b/vllm/model_executor/models/phi3small/phi3small.py
@@ -391,7 +391,11 @@ def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
         logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                        sampling_metadata)
-        if self.dummy_token_indices is not None:
+        if self.dummy_token_indices is not None and logits is not None:
+            # In case of tensor parallelism, the logits processor does a
+            # `tensor_model_parallel_gather` under the hood, so the full vocab
+            # logits end up only on rank 0; all other ranks get None back.
+            # Hence only the rank with non-None logits fills the dummy tokens with -inf.
             logits.index_fill_(-1, self.dummy_token_indices,
                                -torch.inf)
         return logits
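
For context on the second patch: a minimal, self-contained sketch of the guard it introduces, assuming a hypothetical gather_to_rank0 helper as a stand-in for vLLM's tensor_model_parallel_gather (all names below are illustrative, not vLLM APIs). It shows that under tensor parallelism only the rank holding the gathered logits should fill the dummy-token positions with -inf, while every other rank receives None and must not touch it.

import torch
from typing import Optional

def gather_to_rank0(shard: torch.Tensor, rank: int) -> Optional[torch.Tensor]:
    # Hypothetical stand-in for the TP gather: the full logits only
    # materialize on rank 0; every other rank gets None.
    return shard if rank == 0 else None

def compute_logits_sketch(hidden: torch.Tensor,
                          lm_head: torch.Tensor,
                          dummy_token_indices: Optional[torch.Tensor],
                          rank: int) -> Optional[torch.Tensor]:
    logits = gather_to_rank0(hidden @ lm_head.t(), rank)
    # Mirror the patched guard: mask dummy (padding) vocab entries only on
    # the rank that actually holds the logits; other ranks pass None through.
    if dummy_token_indices is not None and logits is not None:
        logits.index_fill_(-1, dummy_token_indices, -torch.inf)
    return logits

# Usage: rank 1 returns None (no masking attempted); rank 0 masks the dummy slots.
hidden = torch.randn(2, 16)
lm_head = torch.randn(100, 16)
dummy = torch.tensor([98, 99])
assert compute_logits_sketch(hidden, lm_head, dummy, rank=1) is None
out = compute_logits_sketch(hidden, lm_head, dummy, rank=0)
assert torch.isinf(out[:, 98:]).all()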