16 | 16 | #
17 | 17 | import pytest
18 | 18 | import torch
19 |    | -from vllm import LLM, SamplingParams
   | 19 | +from vllm import SamplingParams
   | 20 | +
   | 21 | +from tests.conftest import VllmRunner
20 | 22 |
21 | 23 | MODELS = [
22 | 24 |     "Qwen/Qwen2.5-0.5B-Instruct",

38 | 40 | def test_models(model: str, tp_size: int, max_tokens: int, temperature: int,
39 | 41 |                 ignore_eos: bool) -> None:
40 | 42 |     # Create an LLM.
41 |    | -    llm = LLM(
42 |    | -        model=model,
43 |    | -        tensor_parallel_size=tp_size,
44 |    | -    )
45 |    | -    # Prepare sampling_params
46 |    | -    sampling_params = SamplingParams(
47 |    | -        max_tokens=max_tokens,
48 |    | -        temperature=temperature,
49 |    | -        ignore_eos=ignore_eos,
50 |    | -    )
   | 43 | +    with VllmRunner(model_name=model,
   | 44 | +                    tensor_parallel_size=tp_size,
   | 45 | +                    ) as vllm_model:
   | 46 | +        # Prepare sampling_params
   | 47 | +        sampling_params = SamplingParams(
   | 48 | +            max_tokens=max_tokens,
   | 49 | +            temperature=temperature,
   | 50 | +            ignore_eos=ignore_eos,
   | 51 | +        )
51 | 52 |
52 |    | -    # Generate texts from the prompts.
53 |    | -    # The output is a list of RequestOutput objects
54 |    | -    outputs = llm.generate(prompts, sampling_params)
55 |    | -    torch.npu.synchronize()
56 |    | -    # The output length should be equal to prompts length.
57 |    | -    assert len(outputs) == len(prompts)
   | 53 | +        # Generate texts from the prompts.
   | 54 | +        # The output is a list of RequestOutput objects
   | 55 | +        outputs = vllm_model.generate(prompts, sampling_params)
   | 56 | +        torch.npu.synchronize()
   | 57 | +        # The output length should be equal to prompts length.
   | 58 | +        assert len(outputs) == len(prompts)
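Taken together, the touched test reads roughly as sketched below after this change. This is a reading aid rather than part of the diff: it assumes `VllmRunner` from `tests.conftest` is a context manager whose `generate()` mirrors `LLM.generate()` (presumably so each parametrized run releases the model when the block exits), and it inlines a placeholder `prompts` list that lives at module level in the real file.

```python
import torch
from vllm import SamplingParams

from tests.conftest import VllmRunner

# Placeholder; the real test module defines prompts at module level.
prompts = ["Hello, my name is"]


def test_models(model: str, tp_size: int, max_tokens: int, temperature: int,
                ignore_eos: bool) -> None:
    # Create the model via the runner so it is torn down when the block exits.
    with VllmRunner(model_name=model,
                    tensor_parallel_size=tp_size) as vllm_model:
        # Prepare sampling params.
        sampling_params = SamplingParams(
            max_tokens=max_tokens,
            temperature=temperature,
            ignore_eos=ignore_eos,
        )
        # Generate one RequestOutput per prompt, then wait for the NPU to finish.
        outputs = vllm_model.generate(prompts, sampling_params)
        torch.npu.synchronize()
        assert len(outputs) == len(prompts)
```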