99
1010def test_find_longest_matched_ngram_and_propose_tokens ():
1111 tokens = np .array ([1 , 2 , 3 , 4 , 1 , 2 , 3 , 5 , 6 ])
12- assert _find_longest_matched_ngram_and_propose_tokens (origin_tokens = tokens ,
13- min_ngram = 2 ,
14- max_ngram = 2 ,
15- max_model_len = 1024 ,
16- k = 2 ) is None
12+ result = _find_longest_matched_ngram_and_propose_tokens (
13+ origin_tokens = tokens ,
14+ min_ngram = 2 ,
15+ max_ngram = 2 ,
16+ max_model_len = 1024 ,
17+ k = 2 )
18+ assert len (result ) == 0
1719
1820 tokens = np .array ([1 , 2 , 3 , 4 , 1 , 2 , 3 ])
1921 np .testing .assert_array_equal (
@@ -62,7 +64,7 @@ def test_find_longest_matched_ngram_and_propose_tokens():
6264
6365def test_ngram_proposer ():
6466
65- def ngram_proposer (min_n : int , max_n : int , k : int ) -> NgramProposer :
67+ def get_ngram_proposer (min_n : int , max_n : int , k : int ) -> NgramProposer :
6668 # Dummy model config. Just to set max_model_len.
6769 model_config = ModelConfig (model = "facebook/opt-125m" )
6870 return NgramProposer (
@@ -75,36 +77,120 @@ def ngram_proposer(min_n: int, max_n: int, k: int) -> NgramProposer:
7577 )))
7678
7779 # No match.
78- result = ngram_proposer (
79- min_n = 2 , max_n = 2 ,
80- k = 2 ).propose (context_token_ids = np .array ([1 , 2 , 3 , 4 , 5 ]))
81- assert result is None
80+ token_ids_cpu = np .array ([[1 , 2 , 3 , 4 , 5 ]])
81+ result = get_ngram_proposer (min_n = 2 , max_n = 2 , k = 2 ).propose (
82+ sampled_token_ids = [[0 ]],
83+ req_ids = ["0" ],
84+ num_tokens_no_spec = np .array ([len (c ) for c in token_ids_cpu ]),
85+ token_ids_cpu = token_ids_cpu ,
86+ spec_decode_unsupported_reqs = (),
87+ )
88+ assert len (result [0 ]) == 0
8289
8390 # No match for 4-gram.
84- result = ngram_proposer (
85- min_n = 4 , max_n = 4 ,
86- k = 2 ).propose (context_token_ids = np .array ([1 , 2 , 3 , 4 , 1 , 2 , 3 ]))
87- assert result is None
91+ token_ids_cpu = np .array ([[1 , 2 , 3 , 4 , 1 , 2 , 3 ]])
92+ result = get_ngram_proposer (min_n = 4 , max_n = 4 , k = 2 ).propose (
93+ sampled_token_ids = [[0 ]],
94+ req_ids = ["0" ],
95+ num_tokens_no_spec = np .array ([len (c ) for c in token_ids_cpu ]),
96+ token_ids_cpu = token_ids_cpu ,
97+ spec_decode_unsupported_reqs = (),
98+ )
99+ assert len (result [0 ]) == 0
88100
89101 # No match for 4-gram but match for 3-gram.
90- result = ngram_proposer (
91- min_n = 3 , max_n = 4 ,
92- k = 2 ).propose (context_token_ids = np .array ([1 , 2 , 3 , 4 , 1 , 2 , 3 ]))
93- assert np .array_equal (result , np .array ([4 , 1 ]))
102+ token_ids_cpu = np .array ([[1 , 2 , 3 , 4 , 1 , 2 , 3 ]])
103+ result = get_ngram_proposer (min_n = 3 , max_n = 4 , k = 2 ).propose (
104+ sampled_token_ids = [[0 ]],
105+ req_ids = ["0" ],
106+ num_tokens_no_spec = np .array ([len (c ) for c in token_ids_cpu ]),
107+ token_ids_cpu = token_ids_cpu ,
108+ spec_decode_unsupported_reqs = (),
109+ )
110+ assert np .array_equal (result , np .array ([[4 , 1 ]]))
94111
95112 # Match for both 4-gram and 3-gram.
96113 # In this case, the proposer should return the 4-gram match.
97- result = ngram_proposer (min_n = 3 , max_n = 4 , k = 2 ).propose (
98- context_token_ids = np .array ([2 , 3 , 4 , 5 , 1 , 2 , 3 , 4 , 1 , 2 , 3 , 4 ]))
99- assert np .array_equal (result , np .array ([1 , 2 ])) # Not [5, 1]
114+ token_ids_cpu = np .array ([[2 , 3 , 4 , 5 , 1 , 2 , 3 , 4 , 1 , 2 , 3 , 4 ]])
115+ result = get_ngram_proposer (min_n = 3 , max_n = 4 , k = 2 ).propose (
116+ sampled_token_ids = [[0 ]],
117+ req_ids = ["0" ],
118+ num_tokens_no_spec = np .array ([len (c ) for c in token_ids_cpu ]),
119+ token_ids_cpu = token_ids_cpu ,
120+ spec_decode_unsupported_reqs = (),
121+ )
122+ assert np .array_equal (result , np .array ([[1 , 2 ]])) # Not [5, 1]]
100123
101124 # Match for 2-gram and 3-gram, but not 4-gram.
102- result = ngram_proposer (min_n = 2 , max_n = 4 , k = 2 ).propose (
103- context_token_ids = np .array ([3 , 4 , 5 , 2 , 3 , 4 , 1 , 2 , 3 , 4 ]))
104- assert np .array_equal (result , np .array ([1 , 2 ])) # Not [5, 2]
125+ token_ids_cpu = np .array ([[3 , 4 , 5 , 2 , 3 , 4 , 1 , 2 , 3 , 4 ]])
126+ result = get_ngram_proposer (min_n = 2 , max_n = 4 , k = 2 ).propose (
127+ sampled_token_ids = [[0 ]],
128+ req_ids = ["0" ],
129+ num_tokens_no_spec = np .array ([len (c ) for c in token_ids_cpu ]),
130+ token_ids_cpu = token_ids_cpu ,
131+ spec_decode_unsupported_reqs = (),
132+ )
133+ assert np .array_equal (result , np .array ([[1 , 2 ]])) # Not [5, 2]]
105134
106135 # Multiple 3-gram matched, but always pick the first one.
107- result = ngram_proposer (
108- min_n = 3 , max_n = 3 , k = 2 ).propose (context_token_ids = np .array (
109- [1 , 2 , 3 , 100 , 1 , 2 , 3 , 200 , 1 , 2 , 3 , 300 , 1 , 2 , 3 ]))
110- assert np .array_equal (result , np .array ([100 , 1 ]))
136+ token_ids_cpu = np .array (
137+ [[1 , 2 , 3 , 100 , 1 , 2 , 3 , 200 , 1 , 2 , 3 , 300 , 1 , 2 , 3 ]])
138+ result = get_ngram_proposer (min_n = 3 , max_n = 3 , k = 2 ).propose (
139+ sampled_token_ids = [[0 ]],
140+ req_ids = ["0" ],
141+ num_tokens_no_spec = np .array ([len (c ) for c in token_ids_cpu ]),
142+ token_ids_cpu = token_ids_cpu ,
143+ spec_decode_unsupported_reqs = (),
144+ )
145+ assert np .array_equal (result , np .array ([[100 , 1 ]]))
146+
147+ # check empty input
148+ token_ids_cpu = np .array ([[]])
149+ result = get_ngram_proposer (min_n = 2 , max_n = 2 , k = 2 ).propose (
150+ sampled_token_ids = [[0 ]],
151+ req_ids = ["0" ],
152+ num_tokens_no_spec = np .array ([len (c ) for c in token_ids_cpu ]),
153+ token_ids_cpu = token_ids_cpu ,
154+ spec_decode_unsupported_reqs = (),
155+ )
156+ assert len (result [0 ]) == 0
157+
158+ # check multibatch input
159+ # first request has 5 tokens and a match
160+ # second request has 3 tokens and no match. Padded with -1 for max len 5
161+ token_ids_cpu = np .array ([[1 , 2 , 3 , 1 , 2 ], [4 , 5 , 6 , - 1 , - 1 ]])
162+ result = get_ngram_proposer (min_n = 2 , max_n = 2 , k = 2 ).propose (
163+ sampled_token_ids = [[0 ], [1 ]],
164+ req_ids = ["0" , "1" ],
165+ num_tokens_no_spec = np .array ([5 , 3 ]),
166+ token_ids_cpu = token_ids_cpu ,
167+ spec_decode_unsupported_reqs = (),
168+ )
169+ assert len (result [0 ]) == 2
170+ assert np .array_equal (result [0 ], np .array ([3 , 1 ]))
171+ assert np .array_equal (result [1 ], np .array ([]))
172+
173+ # test if 0 threads available: can happen if TP size > CPU count
174+ ngram_proposer = get_ngram_proposer (min_n = 2 , max_n = 2 , k = 2 )
175+ ngram_proposer .num_numba_thread_available = 0
176+ # set max_model_len to 2 * threshold to ensure multithread is used
177+ num_tokens_threshold = ngram_proposer .num_tokens_threshold
178+ ngram_proposer .max_model_len = 2 * num_tokens_threshold
179+ # using multibatch test
180+ middle_integer = num_tokens_threshold // 2
181+ input_1 = [_ for _ in range (num_tokens_threshold )]
182+ input_1 += [middle_integer , middle_integer + 1 ]
183+ input_2 = [- 1 ] * len (input_1 )
184+ input_2 [:3 ] = [4 , 5 , 6 ]
185+ token_ids_cpu = np .array ([input_1 , input_2 ])
186+ result = ngram_proposer .propose (
187+ sampled_token_ids = [[0 ], [1 ]],
188+ req_ids = ["0" , "1" ],
189+ num_tokens_no_spec = np .array ([len (input_1 ), 3 ]),
190+ token_ids_cpu = token_ids_cpu ,
191+ spec_decode_unsupported_reqs = (),
192+ )
193+ assert len (result [0 ]) == 2
194+ assert np .array_equal (result [0 ],
195+ np .array ([middle_integer + 2 , middle_integer + 3 ]))
196+ assert np .array_equal (result [1 ], np .array ([]))
0 commit comments