@@ -48,7 +48,7 @@ def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = Non
4848 return probs
4949
5050def sample (logits , temperature : float = 1.0 , top_k : Optional [int ] = None ):
51- probs = logits_to_probs (logits [0 , - 1 ], temperature , top_k )
51+ probs = logits_to_probs (logits [: , - 1 ], temperature , top_k )
5252 idx_next = multinomial_sample_one_no_sync (probs )
5353 return idx_next , probs
5454
@@ -75,7 +75,7 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc
7575 new_tokens .append (next_token )
7676 callback (new_tokens [- 1 ])
7777 new_probs .append (next_prob )
78- cur_token = next_token . view ( 1 , - 1 )
78+ cur_token = next_token
7979
8080 return new_tokens , new_probs
8181
@@ -88,6 +88,7 @@ def generate(
8888 model : Transformer ,
8989 prompt : torch .Tensor ,
9090 max_new_tokens : int ,
91+ batch_size : int ,
9192 * ,
9293 interactive : bool ,
9394 callback = lambda x : x ,
@@ -102,34 +103,34 @@ def generate(
102103
103104 # create an empty tensor of the expected final shape and fill in the current tokens
104105 device = prompt .device
105- T = prompt .numel ( )
106+ T = prompt .size ( - 1 )
106107
107108 # calculate how many tokens to generate based on max_new_tokens and model's upper bound (block_size)
108109 max_seq_length = min (T + max_new_tokens , model .config .block_size ) if not interactive else 350
109110 new_tokens = max_seq_length - T
110111
112+ # format model input
113+ prompt , input_pos = prepare_inputs_for_model (prompt )
114+ prompt = prompt .repeat (batch_size , 1 ) # expand prompt based on batchsize
115+
111116 # full prompt+output will be stored in seq
112- seq = torch .empty (max_seq_length , dtype = prompt .dtype , device = device )
113- seq [:T ] = prompt . view ( - 1 )
117+ seq = torch .empty (batch_size , max_seq_length , dtype = prompt .dtype , device = device )
118+ seq [:, : T ] = prompt
114119
115120 # setup model caches
116121 with torch .device (device ):
117122 if cache_size is None :
118123 cache_size = max_seq_length
119124 assert cache_size >= max_seq_length , "need cache_size to be greater than max_new_tokens + size-of-prompt"
120- model .setup_caches (max_batch_size = 1 , max_seq_length = cache_size , kv_cache_quantization = kv_cache_quantization , linear_causal_mask = linear_causal_mask , prompt_length = T )
121-
122- # format model input
123- x , input_pos = prepare_inputs_for_model (prompt , max_new_tokens )
125+ model .setup_caches (max_batch_size = batch_size , max_seq_length = cache_size , kv_cache_quantization = kv_cache_quantization , linear_causal_mask = linear_causal_mask , prompt_length = T )
124126
125127 # execute prefill
126- next_token = prefill (model , x , input_pos , ** sampling_kwargs ).clone ()
127- seq [T ] = next_token
128+ next_token = prefill (model , prompt . view ( batch_size , - 1 ) , input_pos , ** sampling_kwargs ).clone ()
129+ seq [:, T ] = next_token . squeeze ()
128130 # execute token generation
129131 input_pos = torch .tensor ([T ], device = device , dtype = torch .int )
130- generated_tokens , _ = decode_n_tokens (model , next_token .view (1 , - 1 ), input_pos , new_tokens - 1 , callback = callback , ** sampling_kwargs )
131-
132- seq = torch .cat ((seq [:T + 1 ], * generated_tokens ))
132+ generated_tokens , _ = decode_n_tokens (model , next_token .view (batch_size , - 1 ), input_pos , new_tokens - 1 , callback = callback , ** sampling_kwargs )
133+ seq = torch .cat ((seq [:, :T + 1 ], * generated_tokens ), dim = - 1 )
133134
134135 return seq
135136
@@ -157,6 +158,7 @@ def main(
157158 interactive : bool = False ,
158159 num_samples : int = 5 ,
159160 max_new_tokens : int = 100 ,
161+ batch_size : int = 1 ,
160162 top_k : int = 200 ,
161163 temperature : float = 0.8 ,
162164 checkpoint_path : Path = Path ("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth" ),
@@ -229,9 +231,9 @@ def main(
229231 use_hqq = True
230232 else :
231233 use_hqq = False
232- groupsize = int (quantization .split ("-" )[1 ])
233- assert groupsize in [32 ,64 ,128 ,256 ], f"int4wo groupsize needs to be one of [32,64,128,256] but got { groupsize } "
234- quantize_ (model , int4_weight_only (group_size = groupsize ))
234+ group_size = int (quantization .split ("-" )[1 ])
235+ assert group_size in [32 ,64 ,128 ,256 ], f"int4wo group_size needs to be one of [32,64,128,256] but got { group_size } "
236+ quantize_ (model , int4_weight_only (group_size = group_size ))
235237 if "marlin" in quantization :
236238 from torchao .dtypes import MarlinSparseLayout
237239 quantize_ (model , int4_weight_only (layout = MarlinSparseLayout ()))
@@ -267,9 +269,9 @@ def main(
267269 use_hqq = "hqq" in quantization
268270 quantize_ (model , awq_uintx (quant_dtype = quant_dtype , group_size = group_size , use_hqq = use_hqq ), is_observed_linear )
269271 if "uintx" in quantization :
270- # uintx-nbits-groupsize , e.g. "uintx-2-64"
272+ # uintx-nbits-group_size , e.g. "uintx-2-64"
271273 if "hqq" in quantization :
272- # uintx-nbits-groupsize -hqq
274+ # uintx-nbits-group_size -hqq
273275 use_hqq = True
274276 else :
275277 use_hqq = False
@@ -303,6 +305,7 @@ def main(
303305 model ,
304306 encode_tokens (tokenizer , prompt , bos = True , device = device ),
305307 max_new_tokens ,
308+ batch_size ,
306309 interactive = False ,
307310 temperature = temperature ,
308311 top_k = top_k ,
@@ -375,6 +378,7 @@ def callback(x):
375378 model ,
376379 encoded ,
377380 max_new_tokens ,
381+ batch_size ,
378382 interactive = interactive ,
379383 callback = callback ,
380384 temperature = temperature ,
@@ -392,13 +396,13 @@ def callback(x):
392396 t = time .perf_counter () - t0
393397
394398 if not interactive :
395- tok_list = y .tolist ()
399+ tok_list = y [ 0 ] .tolist ()
396400 # truncate text after end of string token
397- tokens = tok_list if not tokenizer .eos_id () in y else tok_list [:tok_list .index (tokenizer .eos_id ())]
401+ tokens = tok_list if not tokenizer .eos_id () in tok_list else tok_list [:tok_list .index (tokenizer .eos_id ())]
398402 print (tokenizer .decode (tokens ))
399403 else :
400404 print ()
401- tokens_generated = y .size (0 ) - prompt_length
405+ tokens_generated = ( y .size (- 1 ) - prompt_length )
402406 tokens_sec = tokens_generated / t
403407 aggregate_metrics ['tokens_per_sec' ].append (tokens_sec )
404408 print (f"Time for inference { i + 1 } : { t :.02f} sec total, { tokens_sec :.02f} tokens/sec" )
@@ -421,6 +425,8 @@ def callback(x):
421425 bandwidth = model_size * tokpersec
422426 mem = torch .cuda .max_memory_reserved () / 1e9
423427 print (f"Average tokens/sec: { tokpersec :.2f} " )
428+ if batch_size > 1 :
429+ print (f"Average tokens/sec including batches { batch_size * tokpersec :.2f} " )
424430 print (f"Average Bandwidth: { bandwidth :.02f} GB/s" )
425431 print (f"Peak Memory Usage: { mem :.02f} GB" )
426432 print (f"Model Size: { model_size :.02f} GB" )
@@ -439,6 +445,7 @@ def callback(x):
439445 result_txt += f"--interactive " if interactive else ""
440446 result_txt += f"--num_samples { num_samples } "
441447 result_txt += f"--max_new_tokens { max_new_tokens } "
448+ result_txt += f"--batch_size { batch_size } "
442449 result_txt += f"--top_k { top_k } "
443450 result_txt += f"--temperature { temperature } "
444451 result_txt += f"--cache_size { cache_size } " if cache_size else ""
@@ -459,13 +466,15 @@ def callback(x):
459466 parser .add_argument ('--interactive' , action = 'store_true' , help = 'Whether to launch in interactive mode' )
460467 parser .add_argument ('--num_samples' , type = int , default = 5 , help = 'Number of samples.' )
461468 parser .add_argument ('--max_new_tokens' , type = int , default = 200 , help = 'Maximum number of new tokens.' )
469+ parser .add_argument ('--batch_size' , type = int , default = 1 , help = 'Batch size to benchmark with' )
462470 parser .add_argument ('--top_k' , type = int , default = 200 , help = 'Top-k for sampling.' )
463471 parser .add_argument ('--temperature' , type = float , default = 0.8 , help = 'Temperature for sampling.' )
464472 parser .add_argument ('--checkpoint_path' , type = Path , default = Path ("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth" ), help = 'Model checkpoint path.' )
465473 parser .add_argument ('-q' , '--quantization' , type = str ,
466474 help = (
467475 'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
468- + 'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, embed-int8wo'
476+ + 'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, '
477+ + 'embed-int8wo'
469478 )
470479 )
471480 parser .add_argument ("--calibration_limit" , type = int , default = 10 , help = "Number of calibration examples" )
@@ -484,6 +493,6 @@ def callback(x):
484493
485494 args = parser .parse_args ()
486495 main (
487- args .prompt , args .interactive , args .num_samples , args .max_new_tokens , args .top_k ,
496+ args .prompt , args .interactive , args .num_samples , args .max_new_tokens , args .batch_size , args . top_k ,
488497 args .temperature , args .checkpoint_path , args .quantization , args .calibration_limit , args .calibration_seq_length , args .kv_cache_quantization , args .cache_size , args .linear_causal_mask , args .save , args .compile , args .compile_prefill , args .profile , args .memory_profile , args .device , args .precision , args .write_result
489498 )
0 commit comments