4 changes: 2 additions & 2 deletions vllm/worker/hpu_enc_dec_model_runner.py
@@ -41,8 +41,8 @@

class HpuModelAdapterEncoderDecoder(HpuModelAdapter):

-    def __init__(self, model, vllm_config, layer_names):
-        super().__init__(model, vllm_config, layer_names)
+    def __init__(self, model, vllm_config, layer_names, is_causal):
+        super().__init__(model, vllm_config, layer_names, False)

# We only wrap the language model in HPU graph because some Ops in
# vision model will fallback to CPU and cause the graph building fail.
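
For context on the new is_causal argument: decoder-style attention masks out future positions, while encoder and cross-attention may attend over the whole sequence, which is why the encoder-decoder adapter above hard-codes False when calling the base adapter. A minimal standalone sketch of that distinction (illustrative only, not part of the diff; the helper name is an assumption):

import torch

def toy_attention_mask(seq_len: int, is_causal: bool) -> torch.Tensor:
    # True marks key positions a query is NOT allowed to attend to.
    if not is_causal:
        # Encoder / cross-attention style: nothing is masked by position.
        return torch.zeros(seq_len, seq_len, dtype=torch.bool)
    # Decoder style: query i may not look at keys j > i (strict upper triangle).
    return torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

print(toy_attention_mask(4, is_causal=True))   # upper-triangular True block
print(toy_attention_mask(4, is_causal=False))  # all False
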
50 changes: 30 additions & 20 deletions vllm/worker/hpu_model_runner.py
@@ -242,7 +242,7 @@ def find_rope_layer(parent, path):

class HpuModelAdapter(torch.nn.Module):

-    def __init__(self, model, vllm_config, layer_names):
+    def __init__(self, model, vllm_config, layer_names, is_causal):
super().__init__()
self.model = model
self.prefill_use_fusedsdpa = "fsdpa" in enabled_flags()
@@ -253,9 +253,7 @@ def __init__(self, model, vllm_config, layer_names):
self.dtype = vllm_config.model_config.dtype
self.layer_names = layer_names
self.is_pooler = hasattr(self.model, "_pooler")
-        self.is_causal = True
-        if self.is_pooler:
-            self.set_causal_option(self.model)
+        self.is_causal = is_causal
self.use_merged_prefill = VLLM_MERGED_PREFILL

def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,
@@ -441,18 +439,6 @@ def make_empty_intermediate_tensors(self, *args, **kwargs):
def generate_proposals(self, *args, **kwargs):
return self.model.generate_proposals(*args, **kwargs)

-    def set_causal_option(self, module):
-        if isinstance(module, HPUAttentionImpl) and hasattr(
-                module, 'attn_type'):
-            self.is_causal = not (
-                module.attn_type == AttentionType.ENCODER
-                or module.attn_type == AttentionType.ENCODER_ONLY
-                or module.attn_type == AttentionType.ENCODER_DECODER)
-            return
-        else:
-            for child_name, child_module in module.named_children():
-                self.set_causal_option(child_module)

# sampler property will be used by spec_decode_worker
# don't rename
@property
@@ -628,6 +614,7 @@ def __init__(
return_hidden_states: bool = False,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
+        is_causal: bool = True,
):
ModelRunnerBase.__init__(self, vllm_config=vllm_config)
environment.set_model_config(self.model_config)
@@ -715,6 +702,7 @@ def __init__(
# For both multi-step scheduling and delayed sampling
self.cached_step_outputs: List[torch.Tensor] = []
self.is_pooler = False
+        self.is_causal = is_causal
# For delayed sampling
self.cached_step_inputs: List[
ModelInputForHPUWithSamplingMetadata] = []
@@ -839,12 +827,15 @@ def load_model(self) -> None:
hidden_layer_markstep_interval)
path_to_rope = get_path_to_rope(self.model)
torch.hpu.synchronize()

+            self.is_causal = True
+            if self.is_pooler:
+                self.set_causal_option(self.model)
with HabanaMemoryProfiler() as m_wrap:
self.model = self._maybe_wrap_in_hpu_graph(
self.model,
vllm_config=self.vllm_config,
-                    layer_names=path_to_rope)
+                    layer_names=path_to_rope,
+                    is_causal=self.is_causal)
msg = f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}"
logger.info(msg)
with HabanaMemoryProfiler() as m_wrap:
@@ -1034,17 +1025,36 @@ def make_attn_bias(self, seq_lens, max_prompt_len, dtype):
pad=-1,
dtype=torch.long,
flat=self.use_merged_prefill)

q_seq_idx_t = seq_idx_t.unsqueeze(-1)
kv_seq_idx_t = seq_idx_t.unsqueeze(-2)
q_seq_pos_t = seq_pos_t.unsqueeze(-1)
kv_seq_pos_t = seq_pos_t.unsqueeze(-2)
seq_idx_t = q_seq_idx_t != kv_seq_idx_t
seq_pos_t = kv_seq_pos_t > q_seq_pos_t
-        attn_mask = seq_idx_t | seq_pos_t
+        attn_mask = (seq_idx_t | seq_pos_t) if self.is_causal else seq_idx_t
+        if self.is_pooler:
+            mask_v = torch.where(q_seq_pos_t < 0, True, False)
+            attn_mask = attn_mask | mask_v
+            off_value = -3E38 #small number, avoid nan and overflow
+        else:
+            off_value = -math.inf
attn_bias = torch.zeros_like(attn_mask, dtype=dtype)
-        attn_bias.masked_fill_(attn_mask, -math.inf)
+        attn_bias.masked_fill_(attn_mask, off_value)
return attn_bias.unsqueeze(1)

+    def set_causal_option(self, module):
+        if isinstance(module, HPUAttentionImpl) and hasattr(
+                module, 'attn_type'):
+            self.is_causal = not (
+                module.attn_type == AttentionType.ENCODER
+                or module.attn_type == AttentionType.ENCODER_ONLY
+                or module.attn_type == AttentionType.ENCODER_DECODER)
+            return
+        else:
+            for child_name, child_module in module.named_children():
+                self.set_causal_option(child_module)

def move_to_device(self, tensor):
return tensor if tensor is None else tensor.to(self.device,
non_blocking=True)
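
For readers tracing the mask arithmetic in make_attn_bias above, here is a self-contained sketch of how the cross-sequence, causal, and padding masks combine, and why the pooler path swaps -inf for a large finite value (illustrative only; the function name, shapes, and dtype are assumptions, not the PR's API):

import math
import torch

def sketch_attn_bias(seq_idx: torch.Tensor, seq_pos: torch.Tensor,
                     is_causal: bool, is_pooler: bool,
                     dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    # seq_idx / seq_pos: [batch, padded_len] per-token sequence id and position,
    # with padding marked as -1, mirroring the conventions in the diff.
    q_idx, kv_idx = seq_idx.unsqueeze(-1), seq_idx.unsqueeze(-2)
    q_pos, kv_pos = seq_pos.unsqueeze(-1), seq_pos.unsqueeze(-2)
    cross_seq = q_idx != kv_idx              # never attend across packed sequences
    future = kv_pos > q_pos                  # causal part: no attending to the future
    mask = (cross_seq | future) if is_causal else cross_seq
    if is_pooler:
        mask = mask | (q_pos < 0)            # also mask rows of padded query tokens
        off_value = -3e38                    # finite: a fully masked row stays NaN-free
    else:
        off_value = -math.inf
    bias = torch.zeros_like(mask, dtype=dtype)
    bias.masked_fill_(mask, off_value)
    return bias.unsqueeze(1)                 # insert the head dimension

The finite off-value matters for fully padded query rows: a row that is entirely -inf yields NaN after softmax (inf minus inf in the max-subtraction step), whereas a row of equal large finite negatives degrades to a harmless uniform distribution, which is what the diff's "avoid nan and overflow" comment refers to.
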
3 changes: 3 additions & 0 deletions vllm/worker/hpu_worker.py
@@ -85,15 +85,18 @@ def __init__(

is_encoder_decoder_model = self._is_encoder_decoder_model()
ModelRunnerClass: Type[HPUModelRunnerBase] = HPUModelRunner
+        is_causal = True
if self.model_config.runner_type == "pooling":
ModelRunnerClass = HPUPoolingModelRunner
elif is_encoder_decoder_model:
ModelRunnerClass = HPUEncoderDecoderModelRunner
+            is_causal = False
self.model_runner: HPUModelRunnerBase = ModelRunnerClass(
vllm_config=vllm_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=is_driver_worker,
**speculative_args,
+            is_causal=is_causal,
)
if model_runner_cls is not None:
self.model_runner = model_runner_cls(self.model_runner)
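
Finally, a minimal restatement of the selection logic added to hpu_worker.py above (strings stand in for the real runner classes; illustrative only), showing which combination ends up constructing a non-causal runner:

def choose_runner(runner_type: str, is_encoder_decoder: bool):
    # Mirrors the dispatch in HPUWorker.__init__: only the encoder-decoder
    # runner is built with is_causal=False; pooling models keep the worker
    # default and are refined per-model in load_model via set_causal_option.
    runner_cls, is_causal = "HPUModelRunner", True
    if runner_type == "pooling":
        runner_cls = "HPUPoolingModelRunner"
    elif is_encoder_decoder:
        runner_cls, is_causal = "HPUEncoderDecoderModelRunner", False
    return runner_cls, is_causal

assert choose_runner("pooling", False) == ("HPUPoolingModelRunner", True)
assert choose_runner("generate", True) == ("HPUEncoderDecoderModelRunner", False)
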