# SPDX-License-Identifier: Apache-2.0
import pytest

import vllm
from vllm.lora.request import LoRARequest

# This file contains tests to ensure that LoRA works correctly on the TPU
# backend. We use a series of custom trained adapters for Qwen2.5-3B-Instruct
# for this. The adapters are:
# Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter, where x ranges
# from 1 to 4.

# These adapters are trained using a standard huggingface peft training script,
# where all the inputs are "What is 1+1? \n" and all the outputs are "x". We run
# 100 training iterations with a training batch size of 100.
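

# The helper below is a rough, illustrative sketch of the kind of peft training
# script described above (an assumption; the actual script used to produce the
# Username6568 adapters is not part of this repo). It is never invoked by the
# tests, its imports are deferred so peft/transformers/datasets are not needed
# to run them, and the output paths are made up.
def _train_adapter_sketch(target: str = "1") -> None:
    from datasets import Dataset
    from peft import LoraConfig, get_peft_model
    from transformers import (AutoModelForCausalLM, AutoTokenizer, Trainer,
                              TrainingArguments)

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct")
    model = get_peft_model(
        AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B-Instruct"),
        LoraConfig(r=8,
                   lora_alpha=16,
                   target_modules=["q_proj", "v_proj"],
                   task_type="CAUSAL_LM"))

    def tokenize(example):
        # Every example is the same prompt/answer pair described above.
        tokens = tokenizer(example["text"], max_length=32, truncation=True)
        tokens["labels"] = tokens["input_ids"].copy()
        return tokens

    dataset = Dataset.from_dict({
        "text": [f"What is 1+1? \n{target}"] * 100
    }).map(tokenize, remove_columns=["text"])

    # 100 steps at a batch size of 100, matching the description above.
    Trainer(model=model,
            args=TrainingArguments(output_dir="./lora_train_out",
                                   per_device_train_batch_size=100,
                                   max_steps=100),
            train_dataset=dataset).train()
    model.save_pretrained(f"./1_plus_1_equals_{target}_adapter")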


@pytest.fixture(scope="function", autouse=True)
def use_v1_only(monkeypatch: pytest.MonkeyPatch):
    """
    Since Multi-LoRA is only supported on the v1 TPU backend, set VLLM_USE_V1=1
    for all tests in this file
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        yield


def setup_vllm(num_loras: int) -> vllm.LLM:
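    # Build a small engine with LoRA support enabled. max_loras bounds how many
    # adapters can be active in a single batch, and max_lora_rank must be at
    # least the rank of the adapters being served.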
    return vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
                    num_scheduler_steps=1,
                    max_model_len=256,
                    max_seq_len_to_capture=256,
                    max_num_seqs=8,
                    enable_lora=True,
                    max_loras=num_loras,
                    max_lora_rank=8)


def test_single_lora():
    """
    This test ensures we can run a single LoRA adapter on the TPU backend.
    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter" which
    will force Qwen2.5-3B-Instruct to claim 1+1=1.
    """

    llm = setup_vllm(1)

    prompt = "What is 1+1? \n"

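    # LoRARequest takes a human-readable adapter name, a unique positive
    # integer ID that vLLM uses to track the adapter, and the Hugging Face
    # repo (or local path) containing the adapter weights.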
    lora_request = LoRARequest(
        "lora_adapter_1", 1,
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter")
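    # Greedy decoding (temperature=0) keeps the generated answer deterministic.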
    output = llm.generate(prompt,
                          sampling_params=vllm.SamplingParams(max_tokens=256,
                                                              temperature=0),
                          lora_request=lora_request)[0].outputs[0].text

    answer = output.strip()[0]

    assert answer.isdigit()
    assert int(answer) == 1


def test_lora_hotswapping():
    """
    This test ensures we can run multiple LoRA adapters on the TPU backend, even
    if we only have space to store 1.

    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
    will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.
    """

    lora_name_template = \
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
    lora_requests = [
        LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
        for i in range(1, 5)
    ]

    llm = setup_vllm(1)

    prompt = "What is 1+1? \n"

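    # Only a single LoRA slot is available (max_loras=1), so each request below
    # forces the previously loaded adapter to be swapped out for the next one.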
    for i, req in enumerate(lora_requests):
        output = llm.generate(prompt,
                              sampling_params=vllm.SamplingParams(
                                  max_tokens=256, temperature=0),
                              lora_request=req)[0].outputs[0].text
        answer = output.strip()[0]

        assert answer.isdigit()
        assert int(answer) == i + 1


def test_multi_lora():
    """
    This test ensures we can run multiple LoRA adapters on the TPU backend, when
    we have enough space to store all of them.

    We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which
    will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x.
    """
    lora_name_template = \
        "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter"
    lora_requests = [
        LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i))
        for i in range(1, 5)
    ]

    llm = setup_vllm(4)

    prompt = "What is 1+1? \n"

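    # With max_loras=4 all four adapters fit at once, so no swapping is needed
    # between requests.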
    for i, req in enumerate(lora_requests):
        output = llm.generate(prompt,
                              sampling_params=vllm.SamplingParams(
                                  max_tokens=256, temperature=0),
                              lora_request=req)[0].outputs[0].text

        answer = output.strip()[0]

        assert answer.isdigit()
        assert int(answer) == i + 1