@@ -242,7 +242,7 @@ def find_rope_layer(parent, path):
 
 class HpuModelAdapter(torch.nn.Module):
 
-    def __init__(self, model, vllm_config, layer_names):
+    def __init__(self, model, vllm_config, layer_names, is_causal):
         super().__init__()
         self.model = model
         self.prefill_use_fusedsdpa = "fsdpa" in enabled_flags()
@@ -253,9 +253,7 @@ def __init__(self, model, vllm_config, layer_names):
         self.dtype = vllm_config.model_config.dtype
         self.layer_names = layer_names
         self.is_pooler = hasattr(self.model, "_pooler")
-        self.is_causal = True
-        if self.is_pooler:
-            self.set_causal_option(self.model)
+        self.is_causal = is_causal
         self.use_merged_prefill = VLLM_MERGED_PREFILL
 
     def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device,
@@ -441,18 +439,6 @@ def make_empty_intermediate_tensors(self, *args, **kwargs):
     def generate_proposals(self, *args, **kwargs):
         return self.model.generate_proposals(*args, **kwargs)
 
-    def set_causal_option(self, module):
-        if isinstance(module, HPUAttentionImpl) and hasattr(
-                module, 'attn_type'):
-            self.is_causal = not (
-                module.attn_type == AttentionType.ENCODER
-                or module.attn_type == AttentionType.ENCODER_ONLY
-                or module.attn_type == AttentionType.ENCODER_DECODER)
-            return
-        else:
-            for child_name, child_module in module.named_children():
-                self.set_causal_option(child_module)
-
     # sampler property will be used by spec_decode_worker
     # don't rename
     @property
@@ -628,6 +614,7 @@ def __init__(
         return_hidden_states: bool = False,
         input_registry: InputRegistry = INPUT_REGISTRY,
         mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
+        is_causal: bool = True,
     ):
         ModelRunnerBase.__init__(self, vllm_config=vllm_config)
         environment.set_model_config(self.model_config)
@@ -716,6 +703,7 @@ def __init__(
         # For both multi-step scheduling and delayed sampling
         self.cached_step_outputs: List[torch.Tensor] = []
         self.is_pooler = False
+        self.is_causal = is_causal
         # For delayed sampling
         self.cached_step_inputs: List[
             ModelInputForHPUWithSamplingMetadata] = []
@@ -865,12 +853,15 @@ def load_model(self) -> None:
                 hidden_layer_markstep_interval)
             path_to_rope = get_path_to_rope(self.model)
             torch.hpu.synchronize()
-
+            self.is_causal = True
+            if self.is_pooler:
+                self.set_causal_option(self.model)
             with HabanaMemoryProfiler() as m_wrap:
                 self.model = self._maybe_wrap_in_hpu_graph(
                     self.model,
                     vllm_config=self.vllm_config,
-                    layer_names=path_to_rope)
+                    layer_names=path_to_rope,
+                    is_causal=self.is_causal)
             msg = f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}"
             logger.info(msg)
             with HabanaMemoryProfiler() as m_wrap:
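
For reference, a minimal standalone sketch of the causality detection that `set_causal_option` (defined in the hunk below) performs before the model is wrapped. `DummyAttn`, `DummyEmbeddingModel`, and `infer_is_causal` are illustrative stand-ins, not vLLM classes; this is a functional re-expression of the same module-tree walk, not the actual code path.

```python
import torch.nn as nn

# Attention types treated as non-causal (mirrors AttentionType.ENCODER /
# ENCODER_ONLY / ENCODER_DECODER in the diff); strings here are illustrative.
NON_CAUSAL_TYPES = {"encoder", "encoder_only", "encoder_decoder"}

class DummyAttn(nn.Module):
    """Stand-in for HPUAttentionImpl, exposing only attn_type."""
    def __init__(self, attn_type: str):
        super().__init__()
        self.attn_type = attn_type

class DummyEmbeddingModel(nn.Module):
    """Stand-in for a pooling model: has a _pooler and encoder-only attention."""
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(DummyAttn("encoder_only"), nn.Linear(8, 8))
        self._pooler = lambda hidden: hidden.mean(dim=0)

def infer_is_causal(module) -> bool:
    """Walk the module tree; report False once encoder-style attention is found."""
    if isinstance(module, DummyAttn) and hasattr(module, "attn_type"):
        return module.attn_type not in NON_CAUSAL_TYPES
    for child in module.children():
        if not infer_is_causal(child):
            return False
    return True

model = DummyEmbeddingModel()
is_causal = infer_is_causal(model) if hasattr(model, "_pooler") else True
print(is_causal)  # False -> the adapter is constructed with is_causal=False
```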
@@ -1060,17 +1051,36 @@ def make_attn_bias(self, seq_lens, max_prompt_len, dtype):
             pad=-1,
             dtype=torch.long,
             flat=self.use_merged_prefill)
+
         q_seq_idx_t = seq_idx_t.unsqueeze(-1)
         kv_seq_idx_t = seq_idx_t.unsqueeze(-2)
         q_seq_pos_t = seq_pos_t.unsqueeze(-1)
         kv_seq_pos_t = seq_pos_t.unsqueeze(-2)
         seq_idx_t = q_seq_idx_t != kv_seq_idx_t
         seq_pos_t = kv_seq_pos_t > q_seq_pos_t
-        attn_mask = seq_idx_t | seq_pos_t
+        attn_mask = (seq_idx_t | seq_pos_t) if self.is_causal else seq_idx_t
+        if self.is_pooler:
+            mask_v = torch.where(q_seq_pos_t < 0, True, False)
+            attn_mask = attn_mask | mask_v
+            off_value = -3E38  # small number, avoid nan and overflow
+        else:
+            off_value = -math.inf
         attn_bias = torch.zeros_like(attn_mask, dtype=dtype)
-        attn_bias.masked_fill_(attn_mask, -math.inf)
+        attn_bias.masked_fill_(attn_mask, off_value)
         return attn_bias.unsqueeze(1)
 
+    def set_causal_option(self, module):
+        if isinstance(module, HPUAttentionImpl) and hasattr(
+                module, 'attn_type'):
+            self.is_causal = not (
+                module.attn_type == AttentionType.ENCODER
+                or module.attn_type == AttentionType.ENCODER_ONLY
+                or module.attn_type == AttentionType.ENCODER_DECODER)
+            return
+        else:
+            for child_name, child_module in module.named_children():
+                self.set_causal_option(child_module)
+
     def move_to_device(self, tensor):
         return tensor if tensor is None else tensor.to(self.device,
                                                        non_blocking=True)
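
For reference, a minimal standalone sketch of the mask logic above, outside the vLLM call path. The merged-prefill layout packing two prompts into one row is made up for the demonstration; shapes and values are illustrative.

```python
import math
import torch

# One merged-prefill row packing two prompts: tokens 0-1 belong to prompt 0,
# tokens 2-4 to prompt 1 (seq_idx), with per-prompt positions in seq_pos.
seq_idx_t = torch.tensor([[0, 0, 1, 1, 1]])
seq_pos_t = torch.tensor([[0, 1, 0, 1, 2]])

q_seq_idx_t = seq_idx_t.unsqueeze(-1)    # (1, 5, 1) query-side index
kv_seq_idx_t = seq_idx_t.unsqueeze(-2)   # (1, 1, 5) key/value-side index
q_seq_pos_t = seq_pos_t.unsqueeze(-1)
kv_seq_pos_t = seq_pos_t.unsqueeze(-2)

cross_seq = q_seq_idx_t != kv_seq_idx_t  # mask attention across prompts
future = kv_seq_pos_t > q_seq_pos_t      # mask future positions within a prompt

for is_causal in (True, False):
    attn_mask = (cross_seq | future) if is_causal else cross_seq
    attn_bias = torch.zeros_like(attn_mask, dtype=torch.float32)
    attn_bias.masked_fill_(attn_mask, -math.inf)
    print(f"is_causal={is_causal}\n{attn_bias[0]}")

# is_causal=True  -> block-diagonal *and* lower-triangular inside each prompt
# is_causal=False -> only block-diagonal: every token of a prompt attends to
#                    every other token of the same prompt (encoder/pooling case)
```

In the pooler branch of the diff, padded query rows (`q_seq_pos_t < 0`) are additionally masked, and the fill value is a large finite negative (`-3E38`) rather than `-math.inf`, presumably so fully masked rows do not produce NaNs downstream.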