Skip to content

Commit

Permalink
[router] add base_gpu_id server args & merged radix tree python refer…
Browse files Browse the repository at this point in the history
  • Loading branch information
ByronHsu authored Nov 22, 2024
1 parent f6f7137 commit 30af7df
Show file tree
Hide file tree
Showing 6 changed files with 513 additions and 2 deletions.
2 changes: 1 addition & 1 deletion python/sglang/srt/managers/data_parallel_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def launch_tensor_parallel_group(
)
for tp_rank in tp_rank_range:
reader, writer = mp.Pipe(duplex=False)
gpu_id = base_gpu_id + tp_rank % tp_size_per_node
gpu_id = server_args.base_gpu_id + base_gpu_id + tp_rank % tp_size_per_node
proc = mp.Process(
target=run_scheduler_process,
args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
Expand Down
4 changes: 4 additions & 0 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1380,6 +1380,10 @@ def run_scheduler_process(
dp_rank: Optional[int],
pipe_writer,
):
# [For Router] if env var "DP_RANK" exist, set dp_rank to the value of the env var
if dp_rank is None:
dp_rank = int(os.getenv("DP_RANK", -1))

if dp_rank is None:
configure_logger(server_args, prefix=f" TP{tp_rank}")
else:
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ def launch_engine(
)
for tp_rank in tp_rank_range:
reader, writer = mp.Pipe(duplex=False)
gpu_id = tp_rank % tp_size_per_node
gpu_id = server_args.base_gpu_id + tp_rank % tp_size_per_node
proc = mp.Process(
target=run_scheduler_process,
args=(server_args, port_args, gpu_id, tp_rank, None, writer),
Expand Down
8 changes: 8 additions & 0 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class ServerArgs:
constrained_json_whitespace_pattern: Optional[str] = None
watchdog_timeout: float = 300
download_dir: Optional[str] = None
base_gpu_id: int = 0

# Logging
log_level: str = "info"
Expand Down Expand Up @@ -412,6 +413,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
default=ServerArgs.download_dir,
help="Model download directory.",
)
parser.add_argument(
"--base-gpu-id",
type=int,
default=ServerArgs.base_gpu_id,
help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
)

# Logging
parser.add_argument(
Expand Down Expand Up @@ -736,6 +743,7 @@ def check_server_args(self):
and (self.lora_paths is None or self.disable_cuda_graph)
and (self.lora_paths is None or self.disable_radix_cache)
), "compatibility of lora and cuda graph and radix attention is in progress"
assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"

if isinstance(self.lora_paths, list):
lora_paths = self.lora_paths
Expand Down
207 changes: 207 additions & 0 deletions scripts/playground/router/test_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
import random
import string
import time
import unittest
from typing import Dict, List, Tuple

from tree import MultiTenantRadixTree


class TestMultiTenantRadixTree(unittest.TestCase):
def setUp(self):
self.tree = MultiTenantRadixTree()

def test_insert_exact_match(self):
"""Test 1: Basic insert and exact match operations"""
# Insert a single string for one tenant
self.tree.insert("hello", "tenant1")
matched, tenant = self.tree.prefix_match("hello")
self.assertEqual(matched, "hello")
self.assertEqual(tenant, "tenant1")

# Insert same string for different tenant
self.tree.insert("hello", "tenant2")
matched, tenant = self.tree.prefix_match("hello")
self.assertIn(tenant, ["tenant1", "tenant2"])

# Insert different string for same tenant
self.tree.insert("world", "tenant1")
matched, tenant = self.tree.prefix_match("world")
self.assertEqual(matched, "world")
self.assertEqual(tenant, "tenant1")

print(self.tree.pretty_print())

def test_insert_partial_match(self):
"""Test 2: Insert with partial matching scenarios"""
# Test partial matches with common prefixes
self.tree.insert("hello", "tenant1")
print(self.tree.pretty_print())
self.tree.insert("help", "tenant2")
print(self.tree.pretty_print())

# Match exact strings
matched, tenant = self.tree.prefix_match("hello")
self.assertEqual(matched, "hello")
self.assertEqual(tenant, "tenant1")

matched, tenant = self.tree.prefix_match("help")
self.assertEqual(matched, "help")
self.assertEqual(tenant, "tenant2")

# Match partial string
matched, tenant = self.tree.prefix_match("hel")
self.assertEqual(matched, "hel")
self.assertIn(tenant, ["tenant1", "tenant2"])

# Match longer string
matched, tenant = self.tree.prefix_match("hello_world")
self.assertEqual(matched, "hello")
self.assertEqual(tenant, "tenant1")

def test_insert_edge_cases(self):
"""Test 3: Edge cases for insert and match operations"""
# Empty string
self.tree.insert("", "tenant1")
matched, tenant = self.tree.prefix_match("")
self.assertEqual(matched, "")
self.assertEqual(tenant, "tenant1")

# Single character
self.tree.insert("a", "tenant1")
matched, tenant = self.tree.prefix_match("a")
self.assertEqual(matched, "a")
self.assertEqual(tenant, "tenant1")

# Very long string
long_str = "a" * 1000
self.tree.insert(long_str, "tenant1")
matched, tenant = self.tree.prefix_match(long_str)
self.assertEqual(matched, long_str)
self.assertEqual(tenant, "tenant1")

# Unicode characters
self.tree.insert("你好", "tenant1")
matched, tenant = self.tree.prefix_match("你好")
self.assertEqual(matched, "你好")
self.assertEqual(tenant, "tenant1")

def test_simple_eviction(self):
"""Test 4: Simple eviction scenarios
Tenant1: limit 10 chars
Tenant2: limit 5 chars
Should demonstrate:
1. Basic eviction when size limit exceeded
2. Proper eviction based on last access time
3. Verification that shared nodes remain intact for other tenants
"""
# Set up size limits
max_size = {"tenant1": 10, "tenant2": 5}

# Insert strings for both tenants
self.tree.insert("hello", "tenant1") # size 5
self.tree.insert("hello", "tenant2") # size 5
self.tree.insert("world", "tenant2") # size 5, total for tenant2 = 10

# Verify initial sizes
sizes_before = self.tree.get_used_size_per_tenant()
self.assertEqual(sizes_before["tenant1"], 5) # "hello" = 5
self.assertEqual(sizes_before["tenant2"], 10) # "hello" + "world" = 10

# Evict - should remove "hello" from tenant2 as it's the oldest
self.tree.evict_tenant_data(max_size)

# Verify sizes after eviction
sizes_after = self.tree.get_used_size_per_tenant()
self.assertEqual(sizes_after["tenant1"], 5) # Should be unchanged
self.assertEqual(sizes_after["tenant2"], 5) # Only "world" remains

# Verify "world" remains for tenant2 (was accessed more recently)
matched, tenant = self.tree.prefix_match("world")
self.assertEqual(matched, "world")
self.assertEqual(tenant, "tenant2")

def test_medium_eviction(self):
"""Test 5: Medium complexity eviction scenarios with shared prefixes
Tenant1: limit 10 chars
Tenant2: limit 7 chars (forces one string to be evicted)
Tree structure after inserts:
└── 'h' [t1, t2]
├── 'i' [t1, t2] # Oldest for t2
└── 'e' [t1, t2]
├── 'llo' [t1, t2]
└── 'y' [t2] # Newest for t2
Size calculations:
tenant1: "h"(1) + "i"(1) + "e"(1) + "llo"(3) = 6 chars
tenant2: "h"(1) + "i"(1) + "e"(1) + "llo"(3) + "y"(1) = 7 chars
After eviction (tenant2 exceeds limit by 1 char):
"hi" should be removed from tenant2 as it's the oldest access
"""
max_size = {
"tenant1": 10,
"tenant2": 6,
} # tenant2 will need to evict one string

# Create a tree with overlapping prefixes
self.tree.insert("hi", "tenant1")
self.tree.insert("hi", "tenant2") # OLDEST for t2

self.tree.insert("hello", "tenant1")
self.tree.insert("hello", "tenant2")

self.tree.insert("hey", "tenant2") # NEWEST for t2

# Verify initial sizes
sizes_before = self.tree.get_used_size_per_tenant()
self.assertEqual(sizes_before["tenant1"], 6) # h(1) + i(1) + e(1) + llo(3) = 6
self.assertEqual(
sizes_before["tenant2"], 7
) # h(1) + i(1) + e(1) + llo(3) + y(1) = 7

print("\nTree before eviction:")
print(self.tree.pretty_print())

# Evict - should remove "hi" from tenant2 as it's the oldest
self.tree.evict_tenant_data(max_size)

print("\nTree after eviction:")
print(self.tree.pretty_print())

# Verify sizes after eviction
sizes_after = self.tree.get_used_size_per_tenant()
self.assertEqual(sizes_after["tenant1"], 6) # Should be unchanged
self.assertEqual(sizes_after["tenant2"], 6) # h(1) + e(1) + llo(3) + y(1) = 6

def test_advanced_eviction(self):
...
# Create 4 tenants
# Each tenants keeps adding strings with shared prefixes to thousands usage
# Set a strict limit for each tenant to only 100
# At the end, check whether all of the tenant is under 100 after eviction

max_size = {"tenant1": 100, "tenant2": 100, "tenant3": 100, "tenant4": 100}

prefixes = ["aqwefcisdf", "iajsdfkmade", "kjnzxcvewqe", "iejksduqasd"]
for i in range(100):
for j, prefix in enumerate(prefixes):
random_suffix = "".join(random.choices(string.ascii_letters, k=10))
self.tree.insert(prefix + random_suffix, f"tenant{j+1}")

sizes_before = self.tree.get_used_size_per_tenant()
print(sizes_before)

self.tree.evict_tenant_data(max_size)

sizes_after = self.tree.get_used_size_per_tenant()
print(sizes_after)
# ensure size_after is below max_size
for tenant, size in sizes_after.items():
self.assertLessEqual(size, max_size[tenant])


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit 30af7df

Please sign in to comment.