
Commit 11f3aa2

use old version reject_sample_greedy
Signed-off-by: whx-sjtu <2952154980@qq.com>
1 parent e78eb92

2 files changed, +65 -32 lines changed

tests/ut/sample/test_rejection_sampler.py

Lines changed: 4 additions & 2 deletions

@@ -36,7 +36,8 @@ def test_rejection_greedy_sample_pytorch(self):
         output_token_ids = torch.full((batch_size, max_spec_len + 1),
                                       PLACEHOLDER_TOKEN_ID)
 
-        num_draft_tokens = torch.tensor([2, 2])
+        cu_num_draft_tokens = torch.tensor([2, 4])
+        num_draft_tokens = [2, 2]
         draft_token_ids = torch.tensor([10, 11, 20, 21])
         target_argmax = torch.tensor([10, 99, 20, 22])
         bonus_token_ids = torch.tensor([[100], [200]])
@@ -45,10 +46,11 @@ def test_rejection_greedy_sample_pytorch(self):
 
         rejection_greedy_sample_pytorch(
             output_token_ids,
-            num_draft_tokens,
+            cu_num_draft_tokens,
             draft_token_ids,
             target_argmax,
             bonus_token_ids,
+            num_draft_tokens,
             max_spec_len,
             is_greedy,
         )
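
For context, a minimal standalone driver for the updated call might look like the sketch below. This is an illustration only, not the committed test: the PLACEHOLDER_TOKEN_ID value of -1 and the import path are assumptions, and is_greedy is simply set to all-True here.

import torch

# Hypothetical standalone driver mirroring the test hunk above; not the
# committed test. PLACEHOLDER_TOKEN_ID = -1 and the import path are assumptions.
from vllm_ascend.sample.rejection_sampler import rejection_greedy_sample_pytorch

PLACEHOLDER_TOKEN_ID = -1  # assumed placeholder value
batch_size, max_spec_len = 2, 2
output_token_ids = torch.full((batch_size, max_spec_len + 1),
                              PLACEHOLDER_TOKEN_ID)

cu_num_draft_tokens = torch.tensor([2, 4])  # cumulative draft counts per request
num_draft_tokens = [2, 2]                   # per-request counts, plain Python list
draft_token_ids = torch.tensor([10, 11, 20, 21])
target_argmax = torch.tensor([10, 99, 20, 22])
bonus_token_ids = torch.tensor([[100], [200]])
is_greedy = torch.tensor([True, True])

# The function fills output_token_ids in place; cu_num_draft_tokens replaces the
# old per-request tensor, and the per-request list is passed after bonus_token_ids.
rejection_greedy_sample_pytorch(
    output_token_ids,
    cu_num_draft_tokens,
    draft_token_ids,
    target_argmax,
    bonus_token_ids,
    num_draft_tokens,
    max_spec_len,
    is_greedy,
)
# Both requests mismatch at position 1, so each row keeps one accepted draft
# token plus the recovered target token and no bonus token:
# tensor([[10, 99, -1], [20, 22, -1]])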

vllm_ascend/sample/rejection_sampler.py

Lines changed: 61 additions & 30 deletions

@@ -156,14 +156,13 @@ def rejection_sample(
             bonus_token_ids,
         )
     else:
-        num_draft_tokens_tensor = torch.tensor(num_draft_tokens,
-                                               device=device)
         rejection_greedy_sample_pytorch(
             output_token_ids,
-            num_draft_tokens_tensor,
+            cu_num_draft_tokens,
             draft_token_ids,
             target_argmax,
             bonus_token_ids,
+            num_draft_tokens,
             max_spec_len,
             is_greedy,
         )
@@ -311,40 +310,72 @@ def rejection_greedy_sample_spec_len_1_pytorch(
 
 
 def rejection_greedy_sample_pytorch(
     output_token_ids,  # [batch_size, max_spec_len + 1]
-    num_draft_tokens,  # [batch_size]
+    cu_num_draft_tokens,  # [batch_size]
     draft_token_ids,  # [num_tokens]
     target_argmax,  # [num_tokens]
-    bonus_token_ids,  # [batch_size, 1]
-    max_spec_len,  # int
+    bonus_token_ids,  # [batch_size]
+    draft_tokens_per_req,  # [batch_size], list
+    max_spec_len,
     is_greedy=None,  # [batch_size] or None
 ):
-    batch_size = output_token_ids.shape[0]
+    batch_size = output_token_ids.size(0)
+    num_tokens = draft_token_ids.size(0)
     device = output_token_ids.device
+    draft_tokens_per_req = torch.tensor(draft_tokens_per_req).to(
+        device, non_blocking=True)
     if is_greedy is None:
         is_greedy = torch.ones(batch_size, dtype=torch.bool, device=device)
-    draft_token_mask = draft_token_ids == target_argmax
-    pos_ids = torch.arange(0, max_spec_len + 1,
-                           device=device).view(1, -1).expand(batch_size, -1)
-    pos_mask = pos_ids < num_draft_tokens.view(-1, 1)
-    output_token_mask = torch.zeros([batch_size, max_spec_len + 1],
-                                    dtype=torch.int,
-                                    device=device)
-    output_token_mask[pos_mask] = draft_token_mask.to(torch.int)
-    output_token_mask = torch.cumprod(output_token_mask,
-                                      dim=1)  # [batch_size, max_spec_len + 1]
-    extra_accept_pos = torch.max(
-        pos_ids * output_token_mask, dim=1, keepdim=True)[1] + 1
-    output_token_mask[:, extra_accept_pos] = True
-    output_token_mask *= is_greedy.view(-1, 1)
-    output_token_ids[pos_mask] = target_argmax.to(output_token_ids.dtype)
-    extra_accept_ids = output_token_ids[:, extra_accept_pos]
-    output_token_ids[pos_mask] = draft_token_ids.to(output_token_ids.dtype)
-    output_token_ids[:, extra_accept_pos] = extra_accept_ids.to(
-        output_token_ids.dtype)
-    output_token_ids[:, -1] = bonus_token_ids.squeeze(1).to(
-        output_token_ids.dtype)
-    output_token_ids[~output_token_mask.bool()] = -1
-    return output_token_ids
+
+    start_indices = cu_num_draft_tokens - draft_tokens_per_req
+    req_ids = torch.arange(batch_size, device=device)
+    token_req_ids = torch.repeat_interleave(req_ids, draft_tokens_per_req)
+    token_positions = torch.arange(
+        num_tokens, device=device) - start_indices[token_req_ids]
+
+    # Find the first mismatch position of each request.
+    mismatch_global = (draft_token_ids != target_argmax)
+    if max_spec_len == 0:
+        first_mismatch_pos_per_req = torch.zeros(batch_size,
+                                                 dtype=torch.long,
+                                                 device=device)
+    else:
+        # [bs, max_spec_len]
+        pos_matrix = torch.full((batch_size, max_spec_len),
+                                -1,
+                                dtype=torch.long,
+                                device=device)
+        pos_matrix[token_req_ids, token_positions] = token_positions
+        mismatch_matrix = torch.full((batch_size, max_spec_len),
+                                     False,
+                                     dtype=torch.bool,
+                                     device=device)
+        mismatch_matrix[token_req_ids, token_positions] = mismatch_global
+        mismatch_positions = torch.where(mismatch_matrix, pos_matrix,
+                                         max_spec_len * 2)
+        first_mismatch_pos_per_req, _ = torch.min(mismatch_positions, dim=1)
+        no_mismatch_mask = (first_mismatch_pos_per_req == max_spec_len * 2)
+        first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[
+            no_mismatch_mask]
+
+    # Copy matched target tokens into output.
+    copy_len = torch.minimum(first_mismatch_pos_per_req + 1,
+                             draft_tokens_per_req)
+    copy_indices = torch.arange(max_spec_len + 1,
+                                device=device).expand(batch_size, -1)
+    copy_mask = copy_indices < copy_len.unsqueeze(1)
+    greedy_mask = is_greedy.unsqueeze(1)
+    final_copy_mask = copy_mask & greedy_mask
+    global_idx = start_indices.unsqueeze(1) + copy_indices
+    output_token_ids[final_copy_mask] = target_argmax[
+        global_idx[final_copy_mask]].to(output_token_ids.dtype)
+    # Fill bonus token.
+    needs_bonus = is_greedy & (first_mismatch_pos_per_req
+                               >= draft_tokens_per_req)
+    if torch.any(needs_bonus):
+        bonus_rows = torch.where(needs_bonus)[0]
+        bonus_cols = draft_tokens_per_req[bonus_rows]
+        bonus_token_ids = bonus_token_ids.squeeze(1)
+        output_token_ids[bonus_rows, bonus_cols] = bonus_token_ids[bonus_rows]
 
 
 def rejection_random_sample_pytorch(
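
To make the new first-mismatch logic easier to follow, here is a small self-contained PyTorch walkthrough on the unit-test inputs above. It paraphrases the committed code for illustration; the variable names mirror the diff, but the snippet itself is not part of the commit.

import torch

# Walkthrough of the per-request first-mismatch computation from the hunk
# above, run on the unit-test inputs (a paraphrase, not the committed code).
draft_token_ids = torch.tensor([10, 11, 20, 21])
target_argmax = torch.tensor([10, 99, 20, 22])
cu_num_draft_tokens = torch.tensor([2, 4])
draft_tokens_per_req = torch.tensor([2, 2])
batch_size, max_spec_len = 2, 2
num_tokens = draft_token_ids.numel()

# Map each flat token index back to (request id, position within the request).
start_indices = cu_num_draft_tokens - draft_tokens_per_req                 # [0, 2]
req_ids = torch.arange(batch_size)
token_req_ids = torch.repeat_interleave(req_ids, draft_tokens_per_req)     # [0, 0, 1, 1]
token_positions = torch.arange(num_tokens) - start_indices[token_req_ids]  # [0, 1, 0, 1]

# Scatter mismatches into a [batch_size, max_spec_len] grid and take the
# per-row minimum position; rows with no mismatch accept all of their drafts.
mismatch_global = draft_token_ids != target_argmax                         # [F, T, F, T]
pos_matrix = torch.full((batch_size, max_spec_len), -1, dtype=torch.long)
pos_matrix[token_req_ids, token_positions] = token_positions
mismatch_matrix = torch.zeros(batch_size, max_spec_len, dtype=torch.bool)
mismatch_matrix[token_req_ids, token_positions] = mismatch_global
mismatch_positions = torch.where(mismatch_matrix, pos_matrix, max_spec_len * 2)
first_mismatch_pos_per_req, _ = torch.min(mismatch_positions, dim=1)
no_mismatch_mask = first_mismatch_pos_per_req == max_spec_len * 2
first_mismatch_pos_per_req[no_mismatch_mask] = draft_tokens_per_req[no_mismatch_mask]

print(first_mismatch_pos_per_req)  # tensor([1, 1]) -> both requests reject at position 1

Each request then copies min(first_mismatch_pos + 1, draft_tokens_per_req) entries of target_argmax into its output row, and only rows where every draft token matched receive the bonus token.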
