
Commit 4e53851

mzusman authored and afeldman-nm committed
[Bugfix][Mamba] Fix Multistep on Mamba-like models (vllm-project#10705)
Signed-off-by: mzusman <mor.zusmann@gmail.com>
Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
1 parent 9bf5c8d commit 4e53851

File tree (4 files changed: +84 -4 lines)

tests/models/decoder_only/language/test_jamba.py
tests/models/decoder_only/language/test_mamba.py
vllm/engine/async_llm_engine.py
vllm/engine/llm_engine.py

tests/models/decoder_only/language/test_jamba.py

Lines changed: 38 additions & 0 deletions
@@ -275,6 +275,44 @@ def test_state_cleanup(
                     "could be related to finished_requests_ids")
 
 
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+def test_multistep(
+    vllm_runner,
+    model: str,
+    dtype: str,
+    example_prompts,
+) -> None:
+    # This test verifies that multistep works correctly
+    # on mamba-like models
+    with vllm_runner(model, num_scheduler_steps=8,
+                     max_num_seqs=2) as vllm_model:
+        vllm_model.generate_greedy([example_prompts[0]] * 10, 1)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("max_tokens", [64])
+def test_multistep_correctness(vllm_runner, model: str, dtype: str,
+                               max_tokens: int, example_prompts) -> None:
+    with vllm_runner(model, num_scheduler_steps=8,
+                     max_num_seqs=2) as vllm_model:
+        vllm_outputs_multistep = vllm_model.generate_greedy(
+            example_prompts, max_tokens)
+
+    with vllm_runner(model, num_scheduler_steps=1,
+                     max_num_seqs=2) as vllm_model:
+        vllm_outputs_single_step = vllm_model.generate_greedy(
+            example_prompts, max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=vllm_outputs_multistep,
+        outputs_1_lst=vllm_outputs_single_step,
+        name_0="vllm_outputs_multistep",
+        name_1="vllm_outputs_single_step",
+    )
+
+
 @multi_gpu_test(num_gpus=2)
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["float"])

tests/models/decoder_only/language/test_mamba.py

Lines changed: 36 additions & 0 deletions
@@ -283,3 +283,39 @@ def test_state_cleanup(
     except ValueError:
         pytest.fail("Mamba inner state wasn't cleaned up between states, "
                     "could be related to finished_requests_ids")
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+def test_multistep(
+    vllm_runner,
+    model: str,
+    dtype: str,
+    example_prompts,
+) -> None:
+    with vllm_runner(model, num_scheduler_steps=8,
+                     max_num_seqs=2) as vllm_model:
+        vllm_model.generate_greedy([example_prompts[0]] * 10, 1)
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("max_tokens", [64])
+def test_multistep_correctness(vllm_runner, model: str, dtype: str,
+                               max_tokens: int, example_prompts) -> None:
+    with vllm_runner(model, num_scheduler_steps=8,
+                     max_num_seqs=2) as vllm_model:
+        vllm_outputs_multistep = vllm_model.generate_greedy(
+            example_prompts, max_tokens)
+
+    with vllm_runner(model, num_scheduler_steps=1,
+                     max_num_seqs=2) as vllm_model:
+        vllm_outputs_single_step = vllm_model.generate_greedy(
+            example_prompts, max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=vllm_outputs_multistep,
+        outputs_1_lst=vllm_outputs_single_step,
+        name_0="vllm_outputs_multistep",
+        name_1="vllm_outputs_single_step",
+    )
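The same configuration the tests exercise can be reproduced from the public API. A minimal sketch, assuming the `LLM` constructor forwards `num_scheduler_steps` to the engine arguments (as the `--num-scheduler-steps` CLI flag does) and using a placeholder Mamba checkpoint:

from vllm import LLM, SamplingParams

# num_scheduler_steps=8 enables multistep scheduling, the mode this
# commit fixes for Mamba-like models that keep recurrent state
# instead of a KV cache.
llm = LLM(model="state-spaces/mamba-130m-hf",  # placeholder checkpoint
          num_scheduler_steps=8,
          max_num_seqs=2)

params = SamplingParams(temperature=0.0, max_tokens=64)  # greedy
outputs = llm.generate(["The president of the United States is"], params)
print(outputs[0].outputs[0].text)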

vllm/engine/async_llm_engine.py

Lines changed: 5 additions & 2 deletions
@@ -300,6 +300,9 @@ async def step_async(
             ctx.seq_group_metadata_list = seq_group_metadata_list
             ctx.scheduler_outputs = scheduler_outputs
 
+            finished_requests_ids = self.scheduler[
+                virtual_engine].get_and_reset_finished_requests_ids()
+
             # Maybe switch from async mode to sync mode
             if not allow_async_output_proc and len(ctx.output_queue) > 0:
                 self._process_model_outputs(ctx=ctx)
@@ -311,13 +314,13 @@ async def step_async(
                 self._cache_scheduler_outputs_for_multi_step(
                     virtual_engine, seq_group_metadata_list, scheduler_outputs,
                     allow_async_output_proc)
+        else:
+            finished_requests_ids = list()
 
         assert seq_group_metadata_list is not None
         assert scheduler_outputs is not None
 
         if not scheduler_outputs.is_empty():
-            finished_requests_ids = self.scheduler[
-                virtual_engine].get_and_reset_finished_requests_ids()
 
             # Check if we have a cached last_output from the previous iteration.
             # For supporting PP this is probably the best way to pass the
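The engine-side change is easiest to read as control flow. Before this commit, `get_and_reset_finished_requests_ids()` was called inside `if not scheduler_outputs.is_empty():`, a branch that also runs on multistep continuation steps that reuse cached scheduler outputs; the finished ids could therefore be drained on a step where they were never handed to the model runner, leaving stale per-request state in Mamba-like models (the exact failure mode `test_state_cleanup` above probes). The fix hoists the call into the branch that performs a real scheduling pass. A schematic sketch of the fixed ordering, where `engine` stands in for the LLM engine and `_cached_outputs`, `_build_request`, and `_execute` are illustrative stand-ins, not real vLLM methods:

from typing import List


def step_sketch(engine, virtual_engine: int) -> List:
    """Simplified shape of LLMEngine.step() after this commit."""
    if not engine._has_remaining_steps():
        # A real scheduling pass: consume-and-reset the finished ids
        # here, so they are handed to the model runner exactly once.
        seq_group_metadata_list, scheduler_outputs = \
            engine.scheduler[virtual_engine].schedule()
        finished_requests_ids = engine.scheduler[
            virtual_engine].get_and_reset_finished_requests_ids()
    else:
        # Multistep continuation: scheduler outputs come from the
        # multi-step cache and no ids should be drained.
        seq_group_metadata_list, scheduler_outputs = engine._cached_outputs()
        finished_requests_ids = list()

    if not scheduler_outputs.is_empty():
        # Before the fix, the consume-and-reset lived here, so a
        # continuation step could swallow ids without forwarding them.
        request = engine._build_request(seq_group_metadata_list,
                                        finished_requests_ids)
        return engine._execute(request)
    return []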

vllm/engine/llm_engine.py

Lines changed: 5 additions & 2 deletions
@@ -1398,6 +1398,9 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
             ctx.seq_group_metadata_list = seq_group_metadata_list
             ctx.scheduler_outputs = scheduler_outputs
 
+            finished_requests_ids = self.scheduler[
+                virtual_engine].get_and_reset_finished_requests_ids()
+
             # Maybe switch from async mode to sync mode
             if not allow_async_output_proc and len(ctx.output_queue) > 0:
                 self._process_model_outputs(ctx=ctx)
@@ -1409,13 +1412,13 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
                 self._cache_scheduler_outputs_for_multi_step(
                     virtual_engine, seq_group_metadata_list, scheduler_outputs,
                     allow_async_output_proc)
+        else:
+            finished_requests_ids = list()
 
         assert seq_group_metadata_list is not None
         assert scheduler_outputs is not None
 
         if not scheduler_outputs.is_empty():
-            finished_requests_ids = self.scheduler[
-                virtual_engine].get_and_reset_finished_requests_ids()
 
             # Check if we have a cached last_output from the previous iteration.
             # For supporting PP this is probably the best way to pass the
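On the consumer side, Mamba-like models hold per-request recurrent state (convolution and SSM states) in a dedicated cache rather than in the KV cache, and eviction is driven purely by the finished ids the engine forwards. An illustrative sketch with hypothetical names (vLLM's actual cache manager differs):

from typing import Any, Dict, List, Tuple


class MambaStateCacheSketch:
    """Hypothetical per-request state cache for a Mamba-like model."""

    def __init__(self) -> None:
        # request_id -> (conv_state, ssm_state); tensors in the real thing.
        self._states: Dict[str, Tuple[Any, Any]] = {}

    def release_finished(self, finished_requests_ids: List[str]) -> None:
        # If the engine drains the finished ids without forwarding them
        # (the bug this commit fixes), nothing is ever popped here: the
        # entry leaks and a recycled slot can inherit stale recurrent
        # state from a previous request.
        for request_id in finished_requests_ids:
            self._states.pop(request_id, None)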
