|
2 | 2 |
|
3 | 3 | import numpy as np |
4 | 4 |
|
| 5 | +from vllm.config import SpeculativeConfig, VllmConfig |
5 | 6 | from vllm.v1.spec_decode.ngram_proposer import (NgramProposer, |
6 | 7 | _find_subarray_kmp, |
7 | 8 | _kmp_lps_array) |
@@ -39,50 +40,40 @@ def test_find_subarray_kmp(): |
39 | 40 |
|
40 | 41 |
|
41 | 42 | def test_ngram_proposer(): |
42 | | - proposer = NgramProposer() |
| 43 | + |
| 44 | + def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer: |
| 45 | + return NgramProposer(vllm_config=VllmConfig( |
| 46 | + speculative_config=SpeculativeConfig.from_dict( |
| 47 | + { |
| 48 | + "prompt_lookup_min": min_n, |
| 49 | + "prompt_lookup_max": max_n, |
| 50 | + "num_speculative_tokens": k, |
| 51 | + "method": "ngram", |
| 52 | + }))) |
43 | 53 |
|
44 | 54 | # No match. |
45 | | - result = proposer.propose( |
46 | | - context_token_ids=np.array([1, 2, 3, 4, 5]), |
47 | | - min_n=2, |
48 | | - max_n=2, |
49 | | - k=2, |
50 | | - ) |
| 55 | + result = ngram_proposer( |
| 56 | + 2, 2, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 5])) |
51 | 57 | assert result is None |
52 | 58 |
|
53 | 59 | # No match for 4-gram. |
54 | | - result = proposer.propose( |
55 | | - context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]), |
56 | | - min_n=4, |
57 | | - max_n=4, |
58 | | - k=2, |
59 | | - ) |
| 60 | + result = ngram_proposer( |
| 61 | + 4, 4, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3])) |
60 | 62 | assert result is None |
61 | 63 |
|
62 | 64 | # No match for 4-gram but match for 3-gram. |
63 | | - result = proposer.propose( |
64 | | - context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3]), |
65 | | - min_n=3, |
66 | | - max_n=4, |
67 | | - k=2, |
68 | | - ) |
| 65 | + result = ngram_proposer( |
| 66 | + 3, 4, 2).propose(context_token_ids=np.array([1, 2, 3, 4, 1, 2, 3])) |
69 | 67 | assert np.array_equal(result, np.array([4, 1])) |
70 | 68 |
|
71 | 69 | # Match for both 4-gram and 3-gram. |
72 | 70 | # In this case, the proposer should return the 4-gram match. |
73 | | - result = proposer.propose( |
74 | | - context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4]), |
75 | | - min_n=3, |
76 | | - max_n=4, |
77 | | - k=2, |
78 | | - ) |
| 71 | + result = ngram_proposer(3, 4, 2).propose( |
| 72 | + context_token_ids=np.array([2, 3, 4, 5, 1, 2, 3, 4, 1, 2, 3, 4])) |
79 | 73 | assert np.array_equal(result, np.array([1, 2])) # Not [5, 1] |
80 | 74 |
|
81 | 75 | # Match for 2-gram and 3-gram, but not 4-gram. |
82 | | - result = proposer.propose( |
83 | | - context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4]), |
84 | | - min_n=2, |
85 | | - max_n=4, |
86 | | - k=2, |
87 | | - ) |
| 76 | + result = ngram_proposer( |
| 77 | + 2, 4, |
| 78 | + 2).propose(context_token_ids=np.array([3, 4, 5, 2, 3, 4, 1, 2, 3, 4])) |
88 | 79 | assert np.array_equal(result, np.array([1, 2])) # Not [5, 2] |
0 commit comments