From 71c8afe7b58aef07090983e3870189e6b1a3a701 Mon Sep 17 00:00:00 2001
From: Ying Sheng <sqy1415@gmail.com>
Date: Mon, 16 Sep 2024 07:20:25 +0000
Subject: [PATCH] add hybrid kv

---
 python/sglang/bench_latency.py                |  16 +-
 .../srt/layers/parallel_utils/__init__.py     |   1 +
 .../layers/parallel_utils/parallel_state.py   |  96 ++++
 python/sglang/srt/layers/radix_attention.py   | 279 ++++++++++
 python/sglang/srt/layers/sp_linear.py         | 503 ++++++++++++++++++
 python/sglang/srt/managers/schedule_batch.py  | 118 +++-
 .../srt/managers/seq_parallel_layout.py       | 302 +++++++++++
 .../srt/model_executor/cuda_graph_runner.py   |  13 +
 .../srt/model_executor/forward_batch_info.py  | 147 ++++-
 .../sglang/srt/model_executor/model_runner.py |  45 +-
 python/sglang/srt/models/llama.py             |  54 +-
 python/sglang/srt/server_args.py              |   7 +
 test/srt/test_seq_parallel_attn_kernel.py     | 233 ++++++++
 .../test_seq_parallel_attn_kernel_simple.py   | 279 ++++++++++
 test/srt/test_sp_comm_group.py                |  70 +++
 test/srt/test_sp_decode_attn.py               | 191 +++++++
 16 files changed, 2322 insertions(+), 32 deletions(-)
 create mode 100644 python/sglang/srt/layers/parallel_utils/__init__.py
 create mode 100644 python/sglang/srt/layers/parallel_utils/parallel_state.py
 create mode 100644 python/sglang/srt/layers/sp_linear.py
 create mode 100644 python/sglang/srt/managers/seq_parallel_layout.py
 create mode 100644 test/srt/test_seq_parallel_attn_kernel.py
 create mode 100644 test/srt/test_seq_parallel_attn_kernel_simple.py
 create mode 100644 test/srt/test_sp_comm_group.py
 create mode 100644 test/srt/test_sp_decode_attn.py

diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py
index 9006b7150aa..19f4960c5ac 100644
--- a/python/sglang/bench_latency.py
+++ b/python/sglang/bench_latency.py
@@ -115,7 +115,7 @@ def from_cli_args(cls, args: argparse.Namespace):
         )
 
 
-def load_model(server_args, tp_rank):
+def load_model(server_args, tp_rank, sp_rank: int = 0):
     suppress_other_loggers()
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
 
@@ -130,6 +130,8 @@ def load_model(server_args, tp_rank):
         gpu_id=tp_rank,
         tp_rank=tp_rank,
         tp_size=server_args.tp_size,
+        sp_rank=sp_rank,
+        sp_size=server_args.sp_size,
         nccl_port=28888,
         server_args=server_args,
     )
@@ -206,6 +208,8 @@ def extend(reqs, model_runner):
         req_to_token_pool=model_runner.req_to_token_pool,
         token_to_kv_pool=model_runner.token_to_kv_pool,
         tree_cache=None,
+        sp_size=model_runner.sp_size,
+        sp_rank=model_runner.sp_rank,
     )
     batch.prepare_for_extend(model_runner.model_config.vocab_size)
     sample_output, logits_output = model_runner.forward(batch, ForwardMode.EXTEND)
@@ -225,11 +229,12 @@ def correctness_test(
     server_args,
     bench_args,
     tp_rank,
+    sp_rank=0,
 ):
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
 
     # Load the model
-    model_runner, tokenizer = load_model(server_args, tp_rank)
+    model_runner, tokenizer = load_model(server_args, tp_rank, sp_rank)
 
     # Prepare inputs
     input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
@@ -336,11 +341,12 @@ def latency_test(
     server_args,
     bench_args,
     tp_rank,
+    sp_rank=0,
 ):
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
 
     # Load the model
-    model_runner, tokenizer = load_model(server_args, tp_rank)
+    model_runner, tokenizer = load_model(server_args, tp_rank, sp_rank)
 
     # Prepare inputs for warm up
     reqs = prepare_synthetic_inputs_for_latency_test(
@@ -458,16 +464,18 @@ def main(server_args, bench_args):
         )
 
     if server_args.tp_size == 1:
-        work_func(server_args, bench_args, 0)
+        work_func(server_args, bench_args, 0, 0)
     else:
         workers = []
         for tp_rank in range(server_args.tp_size):
+            sp_rank = tp_rank % server_args.sp_size
             proc = multiprocessing.Process(
                 target=work_func,
                 args=(
                     server_args,
                     bench_args,
                     tp_rank,
+                    sp_rank,
                 ),
             )
             proc.start()
diff --git a/python/sglang/srt/layers/parallel_utils/__init__.py b/python/sglang/srt/layers/parallel_utils/__init__.py
new file mode 100644
index 00000000000..f8104e1d30d
--- /dev/null
+++ b/python/sglang/srt/layers/parallel_utils/__init__.py
@@ -0,0 +1 @@
+from .parallel_state import *
diff --git a/python/sglang/srt/layers/parallel_utils/parallel_state.py b/python/sglang/srt/layers/parallel_utils/parallel_state.py
new file mode 100644
index 00000000000..4c1a05f0724
--- /dev/null
+++ b/python/sglang/srt/layers/parallel_utils/parallel_state.py
@@ -0,0 +1,96 @@
+from typing import List, Optional
+
+import torch
+from vllm.distributed import initialize_model_parallel as vllm_initialize_model_parallel
+from vllm.distributed.parallel_state import (
+    GroupCoordinator,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    get_world_group,
+    init_model_parallel_group,
+)
+
+_SP: Optional[GroupCoordinator] = None
+
+
+def get_sp_group():
+    assert _SP is not None, "sequence parallel group is not initialized"
+    return _SP
+
+
+def init_sequence_parallel_group(
+    group_ranks: List[List[int]], local_rank: int, backend: str
+) -> GroupCoordinator:
+    return GroupCoordinator(
+        group_ranks=group_ranks,
+        local_rank=local_rank,
+        torch_distributed_backend=backend,
+        use_pynccl=True,
+    )
+
+
+def initialize_model_parallel(
+    tensor_model_parallel_size: int = 1,
+    pipeline_model_parallel_size: int = 1,
+    sequence_parallel_size: int = 1,
+    backend: Optional[str] = None,
+) -> None:
+    """
+    Initialize model parallel groups and sequence parallel groups.
+
+    For sequence parallelism, we form SP groups within a TP group and assign
+    GPUs with adjacent ranks to the same SP group. For example, with TP size 8
+    and SP size 2, we have 1 TP group and 4 SP groups:
+    SP groups:
+        [g0, g1], [g2, g3], [g4, g5], [g6, g7]
+    Their KV TP rank:
+        [ 0,  0], [ 1,  1], [ 2,  2], [ 3,  3]
+    Since KV heads are replicated within an SP group, the KV TP size is 4 (8 // 2),
+    and all GPUs in the same SP group share one KV TP rank (0 to 3 across groups).
+    """
+    assert torch.distributed.is_initialized()
+    world_size: int = torch.distributed.get_world_size()
+    backend = backend or torch.distributed.get_backend(get_world_group().device_group)
+
+    num_sequence_parallel_groups: int = world_size // sequence_parallel_size
+    global _SP
+    assert _SP is None, "sequence parallel group is already initialized"
+    group_ranks = []
+    for i in range(num_sequence_parallel_groups):
+        ranks = list(
+            range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size)
+        )
+        group_ranks.append(ranks)
+    _SP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend)
+
+    vllm_initialize_model_parallel(
+        tensor_model_parallel_size, pipeline_model_parallel_size, backend
+    )
+
+
+def sequence_parallel_is_initialized():
+    return _SP is not None
+
+
+def get_sequence_parallel_world_size():
+    return get_sp_group().world_size
+
+
+def get_sequence_parallel_rank():
+    return get_sp_group().rank_in_group
+
+
+def get_sequence_parallel_global_rank():
+    return get_tensor_model_parallel_rank()
+
+
+# NOTE: For sequence parallelism, Q tensors are partitioned along the head dimension,
+# while K/V tensors are partitioned along the head dimension across KV-TP groups and
+# along the sequence dimension within an SP group. Their effective TP size and rank
+# are therefore adjusted as below.
+def get_kv_tensor_model_parallel_world_size():
+    return get_tensor_model_parallel_world_size() // get_sequence_parallel_world_size()
+
+
+def get_kv_tensor_model_parallel_rank():
+    return get_tensor_model_parallel_rank() // get_sequence_parallel_world_size()
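+
+
+# Illustrative example (using the TP-size-8, SP-size-2 layout from the docstring
+# above), showing how the helpers map global TP ranks:
+#   tp_rank:      0  1  2  3  4  5  6  7
+#   sp rank:      0  1  0  1  0  1  0  1   (rank within its SP group)
+#   kv tp rank:   0  0  1  1  2  2  3  3   (= tp_rank // sp_size)
+# with kv_tp_size = tp_size // sp_size = 4.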
diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py
index 91735a1b810..9ec35cc70a6 100644
--- a/python/sglang/srt/layers/radix_attention.py
+++ b/python/sglang/srt/layers/radix_attention.py
@@ -20,10 +20,12 @@
 import torch
 from flashinfer.cascade import merge_state
 from torch import nn
+from torch.distributed import P2POp, batch_isend_irecv, irecv, isend
 
 from sglang.global_config import global_config
 from sglang.srt.layers.decode_attention import decode_attention_fwd
 from sglang.srt.layers.extend_attention import extend_attention_fwd
+from sglang.srt.layers.parallel_utils import get_sp_group
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.model_executor.model_runner import global_server_args_dict
 
@@ -64,6 +66,11 @@ def __init__(
         self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
 
     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
+        if input_metadata.sp_size > 1:
+            raise NotImplementedError(
+                "Sequence parallel is not supported with Triton backend."
+            )
+
         if self.qk_head_dim != self.v_head_dim:
             o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
         else:
@@ -93,6 +100,11 @@ def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         return o
 
     def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
+        if input_metadata.sp_size > 1:
+            raise NotImplementedError(
+                "Sequence parallel is not supported with Triton backend."
+            )
+
         if self.qk_head_dim != self.v_head_dim:
             o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
         else:
@@ -117,6 +129,8 @@ def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         return o
 
     def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
+        if input_metadata.sp_size > 1:
+            return self.seq_parallel_extend_forward_flashinfer(q, k, v, input_metadata)
         # using two wrappers is unnecessary in the current PR, but are prepared for future PRs
         prefill_wrapper_paged = input_metadata.flashinfer_prefill_wrapper_paged
         if self.sliding_window_size != -1:
@@ -171,6 +185,8 @@ def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         return o.view(-1, self.tp_q_head_num * self.head_dim)
 
     def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
+        if input_metadata.sp_size > 1:
+            return self.seq_parallel_decode_forward_flashinfer(q, k, v, input_metadata)
         decode_wrapper = input_metadata.flashinfer_decode_wrapper
         if self.sliding_window_size != -1:
             decode_wrapper = decode_wrapper[0]
@@ -191,6 +207,257 @@ def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
 
         return o.view(-1, self.tp_q_head_num * self.head_dim)
 
+    def launch_sp_comm_ops(
+        self, kv_to_recv, kv_to_send, from_rank, to_rank, my_rank, sp_size, itr
+    ):
+        # Interleave which workers post sends vs. receives first to avoid deadlocks.
+        def _send_first():
+            flags = [None for _ in range(sp_size)]
+            for _rank in range(sp_size):
+                _next = _rank
+                flag = True
+                while flags[_next] is None:
+                    flags[_next] = flag
+                    _next = (_next + itr) % sp_size
+                    flag = not flag
+            return flags[my_rank]
+
+        def _send(handles, group):
+            if my_rank != to_rank:
+                to_global_rank = group.first_rank + to_rank
+                for t in kv_to_send:
+                    handles.append(
+                        P2POp(
+                            op=isend,
+                            tensor=t,
+                            peer=to_global_rank,
+                            group=group.device_group,
+                        )
+                    )
+
+        def _recv(handles, group):
+            if my_rank != from_rank:
+                from_global_rank = group.first_rank + from_rank
+                for t in kv_to_recv:
+                    handles.append(
+                        P2POp(
+                            op=irecv,
+                            tensor=t,
+                            peer=from_global_rank,
+                            group=group.device_group,
+                        )
+                    )
+
+        handles = []
+        reqs = []
+        sp_group = get_sp_group()
+
+        if _send_first():
+            _send(handles, sp_group)
+            _recv(handles, sp_group)
+        else:
+            _recv(handles, sp_group)
+            _send(handles, sp_group)
+        if handles:
+            reqs = batch_isend_irecv(handles)
+        return reqs
+
+    def wait_sp_comm_ops(self, reqs):
+        for req in reqs:
+            req.wait()
+
+    def seq_parallel_extend_forward_flashinfer(
+        self, q, k, v, input_metadata: InputMetadata
+    ):
+        """Here we adopted a unique parallelization strategy.
+        For each SP worker, we have either (1) QKV of entire sequences:
+            q tensor: [padded_total_num_tokens, q_head_num // SP_SIZE, head_dim]
+            k tensor: [padded_total_num_tokens, k_head_num, head_dim]
+            v tensor: [padded_total_num_tokens, v_head_num, head_dim]
+        Or (2) Q of entire sequences and KV of the current SP shard:
+            q tensor: [padded_total_num_tokens, q_head_num // SP_SIZE, head_dim]
+            k tensor: [padded_sp_shard_num_tokens, k_head_num, head_dim]
+            v tensor: [padded_sp_shard_num_tokens, v_head_num, head_dim]
+
+        Case (1) avoids cross-SP-worker communication, while case (2) avoids computing
+        K and V for the entire sequences but requires communication during SP attention.
+        """
+
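+        # Sketch of the schedule below (an illustrative summary, assuming
+        # sp_size == 2 and sp_rank == 0): iteration 0 runs causal self-attention on
+        # this worker's own shard 0 while, in case (2), exchanging KV shards with
+        # rank 1; iteration 1 runs causal self-attention on shard 1 plus one
+        # non-causal cross-shard attention of Q shard 1 against KV shard 0.
+        # Partial outputs for the same Q shard are combined with merge_state.
+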
+        def append_merge_shard(shard_list, o, s):
+            if len(shard_list) == 0:
+                shard_list.append((o, s))
+            else:
+                o_prev, s_prev = shard_list[-1]
+                o, s = merge_state(o_prev, s_prev, o, s)
+                shard_list[-1] = (o, s)
+
+        sp_rank = input_metadata.sp_rank
+        sp_size = input_metadata.sp_size
+        num_shards = num_iters = sp_size
+        sp_shard_size = (q.shape[0] + sp_size - 1) // sp_size
+        assert k.shape[0] == v.shape[0] and (
+            k.shape[0] == q.shape[0] or k.shape[0] == sp_shard_size
+        ), "Invalid K and V partition in sequence parallel."
+
+        qs = []
+        for i in range(num_shards):
+            qs.append(q[sp_shard_size * i : sp_shard_size * (i + 1)])
+        need_comm = k.shape[0] == sp_shard_size  # Case 2.
+
+        owned_sids = [sp_rank]
+        kv_shards = [None for _ in range(num_shards)]
+        output_shards = [[] for _ in range(num_shards)]
+
+        if need_comm:  # Case (2): we only hold the local K/V shard.
+            local_k = k.contiguous().view(-1, self.tp_k_head_num, self.head_dim)
+            local_v = v.contiguous().view(-1, self.tp_v_head_num, self.head_dim)
+            for i in range(sp_size):
+                if i == sp_rank:
+                    kv_shards[i] = (local_k, local_v)
+                else:  # reserve space for kv tensors received from other peers
+                    kv_shards[i] = (
+                        torch.empty_like(local_k),
+                        torch.empty_like(local_v),
+                    )
+        else:  # We need to manually shard K and V.
+            for i in range(num_shards):
+                k_shard = k[sp_shard_size * i : sp_shard_size * (i + 1)]
+                v_shard = v[sp_shard_size * i : sp_shard_size * (i + 1)]
+                kv_shards[i] = (
+                    k_shard.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
+                    v_shard.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
+                )
+            local_k, local_v = kv_shards[sp_rank]
+
+        # For communication
+        to_rank = sp_rank  # which SP worker to send my sequence KV shard to.
+        from_rank = sp_rank  # which SP worker to receive the sequence KV shard from.
+        sid = sp_rank  # start from the worker's own shard
+        for itr in range(num_iters):
+            to_rank = (to_rank + 1) % sp_size
+            from_rank = (from_rank - 1) % sp_size
+            if need_comm:  # Launch async communication operations
+                comm_reqs = self.launch_sp_comm_ops(
+                    kv_shards[from_rank],
+                    kv_shards[sp_rank],
+                    from_rank,
+                    to_rank,
+                    sp_rank,
+                    sp_size,
+                    itr,
+                )
+            q_shard = qs[sid]
+            k_shard, v_shard = kv_shards[sid]
+            # Self attention within the SP shard.
+            attn_wrapper = (  # Only the last SP shard needs a mask.
+                input_metadata.flashinfer_prefill_wrapper_sp_causal
+                if sid == sp_size - 1
+                else input_metadata.flashinfer_prefill_wrapper_ragged
+            )
+            o, s = attn_wrapper.forward_return_lse(
+                q_shard.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
+                k_shard.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
+                v_shard.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
+                causal=True,
+                sm_scale=self.scaling,
+                logits_soft_cap=self.logit_cap,
+            )
+            append_merge_shard(output_shards[sid], o, s)
+            # Cross SP shard attention.
+            # NOTE: the schedule below is for load balancing: at iteration i
+            # (starting from 0), each SP worker runs i cross-shard attentions.
+            for existing_sid in owned_sids:
+                if existing_sid == sid:
+                    continue
+                # Due to the causal nature of the attention, swap shard ids if
+                # necessary so that the Q shard is the later one.
+                i, j = (
+                    (existing_sid, sid) if existing_sid > sid else (sid, existing_sid)
+                )
+                q_shard = qs[i]
+                k_shard, v_shard = kv_shards[j]
+                attn_wrapper = (  # Only the last SP shard needs a mask.
+                    input_metadata.flashinfer_prefill_wrapper_sp_full
+                    if i == sp_size - 1
+                    else input_metadata.flashinfer_prefill_wrapper_ragged
+                )
+                o, s = attn_wrapper.forward_return_lse(
+                    q_shard.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
+                    k_shard.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
+                    v_shard.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
+                    causal=False,
+                    sm_scale=self.scaling,
+                    logits_soft_cap=self.logit_cap,
+                )
+                append_merge_shard(output_shards[i], o, s)
+
+            if need_comm:  # Wait for async communication to complete.
+                self.wait_sp_comm_ops(comm_reqs)
+            if sp_rank != from_rank:
+                owned_sids.append(from_rank)
+            sid = from_rank
+
+        # Concat all output shards along the sequence dimension.
+        os = [o for shard_list in output_shards for o, _ in shard_list]
+        o = torch.cat(os, dim=0)
+
+        self.store_kv_cache(local_k, local_v, input_metadata)
+
+        if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
+            torch.cuda.synchronize()
+
+        return o.view(-1, self.tp_q_head_num * self.head_dim)
+
+    # TODO(yifan): check if flashinfer seq_parallel is broken after the rebase
+    def seq_parallel_decode_forward_flashinfer(
+        self, q, k, v, input_metadata: InputMetadata
+    ):
+        sp_size = input_metadata.sp_size
+        sp_rank = input_metadata.sp_rank
+        total_num_heads = self.tp_q_head_num * sp_size
+
+        sp_offset = input_metadata.sp_local_token_offset
+        sp_len = input_metadata.sp_local_token_length
+        sp_slice = slice(sp_offset, sp_offset + sp_len)
+        cache_k = k[sp_slice]
+        cache_v = v[sp_slice]
+        self.store_kv_cache(cache_k, cache_v, input_metadata)
+
+        # Gather Q from all SP workers and restore the full head ordering.
+        q = q.contiguous().view(-1, self.tp_q_head_num, self.head_dim)
+        gathered_q = get_sp_group().all_gather(q.view(1, *q.shape), dim=0)
+        q = torch.empty_like(gathered_q).view(-1, total_num_heads, self.head_dim)
+        for i in range(sp_size):
+            idxs = _get_sequence_parallel_head_idxes(
+                total_num_heads, self.tp_k_head_num, i, sp_size
+            )
+            q[:, idxs] = gathered_q[i]
+
+        o, s = input_metadata.flashinfer_decode_wrapper.forward_return_lse(
+            q.contiguous().view(-1, total_num_heads, self.head_dim),
+            input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
+            sm_scale=self.scaling,
+            logits_soft_cap=self.logit_cap,
+        )
+
+        # TODO: we could use all-to-all here to collect only the Q head shards needed
+        # by the current SP worker, which would save communication and `merge_state`
+        # computation.
+        os = get_sp_group().all_gather(o.view(1, *o.shape), dim=0)
+        ss = get_sp_group().all_gather(s.view(1, *s.shape), dim=0)
+        for i in range(sp_size):
+            if i != sp_rank:
+                o, s = merge_state(os[i], ss[i], o, s)
+
+        # TODO: consequently, if we use all-to-all rather than all-gather, we don't
+        # need to partition the output again along the head dimension.
+        # Partition the output again along the head dimension.
+        idxs = _get_sequence_parallel_head_idxes(
+            total_num_heads, self.tp_k_head_num, sp_rank, sp_size
+        )
+        o = o[:, idxs]
+
+        return o.view(-1, self.tp_q_head_num * self.head_dim)
+
     def forward(self, q, k, v, input_metadata: InputMetadata):
         if k is not None:
             assert v is not None
@@ -206,3 +473,15 @@ def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
         input_metadata.token_to_kv_pool.set_kv_buffer(
             self.layer_id, input_metadata.out_cache_loc, cache_k, cache_v
         )
+
+
+def _get_sequence_parallel_head_idxes(total_num_heads, num_kv_heads, sp_rank, sp_size):
+    group_size = total_num_heads // num_kv_heads
+    shard_num_heads = group_size // sp_size
+
+    idxes = [
+        group_size * i + sp_rank * shard_num_heads + j
+        for i in range(num_kv_heads)
+        for j in range(0, shard_num_heads)
+    ]
+    return idxes
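+
+
+# Worked example (illustrative): with 8 query heads, 2 KV heads, and sp_size = 2,
+# group_size = 4 and shard_num_heads = 2, so
+#   _get_sequence_parallel_head_idxes(8, 2, 0, 2) == [0, 1, 4, 5]
+#   _get_sequence_parallel_head_idxes(8, 2, 1, 2) == [2, 3, 6, 7]
+# i.e. each SP rank takes a contiguous slice of query heads from every GQA group.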
diff --git a/python/sglang/srt/layers/sp_linear.py b/python/sglang/srt/layers/sp_linear.py
new file mode 100644
index 00000000000..f8757738aee
--- /dev/null
+++ b/python/sglang/srt/layers/sp_linear.py
@@ -0,0 +1,503 @@
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/layers/linear.py#L1
+import logging
+from typing import Dict, Iterable, Optional, Tuple
+
+import torch
+from torch.nn.parameter import Parameter
+from vllm.distributed import divide, get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    RowParallelLinear,
+    adjust_marlin_shard,
+)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+
+from sglang.srt.layers.parallel_utils import (
+    get_kv_tensor_model_parallel_rank,
+    get_kv_tensor_model_parallel_world_size,
+    get_sequence_parallel_rank,
+    get_sequence_parallel_world_size,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def adjust_bitsandbytes_shard(
+    param: Parameter, kv_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str
+) -> Tuple[int, int]:
+    """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
+
+    total, _ = kv_offsets["total"]
+    orig_offset, orig_size = kv_offsets[loaded_shard_id]
+
+    quantized_total = param.data.shape[0]
+    quantized_offset = orig_offset * quantized_total // total
+    quantized_size = orig_size * quantized_total // total
+
+    return quantized_size, quantized_offset
+
+
+def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
+    """For fused modules (KV) we have an array of length
+    N that holds 1 scale for each "logical" matrix. So the param
+    is an array of length N. The loaded_weight corresponds to
+    one of the shards on disk. Here, we slice the param based on
+    the shard_id for loading.
+    """
+    kv_idxs = {"k": 0, "v": 1}
+
+    if isinstance(shard_id, str):
+        shard_id = kv_idxs[shard_id]
+    elif not isinstance(shard_id, int):
+        raise ValueError(f"Unknown Shard Id {shard_id}")
+
+    # AutoFP8 scales do not have a shape
+    # compressed-tensors scales do have a shape
+    if len(loaded_weight.shape) != 0:
+        assert loaded_weight.shape[0] == 1
+        loaded_weight = loaded_weight[0]
+
+    return param[shard_id], loaded_weight
+
+
+class QKVParallelLinear(torch.nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        head_size: int,
+        total_num_heads: int,
+        total_num_kv_heads: Optional[int] = None,
+        bias: bool = True,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        # q projection can be naively tensor parallelized. However, to adapt to
+        # GQA, we need to manually partition q heads for sequence parallelism.
+        # See _get_sequence_parallel_head_idxes() for details.
+        self.q_proj = ColumnSeqParallelLinear(
+            hidden_size,
+            head_size,
+            total_num_heads,
+            total_num_kv_heads,
+            bias,
+            skip_bias_add,
+            params_dtype,
+            quant_config,
+            f"{prefix}.q_proj",
+        )
+        # kv projection needs both tensor and sequence parallelization
+        self.kv_proj = KVSeqParallelLinear(
+            hidden_size,
+            head_size,
+            total_num_heads,
+            total_num_kv_heads,
+            bias,
+            skip_bias_add,
+            params_dtype,
+            quant_config,
+            f"{prefix}.kv_proj",
+        )
+        self.hidden_size = hidden_size
+        self.kv_size = self.kv_proj.num_kv_heads * self.kv_proj.head_size
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        q, _ = self.q_proj(hidden_states)
+        kv, _ = self.kv_proj(hidden_states)
+        k, v = kv.split([self.kv_size, self.kv_size], dim=-1)
+        return q, k, v
+
+
+class ColumnSeqParallelLinear(ColumnParallelLinear):
+    def __init__(
+        self,
+        hidden_size: int,
+        head_size: int,
+        total_num_heads: int,
+        total_num_kv_heads: Optional[int] = None,
+        bias: bool = True,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        self.hidden_size = hidden_size
+        self.head_size = head_size
+        if total_num_kv_heads is None:
+            total_num_kv_heads = total_num_heads
+        self.total_num_kv_heads = total_num_kv_heads
+        # Divide the weight matrix along the last dimension.
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = total_num_heads
+        self.num_heads = divide(total_num_heads, tp_size)
+        # num_kv_heads is used for tracking the number of groups in GQA.
+        kv_tp_size = get_kv_tensor_model_parallel_world_size()
+        if kv_tp_size >= self.total_num_kv_heads:
+            self.num_kv_heads = 1
+            self.num_kv_head_replicas = divide(kv_tp_size, self.total_num_kv_heads)
+        else:
+            self.num_kv_heads = divide(self.total_num_kv_heads, kv_tp_size)
+            self.num_kv_head_replicas = 1
+
+        input_size = self.hidden_size
+        # NOTE: here we use total_num_heads to make the parent class happy because
+        # it expects pure tensor parallelism along the heads dimension. output_size
+        # here is the total size of all TP and SP workers.
+        output_size = self.total_num_heads * self.head_size
+
+        super().__init__(
+            input_size=input_size,
+            output_size=output_size,
+            bias=bias,
+            gather_output=False,
+            skip_bias_add=skip_bias_add,
+            params_dtype=params_dtype,
+            quant_config=quant_config,
+            prefix=prefix,
+        )
+
+    def weight_loader(
+        self,
+        param: Parameter,
+        loaded_weight: torch.Tensor,
+    ):
+        kv_tp_rank = get_kv_tensor_model_parallel_rank()
+        kv_tp_size = get_kv_tensor_model_parallel_world_size()
+        sp_size = get_sequence_parallel_world_size()
+        sp_rank = get_sequence_parallel_rank()
+
+        output_dim = getattr(param, "output_dim", None)
+        param_data = param.data
+        if output_dim is not None:
+            shard_size = param_data.shape[output_dim]
+            # Load TP weight shard
+            tp_shard_size = shard_size * sp_size
+            start_idx = kv_tp_rank * tp_shard_size
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx, tp_shard_size)
+            # Load SP weight shard
+            tp_num_heads = self.total_num_heads // kv_tp_size
+            idxes = torch.tensor(
+                _get_sequence_parallel_head_idxes(
+                    tp_num_heads, self.num_kv_heads, sp_rank, sp_size
+                ),
+                dtype=torch.int32,
+            )
+            weight_shape = loaded_weight.shape
+            tp_shard_shape = _reshape_dimension(
+                weight_shape, output_dim, [tp_num_heads, self.head_size]
+            )
+            sp_shard_shape = _reshape_dimension(weight_shape, output_dim, [shard_size])
+            loaded_weight = (
+                loaded_weight.reshape(tp_shard_shape)
+                .index_select(output_dim, idxes)
+                .contiguous()
+                .view(sp_shard_shape)
+            )
+
+        # Special case for loading scales off disk, which often do not
+        # have a shape (such as in the case of AutoFP8).
+        if len(loaded_weight.shape) == 0:
+            loaded_weight = loaded_weight.reshape(1)
+
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+
+
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/layers/linear.py#L422
+class KVSeqParallelLinear(ColumnParallelLinear):
+    """Linear layers for the attention's KV transformation.
+
+    Linear layer for the linear transformation of the key and value vectors
+    in the attention layer. The K and V weight matrices are concatenated along
+    the output dimension. The layer is parallelized along the head dimension.
+    When the number of key/value heads is smaller than the number of query
+    heads (e.g., multi-query/grouped-query attention), the key/value head may
+    be replicated while the query heads are partitioned.
+
+    Args:
+        hidden_size: input hidden state size of the transformer.
+        head_size: size of each attention head.
+        total_num_heads: total number of attention query heads.
+        total_num_kv_heads: total number of attention key/value heads. If
+                            None, assume total_num_kv_heads = total_num_heads.
+        bias: If true, add bias.
+        skip_bias_add: This was added to enable performance optimizations where
+                       bias can be fused with other element-wise operations. we
+                       skip adding bias but instead return it.
+        params_dtype: Data type for the parameters.
+        quant_config: Quantization configuration, if any.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        head_size: int,
+        total_num_heads: int,
+        total_num_kv_heads: Optional[int] = None,
+        bias: bool = True,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        self.hidden_size = hidden_size
+        self.head_size = head_size
+        if total_num_kv_heads is None:
+            total_num_kv_heads = total_num_heads
+        self.total_num_kv_heads = total_num_kv_heads
+        # Divide the weight matrix along the last dimension.
+        kv_tp_size = get_kv_tensor_model_parallel_world_size()
+        if kv_tp_size >= self.total_num_kv_heads:
+            self.num_kv_heads = 1
+            self.num_kv_head_replicas = divide(kv_tp_size, self.total_num_kv_heads)
+        else:
+            self.num_kv_heads = divide(self.total_num_kv_heads, kv_tp_size)
+            self.num_kv_head_replicas = 1
+        input_size = self.hidden_size
+        # NOTE: here we use tp_size to make the parent class happy because it
+        # expects pure tensor parallelism along the num_heads dimension.
+        tp_size = get_tensor_model_parallel_world_size()
+        output_size = 2 * self.num_kv_heads * tp_size * self.head_size
+        self.output_sizes = [
+            self.num_kv_heads * self.head_size * tp_size,  # k_proj
+            self.num_kv_heads * self.head_size * tp_size,  # v_proj
+        ]
+
+        super().__init__(
+            input_size=input_size,
+            output_size=output_size,
+            bias=bias,
+            gather_output=False,
+            skip_bias_add=skip_bias_add,
+            params_dtype=params_dtype,
+            quant_config=quant_config,
+            prefix=prefix,
+        )
+
+    def weight_loader(
+        self,
+        param: Parameter,
+        loaded_weight: torch.Tensor,
+        loaded_shard_id: Optional[str] = None,
+    ):
+        param_data = param.data
+        output_dim = getattr(param, "output_dim", None)
+        # Special case for AQLM codebooks.
+        is_metadata = getattr(param, "is_metadata", False)
+
+        # Special case for per-tensor scales in fused case.
+        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
+
+        if loaded_shard_id is None:
+            # Loaded weight is already fused on disk (qkv/mlp).
+            if get_sequence_parallel_world_size() > 1:
+                raise NotImplementedError(
+                    "Fused weight loading is not supported in SP."
+                )
+            if output_dim is None:
+                assert param_data.shape == loaded_weight.shape
+                param_data.copy_(loaded_weight)
+                return
+            shard_offsets = [
+                # (shard_id, shard_offset, shard_size)
+                (
+                    "k",
+                    0,
+                    self.total_num_kv_heads * self.head_size,
+                ),
+                (
+                    "v",
+                    self.total_num_kv_heads * self.head_size,
+                    self.total_num_kv_heads * self.head_size,
+                ),
+            ]
+            packed_dim = getattr(param, "packed_dim", None)
+            for shard_id, shard_offset, shard_size in shard_offsets:
+                # If quantized, we need to adjust the offset and size to account
+                # for the packing.
+                if packed_dim == output_dim:
+                    shard_size = shard_size // param.pack_factor
+                    shard_offset = shard_offset // param.pack_factor
+                loaded_weight_shard = loaded_weight.narrow(
+                    output_dim, shard_offset, shard_size
+                )
+                self.weight_loader(param, loaded_weight_shard, shard_id)
+            return
+
+        kv_tp_rank = get_kv_tensor_model_parallel_rank()
+        assert loaded_shard_id in ["k", "v"]
+
+        # If output dim is defined, use the default loading process.
+        if output_dim is not None:
+            if loaded_shard_id == "k":
+                shard_offset = 0
+                shard_size = self.num_kv_heads * self.head_size
+            elif loaded_shard_id == "v":
+                shard_offset = self.num_kv_heads * self.head_size
+                shard_size = self.num_kv_heads * self.head_size
+            # Special case for Quantized Weights.
+            # If quantized, we need to adjust the offset and size to account
+            # for the packing.
+            packed_dim = getattr(param, "packed_dim", None)
+            if packed_dim == output_dim:
+                shard_size = shard_size // param.pack_factor
+                shard_offset = shard_offset // param.pack_factor
+
+                # Special case for Marlin.
+                shard_size, shard_offset = adjust_marlin_shard(
+                    param, shard_size, shard_offset
+                )
+
+            use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
+            if use_bitsandbytes:
+                orig_kv_offsets = {
+                    "k": (
+                        0,
+                        self.num_kv_heads * self.head_size,
+                    ),
+                    "v": (
+                        self.num_kv_heads * self.head_size,
+                        self.num_kv_heads * self.head_size,
+                    ),
+                    "total": (
+                        2 * self.num_kv_heads * self.head_size,
+                        0,
+                    ),
+                }
+                shard_size, shard_offset = adjust_bitsandbytes_shard(
+                    param, orig_kv_offsets, loaded_shard_id
+                )
+
+            param_data = param_data.narrow(output_dim, shard_offset, shard_size)
+            shard_id = kv_tp_rank // self.num_kv_head_replicas
+            start_idx = shard_id * shard_size
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+        # Special case for AQLM codebooks.
+        elif is_metadata:
+            # metadata indicates fixed size concatenated along dim 0
+            shard_size = loaded_weight.shape[0]
+            shard_index = ["k", "v"].index(loaded_shard_id)
+            param_data = param_data.narrow(0, shard_index * shard_size, shard_size)
+        # Special case for per-tensor scales in fused case.
+        elif needs_scalar_to_array:
+            param_data, loaded_weight = adjust_scalar_to_fused_array(
+                param_data, loaded_weight, loaded_shard_id
+            )
+        else:
+            ignore_warning = getattr(param, "ignore_warning", False)
+            if not ignore_warning:
+                logger.warning(
+                    "Loading a weight without `output_dim` attribute in "
+                    "QKVParallelLinear, assume the weight is the same "
+                    "for all partitions."
+                )
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+
+
+class RowSeqParallelLinear(RowParallelLinear):
+    """TODO: add doc string."""
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        total_num_heads: int,
+        num_kv_heads: int,
+        head_dim: int,
+        bias: bool = True,
+        input_is_parallel: bool = True,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        reduce_results: bool = True,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__(
+            input_size,
+            output_size,
+            bias,
+            input_is_parallel,
+            skip_bias_add,
+            params_dtype,
+            reduce_results,
+            quant_config,
+            prefix,
+        )
+        self.total_num_heads = total_num_heads
+        self.num_kv_heads = num_kv_heads
+        self.head_dim = head_dim
+
+    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
+        kv_tp_rank = get_kv_tensor_model_parallel_rank()
+        kv_tp_size = get_kv_tensor_model_parallel_world_size()
+        sp_size = get_sequence_parallel_world_size()
+        sp_rank = get_sequence_parallel_rank()
+
+        input_dim = getattr(param, "input_dim", None)
+        param_data = param.data
+        if input_dim is not None:
+            shard_size = param_data.shape[input_dim]
+            # Load TP weight shard
+            tp_shard_size = shard_size * sp_size
+            start_idx = kv_tp_rank * tp_shard_size
+            loaded_weight = loaded_weight.narrow(input_dim, start_idx, tp_shard_size)
+            # Load SP weight shard
+            tp_num_heads = self.total_num_heads // kv_tp_size
+            idxes = torch.tensor(
+                _get_sequence_parallel_head_idxes(
+                    tp_num_heads, self.num_kv_heads, sp_rank, sp_size
+                )
+            )
+            weight_shape = loaded_weight.shape
+            tp_shard_shape = _reshape_dimension(
+                weight_shape, input_dim, [tp_num_heads, self.head_dim]
+            )
+            sp_shard_shape = _reshape_dimension(weight_shape, input_dim, [shard_size])
+            loaded_weight = (
+                loaded_weight.reshape(tp_shard_shape)
+                .index_select(input_dim, idxes)
+                .contiguous()
+                .view(sp_shard_shape)
+            )
+
+        # Special case for loading scales off disk, which often do not
+        # have a shape (such as in the case of AutoFP8).
+        if len(loaded_weight.shape) == 0:
+            loaded_weight = loaded_weight.reshape(1)
+
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+
+
+def _get_sequence_parallel_head_idxes(total_num_heads, num_kv_heads, sp_rank, sp_size):
+    group_size = total_num_heads // num_kv_heads
+    shard_num_heads = group_size // sp_size
+
+    idxes = [
+        group_size * i + sp_rank * shard_num_heads + j
+        for i in range(num_kv_heads)
+        for j in range(0, shard_num_heads)
+    ]
+    return idxes
+
+
+def _reshape_dimension(shape: Tuple[int], dim_idx: int, new_dims: Iterable[int]):
+    if isinstance(new_dims, int):
+        new_dims = (new_dims,)
+    if not isinstance(shape, tuple):
+        raise TypeError("shape must be a tuple")
+    if not isinstance(new_dims, (list, tuple)):
+        raise TypeError("new_dims must be a list or a tuple")
+    if dim_idx < 0 or dim_idx >= len(shape):
+        raise IndexError("dim_idx out of range")
+
+    new_shape = shape[:dim_idx] + tuple(new_dims) + shape[dim_idx + 1 :]
+    return new_shape
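+
+
+# Worked example (illustrative, with a hypothetical (4096, 4096) weight, 32 heads of
+# size 128, and output_dim = 0): the SP weight loaders above reshape the head
+# dimension, select this rank's heads, and flatten back:
+#   _reshape_dimension((4096, 4096), 0, [32, 128]) == (32, 128, 4096)
+# after which index_select(0, idxes) keeps only the SP-local heads and .view()
+# restores a 2D shard of shape (len(idxes) * 128, 4096).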
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index c80cf2e2723..48c1dcb0e20 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -21,11 +21,18 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, Optional, Union
 
+import numpy as np
 import torch
 
 from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
+from sglang.srt.managers.seq_parallel_layout import (
+    seq_parallel_decode_indices,
+    seq_parallel_input_ids_decode,
+    seq_parallel_input_ids_extend,
+    seq_parallel_local_len_extend,
+)
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -348,8 +355,22 @@ class ScheduleBatch:
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
 
+    # Sequence Parallel params
+    sp_size: int = None
+    sp_rank: int = None
+    prefill_extend_lens: np.ndarray = None
+    sp_decode_local_lens: np.ndarray = None
+
     @classmethod
-    def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
+    def init_new(
+        cls,
+        reqs,
+        req_to_token_pool,
+        token_to_kv_pool,
+        tree_cache,
+        sp_size: int = 1,
+        sp_rank: int = 0,
+    ):
         return_logprob = any(req.return_logprob for req in reqs)
 
         return cls(
@@ -358,6 +379,8 @@ def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
             token_to_kv_pool=token_to_kv_pool,
             tree_cache=tree_cache,
             return_logprob=return_logprob,
+            sp_size=sp_size,
+            sp_rank=sp_rank,
         )
 
     def batch_size(self):
@@ -402,8 +425,24 @@ def prepare_for_extend(self, vocab_size: int):
         extend_num_tokens = sum(len(ids) for ids in input_ids)
         seq_lens = []
 
+        if self.sp_size == 1:
+            flatten_input_ids = sum(input_ids, [])
+        else:
+            flatten_input_ids = seq_parallel_input_ids_extend(
+                input_ids, self.sp_size, bs
+            )
+
         # Allocate memory
         req_pool_indices_cpu = self.alloc_req_slots(bs)
+        if self.sp_size > 1:
+            ext_lens = np.asarray(
+                [len(req.fill_ids) - len(req.prefix_indices) for req in reqs]
+            )
+            extend_local_token_nums = seq_parallel_local_len_extend(
+                self.sp_rank, self.sp_size, ext_lens
+            )
+            self.prefill_extend_lens = ext_lens
+            extend_num_tokens = int(np.sum(extend_local_token_nums))
         out_cache_loc = self.alloc_token_slots(extend_num_tokens)
 
         pt = 0
@@ -418,6 +457,11 @@ def prepare_for_extend(self, vocab_size: int):
                     :pre_len
                 ] = req.prefix_indices
 
+            if self.sp_size > 1:
+                ext_len = extend_local_token_nums[i]
+                # The prefix is stored elsewhere and is not affected by the
+                # layout of **this** request.
+                seq_len = pre_len + ext_len
             self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
                 out_cache_loc[pt : pt + ext_len]
             )
@@ -425,7 +469,7 @@ def prepare_for_extend(self, vocab_size: int):
 
         # Set fields
         with torch.device("cuda"):
-            self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
+            self.input_ids = torch.tensor(flatten_input_ids, dtype=torch.int32)
             self.req_pool_indices = torch.tensor(req_pool_indices_cpu)
             self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
             self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int64)
@@ -636,13 +680,38 @@ def prepare_for_decode(self, input_ids=None):
         self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
         self.seq_lens.add_(1)
 
+        if self.sp_size > 1:
+            seq_lens_cpu = self.seq_lens.cpu().numpy()
+            input_ids = seq_parallel_input_ids_decode(
+                input_ids, self.sp_size, seq_lens_cpu
+            )
+        self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
+
         # Alloc mem
         bs = self.batch_size()
+        if self.sp_size > 1:
+            sp_local_indices = seq_parallel_decode_indices(
+                self.sp_rank, self.sp_size, seq_lens_cpu
+            )
+            bs = len(sp_local_indices)
+
         self.out_cache_loc = self.alloc_token_slots(bs)
 
-        self.req_to_token_pool.req_to_token[
-            self.req_pool_indices, self.seq_lens - 1
-        ] = self.out_cache_loc
+        if self.sp_size > 1:
+            # With SP, each request's KV cache is partitioned across SP workers, so
+            # we index with the SP-local lengths (sp_decode_local_lens) instead of
+            # seq_lens when writing the KV cache.
+            bs = self.batch_size()
+            sp_decode_local_lens = self._sp_decode_local_len(range(bs))
+            self.sp_decode_local_lens = torch.from_numpy(sp_decode_local_lens)
+            local_req_indices = self.req_pool_indices[sp_local_indices]
+            local_lens_cpu = sp_decode_local_lens[sp_local_indices]
+            self.req_to_token_pool.req_to_token[
+                local_req_indices, local_lens_cpu - 1
+            ] = self.out_cache_loc
+        else:
+            self.req_to_token_pool.req_to_token[
+                self.req_pool_indices, self.seq_lens - 1
+            ] = self.out_cache_loc
 
         self.sampling_info.update_regex_vocab_mask(self)
 
@@ -665,6 +734,8 @@ def filter_batch(self, unfinished_indices: List[int]):
         self.out_cache_loc = None
         self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
+        if self.sp_size > 1:
+            self.prefill_extend_lens = self.prefill_extend_lens[new_indices]
 
         self.sampling_info.filter(unfinished_indices, new_indices)
 
@@ -686,6 +757,10 @@ def merge(self, other: "ScheduleBatch"):
         self.out_cache_loc = None
         self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)
+        if self.sp_size > 1:
+            self.prefill_extend_lens = np.concatenate(
+                [self.prefill_extend_lens, other.prefill_extend_lens]
+            )
 
     def check_sample_results(self, sample_output: SampleOutput):
         if not torch.all(sample_output.success):
@@ -701,3 +776,36 @@ def check_sample_results(self, sample_output: SampleOutput):
             sample_output.batch_next_token_ids = batch_next_token_ids
 
         return sample_output.batch_next_token_ids
+
+    def _sp_decode_local_len(self, local_req_indices: np.ndarray):
+        """
+        Args:
+            local_req_indices (np.ndarray): 1D int array indexing the selected
+                requests that store KV cache on this SP rank.
+        Returns:
+            local_len (np.ndarray): 1D int array with the local KV cache length
+                on this SP rank for the selected requests.
+        """
+        sp_size = self.sp_size
+
+        extend_lens = self.prefill_extend_lens[local_req_indices]
+        cur_lens = self.seq_lens.cpu().numpy()[local_req_indices]
+        decode_lens = cur_lens - extend_lens
+
+        extend_chunk_size = np.ceil(extend_lens / sp_size).astype(np.int32)
+        if self.sp_rank != sp_size - 1:
+            extend_size = extend_chunk_size
+        else:
+            extend_size = extend_lens - extend_chunk_size * (sp_size - 1)
+        # NOTE: seq_lens (and thus decode_lens) have already been incremented by 1
+        # in prepare_for_decode. Assuming no prefix, decode token i (0-indexed) is
+        # stored on SP rank (extend_lens + 1 + i) % sp, so the remainder decode
+        # tokens occupy consecutive ranks starting at (extend_lens + 1) % sp.
+        # For example, if sp = 4 and extend_lens = 6, the first decode token is
+        # stored on rank 3 (7 % 4).
+        decode_extra_tok_offset = (self.sp_rank - extend_lens - 1) % sp_size
+        decode_extra_tok = decode_extra_tok_offset < (decode_lens % sp_size)
+        decode_size = decode_lens // sp_size + decode_extra_tok
+        return extend_size + decode_size
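+
+    # Worked example (illustrative, assuming sp_size = 4 and a request with
+    # extend_len = 6): the extend tokens are split as [2, 2, 2, 0] across SP ranks
+    # (ceil(6 / 4) = 2 per rank, with the last rank taking the remaining 0). The
+    # first decode token (seq_len 7) lands on rank 3 (7 % 4) and the second on
+    # rank 0, so the local KV lengths after two decode steps are [3, 2, 2, 1].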
diff --git a/python/sglang/srt/managers/seq_parallel_layout.py b/python/sglang/srt/managers/seq_parallel_layout.py
new file mode 100644
index 00000000000..279e77f7c4a
--- /dev/null
+++ b/python/sglang/srt/managers/seq_parallel_layout.py
@@ -0,0 +1,302 @@
+"""Util functions for sequence parallel layout and runtime metadata."""
+
+import itertools
+from typing import TYPE_CHECKING, Sequence, Union
+
+import numpy as np
+import torch
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import ScheduleBatch
+    from sglang.srt.model_executor.forward_batch_info import InputMetadata
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+
+#### Offset of a sequence parallel shard under the sequence parallel layout.
+def _seq_parallel_offset_extend(sp_rank, sp_size, extend_seq_lens: np.ndarray):
+    return np.sum(np.ceil(extend_seq_lens / sp_size).astype(np.int32)) * sp_rank
+
+
+def _seq_parallel_offset_decode(sp_rank, sp_size, seq_lens: np.ndarray):
+    return np.sum((seq_lens % sp_size) < sp_rank)
+
+
+#### Indices from sequence parallel layout to normal layout
+def _sp_to_normal_indices_extend(sp_size, extend_seq_lens: np.ndarray):
+    """
+    Indices from the Sequence Parallel layout (padded) to the normal layout.
+    """
+    sp_seq_lens = np.ceil(extend_seq_lens / sp_size).astype(np.int32)
+    sp_len = np.sum(sp_seq_lens)
+    sp_seq_offset = np.concatenate(
+        [np.asarray([0], dtype=np.int32), np.cumsum(sp_seq_lens[:-1])]
+    )
+    sp_arange = np.arange(sp_size).reshape(-1, 1)
+    indices = []
+    for i in range(len(extend_seq_lens)):
+        sp_idx = np.arange(sp_seq_lens[i]).reshape(1, -1).repeat(sp_size, axis=0)
+        sp_idx = (sp_idx + sp_seq_offset[i] + sp_len * sp_arange).reshape(-1)
+        sp_idx = sp_idx[: extend_seq_lens[i]]
+        indices.append(sp_idx)
+    indices = np.concatenate(indices)
+    return indices
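+# Worked example (illustrative, assuming sp_size = 2 and extend_seq_lens = [3, 2]):
+# in the padded SP layout, rank 0 holds [r0t0, r0t1, r1t0] and rank 1 holds
+# [r0t2, pad, r1t1], so _sp_to_normal_indices_extend(2, [3, 2]) == [0, 1, 3, 2, 5],
+# i.e. the positions of the normal-layout tokens [r0t0, r0t1, r0t2, r1t0, r1t1]
+# within the SP layout.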
+
+
+def _sp_to_normal_indices_decode(sp_size, seq_lens: np.ndarray):
+    """
+    Indices from the Sequence Parallel layout (padded) to the normal layout.
+    """
+    req_sp_rank = seq_lens % sp_size
+    sp_rank_size = [np.sum(req_sp_rank == r) for r in range(sp_size)]
+    req_sp_offset = np.cumsum(np.asarray([0] + sp_rank_size[:-1]))
+    req_sp_offset = req_sp_offset[req_sp_rank]
+    for sp_rank in range(sp_size):
+        local_reqs = req_sp_rank == sp_rank
+        req_sp_index = np.cumsum(local_reqs) - 1
+        req_sp_offset += req_sp_index * local_reqs  # mask out reqs not here.
+    return req_sp_offset
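+# Worked example (illustrative, assuming sp_size = 4 and seq_lens = [7, 8, 9, 10]):
+# req_sp_rank = [3, 0, 1, 2], so the SP decode layout orders tokens as
+# [req_1, req_2, req_3, req_0] and _sp_to_normal_indices_decode(4, seq_lens)
+# returns [3, 0, 1, 2], the position of each request's token in that layout.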
+
+
+#### From normal layout to sequence parallel layout. Only for debug purpose
+def _debug_normal_to_sp_indices_decode(sp_size, seq_lens):
+    """(Debug only) Indices from normal layout to the SP layout (padded)."""
+    indices = [
+        seq_parallel_decode_indices(sp_rank, sp_size, seq_lens)
+        for sp_rank in range(sp_size)
+    ]
+    indices = [(np.arange(len(idxs)), idxs) for idxs in indices]
+    return indices
+
+
+def _debug_normal_to_sp_indices_extend(sp_size, seq_lens):
+    """(Debug only) Indices from normal layout to the SP layout (padded)."""
+    indices = []
+    sp_seq_lens = np.ceil(seq_lens / sp_size).astype(np.int32)
+    seq_offset = np.concatenate(
+        [np.asarray([0], dtype=np.int32), np.cumsum(seq_lens[:-1])]
+    )
+    sp_seq_offset = np.concatenate(
+        [np.asarray([0], dtype=np.int32), np.cumsum(sp_seq_lens[:-1])]
+    )
+    for sp_rank in range(sp_size):
+        start_idx = seq_offset + sp_seq_lens * sp_rank
+        end_idx = np.minimum(seq_offset + sp_seq_lens * (sp_rank + 1), seq_lens)
+        normal_layout_idx = np.concatenate(
+            [np.arange(start_idx[i], end_idx[i]) for i in range(len(seq_lens))]
+        )
+        if sp_rank == sp_size - 1:
+            length = end_idx - start_idx
+            target_layout_idx = np.concatenate(
+                [
+                    np.arange(sp_seq_offset[i], sp_seq_offset[i] + length[i])
+                    for i in range(len(seq_lens))
+                ]
+            )
+        else:
+            target_layout_idx = np.arange(len(normal_layout_idx))
+        indices.append((target_layout_idx, normal_layout_idx))
+    return indices
+
+
+def _debug_normal_to_sp(indices, output_tensor, tensor):
+    """
+    Use the indices generated above to translate from a normal layout to a
+    SP layout (padded). Due to the padding, `output_tensor`'s shape is different
+    from the input `tensor`'s.
+    """
+    for idxs in indices:
+        output_tensor[idxs] = tensor
+    output_tensor = output_tensor.contiguous()
+    return output_tensor
+
+
+#### Padding
+def seq_parallel_pad_zeros(
+    indices: torch.Tensor, seq_lens, sp_size: int, only_last_shard: bool = False
+):
+    """
+    Add padding zeros to SP-layout indices (a 1D tensor) so that each sequence in the
+    last SP shard is padded at its end and all SP shards end up with the same length.
+
+    This function is used to (1) adjust the positions tensor so that input_ids stay
+    aligned with their positions during positional encoding and (2) adjust the output
+    cache locations so that the KV cache of padded tokens is written to slot 0
+    (reserved for dummy output).
+    """
+    sp_seq_lens = np.ceil(seq_lens / sp_size).astype(np.int32)
+    last_sp_seq_lens = seq_lens - sp_seq_lens * (sp_size - 1)
+    padded_num_tokens = np.sum(sp_seq_lens).astype(np.int32)
+    if only_last_shard:
+        padded_indices = torch.zeros(
+            padded_num_tokens, dtype=indices.dtype, device=indices.device
+        )
+        padded_stt = stt = 0
+    else:
+        padded_indices = torch.zeros(
+            sp_size * padded_num_tokens, dtype=indices.dtype, device=indices.device
+        )
+        # All non-last shards do not need padding and hence can be copied.
+        padded_stt = padded_num_tokens * (sp_size - 1)
+        stt = padded_stt
+        padded_indices[:padded_stt] = indices[:stt]
+
+    bs = seq_lens.size
+    for i in range(bs):
+        padded_end = padded_stt + sp_seq_lens[i]
+        end = stt + last_sp_seq_lens[i]
+        padded_indices[padded_stt : padded_stt + last_sp_seq_lens[i]] = indices[stt:end]
+        padded_stt = padded_end
+        stt = end
+    return padded_indices
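A quick illustration (hypothetical values, assuming the positions have already been permuted into SP layout) of where the padding zeros end up:

import numpy as np
import torch

# Two requests of length 3 and 2 with sp_size = 2; SP-layout positions are
# rank 0: [0, 1, 0] and rank 1: [2, 1].
sp_positions = torch.tensor([0, 1, 0, 2, 1])
padded = seq_parallel_pad_zeros(sp_positions, np.array([3, 2]), sp_size=2)
print(padded)  # tensor([0, 1, 0, 2, 0, 1]) -- a zero is inserted after req 0's last token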
+
+
+def _get_num_padding_tokens(sp_size, extend_seq_lens: np.ndarray):
+    """Get the number of tokens padded for SP."""
+    padded_size = np.ceil(extend_seq_lens / sp_size).astype(np.int32)
+    return sp_size * padded_size - extend_seq_lens
+
+
+#### Get length/indices of sequence parallel local tokens within a batch
+def seq_parallel_local_len_extend(
+    sp_rank, sp_size, extend_seq_lens: Union[int, np.ndarray]
+):
+    """Get the number of tokens on this SP rank. Padding is not considered."""
+    padded_size = np.ceil(extend_seq_lens / sp_size).astype(np.int32)
+    return (
+        padded_size
+        if sp_rank != sp_size - 1
+        else extend_seq_lens - (sp_size - 1) * padded_size
+    )
+
+
+def seq_parallel_extend_local_token_slice(sp_rank, sp_size, seq_len: int):
+    """Get the SP local slice for a single request's extended input ids."""
+    start = int(np.ceil(seq_len / sp_size) * sp_rank)
+    length = seq_parallel_local_len_extend(sp_rank, sp_size, seq_len)
+    return slice(start, start + length)
+
+
+def seq_parallel_decode_indices(sp_rank, sp_size, seq_lens: np.ndarray):
+    """Get indices from the normal layout to the sequence parallel layout."""
+    return np.nonzero((seq_lens % sp_size) == sp_rank)[0]
+
+
+#### Transpose to sequence parallel layout
+def seq_parallel_input_ids_extend(
+    input_ids: Sequence[Sequence[int]], sp_size: int, bs: int
+):
+    # Note: the flattened input ids under sequence parallelism take the form:
+    # [req_0_sp_0, req_1_sp_0, ..., req_n_sp_0,
+    #  req_0_sp_1, req_1_sp_1, ..., req_n_sp_1,
+    #  ...
+    #  req_0_sp_m, req_0_padding, req_1_sp_m, req_1_padding, ...]
+    # The padding is needed by collective primitives, which require every participant
+    # to contribute the same number of tokens. Since we don't expect too many requests
+    # with SP, the extra compute this causes is affordable.
+    flatten_input_ids = [[] for _ in range(sp_size)]
+    num_padding_tokens = _get_num_padding_tokens(
+        sp_size, np.asarray([len(ids) for ids in input_ids])
+    )
+    for i in range(bs):
+        for sp_rank in range(sp_size):
+            ids = input_ids[i]
+            local_slice = seq_parallel_extend_local_token_slice(
+                sp_rank, sp_size, len(ids)
+            )
+            flatten_input_ids[sp_rank].extend(ids[local_slice])
+        flatten_input_ids[-1].extend([0] * num_padding_tokens[i])
+    flatten_input_ids = list(itertools.chain(*flatten_input_ids))
+    return flatten_input_ids
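For example (hypothetical ids), the layout described in the comment above looks like this for two requests with sp_size = 2:

ids = seq_parallel_input_ids_extend([[10, 11, 12], [20, 21]], sp_size=2, bs=2)
print(ids)  # [10, 11, 20, 12, 0, 21]
# rank 0 holds [10, 11, 20]; rank 1 holds [12, 0, 21] (req 0 shard, req 0 pad, req 1 shard)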
+
+
+def seq_parallel_input_ids_decode(
+    input_ids: Sequence[int], sp_size: int, seq_lens: np.ndarray
+):
+    input_indices_sp = [[] for _ in range(sp_size)]
+    # NOTE: in the extend phase, we partition the extended tokens (extend_len)
+    # evenly across SP ranks. For decode, however, since prefix lens have been
+    # cleared, we instead use the whole sequence length (seq_lens) for the
+    # round-robin KV-cache assignment.
+    for sp_rank in range(sp_size):
+        indices = seq_parallel_decode_indices(sp_rank, sp_size, seq_lens)
+        input_indices_sp[sp_rank].extend(indices)
+    flatten_input_indices = list(itertools.chain(*input_indices_sp))
+    flatten_input_ids = np.asarray(input_ids)[flatten_input_indices]
+    return flatten_input_ids
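And the decode counterpart (hypothetical ids and lengths), which simply regroups the next-token ids by the round-robin rank:

import numpy as np

# Next-token ids for sequences of length 5, 8, 7 with sp_size = 2.
ids = seq_parallel_input_ids_decode([100, 200, 300], 2, np.array([5, 8, 7]))
print(ids)  # [200 100 300] -> rank-0 requests first, then rank-1 requests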
+
+
+#### Handle metadata
+def init_sequence_parallel_args(
+    model_runner: "ModelRunner", batch: "ScheduleBatch", forward_mode
+):
+    from sglang.srt.model_executor.forward_batch_info import ForwardMode
+
+    sp_rank = model_runner.sp_rank
+    sp_size = model_runner.sp_size
+    seq_lens = batch.seq_lens
+    extend_seq_lens_cpu = batch.prefill_extend_lens
+    num_tokens = batch.input_ids.numel()
+    if sp_size > 1:
+        # At runtime, positions[local_token_indices] should be used to get the
+        # positions for each SP shard.
+        if forward_mode == ForwardMode.DECODE:
+            seq_lens_cpu = seq_lens.cpu().numpy()
+            sp_to_normal_indices = _sp_to_normal_indices_decode(sp_size, seq_lens_cpu)
+            sp_local_token_length = seq_parallel_decode_indices(
+                sp_rank, sp_size, seq_lens_cpu
+            ).size
+            sp_local_token_offset = _seq_parallel_offset_decode(
+                sp_rank, sp_size, seq_lens_cpu
+            )
+            # Convert positions to SP layout and add padding zeros.
+            normal_to_sp_indices = np.argsort(sp_to_normal_indices)
+            # positions = positions[normal_to_sp_indices]
+        else:
+            sp_to_normal_indices = _sp_to_normal_indices_extend(
+                sp_size, extend_seq_lens_cpu
+            )
+            sp_local_token_length = seq_parallel_local_len_extend(
+                sp_rank, sp_size, extend_seq_lens_cpu
+            )
+            sp_local_token_offset = _seq_parallel_offset_extend(
+                sp_rank, sp_size, extend_seq_lens_cpu
+            )
+            # Convert positions to SP layout and add padding zeros.
+            normal_to_sp_indices = np.argsort(sp_to_normal_indices)
+            # positions = positions[normal_to_sp_indices]
+            # positions = seq_parallel_pad_zeros(positions, extend_seq_lens_cpu, sp_size)
+            # Add padding zeros to out_cache_loc and write KV of padded tokens that may
+            # exist in the last SP shard to slot 0 (reserved for dummy output).
+            if sp_rank == sp_size - 1:
+                batch.out_cache_loc = seq_parallel_pad_zeros(
+                    batch.out_cache_loc, extend_seq_lens_cpu, sp_size, True
+                )
+    else:
+        sp_to_normal_indices = np.arange(num_tokens)
+        normal_to_sp_indices = np.arange(num_tokens)
+        sp_local_token_length = num_tokens
+        sp_local_token_offset = 0
+
+    _debug_normal_to_sp_metadata = None
+    if False and sp_size > 1:
+        if forward_mode == ForwardMode.DECODE:
+            _debug_normal_to_sp_metadata = _debug_normal_to_sp_indices_decode(
+                sp_size, seq_lens_cpu
+            )
+        else:
+            _debug_normal_to_sp_metadata = _debug_normal_to_sp_indices_extend(
+                sp_size, extend_seq_lens_cpu
+            )
+
+    init_args = {
+        "sp_size": sp_size,
+        "sp_rank": sp_rank,
+        "sp_to_normal_indices": sp_to_normal_indices,
+        "sp_local_token_length": sp_local_token_length,
+        "sp_local_token_offset": sp_local_token_offset,
+        "_debug_normal_to_sp_metadata": _debug_normal_to_sp_metadata,
+        "flashinfer_prefill_wrapper_sp_full": model_runner.flashinfer_prefill_wrapper_sp_full,
+        "flashinfer_prefill_wrapper_sp_causal": model_runner.flashinfer_prefill_wrapper_sp_causal,
+    }
+    aux_args = {"normal_to_sp_indices": normal_to_sp_indices}
+    return init_args, aux_args
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 4459213b02f..38d73910726 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -194,6 +194,13 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
         seq_lens = self.seq_lens[:bs]
         position_ids_offsets = self.position_ids_offsets[:bs]
         out_cache_loc = self.out_cache_loc[:bs]
+        # TODO (yonghao): fix parameter initialization below.
+        normal_to_sp_indices = None
+        sp_decode_local_lens = torch.ceil(seq_lens / self.model_runner.sp_size).to(
+            torch.int32
+        )
+        sp_local_token_offset = 0
+        sp_local_token_length = torch.sum(sp_decode_local_lens).to(torch.int32)
 
         # FlashInfer inputs
         if not _grouped_size_compiled_for_decode_kernels(
@@ -237,6 +244,8 @@ def capture_one_batch_size(self, bs: int, forward: Callable):
             seq_lens,
             None,
             flashinfer_decode_wrapper,
+            normal_to_sp_indices=normal_to_sp_indices,
+            sp_decode_local_lens=sp_decode_local_lens,
         )
 
         # Run and capture
@@ -254,6 +263,10 @@ def run_once():
                 top_logprobs_nums=0,
                 positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64),
                 flashinfer_decode_wrapper=flashinfer_decode_wrapper,
+                sp_rank=self.model_runner.sp_rank,
+                sp_size=self.model_runner.sp_size,
+                sp_local_token_offset=sp_local_token_offset,
+                sp_local_token_length=sp_local_token_length,
             )
 
             return forward(input_ids, input_metadata.positions, input_metadata)
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index a443b113d44..824f0586015 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -26,6 +26,11 @@
 import triton.language as tl
 
 from sglang.srt.managers.schedule_batch import ScheduleBatch
+from sglang.srt.managers.seq_parallel_layout import (
+    init_sequence_parallel_args,
+    seq_parallel_local_len_extend,
+    seq_parallel_pad_zeros,
+)
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 
 if TYPE_CHECKING:
@@ -90,6 +95,18 @@ class InputMetadata:
     flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
     flashinfer_use_ragged: bool = False
+    # NOTE: for sequence parallel, we need dedicated kernels for cross-shard attn.
+    # In particular, we need custom masks for the last SP shard, which may contain padding tokens.
+    flashinfer_prefill_wrapper_sp_full: "BatchPrefillWithRaggedKVCacheWrapper" = None
+    flashinfer_prefill_wrapper_sp_causal: "BatchPrefillWithRaggedKVCacheWrapper" = None
+
+    # For Sequence Parallel
+    sp_rank: int = None
+    sp_size: int = None
+    sp_to_normal_indices: np.ndarray = None
+    sp_local_token_length: int = None
+    sp_local_token_offset: int = None
+    _debug_normal_to_sp_metadata: Optional[List[np.ndarray]] = None
 
     def init_multimuldal_info(self, batch: ScheduleBatch):
         reqs = batch.reqs
@@ -97,7 +114,7 @@ def init_multimuldal_info(self, batch: ScheduleBatch):
         self.image_sizes = [r.image_sizes for r in reqs]
         self.image_offsets = [r.image_offsets for r in reqs]
 
-    def compute_positions(self, batch: ScheduleBatch):
+    def compute_positions(self, batch: ScheduleBatch, normal_to_sp_indices):
         position_ids_offsets = batch.position_ids_offsets
 
         if self.forward_mode == ForwardMode.DECODE:
@@ -137,6 +154,9 @@ def compute_positions(self, batch: ScheduleBatch):
 
         # Positions should be in long type
         self.positions = self.positions.to(torch.int64)
+        update_positions_for_seq_parallel(
+            self, normal_to_sp_indices, batch.prefill_extend_lens
+        )
 
     def compute_extend_infos(self, batch: ScheduleBatch):
         if self.forward_mode == ForwardMode.DECODE:
@@ -173,6 +193,9 @@ def from_schedule_batch(
         batch: ScheduleBatch,
         forward_mode: ForwardMode,
     ):
+        sp_args, aux_args = init_sequence_parallel_args(
+            model_runner, batch, forward_mode
+        )
         ret = cls(
             forward_mode=forward_mode,
             sampling_info=batch.sampling_info,
@@ -184,11 +207,12 @@ def from_schedule_batch(
             out_cache_loc=batch.out_cache_loc,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
+            **sp_args,
         )
 
         ret.sampling_info.prepare_penalties()
 
-        ret.compute_positions(batch)
+        ret.compute_positions(batch, aux_args["normal_to_sp_indices"])
 
         ret.compute_extend_infos(batch)
 
@@ -208,12 +232,17 @@ def from_schedule_batch(
         if not model_runner.server_args.disable_flashinfer:
             if (
                 forward_mode != ForwardMode.DECODE
-                and int(torch.sum(ret.seq_lens)) > 4096
+                and (int(torch.sum(ret.seq_lens)) > 4096 or ret.sp_size > 1)
                 and model_runner.sliding_window_size is None
             ):
+                # NOTE: SP requires the ragged kernel regardless of the sequence length.
                 flashinfer_use_ragged = True
             ret.init_flashinfer_handlers(
-                model_runner, batch.prefix_lens_cpu, flashinfer_use_ragged
+                model_runner,
+                batch.prefix_lens_cpu,
+                flashinfer_use_ragged,
+                aux_args["normal_to_sp_indices"],
+                batch.sp_decode_local_lens,
             )
 
         return ret
@@ -236,6 +265,8 @@ def init_flashinfer_handlers(
         model_runner,
         prefix_lens_cpu,
         flashinfer_use_ragged,
+        normal_to_sp_indices,
+        sp_decode_local_lens,
     ):
         if self.forward_mode == ForwardMode.DECODE:
             prefix_lens = None
@@ -249,6 +280,8 @@ def init_flashinfer_handlers(
             self.seq_lens,
             prefix_lens,
             flashinfer_use_ragged=flashinfer_use_ragged,
+            normal_to_sp_indices=normal_to_sp_indices,
+            sp_decode_local_lens=sp_decode_local_lens,
         )
 
         (
@@ -308,10 +341,16 @@ def update_flashinfer_indices(
     prefix_lens,
     flashinfer_decode_wrapper=None,
     flashinfer_use_ragged=False,
+    normal_to_sp_indices=None,
+    sp_decode_local_lens=None,
 ):
     """Init auxiliary variables for FlashInfer attention backend."""
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
-    num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
+    # NOTE (yifan): K and V are partitioned along both the TP and SP dimensions,
+    # so here tp_size represents KV-TP size * SP size.
+    num_kv_heads = model_runner.model_config.get_num_kv_heads(
+        model_runner.tp_size // model_runner.sp_size
+    )
     head_dim = model_runner.model_config.head_dim
     batch_size = len(req_pool_indices)
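To make the head bookkeeping concrete, a small illustrative calculation (the model and parallelism sizes below are hypothetical):

total_q_heads, total_kv_heads = 32, 8        # hypothetical model config
tp_size, sp_size = 8, 2                      # tp_size already includes the SP dimension

num_qo_heads = total_q_heads // tp_size      # 4 query heads per GPU (prefill)
kv_tp_size = tp_size // sp_size              # 4: KV heads are split only across KV-TP
num_kv_heads = total_kv_heads // kv_tp_size  # 2 KV heads per GPU

# In decode the current token is replicated across the SP group, so each worker
# runs all query heads of its SP group (see the decode branch below):
decode_qo_heads = num_qo_heads * sp_size     # 8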
 
@@ -321,6 +360,28 @@ def update_flashinfer_indices(
         else:
             paged_kernel_lens = seq_lens
 
+        sp_size = model_runner.sp_size
+        if forward_mode == ForwardMode.DECODE:
+            # With SP, reqs may have been reordered so we track them here.
+            if normal_to_sp_indices is not None:
+                req_ids = normal_to_sp_indices.tolist()
+            else:
+                req_ids = list(range(batch_size))
+            paged_kernel_lens = seq_lens if sp_size == 1 else sp_decode_local_lens
+        else:
+            extend_lens = seq_lens - prefix_lens
+            # With SP, we use different kernels for sequences that are not evenly
+            # partitioned across SP workers. Here seq_lens works for most SP workers,
+            # which do not need masks, and we initialize the kernels with masks separately below.
+            seq_lens = torch.ceil(seq_lens / sp_size).to(torch.int32)
+            prefix_lens = torch.ceil(prefix_lens / sp_size).to(torch.int32)
+            req_ids = list(range(batch_size))
+
+        if sp_size > 1:
+            req_pool_indices = req_pool_indices[req_ids].contiguous()
+            paged_kernel_lens = paged_kernel_lens[req_ids].contiguous()
+            paged_kernel_lens = paged_kernel_lens.to(req_pool_indices.device)
+
         kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
         kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
 
@@ -338,6 +399,9 @@ def update_flashinfer_indices(
         kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
 
         if forward_mode == ForwardMode.DECODE:
+            # For decode, we replicate the current token across SP workers and hence
+            # each SP worker will have all q heads.
+            num_qo_heads *= model_runner.sp_size
             # CUDA graph uses different flashinfer_decode_wrapper
             if flashinfer_decode_wrapper is None:
                 flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
@@ -381,7 +445,63 @@ def update_flashinfer_indices(
                 head_dim,
                 1,
             )
+        if (
+            sp_size > 1 and forward_mode != ForwardMode.DECODE
+        ):  # Sequence parallel enabled, initialize SP kernels with custom masks.
+            # NOTE (yifan): here we assume that when sequence parallel is enabled,
+            # prefix_lens are always 0s, and we will use the flashinfer paged attn kernel
+            # for cross-SP-shard attn computation. If prefix_lens can later be non-zero
+            # (e.g., extend phases with SP), we will need a dedicated paged attn kernel
+            # wrapper for cross-SP-shard attn.
+            if torch.sum(prefix_lens) != 0:
+                raise ValueError(
+                    "Prefix caching with sequence parallelism is not supported."
+                )
+
+            # Prepare masks.
+            extend_lens_cpu = extend_lens.cpu().numpy()
+            padded_extend_lens = seq_parallel_local_len_extend(
+                0, sp_size, extend_lens_cpu
+            )
+            last_extend_lens = seq_parallel_local_len_extend(
+                sp_size - 1, sp_size, extend_lens_cpu
+            )
+            qo_len = (seq_lens - prefix_lens).cpu().tolist()
+            full_mask_arr = []
+            causal_mask_arr = []
+            for i in range(batch_size):
+                full_mask_i = torch.full((qo_len[i], qo_len[i]), False, device="cuda")
+                full_mask_i[: last_extend_lens[i], : padded_extend_lens[i]] = True
+                full_mask_arr.append(full_mask_i.flatten())
+                causal_mask_i = torch.tril(full_mask_i, diagonal=0)
+                causal_mask_arr.append(causal_mask_i.flatten())
+            full_mask = torch.cat(full_mask_arr, dim=0)
+            causal_mask = torch.cat(causal_mask_arr, dim=0)
+
+            # Cross-SP-shard extend part -- masked for the last SP shard, which may have
+            # padding tokens. For the other shards, we can simply use the ragged kernel.
+            model_runner.flashinfer_prefill_wrapper_sp_causal.end_forward()
+            model_runner.flashinfer_prefill_wrapper_sp_causal.begin_forward(
+                qo_indptr,
+                qo_indptr,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+                custom_mask=causal_mask,
+            )
+
+            model_runner.flashinfer_prefill_wrapper_sp_full.end_forward()
+            model_runner.flashinfer_prefill_wrapper_sp_full.begin_forward(
+                qo_indptr,
+                qo_indptr,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+                custom_mask=full_mask,
+            )
     else:
+        assert model_runner.sp_size == 1, "SP with sliding window not supported"
         # window attention use paged only
         kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
         for wrapper_id in range(2):
@@ -451,3 +571,20 @@ def update_flashinfer_indices(
                     head_dim,
                     1,
                 )
+
+
+def update_positions_for_seq_parallel(
+    input_metadata: InputMetadata, normal_to_sp_indices, extend_seq_lens
+):
+    sp_size = input_metadata.sp_size
+    if sp_size == 1:
+        return
+
+    positions = input_metadata.positions
+
+    if input_metadata.forward_mode == ForwardMode.DECODE:
+        positions = positions[normal_to_sp_indices]
+    else:
+        positions = positions[normal_to_sp_indices]
+        positions = seq_parallel_pad_zeros(positions, extend_seq_lens, sp_size)
+    input_metadata.positions = positions
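To visualise the custom masks built in update_flashinfer_indices above, here is a minimal sketch with hypothetical sizes (one request of 5 tokens, sp_size = 2, so each shard is padded to 3 query tokens and the last shard holds 2 real tokens plus 1 pad):

import torch

qo_len, last_real, other_shard_len = 3, 2, 3
full = torch.full((qo_len, qo_len), False)
full[:last_real, :other_shard_len] = True  # the pad query (row 2) attends to nothing
causal = torch.tril(full)

# full  = [[T, T, T],        causal = [[T, F, F],
#          [T, T, T],                  [T, T, F],
#          [F, F, F]]                  [F, F, F]]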
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 3d3e0cde9d1..8fff310aaf0 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -36,7 +36,6 @@
 from vllm.distributed import (
     get_tp_group,
     init_distributed_environment,
-    initialize_model_parallel,
     set_custom_all_reduce,
 )
 from vllm.distributed.parallel_state import in_the_same_node_as
@@ -45,6 +44,7 @@
 
 from sglang.global_config import global_config
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.parallel_utils import initialize_model_parallel
 from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
@@ -78,6 +78,8 @@ def __init__(
         tp_size: int,
         nccl_port: int,
         server_args: ServerArgs,
+        sp_rank: int = 0,
+        sp_size: int = 1,
     ):
         # Parse args
         self.model_config = model_config
@@ -85,6 +87,8 @@ def __init__(
         self.gpu_id = gpu_id
         self.tp_rank = tp_rank
         self.tp_size = tp_size
+        self.sp_rank = sp_rank
+        self.sp_size = sp_size
         self.nccl_port = nccl_port
         self.server_args = server_args
         self.is_multimodal_model = is_multimodal_model(
@@ -137,7 +141,11 @@ def init_torch_distributed(self):
             local_rank=self.gpu_id,
             distributed_init_method=nccl_init_method,
         )
-        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+        initialize_model_parallel(
+            tensor_model_parallel_size=self.tp_size,
+            sequence_parallel_size=self.sp_size,
+        )
+        self.tp_group = get_tp_group()
         min_per_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
@@ -321,14 +329,18 @@ def profile_max_num_token(self, total_gpu_memory: int):
             self.model_config.attention_arch == AttentionArch.MLA
             and self.server_args.enable_mla
         ):
+            # FIXME: temporarily disable SP with MLA
+            assert self.sp_size == 1, "sequence parallel with MLA not supported"
             cell_size = (
                 (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                 * self.model_config.num_hidden_layers
                 * torch._utils._element_size(self.kv_cache_dtype)
             )
         else:
+            kv_tp_size = self.tp_size // self.sp_size
+            head_num = self.model_config.get_num_kv_heads(kv_tp_size)
             cell_size = (
-                self.model_config.get_num_kv_heads(self.tp_size)
+                head_num
                 * self.model_config.head_dim
                 * self.model_config.num_hidden_layers
                 * 2
@@ -346,6 +358,11 @@ def init_memory_pool(
         max_num_reqs: int = None,
         max_total_tokens: int = None,
     ):
+        if self.tp_size % self.sp_size != 0:
+            raise ValueError(
+                f"Invalid sequence parallel configuration. tp_size={self.tp_size} "
+                f"must be divisible by sp_size={self.sp_size}"
+            )
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":
@@ -389,6 +406,8 @@ def init_memory_pool(
             self.model_config.attention_arch == AttentionArch.MLA
             and self.server_args.enable_mla
         ):
+            # FIXME: temporarily disable SP with MLA
+            assert self.sp_size == 1, "sequence parallel with MLA not supported"
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
                 dtype=self.kv_cache_dtype,
@@ -400,10 +419,11 @@ def init_memory_pool(
             # FIXME: temporarily only Triton MLA is supported
             self.server_args.disable_flashinfer = True
         else:
+            kv_tp_size = self.tp_size // self.sp_size
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
                 dtype=self.kv_cache_dtype,
-                head_num=self.model_config.get_num_kv_heads(self.tp_size),
+                head_num=self.model_config.get_num_kv_heads(kv_tp_size),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
             )
@@ -430,6 +450,9 @@ def init_flashinfer(self):
             self.flashinfer_prefill_wrapper_ragged = None
             self.flashinfer_prefill_wrapper_paged = None
             self.flashinfer_decode_wrapper = None
+            # NOTE: for sequence parallel, we need to use a dedicated kernel for cross-shard attn.
+            self.flashinfer_prefill_wrapper_sp_full = None
+            self.flashinfer_prefill_wrapper_sp_causal = None
             return
 
         if not _grouped_size_compiled_for_decode_kernels(
@@ -440,7 +463,10 @@ def init_flashinfer(self):
         else:
             use_tensor_cores = False
 
+        self.flashinfer_prefill_wrapper_sp_full = None
+        self.flashinfer_prefill_wrapper_sp_causal = None
         if self.sliding_window_size is None:
+            # FIXME: missing SP info here.
             self.flashinfer_workspace_buffer = torch.empty(
                 global_config.flashinfer_workspace_size,
                 dtype=torch.uint8,
@@ -459,6 +485,17 @@ def init_flashinfer(self):
                 "NHD",
                 use_tensor_cores=use_tensor_cores,
             )
+            if self.sp_size > 1:  # Sequence parallel enabled.
+                self.flashinfer_prefill_wrapper_sp_full = (
+                    BatchPrefillWithRaggedKVCacheWrapper(
+                        self.flashinfer_workspace_buffer, "NHD"
+                    )
+                )
+                self.flashinfer_prefill_wrapper_sp_causal = (
+                    BatchPrefillWithRaggedKVCacheWrapper(
+                        self.flashinfer_workspace_buffer, "NHD"
+                    )
+                )
         else:
             self.flashinfer_workspace_buffer = torch.empty(
                 global_config.flashinfer_workspace_size,
diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py
index 926d87db8b7..d11b91cd12d 100644
--- a/python/sglang/srt/models/llama.py
+++ b/python/sglang/srt/models/llama.py
@@ -26,7 +26,6 @@
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
-    QKVParallelLinear,
     RowParallelLinear,
 )
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
@@ -40,8 +39,10 @@
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.parallel_utils import get_kv_tensor_model_parallel_world_size
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.sampler import Sampler
+from sglang.srt.layers.sp_linear import QKVParallelLinear, RowSeqParallelLinear
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -100,20 +101,28 @@ def __init__(
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
+        # This is KV_TP_SIZE * SP_SIZE
         tp_size = get_tensor_model_parallel_world_size()
+        # This is the KV-TP size
+        kv_tp_size = get_kv_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
         assert self.total_num_heads % tp_size == 0
+        # num_heads is partitioned by both TP and SP, so we use tp_size here, which
+        # represents the total TP x SP parallelism.
         self.num_heads = self.total_num_heads // tp_size
         self.total_num_kv_heads = num_kv_heads
-        if self.total_num_kv_heads >= tp_size:
-            # Number of KV heads is greater than TP size, so we partition
+        # num_kv_heads is partitioned only by TP, so we use kv_tp_size here, which
+        # represents the KV-TP parallelism.
+        if self.total_num_kv_heads >= kv_tp_size:
+            # Number of KV heads is greater than KV-TP size, so we partition
             # the KV heads across multiple tensor parallel GPUs.
-            assert self.total_num_kv_heads % tp_size == 0
+            assert self.total_num_kv_heads % kv_tp_size == 0
         else:
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
-            assert tp_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+            assert kv_tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // kv_tp_size)
         # MistralConfig has an optional head_dim introduced by Mistral-Nemo
         self.head_dim = getattr(
             config, "head_dim", self.hidden_size // self.total_num_heads
@@ -133,9 +142,12 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.qkv_proj",
         )
-        self.o_proj = RowParallelLinear(
+        self.o_proj = RowSeqParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
+            self.total_num_heads,
+            self.num_kv_heads,
+            self.head_dim,
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.o_proj",
@@ -163,8 +175,7 @@ def forward(
         hidden_states: torch.Tensor,
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
-        qkv, _ = self.qkv_proj(hidden_states)
-        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k, v = self.qkv_proj(hidden_states)
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v, input_metadata)
         output, _ = self.o_proj(attn_output)
@@ -315,6 +326,13 @@ def forward(
         input_embeds: torch.Tensor = None,
     ) -> LogitsProcessorOutput:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        if input_metadata.sp_size > 1:
+            # TODO: instead of GPU indexing, sample under the SP layout and map the
+            # sampling result back to the normal layout.
+            hidden_states = hidden_states[
+                input_metadata.sp_to_normal_indices
+            ].contiguous()
+            input_ids = input_ids[input_metadata.sp_to_normal_indices].contiguous()
         logits_output = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
@@ -324,11 +342,14 @@ def forward(
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
-            (".qkv_proj", ".q_proj", "q"),
-            (".qkv_proj", ".k_proj", "k"),
-            (".qkv_proj", ".v_proj", "v"),
-            (".gate_up_proj", ".gate_proj", 0),
-            (".gate_up_proj", ".up_proj", 1),
+            ("qkv_proj.kv_proj", "k_proj", "k"),
+            ("qkv_proj.kv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        renamed_params_mapping = [
+            # (param_name, weight_name)
+            ("qkv_proj.q_proj", "q_proj"),
         ]
         params_dict = self.param_dict
 
@@ -357,6 +378,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                for param_name, weight_name in renamed_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    break
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
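For reference, a minimal sketch (the checkpoint names are hypothetical) of how incoming weight names are rewritten; the q case goes through the new renamed_params_mapping loop above, while k/v are folded into the stacked kv shard by the pre-existing stacked-params loop:

# q_proj is only renamed and loaded as its own SP-aware projection.
name = "model.layers.0.self_attn.q_proj.weight"
name = name.replace("q_proj", "qkv_proj.q_proj")
# -> "model.layers.0.self_attn.qkv_proj.q_proj.weight"

# k_proj / v_proj are stacked into kv_proj (shard ids "k" / "v").
name = "model.layers.0.self_attn.k_proj.weight"
name = name.replace("k_proj", "qkv_proj.kv_proj")
# -> "model.layers.0.self_attn.qkv_proj.kv_proj.weight"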
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 8a56c02e162..90093c5b2f5 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -58,6 +58,7 @@ class ServerArgs:
 
     # Other runtime options
     tp_size: int = 1
+    sp_size: int = 1
     stream_interval: int = 1
     random_seed: Optional[int] = None
 
@@ -304,6 +305,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
             default=ServerArgs.tp_size,
             help="The tensor parallelism size.",
         )
+        parser.add_argument(
+            "--sp-size",
+            type=int,
+            default=ServerArgs.sp_size,
+            help="The sequence parallelism size.",
+        )
         parser.add_argument(
             "--stream-interval",
             type=int,
diff --git a/test/srt/test_seq_parallel_attn_kernel.py b/test/srt/test_seq_parallel_attn_kernel.py
new file mode 100644
index 00000000000..a2895422afe
--- /dev/null
+++ b/test/srt/test_seq_parallel_attn_kernel.py
@@ -0,0 +1,233 @@
+import multiprocessing
+import random
+
+import torch
+from vllm.distributed import init_distributed_environment
+
+from sglang.srt.layers.parallel_utils import initialize_model_parallel
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.model_executor.model_runner import InputMetadata
+
+NUM_HEADS = 32
+HEAD_DIM = 128
+SCALING = 1
+NUM_KV_HEADS = 8
+LAYER_ID = 0
+LOGIT_CAP = -1
+
+
+BATCH_SIZE = 1
+QO_LEN = 128
+KV_LEN = 128
+
+
+def gen_qkv(rank: int = 0, sp_size: int = 1):
+    torch.manual_seed(42)
+    random.seed(42)
+    q = torch.randn(BATCH_SIZE * QO_LEN, NUM_HEADS, HEAD_DIM).cuda().half()
+    k = torch.randn(BATCH_SIZE * KV_LEN, NUM_KV_HEADS, HEAD_DIM).cuda().half()
+    v = torch.randn(BATCH_SIZE * KV_LEN, NUM_KV_HEADS, HEAD_DIM).cuda().half()
+
+    # num_heads_per_partition = NUM_HEADS // sp_size
+    # q = q[
+    #     :, :, num_heads_per_partition * rank : num_heads_per_partition * (rank + 1)
+    # ].contiguous()
+    # kv_len_per_partition = KV_LEN // sp_size
+    # k = k[
+    #     :, kv_len_per_partition * rank : kv_len_per_partition * (rank + 1)
+    # ].contiguous()
+    # v = v[
+    #     :, kv_len_per_partition * rank : kv_len_per_partition * (rank + 1)
+    # ].contiguous()
+
+    return q, k, v
+
+
+def get_input_metadata(sp_size: int = 1, tp_size: int = 1):
+    from flashinfer import (
+        BatchPrefillWithPagedKVCacheWrapper,
+        BatchPrefillWithRaggedKVCacheWrapper,
+    )
+
+    input_metadata = InputMetadata(
+        forward_mode=None,
+        batch_size=BATCH_SIZE,
+        total_num_tokens=BATCH_SIZE * QO_LEN,
+        req_pool_indices=None,
+        seq_lens=None,
+        positions=None,
+        req_to_token_pool=None,
+        token_to_kv_pool=None,
+        out_cache_loc=None,
+        extend_seq_lens=None,
+        extend_start_loc=None,
+        extend_no_prefix=True,
+        return_logprob=None,
+        top_logprobs_nums=None,
+        flashinfer_prefill_wrapper_ragged=None,
+        flashinfer_prefill_wrapper_paged=None,
+        flashinfer_decode_wrapper=None,
+        sp_size=sp_size,
+    )
+
+    workspace_buffer = torch.empty(
+        2, 128 * 1024 * 1024, dtype=torch.int8, device="cuda"
+    )
+
+    input_metadata.flashinfer_prefill_wrapper_ragged = (
+        BatchPrefillWithRaggedKVCacheWrapper(workspace_buffer[0], "NHD")
+    )
+    input_metadata.flashinfer_prefill_wrapper_paged = (
+        BatchPrefillWithPagedKVCacheWrapper(workspace_buffer[1], "NHD")
+    )
+
+    num_qo_heads = NUM_HEADS // sp_size
+    num_kv_heads = NUM_KV_HEADS
+    qo_len_per_iter = QO_LEN // sp_size
+    kv_len_per_partition = KV_LEN // sp_size
+
+    qo_indptr = torch.arange(0, BATCH_SIZE + 1).cuda().int() * qo_len_per_iter
+    kv_indptr = torch.arange(0, BATCH_SIZE + 1).cuda().int() * kv_len_per_partition
+    input_metadata.flashinfer_prefill_wrapper_ragged.end_forward()
+    input_metadata.flashinfer_prefill_wrapper_ragged.begin_forward(
+        qo_indptr,
+        kv_indptr,
+        num_qo_heads,
+        num_kv_heads,
+        HEAD_DIM,
+    )
+
+    # cached part
+    kv_indices = torch.arange(0, BATCH_SIZE * kv_len_per_partition).cuda().int()
+    kv_last_page_len = torch.full((BATCH_SIZE,), 1, dtype=torch.int32).cuda()
+    input_metadata.flashinfer_prefill_wrapper_paged.end_forward()
+    input_metadata.flashinfer_prefill_wrapper_paged.begin_forward(
+        qo_indptr,
+        kv_indptr,
+        kv_indices,
+        kv_last_page_len,
+        num_qo_heads,
+        num_kv_heads,
+        HEAD_DIM,
+        1,
+    )
+
+    return input_metadata
+
+
+def sp_worker(rank: int = 0, sp_size: int = 1, tp_size: int = 1):
+    torch.manual_seed(42)
+    random.seed(42)
+
+    def init_comm():
+        nccl_init_method = f"tcp://127.0.0.1:28888"
+        init_distributed_environment(
+            backend="nccl",
+            world_size=tp_size,
+            rank=rank,
+            local_rank=rank,
+            distributed_init_method=nccl_init_method,
+        )
+        initialize_model_parallel(
+            tensor_model_parallel_size=tp_size, sequence_parallel_size=sp_size
+        )
+        torch.cuda.set_device(rank)
+
+    init_comm()
+
+    def init_attention():
+        attention = RadixAttention(
+            num_heads=NUM_HEADS // sp_size,
+            head_dim=HEAD_DIM,
+            scaling=SCALING,
+            num_kv_heads=NUM_KV_HEADS,
+            layer_id=LAYER_ID,
+            logit_cap=LOGIT_CAP,
+        )
+        return attention
+
+    attn = init_attention()
+    print("SP worker", rank, "initialized on", torch.cuda.current_device())
+
+    # Computation
+    input_metadata = get_input_metadata(sp_size=sp_size, tp_size=tp_size)
+    q, k, v = gen_qkv(rank, sp_size)
+    qs, ks, vs = [], [], []
+    q_head_idxes = _get_sequence_parallel_head_idxes(
+        NUM_HEADS, NUM_KV_HEADS, rank, sp_size
+    )
+    print(rank, q_head_idxes)
+    for i in range(sp_size):
+        qs.append(
+            q[(QO_LEN // sp_size) * i : (QO_LEN // sp_size) * (i + 1), q_head_idxes]
+        )
+        ks.append(k[(KV_LEN // sp_size) * i : (KV_LEN // sp_size) * (i + 1)])
+        vs.append(v[(KV_LEN // sp_size) * i : (KV_LEN // sp_size) * (i + 1)])
+
+    output = attn.seq_parallel_extend_forward_flashinfer(qs, ks, vs, input_metadata)
+
+    o_truth = reference_attn()
+    o_truth = (
+        o_truth.contiguous()
+        .view(-1, NUM_HEADS, HEAD_DIM)[:, q_head_idxes]
+        .view(-1, NUM_HEADS // sp_size * HEAD_DIM)
+    )
+
+    print("SP worker", rank, "results:")
+    print("Mean: ", torch.mean(torch.abs(output - o_truth)))
+    print("Max: ", torch.max(torch.abs(output - o_truth)))
+    assert torch.allclose(output, o_truth, rtol=1e-2, atol=1e-3)
+
+
+def _get_sequence_parallel_head_idxes(total_num_heads, num_kv_heads, sp_rank, sp_size):
+    group_num = num_kv_heads
+    group_size = total_num_heads // num_kv_heads
+    shard_num_heads = group_size // sp_size
+    idxes = [
+        group_size * i + sp_rank * shard_num_heads + j
+        for i in range(group_num)
+        for j in range(0, shard_num_heads)
+    ]
+    return idxes
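As a quick illustration (hypothetical GQA shape: 8 query heads, 2 KV heads, sp_size = 2), each SP rank takes a contiguous slice of every KV group, so GQA grouping is preserved:

print(_get_sequence_parallel_head_idxes(8, 2, 0, 2))  # [0, 1, 4, 5]
print(_get_sequence_parallel_head_idxes(8, 2, 1, 2))  # [2, 3, 6, 7]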
+
+
+def reference_attn():
+    torch.manual_seed(42)
+    random.seed(42)
+
+    attn = RadixAttention(
+        num_heads=NUM_HEADS,
+        head_dim=HEAD_DIM,
+        scaling=SCALING,
+        num_kv_heads=NUM_KV_HEADS,
+        layer_id=LAYER_ID,
+        logit_cap=LOGIT_CAP,
+    )
+
+    input_metadata = get_input_metadata()
+    q, k, v = gen_qkv()
+
+    return attn.extend_forward_flashinfer(q, k, v, input_metadata)
+
+
+def main():
+    sp_size = 2
+    tp_size = 2
+
+    multiprocessing.set_start_method("spawn", force=True)
+    sp_procs = []
+    for rank in range(1, sp_size):
+        sp_proc = multiprocessing.Process(
+            target=sp_worker, args=(rank, sp_size, tp_size)
+        )
+        sp_proc.start()
+        sp_procs.append(sp_proc)
+
+    sp_worker(0, sp_size, tp_size)
+
+    for sp_proc in sp_procs:
+        sp_proc.join()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/srt/test_seq_parallel_attn_kernel_simple.py b/test/srt/test_seq_parallel_attn_kernel_simple.py
new file mode 100644
index 00000000000..1ff578b8f07
--- /dev/null
+++ b/test/srt/test_seq_parallel_attn_kernel_simple.py
@@ -0,0 +1,279 @@
+import pytest
+import torch
+from flashinfer import (
+    BatchDecodeWithPagedKVCacheWrapper,
+    BatchPrefillWithPagedKVCacheWrapper,
+    BatchPrefillWithRaggedKVCacheWrapper,
+)
+from flashinfer.cascade import merge_state
+from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+from sglang.srt.layers.extend_attention import extend_attention_fwd, redundant_attention
+from sglang.srt.layers.token_attention import token_attention_fwd
+
+flashinfer_prefill_wrapper_ragged = None
+flashinfer_prefill_wrapper_paged = None
+flashinfer_decode_wrapper = None
+
+
+def get_next_partition_id(curr_partition_id, num_partitions):
+    assert curr_partition_id < num_partitions
+    return (curr_partition_id - 1) % num_partitions
+
+
+def get_sp_prev_local_rank(rank, num_partitions):
+    return (rank - 1) % num_partitions
+
+
+def get_sp_next_local_rank(rank, num_partitions):
+    return (rank + 1) % num_partitions
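A small sanity check (illustrative only) of the ring schedule these helpers produce: each worker processes its own partition first and then receives the remaining partitions in reverse ring order.

rank, n = 1, 4
order, pid = [], rank
for _ in range(n - 1):
    pid = get_sp_prev_local_rank(pid, n)  # the same step the main loop below takes
    order.append(pid)
print(order)  # [0, 3, 2]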
+
+
+def append_merge_partition(partition_list, o, s):
+    if len(partition_list) == 0:
+        partition_list.append((o, s))
+    else:
+        o_prev, s_prev = partition_list[-1]
+        o, s = merge_state(o_prev, s_prev, o, s)
+        partition_list[-1] = (o, s)
+
+
+def seq_parallel_attn(
+    batch_size,
+    kv_len,
+    qo_len,
+    num_kv_heads,
+    num_qo_heads,
+    head_dim,
+    q,
+    k,
+    v,
+    rank: int,
+    sp_size: int,
+):
+    """Simulate a sequence parallel attention kernel. It takes full Q, K, and V
+    with simulated communication. TODO: replace with actual communication.
+    """
+    num_partitions = sp_size
+    num_iters = sp_size
+    # NOTE: we assume sequence length is divisible by num_partitions
+    qo_len_per_iter = qo_len // num_iters
+    kv_len_per_partition = kv_len // num_partitions
+
+    qo_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len_per_iter
+    kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len_per_partition
+    flashinfer_prefill_wrapper_ragged.end_forward()
+    flashinfer_prefill_wrapper_ragged.begin_forward(
+        qo_indptr,
+        kv_indptr,
+        num_qo_heads,
+        num_kv_heads,
+        head_dim,
+    )
+
+    kv_indices = torch.arange(0, batch_size * kv_len_per_partition).to(0).int()
+    kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32).to(0)
+    flashinfer_prefill_wrapper_paged.end_forward()
+    flashinfer_prefill_wrapper_paged.begin_forward(
+        qo_indptr,
+        kv_indptr,
+        kv_indices,
+        kv_last_page_len,
+        num_qo_heads,
+        num_kv_heads,
+        head_dim,
+        1,
+    )
+
+    local_k, local_v = (
+        k[:, rank * kv_len_per_partition : (rank + 1) * kv_len_per_partition]
+        .contiguous()
+        .view(-1, num_kv_heads, head_dim),
+        v[:, rank * kv_len_per_partition : (rank + 1) * kv_len_per_partition]
+        .contiguous()
+        .view(-1, num_kv_heads, head_dim),
+    )
+    k_partition, v_partition = local_k, local_v
+
+    owned_pids = [rank]
+    owned_partitions = [None for _ in range(num_partitions)]
+    owned_partitions[rank] = (local_k, local_v)
+    o_partitions = [[] for _ in range(num_partitions)]
+
+    to_rank = rank  # which SP worker to send my sequence KV partition to.
+    from_rank = rank  # which SP worker to receive the sequence KV partition from.
+
+    pid = rank  # start from the worker's own partition
+    for _ in range(num_iters):
+        # TODO: send-recv communication here
+        to_rank = get_sp_next_local_rank(to_rank, num_partitions)
+        # send_to(to_rank, k, v)
+        q_partition = q[:, pid * qo_len_per_iter : (pid + 1) * qo_len_per_iter]
+        k_partition, v_partition = owned_partitions[pid]
+        # Ragged attention computation for self attention within the partition
+        o, s = flashinfer_prefill_wrapper_ragged.forward_return_lse(
+            q_partition.contiguous().view(-1, num_qo_heads, head_dim),
+            k_partition.contiguous().view(-1, num_kv_heads, head_dim),
+            v_partition.contiguous().view(-1, num_kv_heads, head_dim),
+        )
+        append_merge_partition(o_partitions[pid], o, s)
+        # Paged attention computation for cross partition attention
+        # NOTE: the schedule below is for load balancing
+        for existing_pid in owned_pids:
+            if existing_pid == pid:
+                continue
+            i, j = (existing_pid, pid) if existing_pid > pid else (pid, existing_pid)
+            q_data = q[:, i * qo_len_per_iter : (i + 1) * qo_len_per_iter]
+            kv_data = torch.stack(owned_partitions[j], dim=1)
+            o, s = flashinfer_prefill_wrapper_paged.forward_return_lse(
+                q_data.contiguous().view(-1, num_qo_heads, head_dim),
+                kv_data,
+                causal=False,
+            )
+            append_merge_partition(o_partitions[i], o, s)
+
+        # TODO: send-recv communication here
+        from_rank = get_sp_prev_local_rank(from_rank, num_partitions)
+        # recv_from(from_rank, k, v)
+        pid = from_rank
+        kv_recved = (
+            k[:, pid * kv_len_per_partition : (pid + 1) * kv_len_per_partition]
+            .contiguous()
+            .view(-1, num_kv_heads, head_dim),
+            v[:, pid * kv_len_per_partition : (pid + 1) * kv_len_per_partition]
+            .contiguous()
+            .view(-1, num_kv_heads, head_dim),
+        )
+        owned_pids.append(pid)
+        owned_partitions[pid] = kv_recved
+
+    # Reshape all o tensors so that we can concatenate along the sequence dimension.
+    # At this point each partition_list must contain exactly one merged entry.
+    os = [
+        o.view(batch_size, qo_len_per_iter, num_qo_heads, head_dim)
+        for partition_list in o_partitions
+        for o, _ in partition_list
+    ]
+    o = torch.cat(os, dim=1).view(
+        -1, num_qo_heads, head_dim
+    )  # restore the original shape
+    return o
+
+
+@pytest.mark.parametrize("batch_size", [12, 37, 67])
+@pytest.mark.parametrize("kv_len", [54, 97])
+@pytest.mark.parametrize("qo_len", [37, 17])
+@pytest.mark.parametrize("num_kv_heads", [4])
+@pytest.mark.parametrize("num_qo_heads", [32, 4])
+@pytest.mark.parametrize("head_dim", [128])
+def test_seq_parallel_prefill(
+    batch_size,
+    kv_len,
+    qo_len,
+    num_kv_heads,
+    num_qo_heads,
+    head_dim,
+    rank: int = 0,
+    sp_size: int = 2,
+):
+    init_flashinfer(num_qo_heads, num_kv_heads)
+
+    q = torch.randn(batch_size, qo_len, num_qo_heads, head_dim).to(0).half()
+    k = torch.randn(batch_size, kv_len, num_kv_heads, head_dim).to(0).half()
+    v = torch.randn(batch_size, kv_len, num_kv_heads, head_dim).to(0).half()
+
+    def reference_impl_ragged():
+        qo_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len
+        kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
+
+        flashinfer_prefill_wrapper_ragged.end_forward()
+        flashinfer_prefill_wrapper_ragged.begin_forward(
+            qo_indptr,
+            kv_indptr,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+        )
+        o = flashinfer_prefill_wrapper_ragged.forward(
+            q.contiguous().view(-1, num_qo_heads, head_dim),
+            k.contiguous().view(-1, num_kv_heads, head_dim),
+            v.contiguous().view(-1, num_kv_heads, head_dim),
+        )
+        flashinfer_prefill_wrapper_ragged.end_forward()
+        return o
+
+    def reference_impl_paged():
+        qo_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len
+        total_tokens = kv_len * batch_size
+
+        kv_data = torch.zeros(total_tokens, 2, num_kv_heads, head_dim).to(0).half()
+        kv_data[:, 0] = k.contiguous().view(-1, num_kv_heads, head_dim)
+        kv_data[:, 1] = v.contiguous().view(-1, num_kv_heads, head_dim)
+        kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
+        kv_indices = torch.arange(0, total_tokens).to(0).int()
+        kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32).to(0)
+
+        flashinfer_prefill_wrapper_paged.end_forward()
+        flashinfer_prefill_wrapper_paged.begin_forward(
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            kv_last_page_len,
+            num_qo_heads,
+            num_kv_heads,
+            head_dim,
+            1,
+        )
+        o = flashinfer_prefill_wrapper_paged.forward(
+            q.contiguous().view(-1, num_qo_heads, head_dim), kv_data
+        )
+        flashinfer_prefill_wrapper_paged.end_forward()
+        return o
+
+    o_sp = seq_parallel_attn(
+        batch_size,
+        kv_len,
+        qo_len,
+        num_kv_heads,
+        num_qo_heads,
+        head_dim,
+        q,
+        k,
+        v,
+        rank=1,
+        sp_size=4,
+    )
+    o_truth = reference_impl_paged()
+
+    print("Mean: ", torch.mean(torch.abs(o_sp - o_truth)))
+    print("Max: ", torch.max(torch.abs(o_sp - o_truth)))
+    assert torch.allclose(o_sp, o_truth, rtol=1e-2, atol=1e-3)
+
+
+def init_flashinfer(num_attention_heads, num_kv_heads):
+    if not _grouped_size_compiled_for_decode_kernels(num_attention_heads, num_kv_heads):
+        use_tensor_cores = True
+    else:
+        use_tensor_cores = False
+
+    workspace_buffer = torch.empty(
+        3, 128 * 1024 * 1024, dtype=torch.int8, device="cuda"
+    )
+
+    global flashinfer_prefill_wrapper_ragged, flashinfer_prefill_wrapper_paged, flashinfer_decode_wrapper
+
+    flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
+        workspace_buffer[0], "NHD"
+    )
+    flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
+        workspace_buffer[1], "NHD"
+    )
+    flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+        workspace_buffer[2], "NHD", use_tensor_cores=use_tensor_cores
+    )
+
+
+if __name__ == "__main__":
+    test_seq_parallel_prefill(12, 128, 128, 8, 8, 128, rank=3, sp_size=4)
+    test_seq_parallel_prefill(12, 4096, 4096, 8, 8, 128, rank=4, sp_size=8)
+    test_seq_parallel_prefill(12, 1024, 1024, 32, 32, 128, rank=1, sp_size=2)
diff --git a/test/srt/test_sp_comm_group.py b/test/srt/test_sp_comm_group.py
new file mode 100644
index 00000000000..17d0a45fd9a
--- /dev/null
+++ b/test/srt/test_sp_comm_group.py
@@ -0,0 +1,70 @@
+import multiprocessing
+import random
+
+import torch
+from vllm.distributed import init_distributed_environment
+
+from sglang.srt.layers.parallel_utils import get_sp_group, initialize_model_parallel
+
+NUM_TOKENS = 3
+NUM_KV_HEADS = 2
+HEAD_DIM = 4
+
+
+def gen_kv(rank: int = 0, sp_size: int = 1):
+    torch.manual_seed(42)
+    random.seed(42)
+    k = torch.randn(NUM_TOKENS, NUM_KV_HEADS, HEAD_DIM).cuda().half()
+    v = torch.randn(NUM_TOKENS, NUM_KV_HEADS, HEAD_DIM).cuda().half()
+
+    return k, v
+
+
+def sp_worker(rank: int = 0, sp_size: int = 1, tp_size: int = 1):
+    torch.manual_seed(42)
+    random.seed(42)
+
+    nccl_init_method = f"tcp://127.0.0.1:28888"
+    init_distributed_environment(
+        backend="nccl",
+        world_size=tp_size,
+        rank=rank,
+        local_rank=rank,
+        distributed_init_method=nccl_init_method,
+    )
+    initialize_model_parallel(
+        tensor_model_parallel_size=tp_size, sequence_parallel_size=sp_size
+    )
+    torch.cuda.set_device(rank)
+    print("SP worker", rank, "initialized on", torch.cuda.current_device())
+
+    k, v = gen_kv(rank, sp_size)
+
+    ks = get_sp_group().all_gather(k.view(1, *k.shape), dim=0)
+    vs = get_sp_group().all_gather(v.view(1, *v.shape), dim=0)
+
+    print("SP worker", rank, "all-gathered ks", ks)
+    print("SP worker", rank, "all-gathered vs", vs)
+
+
+def main():
+    sp_size = 2
+    tp_size = 2
+
+    multiprocessing.set_start_method("spawn", force=True)
+    sp_procs = []
+    for rank in range(1, sp_size):
+        sp_proc = multiprocessing.Process(
+            target=sp_worker, args=(rank, sp_size, tp_size)
+        )
+        sp_proc.start()
+        sp_procs.append(sp_proc)
+
+    sp_worker(0, sp_size, tp_size)
+
+    for sp_proc in sp_procs:
+        sp_proc.join()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/srt/test_sp_decode_attn.py b/test/srt/test_sp_decode_attn.py
new file mode 100644
index 00000000000..084c95985d9
--- /dev/null
+++ b/test/srt/test_sp_decode_attn.py
@@ -0,0 +1,191 @@
+import multiprocessing
+import random
+
+import torch
+from flashinfer import BatchDecodeWithPagedKVCacheWrapper, merge_state
+from vllm.distributed import init_distributed_environment
+
+from sglang.srt.layers.parallel_utils import get_sp_group, initialize_model_parallel
+
+NUM_HEADS = 32
+HEAD_DIM = 128
+SCALING = 1
+NUM_KV_HEADS = 8
+LAYER_ID = 0
+LOGIT_CAP = -1
+
+
+BATCH_SIZE = 3
+SEQ_LENS = [16, 64, 128]
+
+
+def gen_qkv(sp_rank: int = 0, sp_size: int = 1):
+    torch.manual_seed(42)
+    random.seed(42)
+
+    q = torch.randn(BATCH_SIZE, NUM_HEADS, HEAD_DIM).cuda().half()
+    total_num_context_tokens = sum(SEQ_LENS)
+    kv_cache = (
+        torch.randn(total_num_context_tokens, 2, NUM_KV_HEADS, HEAD_DIM).cuda().half()
+    )
+
+    if sp_size > 1:
+        q_head_idxes = _get_sequence_parallel_head_idxes(
+            NUM_HEADS, NUM_KV_HEADS, sp_rank, sp_size
+        )
+        q = q[:, q_head_idxes].contiguous()
+
+        sp_kv_cache = (
+            torch.empty(total_num_context_tokens // sp_size, 2, NUM_KV_HEADS, HEAD_DIM)
+            .cuda()
+            .half()
+        )
+        sp_stt, stt = 0, 0
+        for i in range(BATCH_SIZE):
+            seq_len = SEQ_LENS[i]
+            sp_seq_len = seq_len // sp_size
+
+            sp_end = sp_stt + sp_seq_len
+            end = stt + seq_len
+
+            sp_kv_cache[sp_stt:sp_end] = kv_cache[
+                stt + sp_rank * sp_seq_len : stt + (sp_rank + 1) * sp_seq_len
+            ]
+            sp_stt = sp_end
+            stt = end
+        kv_cache = sp_kv_cache
+
+    return q, kv_cache
+
+
+def init_flashinfer(sp_size: int = 1, tp_size: int = 1):
+
+    workspace_buffer = torch.empty(
+        1, 128 * 1024 * 1024, dtype=torch.int8, device="cuda"
+    )
+
+    flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+        workspace_buffer[0], "NHD"
+    )
+
+    num_qo_heads = NUM_HEADS
+    num_kv_heads = NUM_KV_HEADS
+
+    seq_lens = torch.tensor(SEQ_LENS, dtype=torch.int32, device="cuda")
+    seq_lens = seq_lens // sp_size
+    total_num_context_tokens = sum(SEQ_LENS) // sp_size
+
+    kv_indptr = torch.zeros((BATCH_SIZE + 1,), dtype=torch.int32, device="cuda")
+    kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)
+    kv_indices = torch.arange(
+        total_num_context_tokens, dtype=torch.int32, device="cuda"
+    )
+    kv_last_page_len = torch.ones((BATCH_SIZE,), dtype=torch.int32, device="cuda")
+
+    flashinfer_decode_wrapper.end_forward()
+    flashinfer_decode_wrapper.begin_forward(
+        kv_indptr,
+        kv_indices,
+        kv_last_page_len,
+        num_qo_heads,
+        num_kv_heads,
+        HEAD_DIM,
+        1,
+    )
+
+    return flashinfer_decode_wrapper
+
+
+def sp_worker(rank: int = 0, sp_size: int = 1, tp_size: int = 1):
+    torch.manual_seed(42)
+    random.seed(42)
+
+    def init_comm():
+        nccl_init_method = f"tcp://127.0.0.1:28888"
+        init_distributed_environment(
+            backend="nccl",
+            world_size=tp_size,
+            rank=rank,
+            local_rank=rank,
+            distributed_init_method=nccl_init_method,
+        )
+        initialize_model_parallel(
+            tensor_model_parallel_size=tp_size, sequence_parallel_size=sp_size
+        )
+        torch.cuda.set_device(rank)
+
+    init_comm()
+
+    print("SP worker", rank, "initialized on", torch.cuda.current_device())
+
+    decode_wrapper = init_flashinfer(sp_size=sp_size, tp_size=tp_size)
+    q, kv_cache = gen_qkv(rank, sp_size)
+
+    gathered_q = get_sp_group().all_gather(q.view(1, *q.shape), dim=0)
+    q = torch.empty_like(gathered_q).view(-1, NUM_HEADS, HEAD_DIM)
+
+    for i in range(sp_size):
+        idxes = _get_sequence_parallel_head_idxes(NUM_HEADS, NUM_KV_HEADS, i, sp_size)
+        q[:, idxes] = gathered_q[i]
+
+    # Computation
+    o, s = decode_wrapper.forward_return_lse(q, kv_cache)
+
+    os = get_sp_group().all_gather(o.view(1, *o.shape), dim=0)
+    ss = get_sp_group().all_gather(s.view(1, *s.shape), dim=0)
+    for i in range(sp_size):
+        if i != rank:
+            o, s = merge_state(os[i], ss[i], o, s)
+    output = o
+
+    o_truth = reference_attn()
+
+    print("SP worker", rank, "results:")
+    print("Mean: ", torch.mean(torch.abs(output - o_truth)))
+    print("Max: ", torch.max(torch.abs(output - o_truth)))
+    assert torch.allclose(output, o_truth, rtol=1e-2, atol=1e-3)
+
+
+def _get_sequence_parallel_head_idxes(total_num_heads, num_kv_heads, sp_rank, sp_size):
+    group_num = num_kv_heads
+    group_size = total_num_heads // num_kv_heads
+    shard_num_heads = group_size // sp_size
+    idxes = [
+        group_size * i + sp_rank * shard_num_heads + j
+        for i in range(group_num)
+        for j in range(0, shard_num_heads)
+    ]
+    return idxes
+
+
+def reference_attn():
+    torch.manual_seed(42)
+    random.seed(42)
+
+    decode_wrapper = init_flashinfer()
+    q, kv_cache = gen_qkv()
+
+    return decode_wrapper.forward(q, kv_cache)
+
+
+def main():
+    sp_size = 2
+    tp_size = 2
+
+    multiprocessing.set_start_method("spawn", force=True)
+    sp_procs = []
+    for rank in range(1, sp_size):
+        sp_proc = multiprocessing.Process(
+            target=sp_worker, args=(rank, sp_size, tp_size)
+        )
+        sp_proc.start()
+        sp_procs.append(sp_proc)
+
+    sp_worker(0, sp_size, tp_size)
+
+    for sp_proc in sp_procs:
+        sp_proc.join()
+
+
+if __name__ == "__main__":
+    main()