Skip to content

Commit bd91dca

Browse files
authored
feat: add flush_cache endpoint to sglang (#1769)
1 parent b204456 commit bd91dca

File tree

3 files changed

+183
-3
lines changed

3 files changed

+183
-3
lines changed

examples/sglang/components/decode_worker.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,14 @@ async def generate(self, request: str):
5454
async for result in results:
5555
yield result
5656

57+
async def flush_cache(self, request: dict):
    """Kick off a radix-cache flush on the underlying SGLang engine.

    Args:
        request: Ignored; present only to satisfy the endpoint handler
            signature.

    Yields:
        A single status dict. The flush itself runs in the background;
        check the backend logs for its actual outcome.
    """
    _ = request
    # Keep a reference to the task: per the asyncio docs, the event loop
    # holds only a weak reference to tasks, so an unreferenced
    # create_task() result can be garbage-collected before it finishes.
    self._flush_cache_task = asyncio.create_task(
        self.engine.tokenizer_manager.flush_cache()
    )
    yield {
        "status": "success",
        "message": "Cache flush initiated. Check backend logs for status",
    }
64+
5765

5866
async def graceful_shutdown(runtime):
5967
logging.info("Received shutdown signal, shutting down DistributedRuntime")
@@ -89,8 +97,13 @@ async def init(runtime: DistributedRuntime, server_args: ServerArgs):
8997
component = runtime.namespace("dynamo").component("decode")
9098
await component.create_service()
9199

92-
endpoint = component.endpoint("generate")
93-
await endpoint.serve_endpoint(handler.generate)
100+
gen_endpoint = component.endpoint("generate")
101+
flush_endpoint = component.endpoint("flush_cache")
102+
103+
tasks = [gen_endpoint.serve_endpoint(handler.generate)]
104+
tasks.append(flush_endpoint.serve_endpoint(handler.flush_cache))
105+
106+
await asyncio.gather(*tasks)
94107

95108

96109
if __name__ == "__main__":

examples/sglang/components/worker.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,14 @@ async def _prefill_generator(self, prefill):
242242
async for _ in prefill:
243243
pass
244244

245+
async def flush_cache(self, request: dict):
    """Kick off a radix-cache flush on the underlying SGLang engine.

    Args:
        request: Ignored; present only to satisfy the endpoint handler
            signature.

    Yields:
        A single status dict. The flush itself runs in the background;
        check the backend logs for its actual outcome.
    """
    _ = request
    # Keep a reference to the task: per the asyncio docs, the event loop
    # holds only a weak reference to tasks, so an unreferenced
    # create_task() result can be garbage-collected before it finishes.
    self._flush_cache_task = asyncio.create_task(
        self.engine.tokenizer_manager.flush_cache()
    )
    yield {
        "status": "success",
        "message": "Cache flush initiated. Check backend logs for status",
    }
252+
245253

246254
async def graceful_shutdown(runtime):
247255
logging.info("Received shutdown signal, shutting down DistributedRuntime")
@@ -305,7 +313,12 @@ async def init(runtime: DistributedRuntime, server_args: ServerArgs):
305313
)
306314
_ = ZmqKvEventPublisher(component=component, config=zmq_config)
307315

308-
await endpoint.serve_endpoint(handler.generate)
316+
tasks = [endpoint.serve_endpoint(handler.generate)]
317+
318+
flush_endpoint = component.endpoint("flush_cache")
319+
tasks.append(flush_endpoint.serve_endpoint(handler.flush_cache))
320+
321+
await asyncio.gather(*tasks)
309322

310323

