Skip to content

Commit

Permalink
Enable cache ops for beam search (vllm-project#3)
Browse files Browse the repository at this point in the history
Co-authored-by: Mikhail Dvoretckii <mdvoretckii@habana.ai>
  • Loading branch information
mdvoretc-intel and Mikhail Dvoretckii authored Feb 20, 2024
1 parent 512c414 commit f51f149
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions vllm/hpu/cache_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,30 @@ def reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, is_promp
index.add_(1)
key_cache = key_cache.permute(0, 2, 3, 1)
value_cache = value_cache.permute(0, 2, 3, 1)


def swap_blocks(src, dst, block_mapping):
    """Copy selected blocks from ``src`` into ``dst`` in place.

    Blocks are addressed along dimension 0 of both tensors.

    Args:
        src: Tensor the blocks are read from.
        dst: Tensor the blocks are written into (modified in place).
        block_mapping: Mapping of source block index -> destination block
            index.
    """
    # Each index tensor lives on the device of the tensor it indexes. The
    # previous code read ``key_caches[0].device`` here, but ``key_caches`` is
    # not defined in this function, so every call failed with NameError.
    index_src = torch.zeros((1,), dtype=torch.int32, device=src.device)
    index_dst = torch.zeros((1,), dtype=torch.int32, device=dst.device)
    for src_idx, dst_idx in block_mapping.items():
        # The one-element index tensors are reused and overwritten on every
        # iteration instead of allocating new ones per block.
        index_src[0] = src_idx
        index_dst[0] = dst_idx
        # Move the selected block to dst's device before writing it; this is
        # a no-op when src and dst already share a device.
        block = src.index_select(0, index_src).to(dst.device)
        dst.index_put_([index_dst], block)
    if dst.device.type == 'hpu':
        # NOTE(review): presumably flushes the queued HPU ops and then waits
        # for them to finish -- confirm against the Habana/Gaudi docs.
        htorch.core.mark_step()
        torch.hpu.synchronize()


def copy_blocks(key_caches, value_caches, block_mapping):
    """Duplicate blocks inside every key and value cache, in place.

    For each ``src -> dsts`` entry, block ``src`` (along dimension 0) of each
    cache is copied over every block listed in ``dsts`` of that same cache.

    Args:
        key_caches: Sequence of key-cache tensors (modified in place).
        value_caches: Sequence of value-cache tensors (modified in place).
        block_mapping: Mapping of source block index -> iterable of
            destination block indices.
    """
    # Key caches first, then value caches: the same order the two separate
    # loops used, now handled by one loop.
    caches = [*key_caches, *value_caches]
    if not caches or not block_mapping:
        # Nothing to copy. This also avoids indexing into an empty cache
        # list, which previously raised IndexError.
        return
    device = caches[0].device
    index_src = torch.zeros((1,), dtype=torch.int32, device=device)
    index_dst = torch.zeros((1,), dtype=torch.int32, device=device)
    for src, dsts in block_mapping.items():
        # The one-element index tensors are reused and overwritten rather
        # than reallocated for every block.
        index_src[0] = src
        for dst in dsts:
            index_dst[0] = dst
            for cache in caches:
                cache.index_copy_(0, index_dst,
                                  cache.index_select(0, index_src))
        if device.type == 'hpu':
            # NOTE(review): presumably flushes the HPU ops queued for this
            # source block -- confirm against the Habana/Gaudi docs.
            htorch.core.mark_step()

0 comments on commit f51f149

Please sign in to comment.