
Commit b39c491

lsy323 authored and mgoin committed
[torch.compile][TPU] Make @support_torch_compile work for XLA backend (vllm-project#15782)
Signed-off-by: Siyuan Liu <lsiyuan@google.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
1 parent c05ffd9 commit b39c491
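
For readers unfamiliar with the decorator named in the commit title, here is a rough, illustrative sketch of how vLLM's support_torch_compile class decorator is applied to a model module. The ToyModel class, its layers, and its forward body are invented for this example, and the keyword-only vllm_config/prefix constructor signature is recalled from the decorator's general contract; it is not code from this commit and details may differ.

# Illustrative sketch only -- not code from this commit. Shows the general
# shape of decorating a model module so its forward pass is captured by
# Dynamo and dispatched through cached compiled bytecode; on TPU this commit
# makes that path target the XLA (openxla) backend rather than inductor.
from typing import Optional

import torch
import torch.nn as nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig


@support_torch_compile
class ToyModel(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.embed = nn.Embedding(32000, 1024)
        self.proj = nn.Linear(1024, 1024)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Toy computation standing in for a real decoder stack.
        hidden = inputs_embeds if inputs_embeds is not None else self.embed(input_ids)
        return self.proj(hidden)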

File tree

2 files changed: +47 -76 lines changed

vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_worker.py

vllm/v1/worker/tpu_model_runner.py

Lines changed: 38 additions & 73 deletions
@@ -15,6 +15,7 @@
 import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
+from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
 from vllm.config import VllmConfig
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
@@ -691,11 +692,10 @@ def execute_model(
             hidden_states = self.model(
                 input_ids=input_ids,
                 positions=self.position_ids,
-                kv_caches=self.kv_caches,
                 inputs_embeds=inputs_embeds,
             )
-        selected_token_ids = self.model.sample_from_hidden(
-            hidden_states, tpu_sampling_metadata)
+        selected_token_ids = self.sample_from_hidden(hidden_states,
+                                                     tpu_sampling_metadata)
         # Remove padding on cpu and keep dynamic op outside of xla graph.
         selected_token_ids = selected_token_ids.cpu()[:num_reqs]

@@ -795,17 +795,15 @@ def load_model(self) -> None:
                 "get_tensor_model_parallel_rank",
                 return_value=xm_tp_rank):
             model = get_model(vllm_config=self.vllm_config)
-        model = model.eval()
+        # Sync all pending XLA execution during model initialization and weight
+        # loading.
         xm.mark_step()
         xm.wait_device_ops()
-        model = ModelWrapperV1(model)
-        self.model = torch.compile(model,
-                                   backend="openxla",
-                                   fullgraph=True,
-                                   dynamic=False)
+        self.model = model
+        self.sampler = TPUSampler()

     @torch.no_grad()
-    def _dummy_run(self, kv_caches, num_tokens: int) -> None:
+    def _dummy_run(self, num_tokens: int) -> None:
         if self.is_multimodal_model:
             input_ids = None
             inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
@@ -856,7 +854,6 @@ def _dummy_run(self, kv_caches, num_tokens: int) -> None:
         with set_forward_context(attn_metadata, self.vllm_config, 0):
             out = self.model(input_ids=input_ids,
                              positions=position_ids,
-                             kv_caches=kv_caches,
                              inputs_embeds=inputs_embeds)
         self._hidden_states_dtype = out.dtype

@@ -868,7 +865,7 @@ def capture_model(self) -> None:
         start = time.perf_counter()
         for num_tokens in self.num_tokens_paddings:
             logger.info(" -- num_tokens: %d", num_tokens)
-            self._dummy_run(self.kv_caches, num_tokens)
+            self._dummy_run(num_tokens)
             xm.mark_step()
         xm.wait_device_ops()
         end = time.perf_counter()
@@ -899,8 +896,7 @@ def capture_model(self) -> None:
                 from_input_batch(self.input_batch, indices)
             logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
                         num_reqs_to_sample)
-            out = self.model.sample_from_hidden(dummy_hidden,
-                                                sampling_meta)
+            out = self.sample_from_hidden(dummy_hidden, sampling_meta)
             out = out.cpu()
             # Requests can't be more than tokens. But do compile for the
             # next bigger value in case num_tokens uses bucketed padding.
@@ -954,79 +950,48 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             self.vllm_config.compilation_config.static_forward_context,
             self.kv_caches)

-
-class ModelWrapperV1(nn.Module):
-
-    def __init__(self, model: nn.Module):
-        super().__init__()
-        self.model = model
-        self.sampler = TPUSampler()
-
-    def sample(
-        self, logits: torch.Tensor,
-        sampling_metadata: TPUSupportedSamplingMetadata) -> SamplerOutput:
-        sampler_out = self.sampler(logits, sampling_metadata)
-        return sampler_out
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        kv_caches: list[torch.Tensor],
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        """Executes the forward pass of the model.
-
-        Args:
-            input_ids: The input token IDs of shape [num_tokens].
-            positions: The input position IDs of shape [num_tokens].
-            kv_caches: The key and value caches. They can be None during the
-                memory profiling at initialization.
-            inputs_embeds: The input embeddings of shape [num_tokens,
-                hidden_size]. It is used for multimodal models.
-        """
-
-        hidden_states = self.model(
-            input_ids=input_ids,
-            positions=positions,
-            inputs_embeds=inputs_embeds,
-        )
-
-        return hidden_states
+    def reset_dynamo_cache(self):
+        if self.is_multimodal_model:
+            assert hasattr(self.model, "language_model")
+            compiled_model = self.model.language_model.model
+        else:
+            compiled_model = self.model.model
+        if isinstance(compiled_model, TorchCompileWrapperWithCustomDispatcher):
+            logger.info("Clear dynamo cache and cached dynamo bytecode.")
+            torch._dynamo.eval_frame.remove_from_cache(
+                compiled_model.original_code_object)
+            compiled_model.compiled_codes.clear()

     def sample_from_hidden(
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: TPUSupportedSamplingMetadata,
     ) -> torch.Tensor:
         """
-        Sample with xla-friendly function. This function is to be traced
-        separately from `forward` for lighter compilation overhead.
-        """
+        Sample with xla-friendly function. This function is to be traced
+        separately for lighter compilation overhead.
+        """
         # Tensor `sample_hidden_states` is of fixed pre-compiled size.
         sample_hidden_states = \
             hidden_states[sampling_metadata.indices_do_sample]
-        logits = self.compute_logits(sample_hidden_states)
+        # SamplingMetadata here for pruning output in LogitsProcessor, disabled.
+        logits = self.model.compute_logits(sample_hidden_states, None)
+
+        def sample(
+            logits: torch.Tensor,
+            sampling_metadata: TPUSupportedSamplingMetadata
+        ) -> SamplerOutput:
+            sampler_out = self.sampler(logits, sampling_metadata)
+            return sampler_out
+
         # Optimized greedy sampling branch, tracing both paths in a single pass
         # NOTE all_greedy is a scalar, this is just an optimized if/else.
-        out_tokens = torch.where(sampling_metadata.all_greedy,
-                                 torch.argmax(logits, dim=-1, keepdim=True),
-                                 self.sample(logits, sampling_metadata)\
-                                     .sampled_token_ids)
+        out_tokens = torch.where(
+            sampling_metadata.all_greedy,
+            torch.argmax(logits, dim=-1, keepdim=True),
+            sample(logits, sampling_metadata).sampled_token_ids)
         return out_tokens

-    def compute_logits(self,
-                       hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
-        # SamplingMetadata here for pruning output in LogitsProcessor, disabled
-        logits = self.model.compute_logits(hidden_states, None)
-        return logits
-
-    def get_multimodal_embeddings(self, *args, **kwargs):
-        return self.model.get_multimodal_embeddings(*args, **kwargs)
-
-    def get_input_embeddings(self, *args, **kwargs):
-        return self.model.get_input_embeddings(*args, **kwargs)
-

 def _get_padded_number(n: int, multiple: int) -> int:
     return ((n + multiple - 1) // multiple) * multiple
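
The refactored sample_from_hidden above folds greedy and non-greedy decoding into a single traced graph by treating all_greedy as a tensor predicate for torch.where instead of Python control flow. Below is a minimal, self-contained sketch of that pattern; the toy logits and the multinomial draw standing in for the TPU sampler are assumptions for illustration, not the vLLM code.

# Minimal sketch of the single-trace greedy/non-greedy selection used in
# sample_from_hidden. A plain multinomial draw stands in for the real
# sampler so the snippet runs on CPU.
import torch

def select_tokens(logits: torch.Tensor, all_greedy: torch.Tensor) -> torch.Tensor:
    greedy = torch.argmax(logits, dim=-1, keepdim=True)
    sampled = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)
    # `all_greedy` is a scalar bool tensor, so torch.where evaluates both
    # operands inside one graph instead of data-dependent branching.
    return torch.where(all_greedy, greedy, sampled)

logits = torch.randn(4, 32)
print(select_tokens(logits, torch.tensor(True)).shape)   # torch.Size([4, 1])
print(select_tokens(logits, torch.tensor(False)).shape)  # torch.Size([4, 1])

Because the predicate is a scalar tensor, switching between greedy and sampled decoding at runtime does not change the traced graph, which is why the diff's NOTE calls it "just an optimized if/else".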

vllm/v1/worker/tpu_worker.py

Lines changed: 9 additions & 3 deletions
@@ -157,13 +157,19 @@ def determine_available_memory(self) -> int:
                                  runner_kv_caches)

         self.model_runner._dummy_run(
-            runner_kv_caches,
-            num_tokens=self.scheduler_config.max_num_batched_tokens,
-        )
+            self.scheduler_config.max_num_batched_tokens)

         # Synchronize before measuring the memory usage.
         xm.wait_device_ops()

+        # During the profiling run, the model runs without KV cache. After
+        # the profiling run, the model always runs with KV cache. Here we clear
+        # the dynamo cache and cached bytecode to ensure the model always has
+        # one compiled bytecode. Having one FX graph/cached bytecode per
+        # compiled model is required for `support_torch_compile` decorator to
+        # skip dynamo guard.
+        self.model_runner.reset_dynamo_cache()
+
         # Get the maximum amount of memory used by the model weights and
         # intermediate activations.
         m = xm.get_memory_info(self.device)
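
The comment block added above explains why the worker calls reset_dynamo_cache() after the KV-cache-less profiling run. The toy below illustrates the underlying issue with a counting backend: every distinct call signature leaves behind its own guarded compiled bytecode. It is a sketch, not the vLLM code path; the commit removes only the wrapped model's code object via torch._dynamo.eval_frame.remove_from_cache, while torch._dynamo.reset() here is a coarser, global stand-in.

# Toy illustration: two input shapes compiled under dynamic=False produce two
# cached graphs; clearing the dynamo cache and recompiling leaves exactly one.
import torch

graph_count = 0

def counting_backend(gm, example_inputs):
    global graph_count
    graph_count += 1          # one FX graph handed to the backend per compile
    return gm.forward         # run the captured graph eagerly

@torch.compile(backend=counting_backend, dynamic=False)
def forward(x):
    return x * 2 + 1

forward(torch.randn(8))       # "profiling" shape -> first compile
forward(torch.randn(16))      # serving shape -> second compile, new guard
print(graph_count)            # 2

torch._dynamo.reset()         # drop cached bytecode and guards (global reset)
graph_count = 0
forward(torch.randn(16))      # recompiled once; a single artifact remains
print(graph_count)            # 1

After the reset, the next call compiles once, leaving the model with a single FX graph and cached bytecode, which is the invariant the comment says the support_torch_compile guard-skipping dispatch relies on.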
