# SPDX-License-Identifier: Apache-2.0

from unittest.mock import MagicMock, patch

import pytest
import torch
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
from vllm.spec_decode.top1_proposer import Top1Proposer

from vllm_ascend.worker.multi_step_worker import MultiStepWorker

from .test_utils import mock_spec_decode_sampler
from .utils import create_batch, mock_worker


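# queue_size=4 deliberately exceeds disable_by_batch_size=3 (set inside the
# test), so the default parametrization exercises the "speculation disabled"
# path; both acceptance sampler variants are covered.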
@pytest.mark.parametrize('queue_size', [4])
@pytest.mark.parametrize('batch_size', [1])
@pytest.mark.parametrize('k', [1])
@pytest.mark.parametrize("acceptance_sampler_method",
                         ["rejection_sampler", "typical_acceptance_sampler"])
@torch.inference_mode()
def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int,
                             acceptance_sampler_method: str):
    """Verify that speculative tokens are disabled when the running queue
    size exceeds the disable_by_batch_size threshold.
    """
    disable_by_batch_size = 3
    draft_worker = mock_worker(cls=MultiStepWorker)
    target_worker = mock_worker()
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
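    # SpecDecodeWorker coordinates the draft (proposer) and target (scorer)
    # workers; disable_by_batch_size turns speculation off for a step whenever
    # the running queue size exceeds that threshold.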
    worker = SpecDecodeWorker(proposer_worker=draft_worker,
                              scorer_worker=target_worker,
                              spec_decode_sampler=mock_spec_decode_sampler(
                                  acceptance_sampler_method),
                              disable_logprobs=False,
                              metrics_collector=metrics_collector,
                              disable_by_batch_size=disable_by_batch_size)

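    # If speculation is correctly disabled, the draft worker's proposal path
    # must never run; make it raise so any accidental call fails the test.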
    exception_secret = 'artificial stop'
    draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)

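    # The create_batch test helper builds batch_size sequence groups, each set
    # up for k speculative tokens per step (the other two return values are
    # unused here).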
    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k,
        running_queue_size=queue_size)

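    # With queue_size > disable_by_batch_size, execute_model should fall back
    # to the non-speculative path (_run_no_spec). Patch it to raise so the
    # test can confirm that this path was actually taken.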
    if queue_size > disable_by_batch_size:
        with patch.object(worker,
                          '_run_no_spec',
                          side_effect=ValueError(exception_secret)), \
                pytest.raises(ValueError, match=exception_secret):
            worker.execute_model(execute_model_req=execute_model_req)

    # When the running queue size is larger than the threshold,
    # we expect no speculative tokens (0).
    expected_num_spec_tokens = None if queue_size < disable_by_batch_size else 0
    assert seq_group_metadata_list[
        0].num_speculative_tokens == expected_num_spec_tokens

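    # Exercise the proposer directly: sampler_output is what the draft model
    # would produce, so making it raise reveals whether Top1Proposer actually
    # runs the draft model.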
    draft_worker.sampler_output.side_effect = ValueError(exception_secret)

    proposer = Top1Proposer(
        worker=draft_worker,
        device='cpu',  # not used
        vocab_size=100,  # not used
        # Must be long enough to avoid being skipped due to length.
        max_proposal_len=1024,
    )

    if queue_size < disable_by_batch_size:
        # Should raise an exception when executing the mocked draft model.
        with pytest.raises(ValueError, match=exception_secret):
            proposer.get_spec_proposals(
                execute_model_req=ExecuteModelRequest(
                    seq_group_metadata_list=seq_group_metadata_list,
                    num_lookahead_slots=k),
                seq_ids_with_bonus_token_in_last_step=set())
    else:
        # Should not execute the draft model because spec decode is disabled
        # for all requests. Accordingly, the proposal length should be 0.
        proposals = proposer.get_spec_proposals(
            execute_model_req=ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list,
                num_lookahead_slots=k),
            seq_ids_with_bonus_token_in_last_step=set())
        assert proposals.proposal_lens.tolist() == [0] * batch_size