@@ -108,7 +108,7 @@ def prepare_tokenizer(args):
             args.tokenizer_bin is not None
         ), "Please provide tokenizer_bin for stories."
         runtime_tokenizer_path = args.tokenizer_bin
-    elif args.decoder_model == "llama3_2":
+    elif "llama3_2" in args.decoder_model:
         tokenizer = get_tokenizer(args.tokenizer_model)
         assert isinstance(
             tokenizer, TiktokenTokenizer
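Note on this hunk: the exact-string comparison is relaxed to a substring check, so any decoder-model name that merely contains `llama3_2` now takes the Tiktoken branch. A minimal sketch of the behavioral difference, using hypothetical variant names:

# Hedged sketch: "llama3_2_instruct" is a made-up variant name for illustration.
for name in ["llama3_2", "llama3_2_instruct"]:
    print(name, name == "llama3_2")   # exact match: True only for "llama3_2"
    print(name, "llama3_2" in name)   # substring: True for both, so variants share one branch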
@@ -240,7 +240,7 @@ def prequant_algorithm(model, prefill_config, args):
 
     if args.range_setting == "mse_with_act_loss":
         wrapped_model = WrappedLlamaModel(
-            model, atten_mask, args.use_kv_cache, args.max_seq_length, args.device
+            model, *atten_mask, args.use_kv_cache, args.max_seq_length, args.device
         )
         act_bits, weight_bits = {
             "8a8w": (8, 8),
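Note on this hunk: star-unpacking `atten_mask` suggests it is now a sequence of masks rather than a single tensor (an assumption; the construction site is not shown in this diff). A minimal sketch of the calling convention with a stand-in constructor:

# Hedged sketch: `ctor` is a placeholder for WrappedLlamaModel, whose real
# signature is assumed to take each mask as its own positional parameter.
def ctor(model, mask_a, mask_b, use_kv_cache, max_seq_length, device):
    return (model, mask_a, mask_b, use_kv_cache, max_seq_length, device)

atten_mask = ("causal_mask", "sliding_mask")  # placeholder elements
ctor("model", *atten_mask, True, 512, "cpu")  # same as ctor("model", "causal_mask", "sliding_mask", ...)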
@@ -355,20 +355,20 @@ def eval_llm(args):
 
     logging.info("Quantizing the model...")
     model = convert_pt2e(model)
-    logging.info("Quantization complete! Here is some sample generated text:")
-
-    graph_module_inference(
-        use_kv_cache=False,
-        get_example_inputs=lambda use_kv_cache=False: inputs,
-        module=model,
-        tokenizer=tokenizer,
-        ar_len=args.max_seq_len,
-        max_seq_len=args.max_seq_len,
-        kv_updater=args.kv_updater,
-        prompt="Can you tell me about Facebook?",
-        use_i64_token=use_i64_token,
-        event_name="convert_pt2e_prompt",
-    )
+    # logging.info("Quantization complete! Here is some sample generated text:")
+
+    # graph_module_inference(
+    #     use_kv_cache=False,
+    #     get_example_inputs=lambda use_kv_cache=False: inputs,
+    #     module=model,
+    #     tokenizer=tokenizer,
+    #     ar_len=args.max_seq_len,
+    #     max_seq_len=args.max_seq_len,
+    #     kv_updater=args.kv_updater,
+    #     prompt="Can you tell me about Facebook?",
+    #     use_i64_token=use_i64_token,
+    #     event_name="convert_pt2e_prompt",
+    # )
 
     logging.info("Evaluation of QDQ model:")
     graph_module_inference(
@@ -380,6 +380,7 @@ def eval_llm(args):
         max_seq_len=args.max_seq_len,
         kv_updater=args.kv_updater,
         tasks=["wikitext"],
+        tasks_limit=0.1,
         use_i64_token=use_i64_token,
         event_name="convert_pt2e_prompt",
     )
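Note on this hunk: `tasks_limit=0.1` presumably caps the wikitext run at a fraction of the task's documents, following lm-eval's convention that a float limit below 1.0 is read as a fraction rather than an absolute count (an assumption; how `graph_module_inference` forwards it is not shown here). A sketch of that convention:

# Hedged sketch of fractional-limit semantics (assumed to mirror lm-eval's `limit`):
# a float < 1.0 selects a fraction of documents, an int an absolute count.
def resolve_limit(limit, n_docs):
    if limit is None:
        return n_docs
    if isinstance(limit, float) and limit < 1.0:
        return max(1, int(n_docs * limit))
    return min(n_docs, int(limit))

print(resolve_limit(0.1, 1000))  # -> 100 documents (hypothetical document count)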
@@ -424,9 +425,7 @@ def main() -> None:
     )
     parser.add_argument(
         "--decoder_model",
-        choices=["stories260k", "stories110m", "llama3_2"]
-        + list(SUPPORTED_LLM_MODELS.keys()),
-        help=f"The Llama model to export. Current available options are: [stories260k, stories110m, llama3_2] + {SUPPORTED_LLM_MODELS.keys()}",
+        help=f"The Llama model to export. Current available options are: {SUPPORTED_LLM_MODELS.keys()}",
         required=True,
     )
     parser.add_argument(
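Note on this hunk: dropping `choices` means argparse no longer validates `--decoder_model` at parse time; any string is accepted and is presumably checked later against `SUPPORTED_LLM_MODELS` (an assumption). Also, interpolating `dict.keys()` into the f-string renders as a `dict_keys([...])` view. A self-contained sketch with placeholder entries:

import argparse

SUPPORTED_LLM_MODELS = {"stories110m": None, "llama3_2": None}  # placeholder entries

parser = argparse.ArgumentParser()
parser.add_argument(
    "--decoder_model",
    help=f"Options: {SUPPORTED_LLM_MODELS.keys()}",  # renders as dict_keys([...])
    required=True,
)
# Without `choices`, this parses successfully even for an unknown model name:
args = parser.parse_args(["--decoder_model", "not_a_real_model"])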