import pytest
import torch

from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeScores
from vllm.spec_decode.mqa_scorer import MQAScorer
from vllm.worker.worker import Worker

from .utils import create_batch, create_worker


def create_proposal(batch_size: int, propose_len: int, vocab_size: int,
                    device: str) -> SpeculativeProposals:
    """Build dummy proposals: random probabilities over the vocabulary,
    with the proposed token ids taken as their argmax."""
    proposal_probs = torch.rand((batch_size, propose_len, vocab_size),
                                device=device)
    proposal_token_ids = torch.argmax(proposal_probs, dim=-1)
    proposal_lens = torch.tensor([propose_len] * batch_size, device=device)
    return SpeculativeProposals(proposal_token_ids, proposal_probs,
                                proposal_lens)


def assert_score_equal(score1: SpeculativeScores,
                       score2: SpeculativeScores) -> None:
    """Assert that two SpeculativeScores carry identical probabilities,
    logprobs, and token ids."""
    assert torch.allclose(score1.probs, score2.probs)
    assert torch.allclose(score1.logprobs, score2.logprobs)
    assert torch.equal(score1.token_ids, score2.token_ids)


@pytest.mark.parametrize('model_name', ['facebook/opt-125m'])
@pytest.mark.parametrize('batch_size', [1, 2, 4, 8, 16])
@pytest.mark.parametrize('propose_len', [1, 3, 5])
@pytest.mark.parametrize('device', ['cuda'])
def test_scorer(model_name: str, batch_size: int, propose_len: int,
                device: str) -> None:
    """
    Verify that the batch expansion scorer and the MQA scorer return the
    same scores across a range of batch sizes and proposal lengths.
    """
    seed = 0
    block_size = 32
    num_gpu_blocks = 2048 // block_size
    scorer_worker = create_worker(Worker, model_name, block_size,
                                  num_gpu_blocks, seed)
    # The scorers compare sampler probabilities directly, so the sampler
    # must keep its probability tensors on the GPU.
    sampler = scorer_worker.model_runner.model.sampler
    sampler.include_gpu_probs_tensor = True
    sampler.should_modify_greedy_probs_inplace = True

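    # Generate random proposals and a matching dummy batch, then wrap them
    # in a single ExecuteModelRequest for both scorers to score.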
    vocab_size = scorer_worker.vocab_size
    proposals = create_proposal(batch_size, propose_len, vocab_size, device)
    seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                 propose_len,
                                                 block_size=block_size,
                                                 num_gpu_blocks=num_gpu_blocks)
    requests = ExecuteModelRequest(seq_group_metadata_list,
                                   num_lookahead_slots=propose_len)

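    # BatchExpansionTop1Scorer expands each sequence into one query per
    # proposal token, while MQAScorer scores all proposal tokens in a
    # single multi-query pass; both paths should yield identical scores.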
    batch_expansion_scorer = BatchExpansionTop1Scorer(scorer_worker, device,
                                                      vocab_size)
    batch_expansion_score = batch_expansion_scorer.score_proposals(
        requests, proposals)

    mqa_scorer = MQAScorer(scorer_worker, device, vocab_size)
    mqa_score = mqa_scorer.score_proposals(requests, proposals)

    assert_score_equal(batch_expansion_score, mqa_score)