
Commit 1cdbbb3

remi-or and McPatate authored
Support sliding window in CB (#40688)
* CB example: better compare feature
* Cache managers, still issue w/ effective length
* WIP -- fix for effective length
* Renames
* Working, need better parity checks, we might be missing 1 token
* Small fixes
* Fixed wrong attn mask and broke cache into pieces
* Warmup is slowing down things, disabling it
* Cache was too big, fixed
* Simplified index objects
* Added a profile option to the example
* Avoid calls to memory reporting tools
* Restore full attention read indices for better latency
* Addressed some TODOs and style
* Docstrings for cache managers
* Docstrings for schedulers
* Refactor schedulers
* [Important] Cache fix for sliding window, check with small sw size
* Updated doc for cache memory compute and cache as a whole
* Moved a todo
* Nits and style
* Fix for when sliding window is smaller than max batch per token
* Paged interface update
* Support for Flash in new API
* Fix example CB
* Fix bug in CB for paged
* Revert example
* Style
* Review compliance
* Style
* Styleeeee
* Removed NO_SLIDING_WINDOW
* Review #2 compliance
* Better art
* Turn cum_seqlens_k into a dict
* Attn mask is now a dict
* Update examples/pytorch/continuous_batching.py
  Co-authored-by: Luc Georges <McPatate@users.noreply.github.com>
* Addressed McPatate's review
* Style and fix

---------

Co-authored-by: Luc Georges <McPatate@users.noreply.github.com>
1 parent ed10021 commit 1cdbbb3
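The headline change is sliding-window support in continuous batching (CB): a layer that attends through a sliding window only ever needs the last `sliding_window` keys/values per sequence, so its share of the paged cache can be sized by the window rather than by the full maximum sequence length (several commits above deal with exactly this, e.g. "Cache was too big, fixed" and "Fix for when sliding window is smaller than max batch per token"). A minimal sketch of that sizing rule, using a hypothetical helper name that is not taken from the PR:

```python
# Illustrative only -- not the PR's implementation. With a sliding window, a
# sliding-attention layer only attends to the last `sliding_window` positions,
# so its cache length can be capped at the window size; full-attention layers
# still need room for the whole sequence.
def effective_cache_length(max_seq_len: int, sliding_window: int | None) -> int:
    if sliding_window is None or sliding_window <= 0:
        return max_seq_len  # full attention: keep every past token
    return min(max_seq_len, sliding_window)


assert effective_cache_length(4096, None) == 4096
assert effective_cache_length(4096, 512) == 512
assert effective_cache_length(256, 512) == 256  # window larger than the sequence
```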

File tree

11 files changed: +1018 −429 lines changed


examples/pytorch/continuous_batching.py

Lines changed: 58 additions & 39 deletions
@@ -13,49 +13,50 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
+import contextlib
 import json
 import os
 import time
 from typing import Optional

 import datasets
 import torch
+from torch.profiler import ProfilerActivity, profile
+from tqdm import tqdm

 from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers.generation import GenerationConfig


-MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
+# MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
+SLIDING_WINDOW = 0
+MODEL_ID = "google/gemma-2-2b-it" if SLIDING_WINDOW > 0 else "Qwen/Qwen3-4B-Instruct-2507"
+FORCE_MAX_LENGTH = False  # should be False unless you are debugging sliding window features


 def generate_simple(
-    attn_implementation: str, simple_batch_inputs: list[int], generation_config: GenerationConfig
-) -> list[str]:
-    attn_implementation = {
+    attn_impl: str, simple_batch_inputs: list[int], generation_config: GenerationConfig
+) -> dict[str, str]:
+    attn_impl = {
         "sdpa_paged": "sdpa",
         "eager_paged": "eager",
         "flash_paged": "flash_attention_2",
-    }[attn_implementation]
+    }[attn_impl]

-    model = (
-        AutoModelForCausalLM.from_pretrained(
-            MODEL_ID,
-            torch_dtype=torch.bfloat16,
-            attn_implementation=attn_implementation,
-        )
-        .cuda()
-        .eval()
-    )
+    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=torch.bfloat16, attn_implementation=attn_impl)
+    model = model.cuda().eval()
+    if getattr(model.config, "sliding_window", None) is not None:
+        model.config.sliding_window = SLIDING_WINDOW

-    decoded_outputs = []
-    for input_ids in simple_batch_inputs:
+    decoded_outputs = {}
+    for input_ids in tqdm(simple_batch_inputs, desc="Generating outputs without CB"):
+        key = " ".join(map(str, input_ids))  # This will be used to identify the output after batched generation
         input_ids = torch.tensor([input_ids]).to("cuda")
-        attention_mask = torch.ones_like(input_ids)
-        outputs = model.generate(input_ids, attention_mask=attention_mask, generation_config=generation_config)
+        # attention_mask = torch.ones_like(input_ids)
+        outputs = model.generate(input_ids, generation_config=generation_config, use_model_defaults=False)
         generated_tokens = outputs[0][input_ids.shape[1] :]
         decoded_output = tokenizer.decode(generated_tokens, skip_special_tokens=True)
-        decoded_outputs.append(decoded_output)
-
+        decoded_outputs[key] = decoded_output
     return decoded_outputs

@@ -117,7 +118,9 @@ def batch_generate(
     data = []
     for i, request in enumerate(batch_outputs):
         input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=True)
-        data.append({"input": input_text})
+        # The key is used to tie back to the output of unbatched generation
+        key = " ".join(map(str, batch_outputs[request].prompt_ids))
+        data.append({"input": input_text, "key": key})

         # Try to decode the output
         try:
@@ -142,9 +145,11 @@ def batch_generate(

         # Compare with classic generate if asked
         if expected_outputs is not None:
-            matches = output_text == expected_outputs[i]
-            data[-1]["ref"] = expected_outputs[i]
+            expected_output = expected_outputs.pop(key)
+            matches = output_text == expected_output  # TODO: rework this for a better distance metric
+            data[-1]["ref"] = expected_output
             data[-1]["matches"] = matches
+            data[-1].pop("key")
             print(f"Request {i} matches" if matches else f"Request {i} does NOT match!")

     # Compute stats and maybe print them
@@ -191,6 +196,7 @@ def batch_generate(
     parser.add_argument("--output-file", type=str, default=None)
     parser.add_argument("--compare", action="store_true", default=False)
     parser.add_argument("--metrics", action="store_true", default=False)
+    parser.add_argument("--profile", type=str, default=None)
    args = parser.parse_args()

    # If turned on, we setup metrics
@@ -208,6 +214,9 @@ def batch_generate(
         dtype=torch.bfloat16,
     )
     model = model.cuda().eval()
+    if getattr(model.config, "sliding_window", None) is not None:
+        print(f"Setting sliding window from {model.config.sliding_window} to {SLIDING_WINDOW}")
+        model.config.sliding_window = SLIDING_WINDOW

     # If turned on, we compile the model
     if args.compile:
@@ -218,16 +227,17 @@ def batch_generate(

     # Prepare tokenizer and dataset
     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left")
+
     dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
-    dataset = dataset.select(range(args.samples))  # Use only 5 examples for the simple version
-    tokenized_datasets = dataset.map(lambda x: tokenizer(x["question"]), batched=True)
-    simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets]
+    dataset = dataset.select(range(args.samples))
+
+    simple_batch_inputs = [tokenizer(item["question"])["input_ids"] for item in dataset]

     # Prepare generation config
     generation_config = GenerationConfig(
         max_new_tokens=512,
         use_cuda_graph=args.use_cuda_graph,
-        eos_token_id=tokenizer.eos_token_id,
+        eos_token_id=tokenizer.pad_token_id if FORCE_MAX_LENGTH else tokenizer.eos_token_id,
         pad_token_id=tokenizer.pad_token_id,
         do_sample=True,
         temperature=0.8,
@@ -247,7 +257,7 @@ def batch_generate(
         f"runs/cb/{args.num_blocks}_{args.max_batch_tokens}_{attn}_{args.matmul_precision}_{args.samples}.json"
     )

-    # Run warmup batch generation
+    # Run warmup batch generation  # TODO: understand why warmup incurs a large overhead during cache creation
     batch_generate(
         model,
         simple_batch_inputs[: min(5, args.samples)],
@@ -257,17 +267,26 @@ def batch_generate(
         slice_inputs=args.slice_inputs,
     )

-    # Run batch generation
-    gen_time, tok_per_sec = batch_generate(
-        model,
-        simple_batch_inputs,
-        generation_config,
-        tokenizer,
-        displayed_samples=args.displayed,
-        output_file=args.output_file,
-        expected_outputs=expected_outputs,
-        slice_inputs=args.slice_inputs,
-    )
+    if args.profile is not None:
+        cm = profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True)
+    else:
+        cm = contextlib.nullcontext()
+    with cm as prof:
+        # Run batch generation
+        gen_time, tok_per_sec = batch_generate(
+            model,
+            simple_batch_inputs,
+            generation_config,
+            tokenizer,
+            displayed_samples=args.displayed,
+            output_file=args.output_file,
+            expected_outputs=expected_outputs,
+            slice_inputs=args.slice_inputs,
+        )
+    if args.profile is not None:
+        filename = args.profile if args.profile.endswith(".json") else args.profile + ".json"
+        prof.export_chrome_trace(filename)

 # Example usage:
+# python examples/pytorch/continuous_batching.py --attn sdpa_paged -mp none --slice-inputs --samples 3 --compare
 # python examples/pytorch/continuous_batching.py --num-blocks 369 --max-batch-tokens 23 --attn sdpa_paged -mp none --samples 1 --displayed 0 --output-file sliced.json
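Two smaller additions in this example file are worth calling out. The `--compare` path now keys reference outputs by the prompt token ids joined into a string, so batched outputs, which can finish in any order, are matched against the right unbatched reference. The new `--profile` flag wraps the batched run in `torch.profiler` and exports a Chrome trace; the resulting `.json` can be opened in `chrome://tracing` or Perfetto. Below is a standalone sketch of the key-matching pattern with made-up data; none of these values come from the PR:

```python
# Standalone sketch of the key-matching pattern used above, with made-up prompts/outputs.
# Reference outputs are stored under a key derived from the prompt token ids; batched
# outputs rebuild the same key to look up their reference, regardless of finish order.
reference_outputs = {
    "1 2 3": "The answer is 6.",
    "4 5 6": "The answer is 15.",
}  # shaped like generate_simple's return value: {key: decoded_text}

batched_results = [  # (prompt_ids, decoded_text), arbitrary completion order
    ([4, 5, 6], "The answer is 15."),
    ([1, 2, 3], "The answer is 7."),
]

for prompt_ids, output_text in batched_results:
    key = " ".join(map(str, prompt_ids))
    expected = reference_outputs.pop(key)
    print(key, "matches" if output_text == expected else "does NOT match")
```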

src/transformers/generation/continuous_batching/__init__.py

Lines changed: 8 additions & 2 deletions
@@ -13,8 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from .cache import PagedAttentionCache
-from .classes import RequestState, RequestStatus
 from .continuous_api import ContinuousBatchingManager, ContinuousMixin
+from .requests import RequestState, RequestStatus


-__all__ = ["PagedAttentionCache", "RequestState", "RequestStatus", "ContinuousMixin", "ContinuousBatchingManager"]
+__all__ = [
+    "ContinuousBatchingManager",
+    "ContinuousMixin",
+    "PagedAttentionCache",
+    "RequestState",
+    "RequestStatus",
+]
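This file does not change the public surface: the request classes simply moved from a `classes` module to `requests`, and `__all__` is now sorted. A quick import check of the re-exported names (a minimal sketch, assuming only the package path shown in the diff header):

```python
# Exercises the names listed in __all__ above; nothing here is new API.
from transformers.generation.continuous_batching import (
    ContinuousBatchingManager,
    ContinuousMixin,
    PagedAttentionCache,
    RequestState,
    RequestStatus,
)

print([obj.__name__ for obj in (ContinuousBatchingManager, ContinuousMixin, PagedAttentionCache, RequestState, RequestStatus)])
```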
