Fix RoPE Compatibility with Cos/Sin Position Embedding for Batch Size > 1 #477

Open
wants to merge 11 commits into main
32 changes: 23 additions & 9 deletions src/liger_kernel/ops/rope.py
@@ -15,6 +15,7 @@ def _triton_rope(
sin_row_stride,
sl,
bs: tl.constexpr,
cos_bs: tl.constexpr,
n_qh: tl.constexpr,
n_kh: tl.constexpr,
hd: tl.constexpr,
@@ -29,7 +30,7 @@ def _triton_rope(
# k size: (bsz, seq_len, num_kv_heads, head_dim)
# k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1)

# cos size: (1, seq_len, head_dim)
# cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
# stride: (seq_len * head_dim, head_dim, 1)
pid = tl.program_id(0)
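For readers checking the stride comment above: for a contiguous cos/sin tensor the row stride is head_dim whether the leading dimension is 1 or bsz, so the same cos_row_stride/sin_row_stride values work for both layouts. A minimal sketch (illustrative only, not part of the PR):

import torch

bsz, seq_len, head_dim = 2, 128, 64
cos_shared = torch.randn(1, seq_len, head_dim).contiguous()
cos_batched = torch.randn(bsz, seq_len, head_dim).contiguous()
# stride(-2) is the element distance between consecutive sequence positions;
# it is what rope_forward/rope_backward pass to the kernel as the row stride.
assert cos_shared.stride(-2) == head_dim
assert cos_batched.stride(-2) == head_dim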

@@ -48,9 +49,19 @@ def _triton_rope(
# and pid % sl to get the sequence index.
# 2. We only need the left half of the cos and sin matrices because the right half is just
# a clone of the left half.
cos_row_idx = pid % (sl)
cos = cos + cos_row_idx * cos_row_stride
sin = sin + cos_row_idx * sin_row_stride
batch_idx = pid // sl
cos_row_idx = pid % sl
cos = cos + tl.where(
cos_bs == 1,
cos_row_idx * cos_row_stride,
batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride
)
sin = sin + tl.where(
cos_bs == 1,
cos_row_idx * sin_row_stride,
batch_idx * (sl * sin_row_stride) + cos_row_idx * sin_row_stride
)

cos_offsets = tl.arange(0, pad_hd // 2)
cos_mask = cos_offsets < hd // 2
cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0)
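Spelled out in plain Python, the new offset arithmetic amounts to the following (an illustrative restatement of the kernel logic above; names mirror the kernel arguments):

def cos_sin_base_offset(pid, sl, cos_bs, row_stride):
    # Each program id handles one (batch, sequence-position) row of q/k.
    batch_idx = pid // sl
    cos_row_idx = pid % sl
    if cos_bs == 1:
        # cos/sin are shared across the batch: index by sequence position only.
        return cos_row_idx * row_stride
    # cos/sin carry a real batch dimension: skip batch_idx full (seq_len, head_dim) blocks.
    return batch_idx * (sl * row_stride) + cos_row_idx * row_stride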
@@ -118,7 +129,6 @@ def _triton_rope(


def rope_forward(q, k, cos, sin):

# transpose them back to the physical shape because Triton looks at the physical storage
# note: q and k are non-contiguous before the transformation and become contiguous after the transpose
q = q.transpose(1, 2)
@@ -138,6 +148,7 @@ def rope_forward(q, k, cos, sin):
k = k.contiguous()
cos = cos.contiguous()
sin = sin.contiguous()
cos_batch_size = cos.shape[0]

_triton_rope[(n_row,)](
q,
@@ -150,6 +161,7 @@ def rope_forward(q, k, cos, sin):
sin.stride(-2),
seq_len,
batch_size,
cos_batch_size,
n_q_head,
n_kv_head,
head_dim,
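A hedged sketch of the two cos/sin layouts rope_forward now accepts (shapes follow the updated docstrings; the random tensors are placeholders):

import torch

bsz, seq_len, head_dim = 2, 128, 64
# Shared layout: one rotation table broadcast over the whole batch.
cos_shared = torch.randn(1, seq_len, head_dim)
# Batched layout: one rotation table per batch element, e.g. when
# position_ids differ across the sequences in a batch.
cos_batched = torch.randn(bsz, seq_len, head_dim)

for cos in (cos_shared, cos_batched):
    cos_batch_size = cos.shape[0]  # 1 or bsz; forwarded to the kernel as cos_bs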
@@ -167,6 +179,7 @@ def rope_backward(dq, dk, cos, sin):
dk = dk.transpose(1, 2)

batch_size, seq_len, n_q_head, head_dim = dq.shape
cos_batch_size = cos.shape[0]
n_kv_head = dk.shape[2]
pad_hd = triton.next_power_of_2(head_dim)
pad_n_q_head = triton.next_power_of_2(n_q_head)
@@ -191,6 +204,7 @@ def rope_backward(dq, dk, cos, sin):
sin.stride(-2),
seq_len,
batch_size,
cos_batch_size,
n_q_head,
n_kv_head,
head_dim,
@@ -221,8 +235,8 @@ def forward(ctx, q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""
q size: (bsz, n_q_head, seq_len, head_dim)
k size: (bsz, n_kv_head, seq_len, head_dim)
cos size: (1, seq_len, head_dim)
sin size: (1, seq_len, head_dim)
cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
"""
q, k, cos, sin = rope_forward(q, k, cos, sin)
ctx.save_for_backward(cos, sin)
@@ -232,8 +246,8 @@ def backward(ctx, dq, dk):
"""
dq size: (bsz, n_q_head, seq_len, head_dim)
dk size: (bsz, n_kv_head, seq_len, head_dim)
cos size: (1, seq_len, head_dim)
sin size: (1, seq_len, head_dim)
cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim)
"""

cos, sin = ctx.saved_tensors
4 changes: 2 additions & 2 deletions src/liger_kernel/transformers/rope.py
@@ -8,8 +8,8 @@ def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
Args:
q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim).
k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim).
cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim).
sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim).
cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim).
sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim).
position_ids (torch.Tensor, optional): The position ids tensor. Defaults to None.
unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1.
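A usage sketch of the wrapper with a per-batch rotation table (shapes follow the docstring above; the CUDA device and random tensors are placeholder assumptions, and the call is assumed to return the rotated (q, k) pair as with apply_rotary_pos_emb):

import torch
from liger_kernel.transformers.rope import liger_rotary_pos_emb

bsz, n_q_head, n_kv_head, seq_len, head_dim = 2, 8, 2, 128, 64
q = torch.randn(bsz, n_q_head, seq_len, head_dim, device="cuda")
k = torch.randn(bsz, n_kv_head, seq_len, head_dim, device="cuda")
# cos/sin built from position_ids of shape (bsz, seq_len), so they carry
# a real batch dimension instead of broadcasting from 1.
cos = torch.randn(bsz, seq_len, head_dim, device="cuda")
sin = torch.randn(bsz, seq_len, head_dim, device="cuda")
q_rot, k_rot = liger_rotary_pos_emb(q, k, cos, sin)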

16 changes: 14 additions & 2 deletions test/transformers/test_rope.py
@@ -46,8 +46,12 @@
),
],
)
@pytest.mark.parametrize(
"expand_position_ids",
[True, False],
)
def test_correctness(
bsz, seq_len, num_q_heads, num_kv_heads, head_dim, dtype, atol, rtol
bsz, seq_len, num_q_heads, num_kv_heads, head_dim, dtype, expand_position_ids, atol, rtol
):
rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)

@@ -70,6 +74,8 @@ def test_correctness(
k2 = _tensor_k.clone().requires_grad_(True)

pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
if expand_position_ids:
pos_ids = pos_ids.expand(bsz, -1)
cos, sin = rotary_emb(k1, pos_ids)

# validate forward pass
@@ -111,8 +117,12 @@ def test_correctness(
(torch.bfloat16, 1e-1, 1e-5),
],
)
@pytest.mark.parametrize(
"expand_position_ids",
[True, False],
)
def test_functional_correctness(
bsz, seq_len, num_q_heads, num_kv_heads, head_dim, dtype, atol, rtol
bsz, seq_len, num_q_heads, num_kv_heads, head_dim, expand_position_ids, dtype, atol, rtol
):
_q = torch.randn((bsz, num_q_heads, seq_len, head_dim), device=device, dtype=dtype)
_k = torch.randn((bsz, num_kv_heads, seq_len, head_dim), device=device, dtype=dtype)
@@ -126,6 +136,8 @@ def test_functional_correctness(
rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)

pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
if expand_position_ids:
pos_ids = pos_ids.expand(bsz, -1)
cos, sin = rotary_emb(k1, pos_ids)

functional_q, functional_k = liger_rope(q=q1, k=k1, cos=cos, sin=sin)
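For context, the new expand_position_ids=True cases drive the rotary embedding with position ids of shape (bsz, seq_len), which is expected to yield per-batch cos/sin tensors, exactly the layout the kernel change above supports. A hedged sketch mirroring the test setup (the LlamaRotaryEmbedding constructor signature varies across transformers versions):

import torch
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

bsz, num_kv_heads, seq_len, head_dim = 2, 2, 128, 64
device = "cuda"
rotary_emb = LlamaRotaryEmbedding(head_dim, device=device)
k1 = torch.randn(bsz, num_kv_heads, seq_len, head_dim, device=device)

pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
pos_ids = pos_ids.expand(bsz, -1)   # (bsz, seq_len) instead of (1, seq_len)
cos, sin = rotary_emb(k1, pos_ids)  # expected: (bsz, seq_len, head_dim) each
# These batched cos/sin then feed liger_rope / liger_rotary_pos_emb as in the tests above.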