# SPDX-License-Identifier: Apache-2.0
"""Test that we handle an error in the model forward pass and shut down."""

import asyncio

import pytest

from tests.utils import wait_for_gpu_memory_to_clear
from tests.v1.shutdown.utils import (SHUTDOWN_TEST_THRESHOLD_BYTES,
                                     SHUTDOWN_TEST_TIMEOUT_SEC)
from vllm import LLM, AsyncEngineArgs, SamplingParams
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.utils import cuda_device_count_stateless
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.exceptions import EngineDeadError

MODELS = ["meta-llama/Llama-3.2-1B"]


def evil_forward(self, *args, **kwargs):
    """Evil forward method that raises an exception after 10 calls."""
    NUMBER_OF_GOOD_PASSES = 10

    if not hasattr(self, "num_calls"):
        self.num_calls = 0

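    # After NUMBER_OF_GOOD_PASSES successful calls, raise only on TP rank 0
    # to simulate a single-rank crash mid-generation.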
    if (self.num_calls == NUMBER_OF_GOOD_PASSES
            and get_tensor_model_parallel_rank() == 0):
        raise Exception("Simulated illegal memory access on Rank 0!")
    self.num_calls += 1

    return self.model(*args, **kwargs)


@pytest.mark.asyncio
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("model", MODELS)
async def test_async_llm_model_error(monkeypatch, tensor_parallel_size: int,
                                     model: str) -> None:
    """Test that AsyncLLM propagates a forward pass error and frees memory.

    AsyncLLM always uses an MP client.
    """
    if cuda_device_count_stateless() < tensor_parallel_size:
        pytest.skip(reason="Not enough CUDA devices")

    # Monkeypatch an error in the model.
    monkeypatch.setattr(LlamaForCausalLM, "forward", evil_forward)

    engine_args = AsyncEngineArgs(model=model,
                                  enforce_eager=True,
                                  tensor_parallel_size=tensor_parallel_size)
    async_llm = AsyncLLM.from_engine_args(engine_args)

    async def generate(request_id: str):
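        # Drain the output stream and return the exception raised, if any.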
        generator = async_llm.generate("Hello my name is",
                                       request_id=request_id,
                                       sampling_params=SamplingParams())
        try:
            async for _ in generator:
                pass
        except Exception as e:
            return e

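    # Issue several concurrent requests; every in-flight request should
    # observe the engine failure.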
    NUM_REQS = 3
    tasks = [generate(f"request-{idx}") for idx in range(NUM_REQS)]
    outputs = await asyncio.gather(*tasks)

    # Every request should get an EngineDeadError.
    for output in outputs:
        assert isinstance(output, EngineDeadError)

    # AsyncLLM should be errored.
    assert async_llm.errored

    # We should not be able to make another request.
    with pytest.raises(EngineDeadError):
        async for _ in async_llm.generate("Hello my name is",
                                          request_id="abc",
                                          sampling_params=SamplingParams()):
            raise Exception("We should not get here.")

    # Confirm all the processes are cleaned up.
    wait_for_gpu_memory_to_clear(
        devices=list(range(tensor_parallel_size)),
        threshold_bytes=2 * 2**30,
        timeout_s=60,
    )

    # NOTE: in a real deployment, the API Server handles shutdown when an
    # exception occurs; there is no API Server here, so call it explicitly.
    async_llm.shutdown()


@pytest.mark.timeout(SHUTDOWN_TEST_TIMEOUT_SEC)
@pytest.mark.parametrize("enable_multiprocessing", [True])
@pytest.mark.parametrize("tensor_parallel_size", [2, 1])
@pytest.mark.parametrize("model", MODELS)
def test_llm_model_error(monkeypatch, tensor_parallel_size: int,
                         enable_multiprocessing: bool, model: str) -> None:
    """Test that LLM propagates a forward pass error and frees memory.

    TODO(andy) - LLM without multiprocessing; LLM with multiprocessing
    and >1 rank
    """
    if cuda_device_count_stateless() < tensor_parallel_size:
        pytest.skip(reason="Not enough CUDA devices")

    with monkeypatch.context() as m:

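        # VLLM_ENABLE_V1_MULTIPROCESSING toggles whether the V1 engine core
        # runs in a separate process ("1") or in-process ("0").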
        MP_VALUE = "1" if enable_multiprocessing else "0"
        m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", MP_VALUE)

        # Monkeypatch an error in the model.
        m.setattr(LlamaForCausalLM, "forward", evil_forward)

        llm = LLM(model=model,
                  enforce_eager=True,
                  tensor_parallel_size=tensor_parallel_size)

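        # With multiprocessing enabled, the failure surfaces as an
        # EngineDeadError; otherwise the raw exception propagates.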
        with pytest.raises(
                EngineDeadError if enable_multiprocessing else Exception):
            llm.generate("Hello my name is Robert and I")

        # Confirm all the processes are cleaned up.
        wait_for_gpu_memory_to_clear(
            devices=list(range(tensor_parallel_size)),
            threshold_bytes=SHUTDOWN_TEST_THRESHOLD_BYTES,
        )