311324
if __name__ == "__main__":
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import argparse
5+
import asyncio
6+
import logging
7+
8+
import uvicorn
9+
import uvloop
10+
from fastapi import FastAPI
11+
12+
from dynamo.runtime import DistributedRuntime, dynamo_worker
13+
from dynamo.runtime.logging import configure_dynamo_logging
14+
15+
FLUSH_CACHE_ENDPOINT = "flush_cache"
16+
17+
configure_dynamo_logging()
18+
19+
20+
class SglangHttpServer:
    """FastAPI server exposing cache-management operations (currently
    POST /flush_cache) that fan the request out to SGLang workers
    discovered through etcd.
    """

    def __init__(self, port: int, runtime: DistributedRuntime, args):
        # args is the parsed CLI namespace from parse_args():
        # port / ns / comp / endpoint.
        self.port = port
        self.app = FastAPI()
        self.runtime = runtime
        self.args = args
        self.setup_routes()

    async def _discover_endpoints(self):
        """Discover endpoints that match the pattern.

        Scans etcd keys under ``instances/`` and returns the set of
        ``(namespace, component)`` pairs that register an endpoint named
        ``self.args.endpoint``, optionally filtered by the ``--ns`` /
        ``--comp`` CLI arguments.

        Raises:
            RuntimeError: if the runtime has no etcd client.
        """
        etcd_client = self.runtime.etcd_client()
        if etcd_client is None:
            raise RuntimeError("Runtime has no etcd client; cannot discover endpoints")

        prefix = "instances/"
        kvs = await etcd_client.kv_get_prefix(prefix)

        # Collect (namespace, component) combos that expose flush_cache
        discovered = set()
        for kv in kvs:
            # NOTE(review): kv entries are handled defensively — dict-like
            # or object with .key, with bytes or str keys. The exact shape
            # returned by kv_get_prefix isn't visible here; confirm against
            # the dynamo runtime API.
            key = kv["key"] if isinstance(kv, dict) else kv.key
            if isinstance(key, bytes):
                key = key.decode()
            if not key.startswith(prefix):
                continue

            segments = key.split("/")
            # Format: instances/<ns>/<comp>/<endpoint:lease>
            if len(segments) < 4:
                continue
            ns, comp, ep_with_lease = segments[1], segments[2], segments[3]

            # Optional filters: a falsy --ns/--comp means "no filter".
            if self.args.ns and ns != self.args.ns:
                continue
            if self.args.comp and comp != self.args.comp:
                continue

            # Drop the ":<lease>" suffix to get the bare endpoint name.
            ep_name = ep_with_lease.split(":", 1)[0]
            if ep_name == self.args.endpoint:
                discovered.add((ns, comp))
                logging.debug(f"Discovered endpoint: {ns}.{comp}")

        logging.debug(
            f"Endpoint discovery complete. Found {len(discovered)} matching endpoints"
        )
        return discovered

    def setup_routes(self):
        """Register the HTTP routes on the FastAPI app."""

        @self.app.post("/flush_cache")
        async def flush_cache():
            """Flush the radix cache."""
            try:
                discovered = await self._discover_endpoints()

                if not discovered:
                    return {"message": "No matching endpoints found", "success": False}

                logging.debug(
                    f"Found components: {', '.join([f'{ns}.{comp}' for ns, comp in discovered])}"
                )

                # Fan out to every instance of every discovered component.
                # Per-instance failures are logged and skipped, not fatal.
                for ns, comp in discovered:
                    ep = (
                        self.runtime.namespace(ns)
                        .component(comp)
                        .endpoint(self.args.endpoint)
                    )
                    client = await ep.client()
                    await client.wait_for_instances()
                    ids = client.instance_ids()

                    logging.debug(f"-- {ns}.{comp} : {len(ids)} instances --")

                    for inst_id in ids:
                        try:
                            # Address each instance directly with an empty
                            # JSON payload and drain its response stream so
                            # the handler runs to completion.
                            stream = await client.direct("{}", inst_id)
                            async for payload in stream:
                                logging.debug(f"[{ns}.{comp}][{inst_id}] -> {payload}")
                        except Exception as e:
                            logging.error(f"[{ns}.{comp}][{inst_id}] flush error: {e}")

                return {"message": "Cache flush initiated", "success": True}
            except Exception as e:
                # Top-level boundary: report failure in the HTTP response
                # instead of letting the exception escape the route.
                logging.error(f"Cache flush error: {e}")
                return {"message": f"Cache flush failed: {str(e)}", "success": False}

    async def start_server(self):
        """Start the HTTP server"""
        config = uvicorn.Config(
            self.app,
            host="0.0.0.0",
            port=self.port,
        )
        server = uvicorn.Server(config)

        # Single nice log with available endpoints
        logging.info(
            f"🚀 SGL engine HTTP server running on http://0.0.0.0:{self.port} - Endpoints: POST /flush_cache"
        )

        await server.serve()
121+
122+
123+
def parse_args():
    """Parse CLI options for the SGLang cache-management HTTP server.

    Returns:
        argparse.Namespace with ``port``, ``ns``, ``comp`` and
        ``endpoint`` attributes.
    """
    p = argparse.ArgumentParser(description="SGLang HTTP server for cache management")
    p.add_argument("--port", type=int, default=9001, help="Port to listen on")
    p.add_argument(
        "--ns",
        "--namespace",
        default="dynamo",
        # The help text used to claim "(default: discover all)", but the
        # default is the truthy string "dynamo", which always filters
        # discovery to that namespace — state the real default instead.
        help="Specify Dynamo namespace (default: dynamo)",
    )
    p.add_argument(
        "--comp",
        "--component",
        default=None,
        help="Specify component name (default: discover all)",
    )
    p.add_argument(
        "--endpoint", default=FLUSH_CACHE_ENDPOINT, help="Specify endpoint name"
    )
    return p.parse_args()
142+
143+
144+
@dynamo_worker(static=False)
async def main(runtime: DistributedRuntime):
    """Entry point: parse CLI options, then serve the cache-management API."""
    cli_args = parse_args()
    server = SglangHttpServer(cli_args.port, runtime, cli_args)
    await server.start_server()
150+
151+
152+
if __name__ == "__main__":
    # Install uvloop as the asyncio event-loop policy before the worker
    # entry point starts its loop.
    uvloop.install()
    asyncio.run(main())

0 commit comments

Comments
 (0)