
Commit d3d15a6

Varun Sundar Rabindranath authored and committed
[Kernel] LoRA - Enable CUDAGraphs for V1 (vllm-project#14626)
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
1 parent f5cb455 commit d3d15a6

4 files changed: 35 additions, 14 deletions

tests/lora/test_worker.py

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ def set_active_loras(worker: Union[Worker, V1Worker],
             seed=0,
             dtype="float16",
             revision=None,
+            enforce_eager=True,
         ),
         load_config=LoadConfig(
             download_dir=None,
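
Since V1 LoRA no longer forces eager execution, the worker test now pins `enforce_eager=True` explicitly, presumably so the unit test does not attempt CUDA graph capture. At the user-facing level the same knob is the `enforce_eager` constructor argument; a minimal sketch:

from vllm import LLM

# Hedged sketch: opt a LoRA-enabled engine back out of CUDA graphs.
llm = LLM(model="meta-llama/Llama-2-7b-hf",
          enable_lora=True,
          enforce_eager=True)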

vllm/config.py

Lines changed: 19 additions & 6 deletions
@@ -2289,9 +2289,14 @@ def compute_hash(self) -> str:
         excluding anything before input ids/embeddings and after
         the final hidden states.
         """
-        # no factors to consider.
-        # LoRA is not compatible with `torch.compile` .
         factors: list[Any] = []
+        factors.append(self.max_lora_rank)
+        factors.append(self.max_loras)
+        factors.append(self.fully_sharded_loras)
+        factors.append(self.lora_dtype)
+        factors.append(self.lora_extra_vocab_size)
+        factors.append(self.long_lora_scaling_factors)
+        factors.append(self.bias_enabled)
         hash_str = hashlib.md5(str(factors).encode()).hexdigest()
         return hash_str
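
Now that LoRA runs under `torch.compile`, every LoRA setting that can change the generated kernels has to participate in the compilation cache key. A minimal sketch of the factor-hash pattern above (stdlib `hashlib`; the field subset and values are illustrative):

import hashlib
from typing import Any

def lora_config_hash(max_lora_rank: int, max_loras: int,
                     lora_dtype: str) -> str:
    # Same pattern as the diff: stringify the factor list and md5 it.
    factors: list[Any] = [max_lora_rank, max_loras, lora_dtype]
    return hashlib.md5(str(factors).encode()).hexdigest()

# Configs that differ in any factor must not share a compiled artifact.
assert lora_config_hash(16, 4, "bf16") != lora_config_hash(32, 4, "bf16")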

@@ -3305,6 +3310,11 @@ def compute_hash(self) -> str:
             vllm_factors.append("None")
         if self.lora_config:
             vllm_factors.append(self.lora_config.compute_hash())
+            # LoRA creates static buffers based on max_num_batched_tokens.
+            # The tensor sizes and strides get captured in the torch.compile
+            # graph explicitly.
+            vllm_factors.append(
+                str(self.scheduler_config.max_num_batched_tokens))
         else:
             vllm_factors.append("None")
         if self.speculative_config:
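
The comment in the hunk states the key invariant: LoRA's index buffers are allocated once at `max_num_batched_tokens`, and their sizes and strides are baked into the compiled graph, so the scheduler setting must feed the cache key too. An illustrative sketch (buffer name and size are made up, not vLLM internals):

import torch

max_num_batched_tokens = 8192  # scheduler_config setting
# Allocated once; a graph compiled against this buffer's size and strides
# is only reusable for configs with the same max_num_batched_tokens.
token_lora_indices = torch.empty(max_num_batched_tokens, dtype=torch.long)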
@@ -3455,12 +3465,15 @@ def __post_init__(self):
                 " Disabling `torch.compile`.")
             self.compilation_config.level = CompilationLevel.NO_COMPILATION

-        if self.lora_config is not None and self.compilation_config.level !=\
-                CompilationLevel.NO_COMPILATION:
-            logger.warning("LoRA is not supported with `torch.compile` yet. "
-                           "Disabling `torch.compile`.")
+        if ((not envs.VLLM_USE_V1) and self.lora_config is not None
+                and self.compilation_config.level
+                != CompilationLevel.NO_COMPILATION):
+            logger.warning(
+                "LoRA for V0 is not supported with `torch.compile` yet. "
+                "Disabling `torch.compile`.")
             self.compilation_config.level = CompilationLevel.NO_COMPILATION

+
         if self.model_config and self.model_config.use_mla and \
                 not (current_platform.is_cuda() or current_platform.is_rocm()):
             logger.info(
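
The reworked guard only disables `torch.compile` for LoRA on V0; with `VLLM_USE_V1` set, LoRA now proceeds under compilation and CUDA graphs. Condensed into a standalone predicate (an illustrative helper, not vLLM API):

def resolve_compilation_level(use_v1: bool, has_lora: bool,
                              requested_level: int) -> int:
    NO_COMPILATION = 0  # stand-in for CompilationLevel.NO_COMPILATION
    if (not use_v1) and has_lora:
        return NO_COMPILATION  # V0 LoRA is not compile-safe yet
    return requested_level  # V1 LoRA keeps torch.compile + CUDA graphs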

vllm/lora/layers.py

Lines changed: 9 additions & 6 deletions
@@ -237,16 +237,19 @@ def set_lora(
         self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        added_tokens_mask = x > self.base_layer.org_vocab_size - 1
-        embeddings_indices = self.punica_wrapper.embeddings_indices
-        indices = embeddings_indices[1].view_as(x)
+        added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1,
+                                        1, 0)
+        embeddings_indices = torch.narrow(
+            self.punica_wrapper._embeddings_indices, 1, 0, x.size(0))
+
+        indices = embeddings_indices[1]
         full_lora_a_embeddings = F.embedding(
             x + indices,
             self.lora_a_stacked_2d,
         )
-        indices = embeddings_indices[0].view_as(x)
-        full_output = self.base_layer.forward(
-            x.add_(indices * added_tokens_mask))
+        indices = embeddings_indices[0]
+        full_output = self.base_layer.forward(x +
+                                              (indices * added_tokens_mask))

         full_output_org = full_output
         if full_output.ndim == 3:
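
Two things make the rewritten `forward` capture-friendly: indices come from `torch.narrow` views over a persistent max-size buffer (no per-step allocation, and the view length is derived from a tensor already in the graph), and the graph input `x` is no longer mutated in place via `x.add_`. A self-contained sketch of that pattern (shapes and names are illustrative):

import torch

MAX_TOKENS, ORG_VOCAB_SIZE = 8192, 32000
# Persistent buffer, allocated once outside the captured region.
embeddings_indices = torch.zeros(2, MAX_TOKENS, dtype=torch.long)

def lora_embedding_ids(x: torch.Tensor) -> torch.Tensor:
    # 0/1 integer mask instead of a bool tensor, as in the diff above.
    added_tokens_mask = torch.where(x > ORG_VOCAB_SIZE - 1, 1, 0)
    # narrow() is a zero-copy view sized by x itself.
    indices = torch.narrow(embeddings_indices, 1, 0, x.size(0))
    # Out-of-place add: x is left untouched for the rest of the graph.
    return x + indices[0] * added_tokens_mask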

vllm/lora/punica_wrapper/punica_gpu.py

Lines changed: 6 additions & 2 deletions
@@ -254,7 +254,9 @@ def add_expand(self,
         y_org = y
         y = y.view(-1, y.shape[-1])
         if lora_bias_stacked is not None:
-            self._apply_bias(self.token_lora_indices, y, output_slices,
+            token_lora_indices = torch.narrow(self._token_lora_indices, 0, 0,
+                                              y.size(0))
+            self._apply_bias(token_lora_indices, y, output_slices,
                              lora_bias_stacked)

         if env.VLLM_USE_V1:
@@ -365,7 +367,9 @@ def add_lora_linear(self,
         assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
         if lora_bias_stacked is not None:
             assert len(lora_bias_stacked) == len(output_slices)
-            y = self._apply_bias(self.token_lora_indices, y, output_slices,
+            token_lora_indices = torch.narrow(self._token_lora_indices, 0, 0,
+                                              y.size(0))
+            y = self._apply_bias(token_lora_indices, y, output_slices,
                                  lora_bias_stacked)

         if buffer is None:
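
Both bias paths switch from the `token_lora_indices` property to narrowing the raw `_token_lora_indices` buffer down to `y.size(0)`, so the slice length follows a tensor the compiler already tracks rather than Python-side state. `torch.narrow` is just a zero-copy slice:

import torch

buf = torch.arange(10)
view = torch.narrow(buf, 0, 0, 4)   # view of buf, no allocation
assert torch.equal(view, buf[:4])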
