Merged
4 changes: 3 additions & 1 deletion .buildkite/run-tpu-v1-test.sh
@@ -32,7 +32,9 @@ docker run --privileged --net host --shm-size=16G -it \
&& echo TEST_5 \
&& python3 /workspace/vllm/examples/offline_inference/tpu.py \
&& echo TEST_6 \
&& pytest -s -v /workspace/vllm/tests/tpu/worker/test_tpu_model_runner.py" \
&& pytest -s -v /workspace/vllm/tests/tpu/worker/test_tpu_model_runner.py \
Contributor:
Does it make sense to also set enforce_eager=False in test_basic.py?

Collaborator:
I think so. We should use the default value of enforce_eager (which is False) in most cases.

Contributor:
indeed, it is set to True in test_basic.py
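
For reference, a minimal sketch of the suggestion; the model name and other arguments below are illustrative, not the actual contents of test_basic.py:

```python
# Hypothetical sketch: drop the explicit flag so the default enforce_eager=False
# (i.e. compiled execution) is used. Not the actual test_basic.py code.
from vllm import LLM

llm = LLM("Qwen/Qwen2.5-1.5B-Instruct",  # model used by other tests in this PR
          max_num_seqs=16,
          max_model_len=512)             # enforce_eager omitted -> defaults to False
```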

&& echo TEST_7 \
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py" \


# TODO: This test fails because it uses RANDOM_SEED sampling
69 changes: 5 additions & 64 deletions tests/v1/tpu/test_sampler.py
@@ -1,7 +1,4 @@
# SPDX-License-Identifier: Apache-2.0
import tempfile
from time import time

import pytest

from vllm import LLM, envs
@@ -15,60 +12,6 @@
)


@pytest.mark.parametrize("model_name", ["D4nt3/Qwen2.5-two-layers"])
@pytest.mark.skipif(not current_platform.is_tpu(),
reason="This test needs a TPU")
def test_sampler_compilation(model_name: str, monkeypatch):
"""
Check that no recompilation happens despite changing sampling parameters.
We can't read XLA metrics from the engine process, hence we measure time.
"""
with tempfile.TemporaryDirectory() as temp_dir:
monkeypatch.setenv("VLLM_XLA_CACHE_PATH", temp_dir)
# Compiling model init may still take some time, enforce_eager to skip.
llm = LLM(model_name,
enforce_eager=True,
max_num_seqs=16,
max_model_len=1024,
gpu_memory_utilization=0.5)
prompts = [
"A robot may not injure a human being",
"It is only with the heart that one can see rightly;",
]
# First inference should be slow
sampling_params = SamplingParams(
temperature=0.7,
# top_p=0.6, # TODO too slow!
top_k=10,
min_p=0.2,
max_tokens=16)
s = time()
_ = llm.generate(prompts, sampling_params)
run1 = time() - s

# Second request with different params, but for which we
# compiled for in previous eager iteration.
sampling_params = SamplingParams(temperature=0.1,
top_k=12,
min_p=0.8,
max_tokens=24)
s = time()
_ = llm.generate(prompts, sampling_params)
run2 = time() - s
# Much faster after compiling
assert run1 * 0.1 > run2
print("TIMES", run1, run2)

# Third request with min_p set to "None". It will not trigger
# recompilation as a default 0 value will be used.
sampling_params = SamplingParams(max_tokens=24, temperature=0.0)
s = time()
_ = llm.generate(prompts, sampling_params)
run3 = time() - s
assert run1 * 0.1 > run3
print("TIMES", run1, run3)


@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
@pytest.mark.skipif(not current_platform.is_tpu(),
reason="This test needs a TPU")
@@ -77,13 +20,11 @@ def test_sampler_different(model_name: str):
Test significantly different sampling params to assert the model produces
different results.
"""
llm = LLM(
model_name,
enforce_eager=True,
max_num_seqs=1,
max_model_len=64,
# TODO: setting to 0.5 or it will go OOM
gpu_memory_utilization=0.5)
llm = LLM(model_name,
enforce_eager=False,
max_num_seqs=1,
max_model_len=512,
max_num_batched_tokens=512)
prompts = [
"Write a short story about a robot that dreams for the first time."
]
8 changes: 1 addition & 7 deletions vllm/v1/sample/tpu/metadata.py
@@ -88,6 +88,7 @@ def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor,
# Copy slice from CPU to corresponding TPU pre-allocated tensor.
# Pad value is the default one.
cpu_tensor[num_reqs:padded_num_reqs] = fill_val
# Subtle compilation: len(tpu_tensor) must be >= `padded_num_reqs`
tpu_tensor[:padded_num_reqs] = cpu_tensor[:padded_num_reqs]

# NOTE NickLucche The sync CPU-TPU graph we produce here must be
Expand All @@ -101,13 +102,6 @@ def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor,
copy_slice(input_batch.min_p_cpu_tensor, input_batch.min_p,
DEFAULT_SAMPLING_PARAMS["min_p"])

# copy_slice(input_batch.frequency_penalties_cpu_tensor,
# input_batch.frequency_penalties)
# copy_slice(input_batch.presence_penalties_cpu_tensor,
# input_batch.presence_penalties)
# copy_slice(input_batch.repetition_penalties_cpu_tensor,
# input_batch.repetition_penalties)

xm.mark_step()
xm.wait_device_ops()
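
For context, a simplified sketch of the padded copy pattern this hunk documents; names mirror the diff above, but this is not the exact vLLM code:

```python
import torch

def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor,
               num_reqs: int, padded_num_reqs: int, fill_val: float) -> None:
    # Pad the tail [num_reqs:padded_num_reqs] with the default value, then copy a
    # slice whose length depends only on padded_num_reqs. Because padded_num_reqs
    # is drawn from a small set of bucket sizes, XLA sees a fixed shape and does
    # not recompile when the real number of requests changes. len(tpu_tensor)
    # must be >= padded_num_reqs for the destination slice to be valid.
    cpu_tensor[num_reqs:padded_num_reqs] = fill_val
    tpu_tensor[:padded_num_reqs] = cpu_tensor[:padded_num_reqs]
```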

8 changes: 6 additions & 2 deletions vllm/v1/worker/tpu_model_runner.py
@@ -88,6 +88,8 @@ def __init__(
self.max_model_len = model_config.max_model_len
self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
self.max_num_tokens = scheduler_config.max_num_batched_tokens
# InputBatch needs to work with sampling tensors greater than padding
# to avoid dynamic shapes. Also, avoid suboptimal alignment.
self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS)

# Model-related.
@@ -787,6 +789,7 @@ def capture_model(self) -> None:
dummy_hidden = torch.randn((num_tokens, hsize),
device=device,
dtype=torch.bfloat16)
# Compile for [8, 16, .., 128,.., `self.max_num_reqs`]
while True:
indices = torch.zeros(
num_reqs_to_sample,
@@ -803,7 +806,9 @@
out = out.cpu()
Collaborator:
we'd better add xm.mark_step() before this line unless we use torch.compile.

Collaborator (Author):
isn't it redundant? A sync to CPU will still cause the graph to be flushed and executed

Collaborator:
TL;DR: xm.mark_step() computes all pending tensors, but B.cpu() only computes the exact tensor B.

E.g. say we have the code

A = ...
B = op(A)

The graphs generated by xm.mark_step() and B.cpu() have different outputs:

For xm.mark_step(), we get both A and B as outputs.
For B.cpu(), only B is an output.

Then if we have another xm.mark_step() later, since A's result was not returned by the previous computation, we have to compute A again.

Collaborator (Author):
Thanks for explaining, this is really helpful.
I don't think this is the case, as we only need the sampled output tokens from the sampler step, but I could also add it for completeness.

Collaborator:
Although we don't intend to use the other tensors, a later xm.mark_step() will still try to get their results, so we end up with redundant computation.

BTW, we get an implicit xm.mark_step() when using torch.compile. That's one reason I recommend using torch.compile when possible.

Collaborator:
And when we use torch.compile, we don't need so many xm.mark_step() calls.

torch.compile is also much easier for PyTorch developers to understand.
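
To make the point above concrete, here is a small, self-contained sketch (assuming a standard torch_xla setup; not vLLM code) of the two ways of forcing execution being compared:

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# Both A and B are recorded lazily; nothing runs on the TPU yet.
A = torch.randn(8, 8, device=device)
B = A @ A

# Option 1: transfer B to host. Per the discussion above, the resulting graph
# only returns B, so A's value is not kept and may be recomputed by a later step.
b_host = B.cpu()

# Option 2: flush everything pending. Both A and B are materialized as outputs,
# so later steps can reuse them without recomputation.
xm.mark_step()
```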

if num_reqs_to_sample >= self.max_num_reqs:
break
num_reqs_to_sample *= 2
# Make sure to compile the `max_num_reqs` upper-limit case
Collaborator:
@NickLucche thank you for fixing this! qq: does this change actually fix the recompilation issue?

Collaborator:
I have the same question. IMO, the code fixes the recompilation issue when max_num_reqs is not a power of 2. But in our tests, it is already a power of 2.

num_reqs_to_sample = _get_padded_num_reqs_with_upper_limit(
num_reqs_to_sample + 1, self.max_num_reqs)
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in in %.2f [secs].", end - start)
@@ -896,7 +901,6 @@ def forward(

return hidden_states

# @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
Collaborator:
Why don't we use torch.compile for this method?

Usually it's easier to know the boundary of the TPU computation and avoid recompilation if we wrap it inside torch.compile.

Collaborator (Author):
Once main is stable we'll turn that on; I'd like to do that in a separate PR. Last time around, compilation got slower, so I just wanted to be cautious.

Collaborator:
I ran an experiment based on your branch, clearing the XLA cache before each run:

without torch.compile: sampler pre-compilation time is 35.79 [secs]
with torch.compile: sampler pre-compilation time is 35.51 [secs]

The compilation time difference is negligible. Meanwhile, torch.compile can also speed up execution, because torch.compile's guard check is usually faster than torch/xla's graph tracing.

Collaborator (Author):
Happy to add it back then! I'd just like to merge this PR first if you don't mind; it still fixes the pre-compilation when max_num_reqs is not a power of 2.

Collaborator:
Could you add it here in this PR? It's just a one-line code change.

def sample_from_hidden(
self,
hidden_states: torch.Tensor,
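
To illustrate the suggestion, here is a self-contained toy sketch of wrapping TPU-side sampling in torch.compile with the same decorator arguments as the removed line above; this is not vLLM's sample_from_hidden:

```python
import torch
import torch_xla.core.xla_model as xm

# Toy stand-in for a TPU-side sampling step, not the real sample_from_hidden.
@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def greedy_sample(hidden_states: torch.Tensor, lm_head: torch.Tensor) -> torch.Tensor:
    logits = hidden_states @ lm_head        # [num_reqs, vocab_size]
    return torch.argmax(logits, dim=-1)     # greedy token ids, shape [num_reqs]

device = xm.xla_device()
hidden = torch.randn(8, 128, device=device)
head = torch.randn(128, 1024, device=device)
# The compiled region gives an explicit graph boundary and (per the discussion
# above) an implicit mark_step, so no manual xm.mark_step() is needed here.
tokens = greedy_sample(hidden, head).cpu()
```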