|
5 | 5 | import tempfile |
6 | 6 |
|
7 | 7 | import depyf |
| 8 | +import pytest |
8 | 9 |
|
9 | 10 | from vllm.config import CompilationLevel |
10 | 11 |
|
11 | | -temp_dir = tempfile.mkdtemp() |
12 | | -with depyf.prepare_debug(temp_dir): |
13 | | - from vllm import LLM, SamplingParams |
14 | | - |
15 | | - prompts = [ |
16 | | - "A robot may not injure a human being", |
17 | | - "It is only with the heart that one can see rightly;", |
18 | | - "The greatest glory in living lies not in never falling,", |
19 | | - ] |
20 | | - answers = [ |
21 | | - " or, through inaction, allow a human being to come to harm.", |
22 | | - " what is essential is invisible to the eye.", |
23 | | - " but in rising every time we fall.", |
24 | | - ] |
25 | | - N = 1 |
26 | | - # Currently, top-p sampling is disabled. `top_p` should be 1.0. |
27 | | - sampling_params = SamplingParams(temperature=0.7, |
28 | | - top_p=1.0, |
29 | | - n=N, |
30 | | - max_tokens=16) |
31 | | - |
32 | | - # Set `enforce_eager=True` to avoid ahead-of-time compilation. |
33 | | - # In real workloads, `enforace_eager` should be `False`. |
34 | | - |
35 | | - # disable custom dispatcher, let Dynamo takes over |
36 | | - # all the control |
37 | | - llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", |
38 | | - max_model_len=512, |
39 | | - max_num_seqs=64, |
40 | | - enforce_eager=True, |
41 | | - compilation_config={"level": CompilationLevel.DYNAMO_AS_IS}) |
42 | | - outputs = llm.generate(prompts, sampling_params) |
43 | | - for output, answer in zip(outputs, answers): |
44 | | - prompt = output.prompt |
45 | | - generated_text = output.outputs[0].text |
46 | | - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") |
47 | | - assert generated_text.startswith(answer) |
48 | | - |
49 | | -compiled_codes = sorted( |
50 | | - glob.glob(os.path.join(temp_dir, "__transformed_code*.py"))) |
51 | | - |
52 | | -for i, compiled_code in enumerate(compiled_codes): |
53 | | - print("{} file: {}".format(i + 1, compiled_code)) |
54 | | - |
55 | | -# We should only trigger Dynamo compilation 4 times: |
56 | | -# 1. forward pass (symbolic) |
57 | | -# 2. compute_logits (symbolic) |
58 | | -# 3. forward pass (shape 16) |
59 | | -# 4. forward pass (shape 32) |
60 | | -# and later calls should not trigger Dynamo compilation again. |
61 | | -# NOTE: It might still trigger XLA compilation. |
62 | | - |
63 | | -# Check we have 4 compiled codes |
64 | | -assert len(compiled_codes) == 4 |
65 | | - |
66 | | -kv_cache_prefix = "kv_cache" |
67 | | -attn_prefix = "ragged_paged_attention" |
68 | | - |
69 | | -# Check all the compilations are as expected |
70 | | -compiled_fns = sorted( |
71 | | - glob.glob(os.path.join(temp_dir, "__compiled_fn*Captured*.py"))) |
72 | | - |
73 | | -for i, compiled_fn in enumerate(compiled_fns): |
74 | | - print("{} file: {}".format(i + 1, compiled_fn)) |
75 | | - |
76 | | -# The first compilation is symbolic, so it should not have any kv_caches |
77 | | -with open(compiled_fns[0]) as f: |
78 | | - content = f.read() |
79 | | - assert kv_cache_prefix not in content |
80 | | - |
81 | | -# The second compilation is symbolic, so it should not have any kv_caches |
82 | | -with open(compiled_fns[1]) as f: |
83 | | - content = f.read() |
84 | | - assert kv_cache_prefix not in content |
85 | | - |
86 | | -# The third compilation is shape 16, so it should have kv_caches and the |
87 | | -# ragged_paged_attention |
88 | | -with open(compiled_fns[2]) as f: |
89 | | - content = f.read() |
90 | | - assert (kv_cache_prefix in content and attn_prefix in content) |
91 | | - |
92 | | -# The forth compilation is shape 32, so it should have kv_caches and the |
93 | | -# ragged_paged_attention |
94 | | -with open(compiled_fns[3]) as f: |
95 | | - content = f.read() |
96 | | - assert (kv_cache_prefix in content and attn_prefix in content) |
| 12 | + |
| 13 | +@pytest.mark.skip(reason="Not working; needs investigation.") |
| 14 | +def test_tpu_compilation(): |
| 15 | + temp_dir = tempfile.mkdtemp() |
| 16 | + with depyf.prepare_debug(temp_dir): |
| 17 | + from vllm import LLM, SamplingParams |
| 18 | + |
| 19 | + prompts = [ |
| 20 | + "A robot may not injure a human being", |
| 21 | + "It is only with the heart that one can see rightly;", |
| 22 | + "The greatest glory in living lies not in never falling,", |
| 23 | + ] |
| 24 | + answers = [ |
| 25 | + " or, through inaction, allow a human being to come to harm.", |
| 26 | + " what is essential is invisible to the eye.", |
| 27 | + " but in rising every time we fall.", |
| 28 | + ] |
| 29 | + N = 1 |
| 30 | + # Currently, top-p sampling is disabled. `top_p` should be 1.0. |
| 31 | + sampling_params = SamplingParams(temperature=0.7, |
| 32 | + top_p=1.0, |
| 33 | + n=N, |
| 34 | + max_tokens=16) |
| 35 | + |
| 36 | + # Set `enforce_eager=True` to avoid ahead-of-time compilation. |
| 37 | + # In real workloads, `enforace_eager` should be `False`. |
| 38 | + |
| 39 | + # disable custom dispatcher, let Dynamo takes over |
| 40 | + # all the control |
| 41 | + llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", |
| 42 | + max_model_len=512, |
| 43 | + max_num_seqs=64, |
| 44 | + enforce_eager=True, |
| 45 | + compilation_config={"level": CompilationLevel.DYNAMO_AS_IS}) |
| 46 | + outputs = llm.generate(prompts, sampling_params) |
| 47 | + for output, answer in zip(outputs, answers): |
| 48 | + prompt = output.prompt |
| 49 | + generated_text = output.outputs[0].text |
| 50 | + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") |
| 51 | + assert generated_text.startswith(answer) |
| 52 | + |
| 53 | + compiled_codes = sorted( |
| 54 | + glob.glob(os.path.join(temp_dir, "__transformed_code*.py"))) |
| 55 | + |
| 56 | + for i, compiled_code in enumerate(compiled_codes): |
| 57 | + print("{} file: {}".format(i + 1, compiled_code)) |
| 58 | + |
| 59 | + # We should only trigger Dynamo compilation 4 times: |
| 60 | + # 1. forward pass (symbolic) |
| 61 | + # 2. compute_logits (symbolic) |
| 62 | + # 3. forward pass (shape 16) |
| 63 | + # 4. forward pass (shape 32) |
| 64 | + # and later calls should not trigger Dynamo compilation again. |
| 65 | + # NOTE: It might still trigger XLA compilation. |
| 66 | + |
| 67 | + # Check we have 4 compiled codes |
| 68 | + assert len(compiled_codes) == 4 |
| 69 | + |
| 70 | + kv_cache_prefix = "kv_cache" |
| 71 | + attn_prefix = "ragged_paged_attention" |
| 72 | + |
| 73 | + # Check all the compilations are as expected |
| 74 | + compiled_fns = sorted( |
| 75 | + glob.glob(os.path.join(temp_dir, "__compiled_fn*Captured*.py"))) |
| 76 | + |
| 77 | + for i, compiled_fn in enumerate(compiled_fns): |
| 78 | + print("{} file: {}".format(i + 1, compiled_fn)) |
| 79 | + |
| 80 | + # The first compilation is symbolic, so it should not have any kv_caches |
| 81 | + with open(compiled_fns[0]) as f: |
| 82 | + content = f.read() |
| 83 | + assert kv_cache_prefix not in content |
| 84 | + |
| 85 | + # The second compilation is symbolic, so it should not have any kv_caches |
| 86 | + with open(compiled_fns[1]) as f: |
| 87 | + content = f.read() |
| 88 | + assert kv_cache_prefix not in content |
| 89 | + |
| 90 | + # The third compilation is shape 16, so it should have kv_caches and the |
| 91 | + # ragged_paged_attention |
| 92 | + with open(compiled_fns[2]) as f: |
| 93 | + content = f.read() |
| 94 | + assert (kv_cache_prefix in content and attn_prefix in content) |
| 95 | + |
| 96 | + # The forth compilation is shape 32, so it should have kv_caches and the |
| 97 | + # ragged_paged_attention |
| 98 | + with open(compiled_fns[3]) as f: |
| 99 | + content = f.read() |
| 100 | + assert (kv_cache_prefix in content and attn_prefix in content) |
0 commit comments