
Commit b3c3a2f

Fix embedding model accuracy issue when merged prefill is enabled (#1048)
make_attn_bias in hpu_model_runner does not build the non-causal mask that embedding models require, and the vertical mask that switches off padded rows is not applied when merged prefill is enabled.
1 parent 4445dca · commit b3c3a2f
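The heart of the fix is make_attn_bias. Under merged prefill, several prompts are packed into one flattened token stream, and a single additive bias has to encode both the sequence boundaries and (for decoders) causality. The sketch below is a standalone rewrite with shortened names, not the exact vLLM code; it shows the two behaviors this commit adds: the causal term is dropped for non-causal embedding models, and pooler runs mask padded rows with a large finite value instead of -inf.

```python
import math
import torch

def make_attn_bias_sketch(seq_idx, seq_pos, is_causal, is_pooler,
                          dtype=torch.float32):
    # seq_idx: sequence id per packed token; seq_pos: position within its
    # own sequence. Both use -1 for padding slots, as in the real code.
    q_idx, kv_idx = seq_idx.unsqueeze(-1), seq_idx.unsqueeze(-2)
    q_pos, kv_pos = seq_pos.unsqueeze(-1), seq_pos.unsqueeze(-2)
    cross_seq = q_idx != kv_idx   # never attend across merged prompts
    future = kv_pos > q_pos       # causal term: no attending to later tokens
    attn_mask = (cross_seq | future) if is_causal else cross_seq
    if is_pooler:
        attn_mask = attn_mask | (q_pos < 0)  # "vertical" mask on padded rows
        off_value = -3e38  # finite, so fully masked rows don't softmax to NaN
    else:
        off_value = -math.inf
    attn_bias = torch.zeros_like(attn_mask, dtype=dtype)
    return attn_bias.masked_fill_(attn_mask, off_value).unsqueeze(1)

# Two prompts (lengths 2 and 1) merged, plus one padding slot:
seq_idx = torch.tensor([0, 0, 1, -1])
seq_pos = torch.tensor([0, 1, 0, -1])
decoder_bias = make_attn_bias_sketch(seq_idx, seq_pos, True, False)
embedding_bias = make_attn_bias_sketch(seq_idx, seq_pos, False, True)
```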

3 files changed, 35 additions, 22 deletions

vllm/worker/hpu_enc_dec_model_runner.py (2 additions, 2 deletions)

@@ -41,8 +41,8 @@
 
 class HpuModelAdapterEncoderDecoder(HpuModelAdapter):
 
-    def __init__(self, model, vllm_config, layer_names):
-        super().__init__(model, vllm_config, layer_names)
+    def __init__(self, model, vllm_config, layer_names, is_causal):
+        super().__init__(model, vllm_config, layer_names, False)
 
     # We only wrap the language model in HPU graph because some Ops in
     # vision model will fallback to CPU and cause the graph building fail.
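Note that the encoder-decoder adapter accepts the new is_causal argument but ignores it, always forwarding False to the base class, consistent with the is_causal = False that hpu_worker.py (below) selects for HPUEncoderDecoderModelRunner.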

vllm/worker/hpu_model_runner.py (30 additions, 20 deletions)

@@ -242,7 +242,7 @@ def find_rope_layer(parent, path):
 
 class HpuModelAdapter(torch.nn.Module):
 
-    def __init__(self, model, vllm_config, layer_names):
+    def __init__(self, model, vllm_config, layer_names, is_causal):
         super().__init__()
         self.model = model
         self.prefill_use_fusedsdpa = "fsdpa" in enabled_flags()

@@ -253,9 +253,7 @@ def __init__(self, model, vllm_config, layer_names):
         self.dtype = vllm_config.model_config.dtype
         self.layer_names = layer_names
         self.is_pooler = hasattr(self.model, "_pooler")
-        self.is_causal = True
-        if self.is_pooler:
-            self.set_causal_option(self.model)
+        self.is_causal = is_causal
         self.use_merged_prefill = VLLM_MERGED_PREFILL
 
     def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,

@@ -441,18 +439,6 @@ def make_empty_intermediate_tensors(self, *args, **kwargs):
     def generate_proposals(self, *args, **kwargs):
        return self.model.generate_proposals(*args, **kwargs)
 
-    def set_causal_option(self, module):
-        if isinstance(module, HPUAttentionImpl) and hasattr(
-                module, 'attn_type'):
-            self.is_causal = not (
-                module.attn_type == AttentionType.ENCODER
-                or module.attn_type == AttentionType.ENCODER_ONLY
-                or module.attn_type == AttentionType.ENCODER_DECODER)
-            return
-        else:
-            for child_name, child_module in module.named_children():
-                self.set_causal_option(child_module)
-
     # sampler property will be used by spec_decode_worker
     # don't rename
     @property

@@ -628,6 +614,7 @@ def __init__(
         return_hidden_states: bool = False,
         input_registry: InputRegistry = INPUT_REGISTRY,
         mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
+        is_causal: bool = True,
     ):
         ModelRunnerBase.__init__(self, vllm_config=vllm_config)
         environment.set_model_config(self.model_config)

@@ -716,6 +703,7 @@ def __init__(
         # For both multi-step scheduling and delayed sampling
         self.cached_step_outputs: List[torch.Tensor] = []
         self.is_pooler = False
+        self.is_causal = is_causal
         # For delayed sampling
         self.cached_step_inputs: List[
             ModelInputForHPUWithSamplingMetadata] = []

@@ -865,12 +853,15 @@ def load_model(self) -> None:
                 hidden_layer_markstep_interval)
             path_to_rope = get_path_to_rope(self.model)
             torch.hpu.synchronize()
-
+            self.is_causal = True
+            if self.is_pooler:
+                self.set_causal_option(self.model)
             with HabanaMemoryProfiler() as m_wrap:
                 self.model = self._maybe_wrap_in_hpu_graph(
                     self.model,
                     vllm_config=self.vllm_config,
-                    layer_names=path_to_rope)
+                    layer_names=path_to_rope,
+                    is_causal=self.is_causal)
             msg = f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}"
             logger.info(msg)
             with HabanaMemoryProfiler() as m_wrap:

@@ -1060,17 +1051,36 @@ def make_attn_bias(self, seq_lens, max_prompt_len, dtype):
             pad=-1,
             dtype=torch.long,
             flat=self.use_merged_prefill)
+
         q_seq_idx_t = seq_idx_t.unsqueeze(-1)
         kv_seq_idx_t = seq_idx_t.unsqueeze(-2)
         q_seq_pos_t = seq_pos_t.unsqueeze(-1)
         kv_seq_pos_t = seq_pos_t.unsqueeze(-2)
         seq_idx_t = q_seq_idx_t != kv_seq_idx_t
         seq_pos_t = kv_seq_pos_t > q_seq_pos_t
-        attn_mask = seq_idx_t | seq_pos_t
+        attn_mask = (seq_idx_t | seq_pos_t) if self.is_causal else seq_idx_t
+        if self.is_pooler:
+            mask_v = torch.where(q_seq_pos_t < 0, True, False)
+            attn_mask = attn_mask | mask_v
+            off_value = -3E38  # small number, avoid nan and overflow
+        else:
+            off_value = -math.inf
         attn_bias = torch.zeros_like(attn_mask, dtype=dtype)
-        attn_bias.masked_fill_(attn_mask, -math.inf)
+        attn_bias.masked_fill_(attn_mask, off_value)
         return attn_bias.unsqueeze(1)
 
+    def set_causal_option(self, module):
+        if isinstance(module, HPUAttentionImpl) and hasattr(
+                module, 'attn_type'):
+            self.is_causal = not (
+                module.attn_type == AttentionType.ENCODER
+                or module.attn_type == AttentionType.ENCODER_ONLY
+                or module.attn_type == AttentionType.ENCODER_DECODER)
+            return
+        else:
+            for child_name, child_module in module.named_children():
+                self.set_causal_option(child_module)
+
     def move_to_device(self, tensor):
         return tensor if tensor is None else tensor.to(self.device,
                                                        non_blocking=True)
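The switch from -math.inf to -3E38 on the pooler path matters because the new vertical mask can blank an entire padded query row: softmax over a row of -inf yields NaN, while a row of large finite biases degrades to a harmless uniform distribution. A quick illustration (not part of the patch):

```python
import torch

print(torch.softmax(torch.full((4,), float("-inf")), dim=-1))
# tensor([nan, nan, nan, nan])

print(torch.softmax(torch.full((4,), -3e38), dim=-1))
# tensor([0.2500, 0.2500, 0.2500, 0.2500])
```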

vllm/worker/hpu_worker.py (3 additions, 0 deletions)

@@ -85,15 +85,18 @@ def __init__(
 
         is_encoder_decoder_model = self._is_encoder_decoder_model()
         ModelRunnerClass: Type[HPUModelRunnerBase] = HPUModelRunner
+        is_causal = True
         if self.model_config.runner_type == "pooling":
             ModelRunnerClass = HPUPoolingModelRunner
         elif is_encoder_decoder_model:
             ModelRunnerClass = HPUEncoderDecoderModelRunner
+            is_causal = False
         self.model_runner: HPUModelRunnerBase = ModelRunnerClass(
             vllm_config=vllm_config,
             kv_cache_dtype=self.cache_config.cache_dtype,
             is_driver_worker=is_driver_worker,
             **speculative_args,
+            is_causal=is_causal,
         )
         if model_runner_cls is not None:
             self.model_runner = model_runner_cls(self.model_runner)
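Taken together, the three files thread a single flag end to end: the worker chooses it per runner type, the runner stores it (and, for pooling models, re-derives it in load_model via set_causal_option before the model is wrapped), and the adapter consumes it in make_attn_bias. A toy reduction of that flow, with stand-in classes rather than the real vLLM types:

```python
class ToyAdapter:
    # stands in for HpuModelAdapter: the flag now arrives from outside
    def __init__(self, model, is_causal):
        self.model = model
        self.is_causal = is_causal  # read later when building attention biases

class ToyRunner:
    # stands in for HPUModelRunner: stores the flag and threads it through
    def __init__(self, is_causal=True):
        self.is_causal = is_causal

    def load_model(self):
        return ToyAdapter(model=object(), is_causal=self.is_causal)

# worker-side selection, as in hpu_worker.py
is_encoder_decoder_model = True
is_causal = not is_encoder_decoder_model  # encoder-decoder is non-causal
adapter = ToyRunner(is_causal=is_causal).load_model()
assert adapter.is_causal is False
```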
