Skip to content

Commit e46aae7

Browse files
author
Andrew Lapp
committed
fix tests s.t. they mock forgetting the logits processor
1 parent 6b2035e commit e46aae7

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

tests/serve/test_vllm.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import re
22

3+
import pytest
34
import torch
45

56
from outlines.serve.vllm import RegexLogitsProcessor, _patched_apply_logits_processors
@@ -41,14 +42,21 @@ def sample_from_logits(logits):
4142
return torch.multinomial(probs, 1).item()
4243

4344

44-
def test_time_regexp():
45+
@pytest.mark.parametrize("forget_logits_processor", [True, False])
46+
def test_time_regexp(forget_logits_processor):
4547
pattern = r"(0?[1-9]|1[0-2]):[0-5]\d\s?(am|pm)?"
4648
llm = MockModel()
4749
logits_processor = RegexLogitsProcessor(pattern, llm)
4850

4951
token_ids = []
5052
while True:
5153
random_scores = -10 + 20 * torch.rand(len(llm.tokenizer.vocabulary))
54+
55+
# mock "forgetting" the logits processor behavior in
56+
# vLLM tensor-parallel world size > 1
57+
if forget_logits_processor:
58+
logits_processor = RegexLogitsProcessor(pattern, llm)
59+
5260
logits = logits_processor(
5361
input_ids=token_ids,
5462
scores=random_scores,

0 commit comments

Comments
 (0)