 
 from vllm import LLM, SamplingParams
 
-parser = argparse.ArgumentParser()
-
-parser.add_argument(
-    "--dataset",
-    type=str,
-    default="./examples/data/gsm8k.jsonl",
-    help="downloaded from the eagle repo " \
-        "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/"
-)
-parser.add_argument("--max_num_seqs", type=int, default=8)
-parser.add_argument("--num_prompts", type=int, default=80)
-parser.add_argument("--num_spec_tokens", type=int, default=2)
-parser.add_argument("--tp", type=int, default=1)
-parser.add_argument("--draft_tp", type=int, default=1)
-parser.add_argument("--enforce_eager", action='store_true')
-parser.add_argument("--enable_chunked_prefill", action='store_true')
-parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
-parser.add_argument("--temp", type=float, default=0)
-
-args = parser.parse_args()
-
-print(args)
-
-model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
-eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm"
-
-max_model_len = 2048
-
-tokenizer = AutoTokenizer.from_pretrained(model_dir)
-
-if os.path.exists(args.dataset):
-    prompts = []
-    num_prompts = args.num_prompts
-    with open(args.dataset) as f:
-        for line in f:
-            data = json.loads(line)
-            prompts.append(data["turns"][0])
-else:
-    prompts = ["The future of AI is", "The president of the United States is"]
-
-prompts = prompts[:args.num_prompts]
-num_prompts = len(prompts)
-
-prompt_ids = [
-    tokenizer.apply_chat_template([{
-        "role": "user",
-        "content": prompt
-    }],
-                                  add_generation_prompt=True)
-    for prompt in prompts
-]
-
-llm = LLM(
-    model=model_dir,
-    trust_remote_code=True,
-    tensor_parallel_size=args.tp,
-    enable_chunked_prefill=args.enable_chunked_prefill,
-    max_num_batched_tokens=args.max_num_batched_tokens,
-    enforce_eager=args.enforce_eager,
-    max_model_len=max_model_len,
-    max_num_seqs=args.max_num_seqs,
-    gpu_memory_utilization=0.8,
-    speculative_config={
-        "model": eagle_dir,
-        "num_speculative_tokens": args.num_spec_tokens,
-        "draft_tensor_parallel_size": args.draft_tp,
-        "max_model_len": max_model_len,
-    },
-    disable_log_stats=False,
-)
-
-sampling_params = SamplingParams(temperature=args.temp, max_tokens=256)
-
-outputs = llm.generate(prompt_token_ids=prompt_ids,
-                       sampling_params=sampling_params)
-
-# calculate the average number of accepted tokens per forward pass, +1 is
-# to account for the token from the target model that's always going to be
-# accepted
-acceptance_counts = [0] * (args.num_spec_tokens + 1)
-for output in outputs:
-    for step, count in enumerate(output.metrics.spec_token_acceptance_counts):
-        acceptance_counts[step] += count
-
-print(f"mean acceptance length: \
-    {sum(acceptance_counts) / acceptance_counts[0]:.2f}")
+
+def load_prompts(dataset_path, num_prompts):
+    if os.path.exists(dataset_path):
+        prompts = []
+        try:
+            with open(dataset_path) as f:
+                for line in f:
+                    data = json.loads(line)
+                    prompts.append(data["turns"][0])
+        except Exception as e:
+            print(f"Error reading dataset: {e}")
+            return []
+    else:
+        prompts = [
+            "The future of AI is", "The president of the United States is"
+        ]
+
+    return prompts[:num_prompts]
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        default="./examples/data/gsm8k.jsonl",
+        help="downloaded from the eagle repo " \
+            "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/"
+    )
+    parser.add_argument("--max_num_seqs", type=int, default=8)
+    parser.add_argument("--num_prompts", type=int, default=80)
+    parser.add_argument("--num_spec_tokens", type=int, default=2)
+    parser.add_argument("--tp", type=int, default=1)
+    parser.add_argument("--draft_tp", type=int, default=1)
+    parser.add_argument("--enforce_eager", action='store_true')
+    parser.add_argument("--enable_chunked_prefill", action='store_true')
+    parser.add_argument("--max_num_batched_tokens", type=int, default=2048)
+    parser.add_argument("--temp", type=float, default=0)
+    args = parser.parse_args()
+
+    model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
+    eagle_dir = "abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm"
+
+    max_model_len = 2048
+
+    tokenizer = AutoTokenizer.from_pretrained(model_dir)
+
+    prompts = load_prompts(args.dataset, args.num_prompts)
+
+    prompt_ids = [
+        tokenizer.apply_chat_template([{
+            "role": "user",
+            "content": prompt
+        }],
+                                      add_generation_prompt=True)
+        for prompt in prompts
+    ]
+
+    llm = LLM(
+        model=model_dir,
+        trust_remote_code=True,
+        tensor_parallel_size=args.tp,
+        enable_chunked_prefill=args.enable_chunked_prefill,
+        max_num_batched_tokens=args.max_num_batched_tokens,
+        enforce_eager=args.enforce_eager,
+        max_model_len=max_model_len,
+        max_num_seqs=args.max_num_seqs,
+        gpu_memory_utilization=0.8,
+        speculative_config={
+            "model": eagle_dir,
+            "num_speculative_tokens": args.num_spec_tokens,
+            "draft_tensor_parallel_size": args.draft_tp,
+            "max_model_len": max_model_len,
+        },
+        disable_log_stats=False,
+    )
+
+    sampling_params = SamplingParams(temperature=args.temp, max_tokens=256)
+
+    outputs = llm.generate(prompt_token_ids=prompt_ids,
+                           sampling_params=sampling_params)
+
+    # calculate the average number of accepted tokens per forward pass, +1 is
+    # to account for the token from the target model that's always going to be
+    # accepted
+    acceptance_counts = [0] * (args.num_spec_tokens + 1)
+    for output in outputs:
+        for step, count in enumerate(
+                output.metrics.spec_token_acceptance_counts):
+            acceptance_counts[step] += count
+
+    print("-" * 50)
| 102 | + print(f"mean acceptance length: \ |
| 103 | + {sum(acceptance_counts) / acceptance_counts[0]:.2f}") |
| 104 | + print("-" * 50) |
| 105 | + |
| 106 | + |
| 107 | +if __name__ == "__main__": |
| 108 | + main() |
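
For a quick sense of what the final number means, here is a toy calculation with made-up counts (in the script they are accumulated from output.metrics.spec_token_acceptance_counts): with --num_spec_tokens 2, index 0 counts every forward pass (the target model's token is always accepted), while indices 1 and 2 count how often the first and second draft tokens were also accepted.

# Hypothetical counts, for illustration only.
# [forward passes, 1st draft token accepted, 2nd draft token accepted]
acceptance_counts = [100, 70, 30]
# 100 + 70 + 30 = 200 tokens produced over 100 forward passes,
# i.e. on average 2.00 tokens per pass (1 target token + ~1 accepted draft token).
print(f"mean acceptance length: {sum(acceptance_counts) / acceptance_counts[0]:.2f}")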