
Commit c7b820e: Style

1 parent 6e3a460

5 files changed: +12, -13 lines


examples/pytorch/continuous_batching.py

Lines changed: 3 additions & 1 deletion
@@ -241,7 +241,9 @@ def batch_generate(
     if args.output_file is None:
         os.makedirs("runs/cb", exist_ok=True)
         attn = args.attn.replace("|", "_").replace("/", "_")
-        args.output_file = f"runs/cb/{args.num_blocks}_{args.max_batch_tokens}_{attn}_{args.matmul_precision}_{args.samples}.json"
+        args.output_file = (
+            f"runs/cb/{args.num_blocks}_{args.max_batch_tokens}_{attn}_{args.matmul_precision}_{args.samples}.json"
+        )
 
     # Run warmup batch generation
     batch_generate(

examples/pytorch/continuous_batching_simple.py

Lines changed: 2 additions & 2 deletions
@@ -48,14 +48,14 @@
 # Prepare tokenizer and dataset
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left")
 dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
-dataset = dataset.select(range(args.samples))
+dataset = dataset.select(range(args.samples))
 tokenized_datasets = dataset.map(lambda x: tokenizer(x["question"]), batched=True)
 simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets]
 
 # Prepare generation config
 generation_config = GenerationConfig(
     max_new_tokens=512,
-    use_cuda_graph=False, # Not supported for simple version
+    use_cuda_graph=False,  # Not supported for simple version
     eos_token_id=tokenizer.eos_token_id,
     pad_token_id=tokenizer.pad_token_id,
     do_sample=False,
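
For orientation: the inputs and generation config prepared above presumably feed the continuous-batching generation API re-exported in __init__.py below. A minimal sketch of that last step, assuming the model mixes in ContinuousMixin and exposes a generate_batch(inputs, generation_config) helper; the call signature and the attributes of its outputs are assumptions for illustration, not taken from this diff:

# Hypothetical tail of the simple example; generate_batch and the
# shape of its return value are assumptions, not verbatim library API.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto")

batch_outputs = model.generate_batch(
    inputs=simple_batch_inputs,  # tokenized GSM8K questions from above
    generation_config=generation_config,
)
for request_id, output in batch_outputs.items():
    # Decode each request's generated tokens back to text.
    print(request_id, tokenizer.decode(output.generated_tokens, skip_special_tokens=True))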

src/transformers/generation/continuous_batching/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -13,8 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from .cache import PagedAttentionCache
-from .continuous_api import ContinuousBatchingManager, ContinuousMixin
 from .classes import RequestState, RequestStatus
+from .continuous_api import ContinuousBatchingManager, ContinuousMixin
 
 
 __all__ = ["PagedAttentionCache", "RequestState", "RequestStatus", "ContinuousMixin", "ContinuousBatchingManager"]

src/transformers/generation/continuous_batching/cache.py

Lines changed: 3 additions & 5 deletions
@@ -14,7 +14,7 @@
 # limitations under the License.
 from collections import deque
 from math import floor, sqrt
-from typing import Any, Optional, TypeVar, Union
+from typing import Optional, Union
 
 import torch
 
@@ -287,9 +287,7 @@ def compute_num_blocks_and_max_batch_tokens(
         logger.info(f"Cache memory: {cache_memory}")
 
         # Compute memory footprints
-        mem_per_activation_token = (
-            m * self._activation_dtype.itemsize * (self.hidden_size + self.vocab_size)
-        )
+        mem_per_activation_token = m * self._activation_dtype.itemsize * (self.hidden_size + self.vocab_size)
         mem_per_cache_token = 2 * self.num_heads * self.head_dim * self.num_layers * cache_dtype.itemsize
         mem_per_input_token = 8 * m * self._input_dtype.itemsize
         logger.info(f"Memory per activation token: {mem_per_activation_token}")
@@ -299,7 +297,7 @@ def compute_num_blocks_and_max_batch_tokens(
         # Compute second-degree polynomial coefficients
         a = m * self._activation_dtype.itemsize
         b = mem_per_input_token + mem_per_cache_token + mem_per_activation_token
-        c = - cache_memory
+        c = -cache_memory
 
         # Compute discriminant and greatest solution
         discriminant = b**2 - 4 * a * c
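
The last two hunks belong to the cache-sizing math: total memory as a function of the token count x is quadratic, so the code solves a*x**2 + b*x + c = 0 with c = -cache_memory and takes the greatest root as the token budget. A self-contained sketch of that arithmetic; every concrete number below is an invented size for illustration, not a value from the repository:

from math import sqrt

# Invented model dimensions and budget, for illustration only.
m = 1                       # multiplier used by the coefficients (assumed)
activation_itemsize = 2     # bf16 activations
cache_itemsize = 2          # bf16 KV cache
input_itemsize = 8          # int64 input ids
hidden_size, vocab_size = 4096, 128000
num_heads, head_dim, num_layers = 32, 128, 32
cache_memory = 8 * 1024**3  # 8 GiB budget (assumed)

# Per-token footprints, mirroring the expressions in the diff above.
mem_per_activation_token = m * activation_itemsize * (hidden_size + vocab_size)
mem_per_cache_token = 2 * num_heads * head_dim * num_layers * cache_itemsize
mem_per_input_token = 8 * m * input_itemsize

# Quadratic coefficients, discriminant, and greatest solution.
a = m * activation_itemsize
b = mem_per_input_token + mem_per_cache_token + mem_per_activation_token
c = -cache_memory
discriminant = b**2 - 4 * a * c
greatest_root = (-b + sqrt(discriminant)) / (2 * a)
print(int(greatest_root))  # tokens fitting the budget under these assumptions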

src/transformers/generation/continuous_batching/continuous_api.py

Lines changed: 3 additions & 4 deletions
@@ -106,7 +106,7 @@ def __init__(
         self.decode_stream = DecodeStream(skip_special_tokens=True)
 
     def return_attention_mask(self) -> bool:
-        return self.config._attn_implementation != "paged_attention" # we set `is_causal` to True in paged call
+        return self.config._attn_implementation != "paged_attention"  # we set `is_causal` to True in paged call
 
     @traced(standalone=True)
     def setup_static_tensors(self):
@@ -161,7 +161,6 @@ def reset_static_tensors(self):
         if self.attention_mask is not None:
             self.attention_mask[:, :, :t, :c].fill_(torch.finfo(self.model_dtype).min)
 
-
     def get_model_kwargs(self) -> PagedAttentionArgs:
         """Get model keyword arguments for the current batch."""
         # Compute the slice to return
@@ -172,8 +171,8 @@ def get_model_kwargs(self) -> PagedAttentionArgs:
             "input_ids": self.input_ids[:, :t],
             "attention_mask": self.attention_mask,
             "position_ids": self.position_ids[:, :t],
-            "cu_seq_lens_q": self.cumulative_seqlens_q[:t+1],
-            "cu_seq_lens_k": self.cumulative_seqlens_k[:t+1],
+            "cu_seq_lens_q": self.cumulative_seqlens_q[: t + 1],
+            "cu_seq_lens_k": self.cumulative_seqlens_k[: t + 1],
             "write_index": self.write_index[:t],
             "read_index": self.read_index[:c],
             "logits_indices": self.logits_indices[:t],
