@@ -37,7 +37,7 @@
 @dataclass
 class PagedAttentionArgs:
     input_ids: torch.Tensor
-    attention_mask: torch.Tensor
+    attention_mask: Optional[torch.Tensor]
     position_ids: torch.Tensor
     cumulative_seqlens_q: torch.Tensor
     cumulative_seqlens_k: torch.Tensor
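The field is now `Optional`, which assumes `from typing import Optional` is already in scope in this module (the import is not visible in the hunk). A minimal sketch of what the loosened annotation permits, trimmed to the fields shown above (the construction below is hypothetical, not from the PR):

from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class PagedAttentionArgs:
    input_ids: torch.Tensor
    attention_mask: Optional[torch.Tensor]  # None when the kernel masks causally itself
    position_ids: torch.Tensor

# A None mask is now a valid member, not a type error:
args = PagedAttentionArgs(
    input_ids=torch.zeros(1, 4, dtype=torch.long),
    attention_mask=None,
    position_ids=torch.arange(4).unsqueeze(0),
)
assert args.attention_mask is None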
@@ -105,6 +105,9 @@ def __init__(
         self.tokenizer = PreTrainedTokenizerFast.from_pretrained(self.config._name_or_path)
         self.decode_stream = DecodeStream(skip_special_tokens=True)
 
+    def return_attention_mask(self) -> bool:
+        return self.config._attn_implementation != "paged_attention"  # we set `is_causal` to True in paged call
+
     @traced(standalone=True)
     def setup_static_tensors(self):
         T = self.max_batch_tokens
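A tiny self-contained sketch of the gate introduced here (the `SimpleNamespace` stubs are stand-ins; the real check reads `self.config._attn_implementation` as above):

from types import SimpleNamespace

def needs_dense_mask(config) -> bool:
    # paged_attention passes `is_causal=True` to the kernel, so it never
    # consumes a dense additive mask; every other implementation still does
    return config._attn_implementation != "paged_attention"

assert needs_dense_mask(SimpleNamespace(_attn_implementation="eager"))
assert not needs_dense_mask(SimpleNamespace(_attn_implementation="paged_attention"))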
@@ -114,9 +117,6 @@ def setup_static_tensors(self):
         self.tensor_metadata = tensor_metadata
         self.input_ids = torch.empty((1, T), **tensor_metadata)
         self.position_ids = torch.empty((1, T), **tensor_metadata)
-        self.attention_mask = torch.empty(
-            (1, 1, T, max_token_budget), dtype=self.model_dtype, device=self.model_device
-        )
         self.cumulative_seqlens_q = torch.empty((T + 1,), **tensor_metadata)
         self.cumulative_seqlens_k = torch.empty((T + 1,), **tensor_metadata)
         self.write_index = torch.empty((T,), **tensor_metadata)
@@ -125,6 +125,13 @@ def setup_static_tensors(self):
         self.max_seqlen_q = 0
         self.max_seqlen_k = 0
         self.output_ids = torch.empty((1, T), **tensor_metadata)
+        # Since attention_mask is not always needed, we only allocate it if it is needed
+        if self.return_attention_mask():
+            self.attention_mask = torch.empty(
+                (1, 1, T, max_token_budget), dtype=self.model_dtype, device=self.model_device
+            )
+        else:
+            self.attention_mask = None
         # Initialize the tensors by pretending they are in full use
         self.actual_tokens = T
         self.cache_used = max_token_budget
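Skipping this buffer is a real saving: it is a dense `(1, 1, T, max_token_budget)` tensor in model dtype. Back-of-envelope with illustrative numbers (both sizes are deployment-dependent, not values from the PR):

import torch

T = 2048                   # hypothetical max_batch_tokens
max_token_budget = 131072  # hypothetical cache token budget
dtype = torch.float16

bytes_saved = T * max_token_budget * torch.finfo(dtype).bits // 8
print(f"{bytes_saved / 2**20:.0f} MiB")  # -> 512 MiB for these numbers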
@@ -143,7 +150,6 @@ def reset_static_tensors(self):
         # Reset the tensors
         self.input_ids[:, :t].zero_()
         self.position_ids[:, :t].zero_()
-        self.attention_mask[:, :, :t, :c].fill_(torch.finfo(self.model_dtype).min)
         self.cumulative_seqlens_q[: t + 1].zero_()
         self.cumulative_seqlens_k[: t + 1].zero_()
         self.write_index[:t].fill_(-1)
@@ -152,17 +158,20 @@ def reset_static_tensors(self):
         self.max_seqlen_q = 0
         self.max_seqlen_k = 0
         self.output_ids[:, :t].fill_(-1)
+        if self.attention_mask is not None:
+            self.attention_mask[:, :, :t, :c].fill_(torch.finfo(self.model_dtype).min)
+
 
     def get_model_kwargs(self) -> PagedAttentionArgs:
         """Get model keyword arguments for the current batch."""
         # Compute the slice to return
         t = self.actual_tokens if self.slice_inputs else self.write_index.size(0)
         c = self.cache_used if self.slice_inputs else self.read_index.size(0)
-        # Return the tensors
-        return {
+        # Prepare the kwargs
+        kwargs = {
             "input_ids": self.input_ids[:, :t],
+            "attention_mask": self.attention_mask,
             "position_ids": self.position_ids[:, :t],
-            "attention_mask": self.attention_mask[:, :, :t, :c],  # NOTE: this is probably not used for paged attention
             "cu_seq_lens_q": self.cumulative_seqlens_q[: t + 1],
             "cu_seq_lens_k": self.cumulative_seqlens_k[: t + 1],
             "write_index": self.write_index[:t],
@@ -174,6 +183,10 @@ def get_model_kwargs(self) -> PagedAttentionArgs:
174183 "cache" : self .cache ,
175184 "use_cache" : False ,
176185 }
186+ # If the attention mask is not None, we slice it as the others
187+ if self .attention_mask is not None :
188+ kwargs ["attention_mask" ] = self .attention_mask [:, :, :t , :c ]
189+ return kwargs
177190
178191 def __repr__ (self ):
179192 return (
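A self-contained sketch of the None-aware slicing above, with `t` and `c` standing in for `actual_tokens` and `cache_used`:

from typing import Optional

import torch

def build_kwargs(attention_mask: Optional[torch.Tensor], t: int, c: int) -> dict:
    kwargs = {"attention_mask": attention_mask}  # a None mask passes through untouched
    if attention_mask is not None:
        kwargs["attention_mask"] = attention_mask[:, :, :t, :c]
    return kwargs

assert build_kwargs(None, 4, 8)["attention_mask"] is None
assert build_kwargs(torch.zeros(1, 1, 16, 32), 4, 8)["attention_mask"].shape == (1, 1, 4, 8)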
@@ -303,7 +316,7 @@ def _build_tensors(
         self.cache_used = len(read_index)
 
         min_value = torch.finfo(self.model_dtype).min
-        if self.config._attn_implementation != "paged_attention":  # we set `is_causal` to True in paged call`
+        if self.attention_mask is not None:
             for i in range(len(cumulative_seqlens_q) - 1):
                 if (
                     cumulative_seqlens_q[i + 1] - cumulative_seqlens_q[i]
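The hunk is truncated above, but the guarded loop fills per-sequence causal blocks of the mask from the cumulative sequence lengths. A hedged reconstruction of the general pattern (not the exact code from the file; plain-int offsets for brevity):

import torch

def fill_causal_blocks(mask: torch.Tensor, cu_q: list, cu_k: list) -> None:
    # Unmask a causal block per packed sequence; positions outside the blocks
    # keep the pre-filled finfo-min value.
    for i in range(len(cu_q) - 1):
        q0, q1 = cu_q[i], cu_q[i + 1]
        k0, k1 = cu_k[i], cu_k[i + 1]
        # query j may attend keys up to offset (Lk - Lq) + j
        causal = torch.tril(
            torch.ones(q1 - q0, k1 - k0, dtype=torch.bool), diagonal=(k1 - k0) - (q1 - q0)
        )
        mask[0, 0, q0:q1, k0:k1][causal] = 0.0

mask = torch.full((1, 1, 8, 12), torch.finfo(torch.float32).min)
fill_causal_blocks(mask, cu_q=[0, 3, 8], cu_k=[0, 5, 12])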