Conversation

@NickLucche (Collaborator) commented Apr 7, 2025

Fixes #15833.

ERROR 04-07 08:06:22 [core.py:386]   File "/home/nick/vllm/vllm/v1/worker/tpu_model_runner.py", line 609, in execute_model
ERROR 04-07 08:06:22 [core.py:386]     self._update_states(scheduler_output)
ERROR 04-07 08:06:22 [core.py:386]   File "/home/nick/vllm/vllm/v1/worker/tpu_model_runner.py", line 267, in _update_states
ERROR 04-07 08:06:22 [core.py:386]     generator = torch.Generator(device=self.device)
ERROR 04-07 08:06:22 [core.py:386]                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-07 08:06:22 [core.py:386] RuntimeError: XLA device type not an accelerator.

The error above is raised when setting a per-request seed.

I've looked into using a per-request set_rng_state, but it doesn't look very efficient. I think we should just disable per-request seeds for now.
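For reference, a minimal sketch of what "disable per-request seed" could look like: a validation helper that rejects seeded requests on the TPU path instead of building a torch.Generator on the XLA device. The function name and call site are hypothetical, not the actual diff in this PR.

# Hypothetical sketch only: reject per-request seeds on the TPU backend
# up front, instead of constructing a torch.Generator on the XLA device.
from vllm.sampling_params import SamplingParams


def reject_per_request_seed(sampling_params: SamplingParams) -> None:
    """Raise if the request carries a per-request seed (unsupported on XLA)."""
    if sampling_params.seed is not None:
        raise ValueError("Torch XLA does not support per-request seed.")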

Update:

I brought back test_custom_dispatcher.py in CI, as I figured it had been disabled because of the issue above.

@github-actions bot commented Apr 7, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

mergify bot added the v1 and tpu (Related to Google TPUs) labels on Apr 7, 2025
mergify bot added the ci/build label on Apr 7, 2025
@mgoin (Member) commented Apr 7, 2025

Have you tried creating the generator off the device? It seems to work if I use a generator that is not on the device to produce tensors on the device:

>>> import torch
>>> import torch_xla.core.xla_model as xm
>>> torch.randn((2,3), device=xm.xla_device())
tensor([[-0.2866,  0.0079, -3.2667],
        [-0.2544, -0.3828, -3.1057]], device='xla:0')
>>> torch.randn((2,3), device=xm.xla_device(), generator=torch.Generator(device=xm.xla_device()))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: XLA device type not an accelerator.
>>> torch.randn((2,3), device=xm.xla_device(), generator=torch.Generator())
tensor([[ 0.9773,  1.1569, -0.3807],
        [ 1.2525, -0.3068, -1.4705]], device='xla:0')
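For completeness, a seed can still be carried by the CPU generator; a small usage sketch, assuming torch_xla is installed and an XLA device is available (whether this actually lowers to a device-side random op is what the next comment digs into):

import torch
import torch_xla.core.xla_model as xm

# The generator stays on CPU and carries the seed; only the output tensor
# is placed on the XLA device.
gen = torch.Generator()
gen.manual_seed(42)
x = torch.randn((2, 3), device=xm.xla_device(), generator=gen)
print(x.device)  # e.g. xla:0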

@NickLucche (Collaborator, Author) commented

Yep, I was aware of that. XLA generally falls back to CPU whenever it doesn't support an op; I'm not sure why that isn't the case for the generator here.
Anyway, I couldn't see much use for it unless we moved sampling to the CPU, which is why I preferred disabling it altogether.

In particular, for the snippet you posted the traced graph looks something like:

IR {
  %0 = f32[] prim::Constant(), location=<module>@test_generator.py:6, xla_shape=f32[]
  %1 = f32[2,3]{1,0} aten::expand(%0), location=<module>@test_generator.py:6, xla_shape=f32[2,3]{1,0}, ROOT=0
}

Meaning XLA is just compiling a constant rather than actually creating the random tensor on device. This will likely trigger a recompilation whenever the traced constant changes. I think this is what this comment is about: pytorch/xla#2532 (comment).

Normally, when you create a randn tensor on device, you see the whole sequence of ops being traced in the graph:

IR {
  %0 = s64[] xla::device_data(), location=<module>@test_generator.py:6, xla_shape=s64[]
  %1 = s64[] prim::Constant(), location=<module>@test_generator.py:6, xla_shape=s64[]
  %2 = s64[] aten::mul(%1, %0), location=<module>@test_generator.py:6, xla_shape=s64[]
  %3 = s64[] prim::Constant(), location=<module>@test_generator.py:6, xla_shape=s64[]
  %4 = s64[] aten::add(%3, %2), location=<module>@test_generator.py:6, xla_shape=s64[]
  %5 = f32[] prim::Constant(), location=<module>@test_generator.py:6, xla_shape=f32[]
  %6 = f32[2,3]{1,0} aten::expand(%5), location=<module>@test_generator.py:6, xla_shape=f32[2,3]{1,0}
  %7 = f32[] prim::Constant(), location=<module>@test_generator.py:6, xla_shape=f32[]
  %8 = f32[2,3]{1,0} aten::expand(%7), location=<module>@test_generator.py:6, xla_shape=f32[2,3]{1,0}
  %9 = f32[2,3]{1,0} aten::normal(%8, %6, %4), location=<module>@test_generator.py:6, xla_shape=f32[2,3]{1,0}
}
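As a side note, these IR dumps can be reproduced by printing the traced graph of a lazy tensor before it is materialized; a sketch that relies on torch_xla's internal debugging helper _get_xla_tensors_text (an internal API, so it may change across releases):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

# Build a lazy tensor and print the IR torch_xla has traced for it,
# before mark_step()/materialization turns it into a compiled graph.
t = torch.randn((2, 3), device=xm.xla_device())
print(torch_xla._XLAC._get_xla_tensors_text([t]))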

@bvrockwell (Contributor) commented

Thanks @NickLucche, @yaochengji may have some suggestions here.

@yaochengji (Collaborator) commented

RuntimeError: XLA device type not an accelerator.

We get this error because there is a separate device-check logic in the sampler; see https://github.com/pytorch/pytorch/blob/6ea5514e0460604e4b0325a7218a7a8ca2e61819/aten/src/ATen/Context.h#L50.

I'd suggest disabling the generator logic for now as well.

@yaochengji (Collaborator) left a review comment

LGTM, thanks Nick!

@yaochengji added the ready (ONLY add when PR is ready to merge/full CI is needed) label on Apr 9, 2025
@mgoin (Member) left a review comment

LGTM, but should we wait to land this until we can plug into #16291?

@NickLucche (Collaborator, Author) commented

I will rebase on that.

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
@NickLucche force-pushed the tpu-disable-generators branch from b3de56e to 3733339 on April 10, 2025 07:45
@yaochengji (Collaborator) commented

The TPU CI test failed with the error:

Torch XLA does not support per-request seed

@NickLucche (Collaborator, Author) commented

@yaochengji Thanks for the ping. My bad, I forgot I had added that test back before making seeded requests error out.
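For illustration, a regression test for this behavior could look roughly like the sketch below; the model name, exception type, and test entry point are assumptions rather than the actual CI test.

import pytest
from vllm import LLM, SamplingParams


def test_per_request_seed_rejected_on_tpu():
    # Assumes a TPU/XLA environment; the model name is arbitrary for the sketch.
    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")
    # The exact exception type is an assumption; the message mirrors the CI log.
    with pytest.raises(ValueError, match="does not support per-request seed"):
        llm.generate(["Hello"], SamplingParams(seed=1234, max_tokens=8))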

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
@mgoin merged commit 3cc9af8 into vllm-project:main on Apr 10, 2025
41 checks passed
yangw-dev pushed a commit to yangw-dev/vllm that referenced this pull request Apr 21, 2025
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Yang Wang <elainewy@meta.com>
jikunshang pushed a commit to jikunshang/vllm that referenced this pull request Apr 29, 2025
Signed-off-by: NickLucche <nlucches@redhat.com>
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
Signed-off-by: NickLucche <nlucches@redhat.com>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
Labels: ci/build, ready (ONLY add when PR is ready to merge/full CI is needed), tpu (Related to Google TPUs), v1

Linked issue: [Bug]: [TPU] V1 seems to silently crash after a while

4 participants