Commit 7139925

Apply patch before 4932bcd
Signed-off-by: sibi <85477603+t-sibiraj@users.noreply.github.com>
1 parent 34bc647 · commit 7139925

37 files changed: +1949 −1753 lines

tests/basic_correctness/test_basic_correctness.py

58 additions, 51 deletions
@@ -64,31 +64,33 @@ def test_models(
         pytest.skip(
             f"{backend} does not support gemma2 with full context length.")
 
-    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)
-
-    # 5042 tokens for gemma2
-    # gemma2 has alternating sliding window size of 4096
-    # we need a prompt with more than 4096 tokens to test the sliding window
-    prompt = "The following numbers of the sequence " + ", ".join(
-        str(i) for i in range(1024)) + " are:"
-    example_prompts = [prompt]
-
-    with hf_runner(model, dtype=dtype) as hf_model:
-        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
-
-    with VllmRunner(model,
-                    max_model_len=8192,
-                    dtype=dtype,
-                    enforce_eager=enforce_eager,
-                    gpu_memory_utilization=0.7) as vllm_model:
-        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
-
-    check_outputs_equal(
-        outputs_0_lst=hf_outputs,
-        outputs_1_lst=vllm_outputs,
-        name_0="hf",
-        name_1="vllm",
-    )
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_ATTENTION_BACKEND", backend)
+
+        # 5042 tokens for gemma2
+        # gemma2 has alternating sliding window size of 4096
+        # we need a prompt with more than 4096 tokens to test the sliding window
+        prompt = "The following numbers of the sequence " + ", ".join(
+            str(i) for i in range(1024)) + " are:"
+        example_prompts = [prompt]
+
+        with hf_runner(model, dtype=dtype) as hf_model:
+            hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+
+        with VllmRunner(model,
+                        max_model_len=8192,
+                        dtype=dtype,
+                        enforce_eager=enforce_eager,
+                        gpu_memory_utilization=0.7) as vllm_model:
+            vllm_outputs = vllm_model.generate_greedy(example_prompts,
+                                                      max_tokens)
+
+        check_outputs_equal(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=vllm_outputs,
+            name_0="hf",
+            name_1="vllm",
+        )
 
 
 @multi_gpu_test(num_gpus=2)
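
A note on the hunk above: the old code set VLLM_ATTENTION_BACKEND through the bare monkeypatch fixture, which is only undone at test teardown; wrapping the body in monkeypatch.context() rolls the override back as soon as the block exits. A minimal sketch of that behavior, assuming pytest (the test name and the FLASH_ATTN value are illustrative, not taken from this commit):

import os


def test_backend_env_is_scoped(monkeypatch):
    # Changes made through the context manager are undone at block exit,
    # not at test teardown, so they cannot leak into code that runs later
    # in the same test or process.
    assert "VLLM_ATTENTION_BACKEND" not in os.environ
    with monkeypatch.context() as m:
        m.setenv("VLLM_ATTENTION_BACKEND", "FLASH_ATTN")  # example value
        assert os.environ["VLLM_ATTENTION_BACKEND"] == "FLASH_ATTN"
    # Restored as soon as the context closes.
    assert "VLLM_ATTENTION_BACKEND" not in os.environ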
@@ -125,29 +127,34 @@ def test_models_distributed(
             monkeypatch_context.setenv("VLLM_USE_RAY_COMPILED_DAG", "1")
 
         if attention_backend:
-            monkeypatch_context.setenv("VLLM_ATTENTION_BACKEND",
-                                       attention_backend)
-
-        dtype = "half"
-        max_tokens = 5
-
-        # NOTE: take care of the order. run vLLM first, and then run HF.
-        # vLLM needs a fresh new process without cuda initialization.
-        # if we run HF first, the cuda initialization will be done and it
-        # will hurt multiprocessing backend with fork method (the default method).
-        with vllm_runner(model,
-                         dtype=dtype,
-                         tensor_parallel_size=2,
-                         distributed_executor_backend=distributed_executor_backend
-                         ) as vllm_model:
-            vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
-
-        with hf_runner(model, dtype=dtype) as hf_model:
-            hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
-
-        check_outputs_equal(
-            outputs_0_lst=hf_outputs,
-            outputs_1_lst=vllm_outputs,
-            name_0="hf",
-            name_1="vllm",
-        )
+            monkeypatch_context.setenv(
+                "VLLM_ATTENTION_BACKEND",
+                attention_backend,
+            )
+
+        dtype = "half"
+        max_tokens = 5
+
+        # NOTE: take care of the order. run vLLM first, and then run HF.
+        # vLLM needs a fresh new process without cuda initialization.
+        # if we run HF first, the cuda initialization will be done and it
+        # will hurt multiprocessing backend with fork method
+        # (the default method).
+        with vllm_runner(
+                model,
+                dtype=dtype,
+                tensor_parallel_size=2,
+                distributed_executor_backend=distributed_executor_backend,
+        ) as vllm_model:
+            vllm_outputs = vllm_model.generate_greedy(example_prompts,
+                                                      max_tokens)
+
+        with hf_runner(model, dtype=dtype) as hf_model:
+            hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+
+        check_outputs_equal(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=vllm_outputs,
+            name_0="hf",
+            name_1="vllm",
+        )
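
The NOTE in this hunk deserves a gloss: with the default fork start method on Linux, a child process inherits the parent's CUDA state, and PyTorch refuses to re-initialize CUDA in a forked subprocess. Running vLLM (which forks worker processes) before the HF model (which initializes CUDA in-process) keeps the parent CUDA-free at fork time. A standalone sketch of the hazard, hypothetical and assuming PyTorch on a machine with a CUDA device:

import multiprocessing as mp


def gpu_worker() -> None:
    # This works only because the parent was still CUDA-free when it
    # forked; had the parent touched the GPU first, this call would raise
    # "Cannot re-initialize CUDA in forked subprocess".
    import torch
    torch.zeros(1, device="cuda")


if __name__ == "__main__":
    ctx = mp.get_context("fork")  # the default start method on Linux
    p = ctx.Process(target=gpu_worker)
    p.start()
    p.join()
    # Only now is it safe for the parent itself to initialize CUDA.
    import torch
    torch.zeros(1, device="cuda")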

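Both tests converge on the same assertion: greedy decoding is deterministic, so HF and vLLM should agree token-for-token on identical prompts. A sketch of what a check_outputs_equal helper of this shape does, assuming each output is a (token_ids, text) pair as in vLLM's test utilities; the body below is illustrative, not the repo's implementation:

def check_outputs_equal(*, outputs_0_lst, outputs_1_lst, name_0, name_1):
    """Assert two runners produced identical greedy completions."""
    assert len(outputs_0_lst) == len(outputs_1_lst)
    for i, (out_0, out_1) in enumerate(zip(outputs_0_lst, outputs_1_lst)):
        token_ids_0, text_0 = out_0
        token_ids_1, text_1 = out_1
        # Compare token ids rather than text so detokenization quirks
        # cannot mask a real divergence.
        assert token_ids_0 == token_ids_1, (
            f"prompt {i}: {name_0} gave {text_0!r}, {name_1} gave {text_1!r}")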