@@ -45,8 +45,12 @@ def main():
4545 parser .add_argument ("--enable_chunked_prefill" , action = 'store_true' )
4646 parser .add_argument ("--max_num_batched_tokens" , type = int , default = 2048 )
4747 parser .add_argument ("--temp" , type = float , default = 0 )
48+ parser .add_argument ("--use_v1" , type = str , default = "1" , help = '1 or 0' )
4849 args = parser .parse_args ()
4950
51+ # TODO: remove this option once EAGLE in v1 is ready.
52+ os .environ ["VLLM_USE_V1" ] = args .use_v1
53+
5054 model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
5155 eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm"
5256
@@ -94,10 +98,16 @@ def main():
9498 # to account for the token from the target model that's always going to be
9599 # accepted
96100 acceptance_counts = [0 ] * (args .num_spec_tokens + 1 )
97- for output in outputs :
98- for step , count in enumerate (
99- output .metrics .spec_token_acceptance_counts ):
100- acceptance_counts [step ] += count
101+ if args .use_v1 == '1' :
102+ for output in outputs :
103+ for step , count in enumerate (
104+ output .spec_token_acceptance_counts [0 ]):
105+ acceptance_counts [step ] += count
106+ else :
107+ for output in outputs :
108+ for step , count in enumerate (
109+ output .metrics .spec_token_acceptance_counts ):
110+ acceptance_counts [step ] += count
101111
102112 print ("-" * 50 )
103113 print (f"mean acceptance length: \
0 commit comments