
Commit e9d6a3d

Authored by: Chengji Yao (yaochengji)
[TPU] make ptxla not imported when using tpu_commons (#23081)
Signed-off-by: Chengji Yao <chengjiyao@gmail.com>
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Co-authored-by: Chengji Yao <chengjiyao@gmail.com>
1 parent a4454e9 commit e9d6a3d

File tree

6 files changed: +94 additions, -78 deletions


vllm/distributed/device_communicators/tpu_communicator.py

Lines changed: 13 additions & 14 deletions
@@ -10,6 +10,7 @@
 from vllm.config import get_current_vllm_config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.platforms.tpu import USE_TPU_COMMONS
 
 from .base_device_communicator import DeviceCommunicatorBase
 
@@ -18,16 +19,17 @@
 
 logger = init_logger(__name__)
 
-if current_platform.is_tpu():
-    import torch_xla
-    import torch_xla.core.xla_model as xm
-    import torch_xla.runtime as xr
-    from torch_xla._internal import pjrt
-    from torch_xla.distributed.xla_multiprocessing import (
-        create_optimized_replica_groups)
-
-    if USE_RAY:
-        from vllm.executor import ray_utils
+if not USE_TPU_COMMONS:
+    logger.info("tpu_commons not found, using vLLM's TpuCommunicator")
+    if current_platform.is_tpu():
+        import torch_xla
+        import torch_xla.core.xla_model as xm
+        import torch_xla.runtime as xr
+        from torch_xla._internal import pjrt
+        from torch_xla.distributed.xla_multiprocessing import (
+            create_optimized_replica_groups)
+        if USE_RAY:
+            from vllm.executor import ray_utils
 
 
 class TpuCommunicator(DeviceCommunicatorBase):
@@ -94,10 +96,7 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         return xm.all_gather(input_, dim=dim)
 
 
-try:
+if USE_TPU_COMMONS:
     from tpu_commons.distributed.device_communicators import (
         TpuCommunicator as TpuCommonsCommunicator)
     TpuCommunicator = TpuCommonsCommunicator  # type: ignore
-except ImportError:
-    logger.info("tpu_commons not found, using vLLM's TpuCommunicator")
-    pass
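
The pattern throughout this commit is the same: a module-level flag chooses between the tpu_commons implementation and the ptxla fallback, the heavy torch_xla imports sit inside the fallback branch, and the public class name is re-bound once at import time. A minimal, self-contained sketch of that shape (fast_backend, Communicator and the other names here are hypothetical stand-ins, not the vLLM or tpu_commons API):

USE_FAST_BACKEND = False
try:
    import fast_backend  # hypothetical optional package
    USE_FAST_BACKEND = True
except ImportError:
    pass

if not USE_FAST_BACKEND:
    # Fallback-only imports live inside the guard, so they are skipped
    # entirely when the optional package is installed.
    import array  # noqa: F401  # stand-in for torch_xla and friends


class Communicator:
    """Fallback implementation used when fast_backend is absent."""

    def all_gather(self, x):
        return x


if USE_FAST_BACKEND:
    # Re-bind the public name so callers transparently pick up the
    # optional implementation.
    from fast_backend import Communicator  # type: ignore # noqa

Callers keep importing the same name; which implementation they get was decided once, when the module was first imported.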

vllm/model_executor/layers/fused_moe/moe_pallas.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,6 @@
 
 import torch
 import torch.nn.functional as F
-import torch_xla.experimental.custom_kernel  # noqa: F401
 
 
 def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor:
@@ -41,6 +40,7 @@ def fused_moe(
         gating_output: [*, num_experts]
     """
     assert expert_map is None, "expert_map is not supported for pallas MoE."
+    import torch_xla.experimental.custom_kernel  # noqa: F401
    orig_shape = hidden_states.shape
    hidden_size = hidden_states.shape[-1]
    num_tokens = hidden_states.shape[:-1].numel()
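
Moving the torch_xla.experimental.custom_kernel import from module scope into fused_moe means merely importing moe_pallas no longer pulls in torch_xla; the import runs for its kernel-registration side effects the first time fused_moe is called, and later calls are effectively free because Python caches the module in sys.modules. A tiny sketch of the deferred-import pattern (statistics stands in for the heavy dependency):

def fused_op(values):
    # Deferred import: the dependency is loaded on first call only; its
    # import-time side effects (e.g. kernel registration) still happen
    # before the op is actually used.
    import statistics  # stand-in for torch_xla.experimental.custom_kernel
    return statistics.mean(values)


print(fused_op([1.0, 2.0, 3.0]))  # 2.0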

vllm/model_executor/model_loader/default_loader.py

Lines changed: 13 additions & 8 deletions
@@ -207,16 +207,21 @@ def _get_weights_iterator(
         )
 
         if current_platform.is_tpu():
-            # In PyTorch XLA, we should call `xm.mark_step` frequently so that
-            # not too many ops are accumulated in the XLA program.
-            import torch_xla.core.xla_model as xm
+            from vllm.platforms.tpu import USE_TPU_COMMONS
 
-            def _xla_weights_iterator(iterator: Generator):
-                for weights in iterator:
-                    yield weights
-                    xm.mark_step()
+            if not USE_TPU_COMMONS:
+                # In PyTorch XLA, we should call `xm.mark_step`
+                # frequently so that not too many ops are
+                # accumulated in the XLA program.
+                import torch_xla.core.xla_model as xm
 
-            weights_iterator = _xla_weights_iterator(weights_iterator)
+                def _xla_weights_iterator(iterator: Generator):
+                    for weights in iterator:
+                        yield weights
+                        xm.mark_step()
+
+                weights_iterator = _xla_weights_iterator(weights_iterator)
 
         if self.counter_before_loading_weights == 0.0:
             self.counter_before_loading_weights = time.perf_counter()
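
The _xla_weights_iterator wrapper yields each weight and then calls xm.mark_step() so the lazily built XLA program is flushed incrementally instead of accumulating across the whole checkpoint; when tpu_commons is active, the wrapper and the torch_xla import it needs are skipped. A generic sketch of the "run a hook after every yield" wrapper (names are illustrative, not the vLLM loader API):

from typing import Callable, Generator, Iterable, TypeVar

T = TypeVar("T")


def with_post_yield_hook(items: Iterable[T],
                         hook: Callable[[], None]) -> Generator[T, None, None]:
    """Yield each item, then run a hook (e.g. xm.mark_step) before the next."""
    for item in items:
        yield item
        hook()


# Usage sketch: flush pending work after every yielded weight.
flushes = []
for _ in with_post_yield_hook(["w1", "w2"], hook=lambda: flushes.append("flush")):
    pass
assert flushes == ["flush", "flush"]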

vllm/platforms/tpu.py

Lines changed: 3 additions & 0 deletions
@@ -24,6 +24,8 @@
 
 logger = init_logger(__name__)
 
+USE_TPU_COMMONS = False
+
 
 class TpuPlatform(Platform):
     _enum = PlatformEnum.TPU
@@ -201,6 +203,7 @@ def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
 try:
     from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
     TpuPlatform = TpuCommonsPlatform  # type: ignore
+    USE_TPU_COMMONS = True
 except ImportError:
     logger.info("tpu_commons not found, using vLLM's TpuPlatform")
     pass
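
USE_TPU_COMMONS is flipped exactly where the optional package is first probed, inside the existing try/except that already swaps in TpuCommonsPlatform, so the other modules in this commit can import the boolean instead of repeating the probe. If only a presence check were needed (no re-export side effects), importlib can answer the same question without importing the package; a one-line sketch of that variant:

import importlib.util

# True when the optional tpu_commons package is installed; this variant only
# looks up the module spec and does not import it.
USE_TPU_COMMONS = importlib.util.find_spec("tpu_commons") is not None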

vllm/v1/attention/backends/pallas.py

Lines changed: 51 additions & 46 deletions
@@ -5,12 +5,6 @@
 from typing import Optional
 
 import torch
-import torch_xla.core.xla_builder as xb
-import torch_xla.experimental.custom_kernel  # noqa: F401
-# Required to register custom ops.
-from torch.library import impl
-from torch_xla._internal.jax_workarounds import requires_jax
-from torch_xla.experimental.custom_kernel import XLA_LIB
 
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
@@ -37,6 +31,57 @@
     "uint8": torch.uint8,
 }
 
+try:
+    import tpu_commons  # noqa: F401
+except ImportError:
+    # Lazy import torch_xla
+    import torch_xla.core.xla_builder as xb
+    import torch_xla.experimental.custom_kernel  # noqa: F401
+    from torch.library import impl
+    from torch_xla._internal.jax_workarounds import requires_jax
+    from torch_xla.experimental.custom_kernel import XLA_LIB
+
+    @requires_jax
+    def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
+                                kv_cache: torch.Tensor,
+                                num_kv_update_slices: torch.Tensor,
+                                page_size: int, num_slices_per_block: int):
+        from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
+        new_kv_cache = xb.call_jax(
+            kv_cache_update,
+            (kv, slot_mapping, kv_cache, num_kv_update_slices), {
+                "page_size": page_size,
+                "num_slices_per_block": num_slices_per_block
+            })
+        return new_kv_cache
+
+    XLA_LIB.define(
+        "kv_cache_update_op(Tensor kv, Tensor slot_mapping," \
+        "Tensor kv_cache, Tensor num_kv_update_slices, int page_size," \
+        "int num_slices_per_block)" \
+        "-> Tensor", )
+
+    @impl(XLA_LIB, "kv_cache_update_op", "XLA")
+    def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
+                               kv_cache: torch.Tensor,
+                               num_kv_update_slices: torch.Tensor,
+                               page_size: int,
+                               num_slices_per_block: int) -> torch.Tensor:
+        new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
+                                               num_kv_update_slices, page_size,
+                                               num_slices_per_block)
+        return new_kv_cache
+
+    @impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
+    def kv_cache_update_op_non_xla(kv: torch.Tensor,
+                                   slot_mapping: torch.Tensor,
+                                   kv_cache: torch.Tensor,
+                                   num_kv_update_slices: torch.Tensor,
+                                   page_size: int,
+                                   num_slices_per_block: int) -> torch.Tensor:
+        return kv_cache
+
 
 class PallasAttentionBackend(AttentionBackend):
 
@@ -313,46 +358,6 @@ def write_to_kv_cache(
     kv_cache.copy_(new_kv_cache)
 
 
-@requires_jax
-def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
-                            kv_cache: torch.Tensor,
-                            num_kv_update_slices: torch.Tensor, page_size: int,
-                            num_slices_per_block: int):
-    from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
-    new_kv_cache = xb.call_jax(
-        kv_cache_update, (kv, slot_mapping, kv_cache, num_kv_update_slices), {
-            "page_size": page_size,
-            "num_slices_per_block": num_slices_per_block
-        })
-    return new_kv_cache
-
-
-XLA_LIB.define(
-    "kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache," \
-    "Tensor num_kv_update_slices, int page_size, int num_slices_per_block)" \
-    "-> Tensor", )
-
-
-@impl(XLA_LIB, "kv_cache_update_op", "XLA")
-def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
-                           kv_cache: torch.Tensor,
-                           num_kv_update_slices: torch.Tensor, page_size: int,
-                           num_slices_per_block: int) -> torch.Tensor:
-    new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
-                                           num_kv_update_slices, page_size,
-                                           num_slices_per_block)
-    return new_kv_cache
-
-
-@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
-def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
-                               kv_cache: torch.Tensor,
-                               num_kv_update_slices: torch.Tensor,
-                               page_size: int,
-                               num_slices_per_block: int) -> torch.Tensor:
-    return kv_cache
-
-
 # We can move this function to a common utils file if it's also useful for other
 # hardware.
 def dtype_bits(dtype: torch.dtype):
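
Registering kv_cache_update_op inside the except ImportError branch keeps the torch_xla XLA_LIB machinery and the JAX-backed implementation off the import path whenever tpu_commons is present. A reduced sketch of guarding a torch.library custom-op registration behind an optional dependency, using the same decorator style as above (the pallas_demo namespace, scale op, and HAVE_TPU_COMMONS flag are made up for illustration):

import torch

try:
    import tpu_commons  # noqa: F401  # optional accelerated backend
    HAVE_TPU_COMMONS = True
except ImportError:
    HAVE_TPU_COMMONS = False

if not HAVE_TPU_COMMONS:
    from torch.library import Library, impl

    # Register a toy op only on the fallback path, mirroring how the real
    # kv_cache_update_op registration is confined to the except branch.
    _DEMO_LIB = Library("pallas_demo", "DEF")  # hypothetical namespace
    _DEMO_LIB.define("scale(Tensor x, float factor) -> Tensor")

    @impl(_DEMO_LIB, "scale", "CompositeExplicitAutograd")
    def _scale(x: torch.Tensor, factor: float) -> torch.Tensor:
        return x * factor

    # torch.ops.pallas_demo.scale(torch.ones(2), 2.0) -> tensor([2., 2.])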

vllm/v1/worker/tpu_worker.py

Lines changed: 13 additions & 9 deletions
@@ -1,15 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """A TPU worker class."""
+
 import os
 from typing import Any, Optional
 
 import torch
 import torch.distributed
 import torch.nn as nn
-import torch_xla.core.xla_model as xm
-import torch_xla.debug.profiler as xp
-import torch_xla.runtime as xr
 
 import vllm.envs as envs
 from vllm.config import VllmConfig
@@ -21,19 +19,27 @@
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
+from vllm.platforms.tpu import USE_TPU_COMMONS
 from vllm.tasks import SupportedTask
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
-from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.utils import report_usage_stats
-from vllm.v1.worker.tpu_model_runner import TPUModelRunner
 from vllm.v1.worker.utils import bind_kv_cache
 
 logger = init_logger(__name__)
 
+if not USE_TPU_COMMONS:
+    logger.info("tpu_commons not found, using vLLM's TPUWorker.")
+    import torch_xla.core.xla_model as xm
+    import torch_xla.debug.profiler as xp
+    import torch_xla.runtime as xr
+
+    from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
+    from vllm.v1.worker.tpu_model_runner import TPUModelRunner
+
 
 class TPUWorker:
 
@@ -325,9 +331,7 @@ def _init_tpu_worker_distributed_environment(
     ensure_kv_transfer_initialized(vllm_config)
 
 
-try:
+if USE_TPU_COMMONS:
     from tpu_commons.worker import TPUWorker as TPUCommonsWorker
+
     TPUWorker = TPUCommonsWorker  # type: ignore
-except ImportError:
-    logger.info("tpu_commons not found, using vLLM's TPUWorker.")
-    pass
