 # See the License for the specific language governing permissions and
 # limitations under the License.
 import argparse
+import contextlib
 import json
 import os
 import time
 from typing import Optional
 
 import datasets
 import torch
+from torch.profiler import ProfilerActivity, profile
+from tqdm import tqdm
 
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers.generation import GenerationConfig
 
 
-MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
+# MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
+SLIDING_WINDOW = 0
+MODEL_ID = "google/gemma-2-2b-it" if SLIDING_WINDOW > 0 else "Qwen/Qwen3-4B-Instruct-2507"
+FORCE_MAX_LENGTH = False  # should be False unless you are debugging sliding window features
 
 
 def generate_simple(
-    attn_implementation: str, simple_batch_inputs: list[int], generation_config: GenerationConfig
-) -> list[str]:
-    attn_implementation = {
+    attn_impl: str, simple_batch_inputs: list[int], generation_config: GenerationConfig
+) -> dict[str, str]:
+    attn_impl = {
         "sdpa_paged": "sdpa",
         "eager_paged": "eager",
         "flash_paged": "flash_attention_2",
-    }[attn_implementation]
+    }[attn_impl]
 
-    model = (
-        AutoModelForCausalLM.from_pretrained(
-            MODEL_ID,
-            torch_dtype=torch.bfloat16,
-            attn_implementation=attn_implementation,
-        )
-        .cuda()
-        .eval()
-    )
+    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=torch.bfloat16, attn_implementation=attn_impl)
+    model = model.cuda().eval()
+    if getattr(model.config, "sliding_window", None) is not None:
+        model.config.sliding_window = SLIDING_WINDOW
 
-    decoded_outputs = []
-    for input_ids in simple_batch_inputs:
+    decoded_outputs = {}
+    for input_ids in tqdm(simple_batch_inputs, desc="Generating outputs without CB"):
+        key = " ".join(map(str, input_ids))  # This will be used to identify the output after batched generation
         input_ids = torch.tensor([input_ids]).to("cuda")
-        attention_mask = torch.ones_like(input_ids)
-        outputs = model.generate(input_ids, attention_mask=attention_mask, generation_config=generation_config)
+        # attention_mask = torch.ones_like(input_ids)
+        outputs = model.generate(input_ids, generation_config=generation_config, use_model_defaults=False)
         generated_tokens = outputs[0][input_ids.shape[1] :]
         decoded_output = tokenizer.decode(generated_tokens, skip_special_tokens=True)
-        decoded_outputs.append(decoded_output)
-
+        decoded_outputs[key] = decoded_output
     return decoded_outputs
 
 
@@ -117,7 +118,9 @@ def batch_generate(
     data = []
     for i, request in enumerate(batch_outputs):
         input_text = tokenizer.decode(batch_outputs[request].prompt_ids, skip_special_tokens=True)
-        data.append({"input": input_text})
+        # The key is used to tie back to the output of unbatched generation
+        key = " ".join(map(str, batch_outputs[request].prompt_ids))
+        data.append({"input": input_text, "key": key})
 
         # Try to decode the output
         try:
@@ -142,9 +145,11 @@ def batch_generate(
 
         # Compare with classic generate if asked
         if expected_outputs is not None:
-            matches = output_text == expected_outputs[i]
-            data[-1]["ref"] = expected_outputs[i]
+            expected_output = expected_outputs.pop(key)
+            matches = output_text == expected_output  # TODO: rework this for a better distance metric
+            data[-1]["ref"] = expected_output
             data[-1]["matches"] = matches
+            data[-1].pop("key")
             print(f"Request {i} matches" if matches else f"Request {i} does NOT match!")
 
     # Compute stats and maybe print them
@@ -191,6 +196,7 @@ def batch_generate(
     parser.add_argument("--output-file", type=str, default=None)
     parser.add_argument("--compare", action="store_true", default=False)
     parser.add_argument("--metrics", action="store_true", default=False)
+    parser.add_argument("--profile", type=str, default=None)
     args = parser.parse_args()
 
     # If turned on, we setup metrics
@@ -208,6 +214,9 @@ def batch_generate(
         dtype=torch.bfloat16,
     )
     model = model.cuda().eval()
+    if getattr(model.config, "sliding_window", None) is not None:
+        print(f"Setting sliding window from {model.config.sliding_window} to {SLIDING_WINDOW}")
+        model.config.sliding_window = SLIDING_WINDOW
 
     # If turned on, we compile the model
     if args.compile:
@@ -218,16 +227,17 @@ def batch_generate(
 
     # Prepare tokenizer and dataset
     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left")
+
     dataset = datasets.load_dataset("openai/gsm8k", "socratic", split="test")
-    dataset = dataset.select(range(args.samples))  # Use only 5 examples for the simple version
-    tokenized_datasets = dataset.map(lambda x: tokenizer(x["question"]), batched=True)
-    simple_batch_inputs = [item["input_ids"] for item in tokenized_datasets]
+    dataset = dataset.select(range(args.samples))
+
+    simple_batch_inputs = [tokenizer(item["question"])["input_ids"] for item in dataset]
 
     # Prepare generation config
     generation_config = GenerationConfig(
         max_new_tokens=512,
         use_cuda_graph=args.use_cuda_graph,
-        eos_token_id=tokenizer.eos_token_id,
+        eos_token_id=tokenizer.pad_token_id if FORCE_MAX_LENGTH else tokenizer.eos_token_id,
         pad_token_id=tokenizer.pad_token_id,
         do_sample=True,
         temperature=0.8,
@@ -247,7 +257,7 @@ def batch_generate(
         f"runs/cb/{args.num_blocks}_{args.max_batch_tokens}_{attn}_{args.matmul_precision}_{args.samples}.json"
     )
 
-    # Run warmup batch generation
+    # Run warmup batch generation  # TODO: understand why warmup incurs a large overhead during cache creation
     batch_generate(
         model,
         simple_batch_inputs[: min(5, args.samples)],
@@ -257,17 +267,26 @@ def batch_generate(
         slice_inputs=args.slice_inputs,
     )
 
-    # Run batch generation
-    gen_time, tok_per_sec = batch_generate(
-        model,
-        simple_batch_inputs,
-        generation_config,
-        tokenizer,
-        displayed_samples=args.displayed,
-        output_file=args.output_file,
-        expected_outputs=expected_outputs,
-        slice_inputs=args.slice_inputs,
-    )
+    if args.profile is not None:
+        cm = profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True)
+    else:
+        cm = contextlib.nullcontext()
+    with cm as prof:
+        # Run batch generation
+        gen_time, tok_per_sec = batch_generate(
+            model,
+            simple_batch_inputs,
+            generation_config,
+            tokenizer,
+            displayed_samples=args.displayed,
+            output_file=args.output_file,
+            expected_outputs=expected_outputs,
+            slice_inputs=args.slice_inputs,
+        )
+    if args.profile is not None:
+        filename = args.profile if args.profile.endswith(".json") else args.profile + ".json"
+        prof.export_chrome_trace(filename)
 
 # Example usage:
+# python examples/pytorch/continuous_batching.py --attn sdpa_paged -mp none --slice-inputs --samples 3 --compare
 # python examples/pytorch/continuous_batching.py --num-blocks 369 --max-batch-tokens 23 --attn sdpa_paged -mp none --samples 1 --displayed 0 --output-file sliced.json
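
A hypothetical invocation exercising the new `--profile` path (the trace name `cb_trace` is only an illustration; per the diff, the script appends `.json` if the name lacks it and writes a Chrome trace via `prof.export_chrome_trace`):

python examples/pytorch/continuous_batching.py --attn sdpa_paged -mp none --slice-inputs --samples 10 --profile cb_trace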