diff --git a/src/liger_kernel/ops/rope.py b/src/liger_kernel/ops/rope.py index 0cd88efeb..5ca9b62f1 100644 --- a/src/liger_kernel/ops/rope.py +++ b/src/liger_kernel/ops/rope.py @@ -15,6 +15,7 @@ def _triton_rope( sin_row_stride, sl, bs: tl.constexpr, + cos_bs: tl.constexpr, n_qh: tl.constexpr, n_kh: tl.constexpr, hd: tl.constexpr, @@ -29,7 +30,7 @@ def _triton_rope( # k size: (bsz, seq_len, num_kv_heads, head_dim) # k stride: (seq_len * num_kv_heads * head_dim, num_kv_heads * head_dim, head_dim, 1) - # cos size: (1, seq_len, head_dim) + # cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) # stride: (seq_len * head_dim, head_dim, 1) pid = tl.program_id(0) @@ -48,9 +49,19 @@ def _triton_rope( # and pid % sl to get the sequence index. # 2. We only need the left half of cos and sin matrix because the right half is just # a clone of the left half. - cos_row_idx = pid % (sl) - cos = cos + cos_row_idx * cos_row_stride - sin = sin + cos_row_idx * sin_row_stride + batch_idx = pid // sl + cos_row_idx = pid % sl + cos = cos + tl.where( + cos_bs == 1, + cos_row_idx * cos_row_stride, + batch_idx * (sl * cos_row_stride) + cos_row_idx * cos_row_stride, + ) + sin = sin + tl.where( + cos_bs == 1, + cos_row_idx * sin_row_stride, + batch_idx * (sl * sin_row_stride) + cos_row_idx * sin_row_stride, + ) + cos_offsets = tl.arange(0, pad_hd // 2) cos_mask = cos_offsets < hd // 2 cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0) @@ -118,7 +129,6 @@ def _triton_rope( def rope_forward(q, k, cos, sin): - # transpose it back to the physical shape because Triton looks at the physical storage # note: q and k are incontiguous before the transformation and will become contiguous after transpose q = q.transpose(1, 2) @@ -138,6 +148,7 @@ def rope_forward(q, k, cos, sin): k = k.contiguous() cos = cos.contiguous() sin = sin.contiguous() + cos_batch_size = cos.shape[0] _triton_rope[(n_row,)]( q, @@ -150,6 +161,7 @@ def rope_forward(q, k, cos, sin): sin.stride(-2), seq_len, batch_size, + cos_batch_size, n_q_head, n_kv_head, head_dim, @@ -167,6 +179,7 @@ def rope_backward(dq, dk, cos, sin): dk = dk.transpose(1, 2) batch_size, seq_len, n_q_head, head_dim = dq.shape + cos_batch_size = cos.shape[0] n_kv_head = dk.shape[2] pad_hd = triton.next_power_of_2(head_dim) pad_n_q_head = triton.next_power_of_2(n_q_head) @@ -191,6 +204,7 @@ def rope_backward(dq, dk, cos, sin): sin.stride(-2), seq_len, batch_size, + cos_batch_size, n_q_head, n_kv_head, head_dim, @@ -221,8 +235,8 @@ def forward(ctx, q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """ q size: (bsz, n_q_head, seq_len, head_dim) k size: (bsz, n_kv_head, seq_len, head_dim) - cos size: (1, seq_len, head_dim) - sin size: (1, seq_len, head_dim) + cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) + sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) """ q, k, cos, sin = rope_forward(q, k, cos, sin) ctx.save_for_backward(cos, sin) @@ -232,8 +246,8 @@ def backward(ctx, dq, dk): """ dq size: (bsz, n_q_head, seq_len, head_dim) dk size: (bsz, n_kv_head, seq_len, head_dim) - cos size: (1, seq_len, head_dim) - sin size: (1, seq_len, head_dim) + cos size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) + sin size: (1, seq_len, head_dim) or (bsz, seq_len, head_dim) """ cos, sin = ctx.saved_tensors diff --git a/src/liger_kernel/transformers/rope.py b/src/liger_kernel/transformers/rope.py index a40b29af3..de060ea01 100644 --- a/src/liger_kernel/transformers/rope.py +++ b/src/liger_kernel/transformers/rope.py @@ -8,8 +8,8 @@ def liger_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): Args: q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim). k (torch.Tensor): The key tensor of shape (bsz, n_kv_head, seq_len, head_dim). - cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim). - sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim). + cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim). + sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim) or (bsz, seq_len, head_dim). position_ids (torch.Tensor, optional): The position ids tensor. Defaults to None. unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1. diff --git a/test/transformers/test_rope.py b/test/transformers/test_rope.py index 74080b57f..190d9f8f7 100644 --- a/test/transformers/test_rope.py +++ b/test/transformers/test_rope.py @@ -46,8 +46,20 @@ ), ], ) +@pytest.mark.parametrize( + "expand_position_ids", + [True, False], +) def test_correctness( - bsz, seq_len, num_q_heads, num_kv_heads, head_dim, dtype, atol, rtol + bsz, + seq_len, + num_q_heads, + num_kv_heads, + head_dim, + dtype, + expand_position_ids, + atol, + rtol, ): rotary_emb = LlamaRotaryEmbedding(head_dim, device=device) @@ -70,6 +82,8 @@ def test_correctness( k2 = _tensor_k.clone().requires_grad_(True) pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) + if expand_position_ids: + pos_ids = pos_ids.expand(bsz, -1) cos, sin = rotary_emb(k1, pos_ids) # validate forward pass @@ -111,8 +125,20 @@ def test_correctness( (torch.bfloat16, 1e-1, 1e-5), ], ) +@pytest.mark.parametrize( + "expand_position_ids", + [True, False], +) def test_functional_correctness( - bsz, seq_len, num_q_heads, num_kv_heads, head_dim, dtype, atol, rtol + bsz, + seq_len, + num_q_heads, + num_kv_heads, + head_dim, + expand_position_ids, + dtype, + atol, + rtol, ): _q = torch.randn((bsz, num_q_heads, seq_len, head_dim), device=device, dtype=dtype) _k = torch.randn((bsz, num_kv_heads, seq_len, head_dim), device=device, dtype=dtype) @@ -126,6 +152,8 @@ def test_functional_correctness( rotary_emb = LlamaRotaryEmbedding(head_dim, device=device) pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0) + if expand_position_ids: + pos_ids = pos_ids.expand(bsz, -1) cos, sin = rotary_emb(k1, pos_ids) functional_q, functional_k = liger_rope(q=q1, k=k1, cos=cos, sin=sin)