Skip to content

Commit d559979

Browse files
authored
[bugfix] fix cpu tests (#10585)
Signed-off-by: youkaichao <youkaichao@gmail.com>
1 parent d345f40 commit d559979

File tree

3 files changed

+16
-10
lines changed

3 files changed

+16
-10
lines changed

vllm/worker/cpu_embedding_model_runner.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import torch
55

6+
from vllm.forward_context import set_forward_context
67
from vllm.model_executor.pooling_metadata import PoolingMetadata
78
from vllm.multimodal import MultiModalKwargs
89
from vllm.pooling_params import PoolingParams
@@ -64,7 +65,8 @@ def execute_model(
6465
intermediate_tensors,
6566
}
6667

67-
hidden_states = model_executable(**execute_model_kwargs)
68+
with set_forward_context(model_input.attn_metadata, self.vllm_config):
69+
hidden_states = model_executable(**execute_model_kwargs)
6870

6971
# Only perform pooling in the driver worker.
7072
if not self.is_driver_worker:

vllm/worker/cpu_enc_dec_model_runner.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torch
55

66
from vllm.attention import AttentionMetadata
7+
from vllm.forward_context import set_forward_context
78
from vllm.model_executor import SamplingMetadata
89
from vllm.model_executor.layers.sampler import SamplerOutput
910
from vllm.multimodal import MultiModalKwargs
@@ -303,7 +304,8 @@ def execute_model(
303304
intermediate_tensors,
304305
}
305306

306-
hidden_states = model_executable(**execute_model_kwargs)
307+
with set_forward_context(model_input.attn_metadata, self.vllm_config):
308+
hidden_states = model_executable(**execute_model_kwargs)
307309

308310
# Compute the logits.
309311
logits = self.model.compute_logits(hidden_states,

vllm/worker/cpu_model_runner.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from vllm.attention import AttentionMetadata, get_attn_backend
1212
from vllm.config import VllmConfig
13+
from vllm.forward_context import set_forward_context
1314
from vllm.logger import init_logger
1415
from vllm.model_executor import SamplingMetadata
1516
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
@@ -487,14 +488,15 @@ def execute_model(
487488
multimodal_kwargs = MultiModalKwargs.as_kwargs(
488489
model_input.multi_modal_kwargs, device=self.device)
489490

490-
hidden_states = model_executable(
491-
input_ids=model_input.input_tokens,
492-
positions=model_input.input_positions,
493-
kv_caches=kv_caches,
494-
attn_metadata=model_input.attn_metadata,
495-
intermediate_tensors=intermediate_tensors,
496-
**multimodal_kwargs,
497-
)
491+
with set_forward_context(model_input.attn_metadata, self.vllm_config):
492+
hidden_states = model_executable(
493+
input_ids=model_input.input_tokens,
494+
positions=model_input.input_positions,
495+
kv_caches=kv_caches,
496+
attn_metadata=model_input.attn_metadata,
497+
intermediate_tensors=intermediate_tensors,
498+
**multimodal_kwargs,
499+
)
498500

499501
# Compute the logits.
500502
logits = self.model.compute_logits(hidden_states,

0 commit comments

Comments
 (0)