Merge pull request vllm-project#2 from lsy323/lsiyuan/torchax-on-v1

Use torchax in pallas.py
yaochengji authored Feb 4, 2025
2 parents 56660db + c9e8d41 commit d121759
Showing 2 changed files with 691 additions and 28 deletions.
50 changes: 22 additions & 28 deletions vllm/v1/attention/backends/pallas.py
@@ -2,12 +2,16 @@
 from typing import Any, Dict, List, Optional, Tuple, Type
 
 import torch
-import torch_xla.experimental.custom_kernel  # Required to register custom ops.
+
+# import torch_xla.experimental.custom_kernel  # Required to register custom ops.
+import torchax
+from jax.experimental.pallas.ops.tpu import flash_attention
+from jax.experimental.pallas.ops.tpu.paged_attention.paged_attention_kernel import paged_attention
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer,
                                               AttentionMetadata, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.v1.attention.backends.pallas_multi_queries_paged_attention_kernel import \
     paged_attention as multi_queries_paged_attention
 
 
 class PallasAttentionBackend(AttentionBackend):
@@ -53,9 +57,9 @@ def copy_blocks(
     ) -> None:
         src_indices, dst_indices = src_to_dists
         for k_cache, v_cache in kv_caches:
-            torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True)
+            # torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True)
             k_cache[:, dst_indices] = k_cache[:, src_indices]
-            torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
+            # torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
             v_cache[:, dst_indices] = v_cache[:, src_indices]
 
 
@@ -213,7 +217,8 @@ def forward(
                 # [batch_size, num_heads, seq_len, d_model]
                 # while the input is [batch_size, seq_len, num_heads, d_model].
                 # Permute the input to match the required format.
-                output = torch.ops.xla.flash_attention(
+                output = torchax.interop.call_jax(
+                    flash_attention.flash_attention,
                     query.permute(0, 2, 1, 3),
                     key.permute(0, 2, 1, 3),
                     value.permute(0, 2, 1, 3),
@@ -226,7 +231,8 @@
                 num_kv_pages_per_compute_block = 16
                 num_queries_per_compute_block = 16
                 assert seq_len % num_queries_per_compute_block == 0
-                output = torch.ops.xla.multi_queries_paged_attention(
+                output = torchax.interop.call_jax(
+                    multi_queries_paged_attention,
                     query,
                     key_cache,
                     value_cache,
@@ -326,26 +332,14 @@ def paged_attention(
     else:
         megacore_mode = megacore_mode
 
-    # NOTE(woosuk): A temporary workaround to avoid the error:
-    # "xla::paged_attention() Expected a value of type 'str' for
-    # argument 'megacore_mode' but instead found type 'NoneType'."
-    if megacore_mode is not None:
-        output = torch.ops.xla.paged_attention(
-            query,
-            key_cache,
-            value_cache,
-            context_lens,
-            block_tables,
-            pages_per_compute_block,
-            megacore_mode=megacore_mode,
-        )
-    else:
-        output = torch.ops.xla.paged_attention(
-            query,
-            key_cache,
-            value_cache,
-            context_lens,
-            block_tables,
-            pages_per_compute_block,
-        )
+    output = torchax.interop.call_jax(
+        paged_attention,
+        query,
+        key_cache,
+        value_cache,
+        context_lens,
+        block_tables,
+        pages_per_compute_block,
+        megacore_mode=megacore_mode,
+    )
     return output
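
For readers unfamiliar with torchax, the interop pattern this commit introduces can be sketched in isolation as below. This is a minimal sketch, not part of the commit: it assumes torchax.interop.call_jax takes the JAX callable first, followed by its arguments (as in the calls above), converting torchax-backed torch tensors to JAX arrays before invoking the Pallas kernel and wrapping the result back as a torch tensor. The wrapper name flash_attention_via_torchax is hypothetical.

# Minimal sketch of the torchax.interop.call_jax pattern used in this diff.
# Assumption: call_jax(jax_fn, *args, **kwargs) bridges torch tensors to JAX
# arrays, calls jax_fn, and converts the result back to a torch tensor.
import torch
import torchax
from jax.experimental.pallas.ops.tpu import flash_attention


def flash_attention_via_torchax(
    query: torch.Tensor,  # [batch_size, num_heads, seq_len, head_dim]
    key: torch.Tensor,    # [batch_size, num_heads, seq_len, head_dim]
    value: torch.Tensor,  # [batch_size, num_heads, seq_len, head_dim]
) -> torch.Tensor:
    # The JAX callable comes first, followed by its positional arguments,
    # mirroring the flash_attention call in the forward() hunk above.
    return torchax.interop.call_jax(
        flash_attention.flash_attention,
        query,
        key,
        value,
    )

Passing keyword arguments straight through call_jax (as done with megacore_mode above) is presumably what lets the commit drop the earlier NoneType workaround, since the JAX paged_attention kernel accepts megacore_mode=None directly.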