Skip to content

Commit

Permalink
[Minor] Fix duplication of ignored seq group in engine step (vllm-pro…
Browse files Browse the repository at this point in the history
  • Loading branch information
simon-mo authored Nov 16, 2023
1 parent 3b3161d commit 0874258
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
27 changes: 27 additions & 0 deletions tests/test_regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Containing tests that check for regressions in vLLM's behavior.
It should include tests for issues reported by users, making sure they
never happen again.
"""
from vllm import LLM, SamplingParams


def test_duplicated_ignored_sequence_group():
    """Regression test for https://github.com/vllm-project/vllm/issues/1655.

    Submits a batch that mixes a short prompt with a very long one
    (the same phrase repeated 1000 times) and checks that the engine
    returns exactly one output per prompt.
    """
    short_prompt = "This is a short prompt"
    long_prompt = "This is a very long prompt " * 1000
    prompts = [short_prompt, long_prompt]

    # Construction order (sampling params first, then the engine) is kept
    # as in the original test.
    params = SamplingParams(temperature=0.01, top_p=0.1, max_tokens=256)
    engine = LLM(
        model="facebook/opt-125m",
        max_num_batched_tokens=4096,
        tensor_parallel_size=1,
    )

    outputs = engine.generate(prompts, sampling_params=params)

    # One result per submitted prompt: no more, no fewer.
    assert len(prompts) == len(outputs)


if __name__ == "__main__":
    # Allow running this file directly as a script: hand it to pytest.
    import pytest

    pytest.main(args=[__file__])
2 changes: 1 addition & 1 deletion vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ def step(self) -> List[RequestOutput]:
blocks_to_copy=scheduler_outputs.blocks_to_copy,
)

return self._process_model_outputs(output, scheduler_outputs) + ignored
return self._process_model_outputs(output, scheduler_outputs)

def _log_system_stats(
self,
Expand Down

0 comments on commit 0874258

Please sign in to comment.