|
3 | 3 | import bisect |
4 | 4 | import gc |
5 | 5 | import time |
6 | | -from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast |
| 6 | +from typing import TYPE_CHECKING, Any, Optional, cast |
7 | 7 | from unittest.mock import patch |
8 | 8 |
|
9 | 9 | import numpy as np |
|
23 | 23 | get_layers_from_vllm_config, update_config) |
24 | 24 | from vllm.distributed.kv_transfer import (get_kv_transfer_group, |
25 | 25 | has_kv_transfer_group) |
| 26 | +from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks |
26 | 27 | from vllm.forward_context import set_forward_context |
27 | 28 | from vllm.logger import init_logger |
28 | 29 | from vllm.lora.layers import BaseLayerWithLoRA |
@@ -1887,75 +1888,6 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int: |
1887 | 1888 | return paddings[index] |
1888 | 1889 |
|
1889 | 1890 |
|
1890 | | -def _make_src_and_dst_indices( |
1891 | | - src_block_ids: list[int], |
1892 | | - dst_block_ids: list[int], |
1893 | | - src_device: Union[torch.device, str], |
1894 | | - dst_device: Union[torch.device, str], |
1895 | | -) -> tuple[torch.Tensor, torch.Tensor]: |
1896 | | - src_indices = torch.tensor(src_block_ids, |
1897 | | - device=src_device, |
1898 | | - dtype=torch.int64) |
1899 | | - dst_indices = torch.tensor(dst_block_ids, |
1900 | | - device=dst_device, |
1901 | | - dtype=torch.int64) |
1902 | | - return src_indices, dst_indices |
1903 | | - |
1904 | | - |
1905 | | -@torch.compile(backend="openxla") |
1906 | | -def _insert_blocks_to_tpu( |
1907 | | - cpu_cache: torch.Tensor, |
1908 | | - tpu_cache: torch.Tensor, |
1909 | | - cpu_block_indices: torch.Tensor, |
1910 | | - tpu_block_indices: torch.Tensor, |
1911 | | -) -> None: |
1912 | | - torch.ops.xla.dynamo_set_buffer_donor_(tpu_cache, True) |
1913 | | - tpu_cache[tpu_block_indices] = cpu_cache[cpu_block_indices].to( |
1914 | | - tpu_cache.device) |
1915 | | - |
1916 | | - |
1917 | | -@torch.compile(backend="openxla") |
1918 | | -def _swap_out_tpu_blocks( |
1919 | | - tpu_cache: torch.Tensor, |
1920 | | - cpu_cache: torch.Tensor, |
1921 | | - tpu_block_indices: torch.Tensor, |
1922 | | - cpu_block_indices: torch.Tensor, |
1923 | | -) -> None: |
1924 | | - """ tpu blocks to cpu blocks""" |
1925 | | - torch.ops.xla.dynamo_set_buffer_donor_(tpu_cache, True) |
1926 | | - cpu_cache[cpu_block_indices] = tpu_cache[tpu_block_indices].cpu() |
1927 | | - |
1928 | | - |
1929 | | -def copy_kv_blocks( |
1930 | | - src_kv_caches: dict[str, torch.Tensor], |
1931 | | - dst_kv_caches: dict[str, torch.Tensor], |
1932 | | - src_block_ids: list[int], |
1933 | | - dst_block_ids: list[int], |
1934 | | - direction: Literal["h2d", "d2h"], |
1935 | | -) -> None: |
1936 | | - """Copy kv blocks between different buffers.""" |
1937 | | - if not src_kv_caches or not dst_kv_caches or \ |
1938 | | - not src_block_ids or not dst_block_ids or \ |
1939 | | - len(src_block_ids) != len(dst_block_ids): |
1940 | | - return |
1941 | | - |
1942 | | - src_device = next(iter(src_kv_caches.values())).device |
1943 | | - dst_device = next(iter(dst_kv_caches.values())).device |
1944 | | - |
1945 | | - src_indices, dst_indices = _make_src_and_dst_indices( |
1946 | | - src_block_ids=src_block_ids, |
1947 | | - dst_block_ids=dst_block_ids, |
1948 | | - src_device=src_device, |
1949 | | - dst_device=dst_device) |
1950 | | - |
1951 | | - _copy_fn = _insert_blocks_to_tpu if direction == "h2d" else \ |
1952 | | - _swap_out_tpu_blocks |
1953 | | - for layer_name in src_kv_caches: |
1954 | | - src_tensor = src_kv_caches[layer_name] |
1955 | | - dst_tensor = dst_kv_caches[layer_name] |
1956 | | - _copy_fn(src_tensor, dst_tensor, src_indices, dst_indices) |
1957 | | - |
1958 | | - |
1959 | 1891 | def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int, |
1960 | 1892 | page_size: int) -> int: |
1961 | 1893 | """Calculates the padded number of KV cache update slices to avoid |
|
0 commit comments