diff --git a/vllm/worker/hpu_enc_dec_model_runner.py b/vllm/worker/hpu_enc_dec_model_runner.py
index 6ffe7071d22d..b45d0f542c24 100644
--- a/vllm/worker/hpu_enc_dec_model_runner.py
+++ b/vllm/worker/hpu_enc_dec_model_runner.py
@@ -41,8 +41,8 @@ class HpuModelAdapterEncoderDecoder(HpuModelAdapter):
 
-    def __init__(self, model, vllm_config, layer_names):
-        super().__init__(model, vllm_config, layer_names)
+    def __init__(self, model, vllm_config, layer_names, is_causal):
+        super().__init__(model, vllm_config, layer_names, False)
 
     # We only wrap the language model in HPU graph because some Ops in
     # vision model will fallback to CPU and cause the graph building fail.
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index 80bda7407f49..02cdf6795785 100755
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -242,7 +242,7 @@ def find_rope_layer(parent, path):
 
 class HpuModelAdapter(torch.nn.Module):
 
-    def __init__(self, model, vllm_config, layer_names):
+    def __init__(self, model, vllm_config, layer_names, is_causal):
         super().__init__()
         self.model = model
         self.prefill_use_fusedsdpa = "fsdpa" in enabled_flags()
@@ -253,9 +253,7 @@ def __init__(self, model, vllm_config, layer_names):
         self.dtype = vllm_config.model_config.dtype
         self.layer_names = layer_names
         self.is_pooler = hasattr(self.model, "_pooler")
-        self.is_causal = True
-        if self.is_pooler:
-            self.set_causal_option(self.model)
+        self.is_causal = is_causal
         self.use_merged_prefill = VLLM_MERGED_PREFILL
 
     def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,
@@ -441,18 +439,6 @@ def make_empty_intermediate_tensors(self, *args, **kwargs):
     def generate_proposals(self, *args, **kwargs):
         return self.model.generate_proposals(*args, **kwargs)
 
-    def set_causal_option(self, module):
-        if isinstance(module, HPUAttentionImpl) and hasattr(
-                module, 'attn_type'):
-            self.is_causal = not (
-                module.attn_type == AttentionType.ENCODER
-                or module.attn_type == AttentionType.ENCODER_ONLY
-                or module.attn_type == AttentionType.ENCODER_DECODER)
-            return
-        else:
-            for child_name, child_module in module.named_children():
-                self.set_causal_option(child_module)
-
     # sampler property will be used by spec_decode_worker
     # don't rename
     @property
@@ -628,6 +614,7 @@ def __init__(
         return_hidden_states: bool = False,
         input_registry: InputRegistry = INPUT_REGISTRY,
         mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
+        is_causal: bool = True,
     ):
         ModelRunnerBase.__init__(self, vllm_config=vllm_config)
         environment.set_model_config(self.model_config)
@@ -715,6 +702,7 @@ def __init__(
         # For both multi-step scheduling and delayed sampling
         self.cached_step_outputs: List[torch.Tensor] = []
         self.is_pooler = False
+        self.is_causal = is_causal
         # For delayed sampling
         self.cached_step_inputs: List[
             ModelInputForHPUWithSamplingMetadata] = []
@@ -839,12 +827,15 @@ def load_model(self) -> None:
                 hidden_layer_markstep_interval)
             path_to_rope = get_path_to_rope(self.model)
             torch.hpu.synchronize()
-
+            self.is_causal = True
+            if self.is_pooler:
+                self.set_causal_option(self.model)
             with HabanaMemoryProfiler() as m_wrap:
                 self.model = self._maybe_wrap_in_hpu_graph(
                     self.model,
                     vllm_config=self.vllm_config,
-                    layer_names=path_to_rope)
+                    layer_names=path_to_rope,
+                    is_causal=self.is_causal)
             msg = f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}"
             logger.info(msg)
             with HabanaMemoryProfiler() as m_wrap:
@@ -1034,17 +1025,36 @@ def make_attn_bias(self, seq_lens, max_prompt_len, dtype):
                                     pad=-1,
                                     dtype=torch.long,
                                     flat=self.use_merged_prefill)
+
         q_seq_idx_t = seq_idx_t.unsqueeze(-1)
         kv_seq_idx_t = seq_idx_t.unsqueeze(-2)
         q_seq_pos_t = seq_pos_t.unsqueeze(-1)
         kv_seq_pos_t = seq_pos_t.unsqueeze(-2)
         seq_idx_t = q_seq_idx_t != kv_seq_idx_t
         seq_pos_t = kv_seq_pos_t > q_seq_pos_t
-        attn_mask = seq_idx_t | seq_pos_t
+        attn_mask = (seq_idx_t | seq_pos_t) if self.is_causal else seq_idx_t
+        if self.is_pooler:
+            mask_v = torch.where(q_seq_pos_t < 0, True, False)
+            attn_mask = attn_mask | mask_v
+            off_value = -3E38  # large negative value; avoids NaN and overflow
+        else:
+            off_value = -math.inf
         attn_bias = torch.zeros_like(attn_mask, dtype=dtype)
-        attn_bias.masked_fill_(attn_mask, -math.inf)
+        attn_bias.masked_fill_(attn_mask, off_value)
         return attn_bias.unsqueeze(1)
 
+    def set_causal_option(self, module):
+        if isinstance(module, HPUAttentionImpl) and hasattr(
+                module, 'attn_type'):
+            self.is_causal = not (
+                module.attn_type == AttentionType.ENCODER
+                or module.attn_type == AttentionType.ENCODER_ONLY
+                or module.attn_type == AttentionType.ENCODER_DECODER)
+            return
+        else:
+            for child_name, child_module in module.named_children():
+                self.set_causal_option(child_module)
+
     def move_to_device(self, tensor):
         return tensor if tensor is None else tensor.to(self.device,
                                                        non_blocking=True)
diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py
index f20c5adb258c..d8c965e0d20e 100755
--- a/vllm/worker/hpu_worker.py
+++ b/vllm/worker/hpu_worker.py
@@ -85,15 +85,18 @@ def __init__(
 
         is_encoder_decoder_model = self._is_encoder_decoder_model()
         ModelRunnerClass: Type[HPUModelRunnerBase] = HPUModelRunner
+        is_causal = True
         if self.model_config.runner_type == "pooling":
             ModelRunnerClass = HPUPoolingModelRunner
         elif is_encoder_decoder_model:
             ModelRunnerClass = HPUEncoderDecoderModelRunner
+            is_causal = False
         self.model_runner: HPUModelRunnerBase = ModelRunnerClass(
             vllm_config=vllm_config,
             kv_cache_dtype=self.cache_config.cache_dtype,
             is_driver_worker=is_driver_worker,
             **speculative_args,
+            is_causal=is_causal,
         )
         if model_runner_cls is not None:
             self.model_runner = model_runner_cls(self.model_runner)
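
Note: below is a minimal, self-contained sketch of the mask construction this
patch adds to make_attn_bias. The helper name make_attn_bias_sketch and the
toy inputs are hypothetical (not part of vLLM); the tensor names mirror the
patch. With is_causal=False only cross-sequence attention is blocked; with
is_causal=True future positions are blocked too; for poolers, padded query
rows (position -1) are also masked, using a large finite negative value
instead of -inf so a fully masked row does not softmax to NaN.

import math
import torch

def make_attn_bias_sketch(seq_pos_t, seq_idx_t, dtype, is_causal, is_pooler):
    # seq_pos_t / seq_idx_t: [batch, max_len]; padded slots hold -1.
    q_seq_idx_t = seq_idx_t.unsqueeze(-1)    # [batch, max_len, 1]
    kv_seq_idx_t = seq_idx_t.unsqueeze(-2)   # [batch, 1, max_len]
    q_seq_pos_t = seq_pos_t.unsqueeze(-1)
    kv_seq_pos_t = seq_pos_t.unsqueeze(-2)
    cross_seq = q_seq_idx_t != kv_seq_idx_t  # query/key from different sequences
    future = kv_seq_pos_t > q_seq_pos_t      # key lies ahead of the query
    attn_mask = (cross_seq | future) if is_causal else cross_seq
    if is_pooler:
        attn_mask = attn_mask | (q_seq_pos_t < 0)  # mask padded query rows
        off_value = -3e38  # finite: -inf on a fully masked row yields NaN
    else:
        off_value = -math.inf
    attn_bias = torch.zeros_like(attn_mask, dtype=dtype)
    attn_bias.masked_fill_(attn_mask, off_value)
    return attn_bias.unsqueeze(1)  # add a head dim for broadcasting

# Toy check: two packed sequences of lengths 2 and 1, plus one pad slot.
seq_pos_t = torch.tensor([[0, 1, 0, -1]])
seq_idx_t = torch.tensor([[0, 0, 1, -1]])
bias = make_attn_bias_sketch(seq_pos_t, seq_idx_t, torch.float32,
                             is_causal=False, is_pooler=True)
print(bias.shape)  # torch.Size([1, 1, 4, 4])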