@@ -39,16 +39,20 @@ def parse_args():
3939 parser .add_argument ("--top-k" , type = int , default = - 1 )
4040 parser .add_argument ("--print-output" , action = "store_true" )
4141 parser .add_argument ("--output-len" , type = int , default = 256 )
42+ parser .add_argument ("--model-dir" , type = str , default = None )
43+ parser .add_argument ("--eagle-dir" , type = str , default = None )
44+ parser .add_argument ("--max-model-len" , type = int , default = 2048 )
4245 return parser .parse_args ()
4346
4447
4548def main ():
4649 args = parse_args ()
4750 args .endpoint_type = "openai-chat"
4851
49- model_dir = "meta-llama/Llama-3.1-8B-Instruct"
52+ model_dir = args .model_dir
53+ if args .model_dir is None :
54+ model_dir = "meta-llama/Llama-3.1-8B-Instruct"
5055 tokenizer = AutoTokenizer .from_pretrained (model_dir )
51- max_model_len = 2048
5256
5357 prompts = get_samples (args , tokenizer )
5458 # add_special_tokens is False to avoid adding bos twice when using chat templates
@@ -57,24 +61,26 @@ def main():
5761 ]
5862
5963 if args .method == "eagle" or args .method == "eagle3" :
60- if args .method == "eagle" :
64+ eagle_dir = args .eagle_dir
65+ if args .method == "eagle" and eagle_dir is None :
6166 eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
62- elif args .method == "eagle3" :
67+
68+ elif args .method == "eagle3" and eagle_dir is None :
6369 eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
6470 speculative_config = {
6571 "method" : args .method ,
6672 "model" : eagle_dir ,
6773 "num_speculative_tokens" : args .num_spec_tokens ,
6874 "draft_tensor_parallel_size" : args .draft_tp ,
69- "max_model_len" : max_model_len ,
75+ "max_model_len" : args . max_model_len ,
7076 }
7177 elif args .method == "ngram" :
7278 speculative_config = {
7379 "method" : "ngram" ,
7480 "num_speculative_tokens" : args .num_spec_tokens ,
7581 "prompt_lookup_max" : args .prompt_lookup_max ,
7682 "prompt_lookup_min" : args .prompt_lookup_min ,
77- "max_model_len" : max_model_len ,
83+ "max_model_len" : args . max_model_len ,
7884 }
7985 else :
8086 raise ValueError (f"unknown method: { args .method } " )
@@ -86,7 +92,7 @@ def main():
8692 enable_chunked_prefill = args .enable_chunked_prefill ,
8793 max_num_batched_tokens = args .max_num_batched_tokens ,
8894 enforce_eager = args .enforce_eager ,
89- max_model_len = max_model_len ,
95+ max_model_len = args . max_model_len ,
9096 max_num_seqs = args .max_num_seqs ,
9197 gpu_memory_utilization = 0.8 ,
9298 speculative_config = speculative_config ,
0 commit comments