5 | 5 | from typing import Optional
6 | 6 |
7 | 7 | import torch
8 | | -import torch_xla.core.xla_builder as xb
9 | | -import torch_xla.experimental.custom_kernel # noqa: F401
10 | | -# Required to register custom ops.
11 | | -from torch.library import impl
12 | | -from torch_xla._internal.jax_workarounds import requires_jax
13 | | -from torch_xla.experimental.custom_kernel import XLA_LIB
14 | 8 |
15 | 9 | from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
16 | 10 |                                               AttentionLayer, AttentionType)

37 | 31 |     "uint8": torch.uint8,
38 | 32 | }
39 | 33 |
| 34 | +try:
| 35 | +    import tpu_commons  # noqa: F401
| 36 | +except ImportError:
| 37 | +    # tpu_commons is unavailable; fall back to torch_xla and register the ops.
| 38 | +    import torch_xla.core.xla_builder as xb
| 39 | +    import torch_xla.experimental.custom_kernel  # noqa: F401
| 40 | +    from torch.library import impl
| 41 | +    from torch_xla._internal.jax_workarounds import requires_jax
| 42 | +    from torch_xla.experimental.custom_kernel import XLA_LIB
| 43 | +
| 44 | +    @requires_jax
| 45 | +    def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
| 46 | +                                kv_cache: torch.Tensor,
| 47 | +                                num_kv_update_slices: torch.Tensor,
| 48 | +                                page_size: int, num_slices_per_block: int):
| 49 | +        from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
| 50 | +        new_kv_cache = xb.call_jax(
| 51 | +            kv_cache_update,
| 52 | +            (kv, slot_mapping, kv_cache, num_kv_update_slices), {
| 53 | +                "page_size": page_size,
| 54 | +                "num_slices_per_block": num_slices_per_block
| 55 | +            })
| 56 | +        return new_kv_cache
| 57 | +
| 58 | +
| 59 | +    XLA_LIB.define(
| 60 | +        "kv_cache_update_op(Tensor kv, Tensor slot_mapping, "
| 61 | +        "Tensor kv_cache, Tensor num_kv_update_slices, int page_size, "
| 62 | +        "int num_slices_per_block)"
| 63 | +        " -> Tensor")
| 64 | +
| 65 | +    @impl(XLA_LIB, "kv_cache_update_op", "XLA")
| 66 | +    def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
| 67 | +                               kv_cache: torch.Tensor,
| 68 | +                               num_kv_update_slices: torch.Tensor,
| 69 | +                               page_size: int,
| 70 | +                               num_slices_per_block: int) -> torch.Tensor:
| 71 | +        new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
| 72 | +                                               num_kv_update_slices, page_size,
| 73 | +                                               num_slices_per_block)
| 74 | +        return new_kv_cache
| 75 | +
| 76 | +    @impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
| 77 | +    def kv_cache_update_op_non_xla(kv: torch.Tensor,
| 78 | +                                   slot_mapping: torch.Tensor,
| 79 | +                                   kv_cache: torch.Tensor,
| 80 | +                                   num_kv_update_slices: torch.Tensor,
| 81 | +                                   page_size: int,
| 82 | +                                   num_slices_per_block: int) -> torch.Tensor:
| 83 | +        return kv_cache  # No-op fallback: non-XLA tracing sees unchanged shapes.
| 84 | +
40 | 85 |
41 | 86 | class PallasAttentionBackend(AttentionBackend):
42 | 87 |
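With the except-branch above in place, the kernel is reachable through torch's custom-op registry. A minimal call-site sketch, assuming XLA_LIB registers into the `xla` namespace (as in torch_xla.experimental.custom_kernel); the wrapper name `update_kv_cache` and the two trailing constants are illustrative, not part of this diff:

    def update_kv_cache(kv: torch.Tensor, slot_mapping: torch.Tensor,
                        kv_cache: torch.Tensor,
                        num_kv_update_slices: torch.Tensor) -> torch.Tensor:
        # On XLA devices this dispatches to kv_cache_update_op_xla (the Pallas
        # kernel via xb.call_jax); on other backends the
        # CompositeExplicitAutograd registration returns kv_cache unchanged.
        return torch.ops.xla.kv_cache_update_op(
            kv, slot_mapping, kv_cache, num_kv_update_slices,
            16,  # page_size (illustrative)
            8)   # num_slices_per_block (illustrative)

The caller then writes the result back into the cache, which is what the `kv_cache.copy_(new_kv_cache)` context line in the next hunk shows.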
@@ -313,46 +358,6 @@ def write_to_kv_cache(
313 | 358 |     kv_cache.copy_(new_kv_cache)
314 | 359 |
315 | 360 |
316 | | -@requires_jax
317 | | -def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
318 | | -                            kv_cache: torch.Tensor,
319 | | -                            num_kv_update_slices: torch.Tensor, page_size: int,
320 | | -                            num_slices_per_block: int):
321 | | -    from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
322 | | -    new_kv_cache = xb.call_jax(
323 | | -        kv_cache_update, (kv, slot_mapping, kv_cache, num_kv_update_slices), {
324 | | -            "page_size": page_size,
325 | | -            "num_slices_per_block": num_slices_per_block
326 | | -        })
327 | | -    return new_kv_cache
328 | | -
329 | | -
330 | | -XLA_LIB.define(
331 | | -    "kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache," \
332 | | -    "Tensor num_kv_update_slices, int page_size, int num_slices_per_block)" \
333 | | -    "-> Tensor", )
334 | | -
335 | | -
336 | | -@impl(XLA_LIB, "kv_cache_update_op", "XLA")
337 | | -def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
338 | | -                           kv_cache: torch.Tensor,
339 | | -                           num_kv_update_slices: torch.Tensor, page_size: int,
340 | | -                           num_slices_per_block: int) -> torch.Tensor:
341 | | -    new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
342 | | -                                           num_kv_update_slices, page_size,
343 | | -                                           num_slices_per_block)
344 | | -    return new_kv_cache
345 | | -
346 | | -
347 | | -@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
348 | | -def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
349 | | -                               kv_cache: torch.Tensor,
350 | | -                               num_kv_update_slices: torch.Tensor,
351 | | -                               page_size: int,
352 | | -                               num_slices_per_block: int) -> torch.Tensor:
353 | | -    return kv_cache
354 | | -
355 | | -
356 | 361 | # We can move this function to a common utils file if it's also useful for other
357 | 362 | # hardware.
358 | 363 | def dtype_bits(dtype: torch.dtype):
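The hunk ends mid-function, but the idea behind dtype_bits is straightforward: read the bit width from torch's dtype metadata. A minimal sketch under that assumption (not necessarily the exact body in this file):

    def dtype_bits(dtype: torch.dtype) -> int:
        # e.g. 16 for torch.bfloat16, 8 for torch.uint8
        if dtype.is_floating_point:
            return torch.finfo(dtype).bits
        return torch.iinfo(dtype).bits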