diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 10e8cd80104..aad19c209a4 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -156,7 +156,7 @@ def launch_tensor_parallel_group( ) for tp_rank in tp_rank_range: reader, writer = mp.Pipe(duplex=False) - gpu_id = base_gpu_id + tp_rank % tp_size_per_node + gpu_id = server_args.base_gpu_id + base_gpu_id + tp_rank % tp_size_per_node proc = mp.Process( target=run_scheduler_process, args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer), diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 6241ed6b260..4e96fbc26aa 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1380,6 +1380,10 @@ def run_scheduler_process( dp_rank: Optional[int], pipe_writer, ): + # [For Router] if env var "DP_RANK" exist, set dp_rank to the value of the env var + if dp_rank is None: + dp_rank = int(os.getenv("DP_RANK", -1)) + if dp_rank is None: configure_logger(server_args, prefix=f" TP{tp_rank}") else: diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 5621933a6f7..b1d0b3ea254 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -418,7 +418,7 @@ def launch_engine( ) for tp_rank in tp_rank_range: reader, writer = mp.Pipe(duplex=False) - gpu_id = tp_rank % tp_size_per_node + gpu_id = server_args.base_gpu_id + tp_rank % tp_size_per_node proc = mp.Process( target=run_scheduler_process, args=(server_args, port_args, gpu_id, tp_rank, None, writer), diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 5487f772f4c..78581f40759 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -72,6 +72,7 @@ class ServerArgs: constrained_json_whitespace_pattern: Optional[str] = None watchdog_timeout: float = 300 download_dir: Optional[str] = None + base_gpu_id: int = 0 # Logging log_level: str = "info" @@ -412,6 +413,12 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.download_dir, help="Model download directory.", ) + parser.add_argument( + "--base-gpu-id", + type=int, + default=ServerArgs.base_gpu_id, + help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.", + ) # Logging parser.add_argument( @@ -736,6 +743,7 @@ def check_server_args(self): and (self.lora_paths is None or self.disable_cuda_graph) and (self.lora_paths is None or self.disable_radix_cache) ), "compatibility of lora and cuda graph and radix attention is in progress" + assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative" if isinstance(self.lora_paths, list): lora_paths = self.lora_paths diff --git a/scripts/playground/router/test_tree.py b/scripts/playground/router/test_tree.py new file mode 100644 index 00000000000..af41c738e02 --- /dev/null +++ b/scripts/playground/router/test_tree.py @@ -0,0 +1,207 @@ +import random +import string +import time +import unittest +from typing import Dict, List, Tuple + +from tree import MultiTenantRadixTree + + +class TestMultiTenantRadixTree(unittest.TestCase): + def setUp(self): + self.tree = MultiTenantRadixTree() + + def test_insert_exact_match(self): + """Test 1: Basic insert and exact match operations""" + # Insert a single string for one tenant + self.tree.insert("hello", "tenant1") + matched, tenant = self.tree.prefix_match("hello") + self.assertEqual(matched, "hello") + self.assertEqual(tenant, "tenant1") + + # Insert same string for different tenant + self.tree.insert("hello", "tenant2") + matched, tenant = self.tree.prefix_match("hello") + self.assertIn(tenant, ["tenant1", "tenant2"]) + + # Insert different string for same tenant + self.tree.insert("world", "tenant1") + matched, tenant = self.tree.prefix_match("world") + self.assertEqual(matched, "world") + self.assertEqual(tenant, "tenant1") + + print(self.tree.pretty_print()) + + def test_insert_partial_match(self): + """Test 2: Insert with partial matching scenarios""" + # Test partial matches with common prefixes + self.tree.insert("hello", "tenant1") + print(self.tree.pretty_print()) + self.tree.insert("help", "tenant2") + print(self.tree.pretty_print()) + + # Match exact strings + matched, tenant = self.tree.prefix_match("hello") + self.assertEqual(matched, "hello") + self.assertEqual(tenant, "tenant1") + + matched, tenant = self.tree.prefix_match("help") + self.assertEqual(matched, "help") + self.assertEqual(tenant, "tenant2") + + # Match partial string + matched, tenant = self.tree.prefix_match("hel") + self.assertEqual(matched, "hel") + self.assertIn(tenant, ["tenant1", "tenant2"]) + + # Match longer string + matched, tenant = self.tree.prefix_match("hello_world") + self.assertEqual(matched, "hello") + self.assertEqual(tenant, "tenant1") + + def test_insert_edge_cases(self): + """Test 3: Edge cases for insert and match operations""" + # Empty string + self.tree.insert("", "tenant1") + matched, tenant = self.tree.prefix_match("") + self.assertEqual(matched, "") + self.assertEqual(tenant, "tenant1") + + # Single character + self.tree.insert("a", "tenant1") + matched, tenant = self.tree.prefix_match("a") + self.assertEqual(matched, "a") + self.assertEqual(tenant, "tenant1") + + # Very long string + long_str = "a" * 1000 + self.tree.insert(long_str, "tenant1") + matched, tenant = self.tree.prefix_match(long_str) + self.assertEqual(matched, long_str) + self.assertEqual(tenant, "tenant1") + + # Unicode characters + self.tree.insert("你好", "tenant1") + matched, tenant = self.tree.prefix_match("你好") + self.assertEqual(matched, "你好") + self.assertEqual(tenant, "tenant1") + + def test_simple_eviction(self): + """Test 4: Simple eviction scenarios + Tenant1: limit 10 chars + Tenant2: limit 5 chars + + Should demonstrate: + 1. Basic eviction when size limit exceeded + 2. Proper eviction based on last access time + 3. Verification that shared nodes remain intact for other tenants + """ + # Set up size limits + max_size = {"tenant1": 10, "tenant2": 5} + + # Insert strings for both tenants + self.tree.insert("hello", "tenant1") # size 5 + self.tree.insert("hello", "tenant2") # size 5 + self.tree.insert("world", "tenant2") # size 5, total for tenant2 = 10 + + # Verify initial sizes + sizes_before = self.tree.get_used_size_per_tenant() + self.assertEqual(sizes_before["tenant1"], 5) # "hello" = 5 + self.assertEqual(sizes_before["tenant2"], 10) # "hello" + "world" = 10 + + # Evict - should remove "hello" from tenant2 as it's the oldest + self.tree.evict_tenant_data(max_size) + + # Verify sizes after eviction + sizes_after = self.tree.get_used_size_per_tenant() + self.assertEqual(sizes_after["tenant1"], 5) # Should be unchanged + self.assertEqual(sizes_after["tenant2"], 5) # Only "world" remains + + # Verify "world" remains for tenant2 (was accessed more recently) + matched, tenant = self.tree.prefix_match("world") + self.assertEqual(matched, "world") + self.assertEqual(tenant, "tenant2") + + def test_medium_eviction(self): + """Test 5: Medium complexity eviction scenarios with shared prefixes + Tenant1: limit 10 chars + Tenant2: limit 7 chars (forces one string to be evicted) + + Tree structure after inserts: + └── 'h' [t1, t2] + ├── 'i' [t1, t2] # Oldest for t2 + └── 'e' [t1, t2] + ├── 'llo' [t1, t2] + └── 'y' [t2] # Newest for t2 + + Size calculations: + tenant1: "h"(1) + "i"(1) + "e"(1) + "llo"(3) = 6 chars + tenant2: "h"(1) + "i"(1) + "e"(1) + "llo"(3) + "y"(1) = 7 chars + + After eviction (tenant2 exceeds limit by 1 char): + "hi" should be removed from tenant2 as it's the oldest access + """ + max_size = { + "tenant1": 10, + "tenant2": 6, + } # tenant2 will need to evict one string + + # Create a tree with overlapping prefixes + self.tree.insert("hi", "tenant1") + self.tree.insert("hi", "tenant2") # OLDEST for t2 + + self.tree.insert("hello", "tenant1") + self.tree.insert("hello", "tenant2") + + self.tree.insert("hey", "tenant2") # NEWEST for t2 + + # Verify initial sizes + sizes_before = self.tree.get_used_size_per_tenant() + self.assertEqual(sizes_before["tenant1"], 6) # h(1) + i(1) + e(1) + llo(3) = 6 + self.assertEqual( + sizes_before["tenant2"], 7 + ) # h(1) + i(1) + e(1) + llo(3) + y(1) = 7 + + print("\nTree before eviction:") + print(self.tree.pretty_print()) + + # Evict - should remove "hi" from tenant2 as it's the oldest + self.tree.evict_tenant_data(max_size) + + print("\nTree after eviction:") + print(self.tree.pretty_print()) + + # Verify sizes after eviction + sizes_after = self.tree.get_used_size_per_tenant() + self.assertEqual(sizes_after["tenant1"], 6) # Should be unchanged + self.assertEqual(sizes_after["tenant2"], 6) # h(1) + e(1) + llo(3) + y(1) = 6 + + def test_advanced_eviction(self): + ... + # Create 4 tenants + # Each tenants keeps adding strings with shared prefixes to thousands usage + # Set a strict limit for each tenant to only 100 + # At the end, check whether all of the tenant is under 100 after eviction + + max_size = {"tenant1": 100, "tenant2": 100, "tenant3": 100, "tenant4": 100} + + prefixes = ["aqwefcisdf", "iajsdfkmade", "kjnzxcvewqe", "iejksduqasd"] + for i in range(100): + for j, prefix in enumerate(prefixes): + random_suffix = "".join(random.choices(string.ascii_letters, k=10)) + self.tree.insert(prefix + random_suffix, f"tenant{j+1}") + + sizes_before = self.tree.get_used_size_per_tenant() + print(sizes_before) + + self.tree.evict_tenant_data(max_size) + + sizes_after = self.tree.get_used_size_per_tenant() + print(sizes_after) + # ensure size_after is below max_size + for tenant, size in sizes_after.items(): + self.assertLessEqual(size, max_size[tenant]) + + +if __name__ == "__main__": + unittest.main() diff --git a/scripts/playground/router/tree.py b/scripts/playground/router/tree.py new file mode 100644 index 00000000000..9cbfa7cfe93 --- /dev/null +++ b/scripts/playground/router/tree.py @@ -0,0 +1,292 @@ +import time +from collections import defaultdict +from typing import Dict, List + + +class Node: + def __init__(self): + self.children: Dict[str, Node] = dict() + # We choose to use text because most of the use cases are text-to-text, + # so we can save the tokenizing overhead. + self.text: str = "" + # Maps tenant_id to their last access timestamp + self.tenant_last_access_time: Dict[str, float] = dict() + self.parent = None + + +def shared_prefix_length(s1, s2): + min_length = min(len(s1), len(s2)) + for i in range(min_length): + if s1[i] != s2[i]: + return i + return min_length + + +class MultiTenantRadixTree: + """ + Python Reference of Rust implementation of MultiTenantRadixTree + + MultiTenantRadixTree is the overlap of multiple radix trees by different tenant + Each node in the tree can be owned by multiple tenants, allowing for efficient storage of common prefixes + while maintaining tenant isolation. + + Key concepts: + - Tenant: An entity that owns a subset of the stored strings + - Each node tracks which tenants have access to it via tenant_last_access_time + - The tree structure is shared, but queries can be filtered by tenant_id + """ + + def __init__(self): + self.root = Node() + + def insert(self, s: str, tenant_id: str) -> None: + """ + Insert string 's' and associate it with the given tenant_id. + + Args: + s: The string to insert + tenant_id: The identifier of the tenant who owns this string + """ + curr = self.root + curr_idx = 0 + curr.tenant_last_access_time[tenant_id] = time.time() + + while curr_idx < len(s): + matched_node = None + if s[curr_idx] in curr.children: + matched_node = curr.children[s[curr_idx]] + + if matched_node is None: + # No match => create a new node + new_node = Node() + new_node.text = s[curr_idx:] + new_node.parent = curr + + curr.children[s[curr_idx]] = new_node + curr_idx = len(s) + curr = new_node + curr.tenant_last_access_time[tenant_id] = time.time() + else: + shared_len = shared_prefix_length(s[curr_idx:], matched_node.text) + + # 1. If the matched text is shorter than the node text => split the node + if shared_len < len(matched_node.text): + # Split structure: [matched_node] => [new_node] -> [contracted_matched_node] + + matched_text = matched_node.text[:shared_len] + unmatched_text = matched_node.text[shared_len:] + + new_node = Node() + new_node.text = matched_text + new_node.children = {unmatched_text[0]: matched_node} + new_node.parent = curr + new_node.parent.children[matched_text[0]] = new_node + new_node.tenant_last_access_time = ( + matched_node.tenant_last_access_time.copy() + ) + + # Contract matched node + matched_node.text = unmatched_text + matched_node.parent = new_node + + curr_idx += shared_len + curr = new_node + curr.tenant_last_access_time[tenant_id] = time.time() + # 2. If the matched text is longer or equal to the node text => walk down the node + else: + curr_idx += shared_len + curr = matched_node + curr.tenant_last_access_time[tenant_id] = time.time() + + def prefix_match(self, s: str) -> tuple[str, int]: + """ + Match string 's' with multiple tenants' trees in one operation. + + Args: + s: The string to match + + Returns: + Tuple(str, int): The longest prefix of 's' that matches the tree and the first tenant_id that own the matched prefix + """ + curr = self.root + curr_idx = 0 + + ret_text = "" + ret_tenant = None + + while curr_idx < len(s): + matched_node = None + if s[curr_idx] in curr.children: + matched_node = curr.children[s[curr_idx]] + + if matched_node is None: + break + + shared_len = shared_prefix_length(s[curr_idx:], matched_node.text) + if shared_len == len(matched_node.text): + curr_idx += shared_len + curr = matched_node + else: + curr_idx += shared_len + curr = matched_node + break + + selected_tenant = list(curr.tenant_last_access_time.keys())[0] + + # traverse back to the root to update last access time for the selected tenant + while curr != self.root: + curr.tenant_last_access_time[selected_tenant] = time.time() + curr = curr.parent + + return s[:curr_idx], selected_tenant + + def evict_tenant_data(self, max_size_per_tenant: Dict[str, int]) -> None: + """ + Evict data for tenants that have exceeded their storage limits. + + Args: + max_size_per_tenant: Dictionary mapping tenant_id to their maximum allowed storage size + """ + + def leaf_of(node): + """ + If the node is a leaf for a tenant, add tenant_id to the return list + This will return list of tenant ids + If not a leaf for all tenants, return [] + """ + candidates = dict([(k, True) for k in node.tenant_last_access_time.keys()]) + + for n in node.children.values(): + for c in n.tenant_last_access_time.keys(): + candidates[c] = False + + return [k for k, v in candidates.items() if v] + + # maintain a heap with (time, tenant, node) as the value + import heapq + + # 1. traverse the tree to + # a. add all the leaves into a heap (a node with N tenants will be added N times into the heap) + # b. calculate the used size for each tenant + # do a dfs with stack + stack = [self.root] + pq = [] + used_size_per_tenant = defaultdict(int) + + while stack: + curr = stack.pop() + for t in curr.tenant_last_access_time.keys(): + used_size_per_tenant[t] += len(curr.text) + + for c in curr.children.values(): + stack.append(c) + + # if the node is a leaf for a tenant, add the tenant to the heap + tenants = leaf_of(curr) + for t in tenants: + heapq.heappush(pq, (curr.tenant_last_access_time[t], t, curr)) + + # 2. pop the heap + # a. if the tenant's used size is less than the limit, continue + # b. if the tenant's used size is greater than the limit, remove the leaf and update the used size, and add its parent to the heap + while len(pq) > 0: + time, tenant, node = heapq.heappop(pq) + if used_size_per_tenant[tenant] <= max_size_per_tenant[tenant]: + continue + + # remove the leaf + used_size_per_tenant[tenant] -= len(node.text) + del node.tenant_last_access_time[tenant] + # if no children and no tenants, remove the node + if len(node.children) == 0 and len(node.tenant_last_access_time) == 0: + del node.parent.children[node.text[0]] + + # add its parent to the heap + if tenant in leaf_of(node.parent): + heapq.heappush( + pq, + (node.parent.tenant_last_access_time[tenant], tenant, node.parent), + ) + + def get_used_size_per_tenant(self) -> Dict[str, int]: + """ + Calculate the used storage size for each tenant. + + Returns: + Dict[str, int]: A dictionary mapping tenant_id to their used storage size + """ + used_size_per_tenant = defaultdict(int) + + stack = [self.root] + while stack: + curr = stack.pop() + for t in curr.tenant_last_access_time.keys(): + used_size_per_tenant[t] += len(curr.text) + + for c in curr.children.values(): + stack.append(c) + + return used_size_per_tenant + + def remove_tenant(self, tenant_id: str) -> None: + """ + Remove all data associated with a specific tenant from the tree. + This operation maintains the integrity of the shared tree structure while + removing only the specified tenant's access information. + + Args: + tenant_id: The identifier of the tenant whose data should be removed + """ + # TODO: Implementation needed + pass + + def pretty_print(self) -> str: + """ + Returns a string representation of the tree showing the structure, tenant ownership, + and leaf status for each node. + + Returns: + str: A formatted string showing the tree hierarchy with tenant information + """ + + def _node_to_str(node: Node, prefix: str = "", is_last: bool = True) -> str: + # Current node representation + node_str = prefix + node_str += "└── " if is_last else "├── " + + # Add node text + node_str += f"'{node.text}' [" + + # Add tenant information including both timestamp and leaf status + tenant_info = [] + for tid, ts in node.tenant_last_access_time.items(): + time_str = ( + time.strftime("%H:%M:%S.", time.localtime(ts)) + + f"{(ts % 1):0.3f}"[2:] + ) + tenant_info.append(f"{tid} | {time_str}") + + node_str += ", ".join(tenant_info) + node_str += "]\n" + + # Handle children + children = list(node.children.items()) + for i, (char, child) in enumerate(children): + is_last_child = i == len(children) - 1 + # Adjust prefix for children based on whether this is the last child + new_prefix = prefix + (" " if is_last else "│ ") + node_str += _node_to_str(child, new_prefix, is_last_child) + + return node_str + + if not self.root.children: + return "Empty tree" + + # Start with root's children since root itself is just an empty node + result = "" + children = list(self.root.children.items()) + for i, (char, child) in enumerate(children): + is_last = i == len(children) - 1 + result += _node_to_str(child, "", is_last) + + return result