
Commit 6e3a460

Made allocation of attention mask optional
1 parent f9f1f41 commit 6e3a460

File tree

1 file changed: +22, -9


src/transformers/generation/continuous_batching/continuous_api.py

Lines changed: 22 additions & 9 deletions
@@ -37,7 +37,7 @@
 @dataclass
 class PagedAttentionArgs:
     input_ids: torch.Tensor
-    attention_mask: torch.Tensor
+    attention_mask: Optional[torch.Tensor]
     position_ids: torch.Tensor
     cumulative_seqlens_q: torch.Tensor
     cumulative_seqlens_k: torch.Tensor
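
For context: the new `Optional[torch.Tensor]` annotation relies on `Optional` from `typing` being imported in this module. A minimal sketch of the updated field, illustrative rather than taken from the file:

    from dataclasses import dataclass
    from typing import Optional

    import torch

    @dataclass
    class PagedAttentionArgsSketch:
        # None means the attention backend needs no dense mask
        attention_mask: Optional[torch.Tensor] = None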
@@ -105,6 +105,9 @@ def __init__(
         self.tokenizer = PreTrainedTokenizerFast.from_pretrained(self.config._name_or_path)
         self.decode_stream = DecodeStream(skip_special_tokens=True)
 
+    def return_attention_mask(self) -> bool:
+        return self.config._attn_implementation != "paged_attention"  # we set `is_causal` to True in paged call
+
     @traced(standalone=True)
     def setup_static_tensors(self):
         T = self.max_batch_tokens
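
The new helper gates mask allocation on the attention backend: the paged attention kernel is called with `is_causal=True` and handles causality internally, so only the other backends need a dense additive mask. The same check in isolation, as a sketch with a hypothetical config object:

    class _ConfigSketch:
        def __init__(self, attn_implementation: str):
            self._attn_implementation = attn_implementation

    def return_attention_mask(config) -> bool:
        # Only non-paged backends need the dense additive mask
        return config._attn_implementation != "paged_attention"

    assert return_attention_mask(_ConfigSketch("paged_attention")) is False
    assert return_attention_mask(_ConfigSketch("sdpa")) is True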
@@ -114,9 +117,6 @@ def setup_static_tensors(self):
         self.tensor_metadata = tensor_metadata
         self.input_ids = torch.empty((1, T), **tensor_metadata)
         self.position_ids = torch.empty((1, T), **tensor_metadata)
-        self.attention_mask = torch.empty(
-            (1, 1, T, max_token_budget), dtype=self.model_dtype, device=self.model_device
-        )
         self.cumulative_seqlens_q = torch.empty((T + 1,), **tensor_metadata)
         self.cumulative_seqlens_k = torch.empty((T + 1,), **tensor_metadata)
         self.write_index = torch.empty((T,), **tensor_metadata)
@@ -125,6 +125,13 @@ def setup_static_tensors(self):
         self.max_seqlen_q = 0
         self.max_seqlen_k = 0
         self.output_ids = torch.empty((1, T), **tensor_metadata)
+        # Since attention_mask is not always needed, we only allocate it if it is needed
+        if self.return_attention_mask():
+            self.attention_mask = torch.empty(
+                (1, 1, T, max_token_budget), dtype=self.model_dtype, device=self.model_device
+            )
+        else:
+            self.attention_mask = None
         # Initialize the tensors by pretending they are in full use
         self.actual_tokens = T
         self.cache_used = max_token_budget
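
Skipping the allocation is worthwhile because the dense mask is a (1, 1, T, max_token_budget) tensor in the model dtype. A back-of-the-envelope estimate of the memory saved, with illustrative sizes rather than values from the commit:

    import torch

    T = 2048                 # max_batch_tokens (illustrative)
    max_token_budget = 8192  # cache token budget (illustrative)
    dtype = torch.float16

    bytes_saved = T * max_token_budget * torch.finfo(dtype).bits // 8
    print(f"{bytes_saved / 2**20:.0f} MiB saved")  # 32 MiB for these sizes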
@@ -143,7 +150,6 @@ def reset_static_tensors(self):
         # Reset the tensors
         self.input_ids[:, :t].zero_()
         self.position_ids[:, :t].zero_()
-        self.attention_mask[:, :, :t, :c].fill_(torch.finfo(self.model_dtype).min)
         self.cumulative_seqlens_q[: t + 1].zero_()
         self.cumulative_seqlens_k[: t + 1].zero_()
         self.write_index[:t].fill_(-1)
@@ -152,17 +158,20 @@ def reset_static_tensors(self):
         self.max_seqlen_q = 0
         self.max_seqlen_k = 0
         self.output_ids[:, :t].fill_(-1)
+        if self.attention_mask is not None:
+            self.attention_mask[:, :, :t, :c].fill_(torch.finfo(self.model_dtype).min)
+
 
     def get_model_kwargs(self) -> PagedAttentionArgs:
         """Get model keyword arguments for the current batch."""
         # Compute the slice to return
         t = self.actual_tokens if self.slice_inputs else self.write_index.size(0)
         c = self.cache_used if self.slice_inputs else self.read_index.size(0)
-        # Return the tensors
-        return {
+        # Prepare the kwargs
+        kwargs = {
             "input_ids": self.input_ids[:, :t],
+            "attention_mask": self.attention_mask,
             "position_ids": self.position_ids[:, :t],
-            "attention_mask": self.attention_mask[:, :, :t, :c],  # NOTE: this is probably not used for paged attention
             "cu_seq_lens_q": self.cumulative_seqlens_q[:t+1],
             "cu_seq_lens_k": self.cumulative_seqlens_k[:t+1],
             "write_index": self.write_index[:t],
@@ -174,6 +183,10 @@ def get_model_kwargs(self) -> PagedAttentionArgs:
             "cache": self.cache,
             "use_cache": False,
         }
+        # If the attention mask is not None, we slice it like the others
+        if self.attention_mask is not None:
+            kwargs["attention_mask"] = self.attention_mask[:, :, :t, :c]
+        return kwargs
 
     def __repr__(self):
         return (
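
get_model_kwargs now builds the dict in two steps: it stores the possibly-None mask first, then overwrites it with a sliced view when the mask exists, so callers always see an `attention_mask` key. The pattern in isolation, as a sketch where t and c stand for the token and cache slice sizes:

    from typing import Optional

    import torch

    def build_kwargs(attention_mask: Optional[torch.Tensor], t: int, c: int) -> dict:
        # Default to the raw (possibly None) mask, then slice it if present
        kwargs = {"attention_mask": attention_mask}
        if attention_mask is not None:
            kwargs["attention_mask"] = attention_mask[:, :, :t, :c]
        return kwargs

    print(build_kwargs(None, 4, 16)["attention_mask"])                            # None
    print(build_kwargs(torch.zeros(1, 1, 8, 32), 4, 16)["attention_mask"].shape)  # torch.Size([1, 1, 4, 16])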
@@ -303,7 +316,7 @@ def _build_tensors(
         self.cache_used = len(read_index)
 
         min_value = torch.finfo(self.model_dtype).min
-        if self.config._attn_implementation != "paged_attention":  # we set `is_causal` to True in paged call`
+        if self.attention_mask is not None:
             for i in range(len(cumulative_seqlens_q) - 1):
                 if (
                     cumulative_seqlens_q[i + 1] - cumulative_seqlens_q[i]
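
min_value is used because the mask is additive: positions filled with torch.finfo(dtype).min act like negative infinity under softmax, so only the unmasked positions receive attention weight. A minimal illustration with made-up shapes and values:

    import torch

    dtype = torch.float16
    min_value = torch.finfo(dtype).min

    scores = torch.zeros(1, 4)            # raw attention scores
    mask = torch.full((1, 4), min_value)  # mask everything out...
    mask[0, :2] = 0.0                     # ...except the first two positions

    probs = torch.softmax(scores + mask, dim=-1)
    print(probs)  # ~[0.5, 0.5, 0.0, 0.0]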

0 commit comments