[skyrl-train][refactor] 2/N Inference Server Refactor -- RemoteInferenceClient #904
Conversation
Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com>
Code Review
This pull request refactors the inference engine by replacing the `InferenceEngineInterface` with a new `RemoteInferenceClient` for HTTP-based inference, and introduces new modules for common utilities, protocols, server groups, and a robust router. While the changes are well structured and include comprehensive unit and GPU CI tests, they introduce significant security risks. The most critical issue is the use of `pickle.loads` in the vLLM worker extension, which provides a direct path to Remote Code Execution (RCE). Additionally, the lack of authentication on sensitive control plane and weight synchronization endpoints in both the router and the server actor exposes the cluster to unauthorized control and potential weight hijacking. These security concerns must be addressed before deployment in untrusted network environments.
```python
# Unpickle init_info to restore the original object type
assert isinstance(init_info, bytes), f"Expected bytes, got {type(init_info).__name__}"
init_info = pickle.loads(init_info)
```
The `init_weight_update_communicator` method uses `pickle.loads()` to deserialize `init_info`. This is a critical security vulnerability because pickle is inherently insecure and can be exploited to execute arbitrary code during deserialization. An attacker who can trigger this RPC call with a malicious payload can achieve Remote Code Execution (RCE) on all vLLM workers.
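For concreteness, deserializing attacker-controlled bytes with `pickle.loads` runs whatever callable the payload's `__reduce__` method specifies; this is the textbook demonstration of the exploit pattern:

```python
import os
import pickle

class Exploit:
    def __reduce__(self):
        # pickle records "call os.system('echo pwned')" as the
        # reconstruction recipe for this object.
        return (os.system, ("echo pwned",))

payload = pickle.dumps(Exploit())
pickle.loads(payload)  # runs the shell command during deserialization
```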
Recommendation: Replace pickle with a secure serialization format such as JSON. Since `BroadcastInitInfo` is a dataclass, it can be easily converted to and from a JSON-compatible dictionary.
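A minimal sketch of the JSON approach, assuming `BroadcastInitInfo` is a flat dataclass with JSON-serializable fields (the field names below are illustrative, not the real ones):

```python
import json
from dataclasses import asdict, dataclass

@dataclass
class BroadcastInitInfo:
    # Illustrative fields; the real dataclass lives in skyrl_train.weight_sync.
    master_address: str
    master_port: int
    world_size: int

def serialize_init_info(info: BroadcastInitInfo) -> bytes:
    return json.dumps(asdict(info)).encode("utf-8")

def deserialize_init_info(payload: bytes) -> BroadcastInitInfo:
    # json.loads can only yield plain data types, so no code runs here.
    return BroadcastInitInfo(**json.loads(payload.decode("utf-8")))
```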
```python
# Unpickle request to restore the original object type
assert isinstance(request, bytes), f"Expected bytes, got {type(request).__name__}"
request = pickle.loads(request)
```
```python
def _build_app(self) -> FastAPI:
    """Build the FastAPI app with proxy routes."""
    app = FastAPI(
        title="SkyRL Inference Router",
        docs_url=None,
        redoc_url=None,
        openapi_url=None,
    )

    @app.get("/health")
    async def health():
        """Router health check (doesn't proxy to backends)."""
        return {"status": "healthy"}

    @app.get("/servers")
    async def list_servers():
        """Return list of server URLs."""
        return {"servers": self._server_urls}

    @app.get("/get_server_info")
    async def get_server_info():
        """Fetch server info from all servers, return mapping."""
        return await self._fan_out_get("/get_server_info")

    # Catch-all: proxy everything else to backends
    @app.api_route(
        "/{path:path}",
        methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"],
    )
    async def proxy(request: Request, path: str):
        return await self._proxy_request(request, f"/{path}")

    return app
```
The `InferenceRouter` exposes several sensitive control plane routes (e.g., `/pause`, `/resume`, `/init_weight_transfer`, `/update_weights`) without any authentication or authorization mechanism. This allows any user with network access to the router to disrupt the inference service or potentially hijack model weights by pointing workers to a malicious master node.
Recommendation: Implement an authentication mechanism, such as API keys or OAuth2, and ensure the router validates credentials before processing or proxying requests to these sensitive endpoints.
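One possible shape for this, as a sketch: a FastAPI dependency that validates a shared API key before a sensitive route runs. The header scheme, environment variable, and example route are assumptions, not part of this PR:

```python
import os
import secrets

from fastapi import Depends, FastAPI, HTTPException, Request

# Hypothetical shared secret, provisioned out of band.
_API_KEY = os.environ.get("SKYRL_ROUTER_API_KEY", "")

async def require_api_key(request: Request) -> None:
    """Reject requests that lack the expected bearer token."""
    provided = request.headers.get("Authorization", "")
    if not _API_KEY or not secrets.compare_digest(provided, f"Bearer {_API_KEY}"):
        raise HTTPException(status_code=401, detail="Unauthorized")

app = FastAPI()

# Sensitive control-plane routes opt in via the dependency; /health stays open.
@app.post("/pause", dependencies=[Depends(require_api_key)])
async def pause():
    return {"status": "paused"}
```

Note that because this router reaches `/pause` and the other control-plane paths through its catch-all proxy route, the same dependency would also need to be attached to that route (or enforced inside `_proxy_request`) for the protection to hold.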
| """Add custom SkyRL endpoints to the FastAPI app.""" | ||
| engine = self._engine | ||
|
|
||
| @app.get("/get_server_info") | ||
| async def _get_server_info(): | ||
| """Return server parallelism info.""" | ||
| return self._get_extended_server_info() | ||
|
|
||
| # TODO (Kourosh): After https://github.com/vllm-project/vllm/pull/ | ||
| # 31943/ is merged, use the native API. | ||
| @app.post("/init_weight_transfer") | ||
| async def _init_weight_transfer(request: Request): | ||
| """Initialize weight sync process group.""" | ||
| from skyrl_train.weight_sync import BroadcastInitInfo | ||
|
|
||
| data = await request.json() | ||
| init_info = BroadcastInitInfo(**data).for_engine( | ||
| engine_index=self._server_idx, | ||
| tp_size=self._cli_args.tensor_parallel_size, | ||
| pp_size=self._cli_args.pipeline_parallel_size, | ||
| ) | ||
| pickled_init_info = pickle.dumps(init_info) | ||
|
|
||
| await engine.collective_rpc( | ||
| "init_weight_update_communicator", | ||
| args=(pickled_init_info,), | ||
| ) | ||
| return {"status": "ok"} | ||
|
|
||
| @app.post("/update_weights") | ||
| async def _update_weights(request: Request): | ||
| """Update model weights via NCCL broadcast.""" | ||
| from skyrl_train.weight_sync import BroadcastWeightUpdateRequest | ||
|
|
||
| data = await request.json() | ||
| weight_request = BroadcastWeightUpdateRequest(**data) | ||
| pickled_request = pickle.dumps(weight_request) | ||
|
|
||
| await engine.collective_rpc( | ||
| "load_weights", | ||
| args=(pickled_request,), | ||
| ) | ||
| return {"status": "ok"} | ||
|
|
||
| @app.post("/finalize_weight_update") | ||
| async def _finalize_weight_update(request: Request): | ||
| """ | ||
| Finalize weight update - post-processing hook. | ||
|
|
||
| Currently a no-op, reserved for future use e.g. Quantization | ||
| See https://github.com/vllm-project/vllm/issues/31848 for more | ||
| details. | ||
| """ | ||
| # No-op for now - placeholder for future post-processing | ||
| return {"status": "ok"} | ||
|
|
The `VLLMServerActor` adds custom endpoints for weight synchronization and cluster management directly to the FastAPI application without any authentication. These endpoints trigger sensitive operations, including the insecure `pickle.loads` calls in the workers.
Recommendation: Protect these custom endpoints with an authentication layer (e.g., FastAPI dependencies or middleware) to ensure only authorized training components can trigger these operations.
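A sketch of the middleware variant, assuming a shared-key header; the header name, environment variable, and protected path list are illustrative:

```python
import os
import secrets

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

app = FastAPI()

# Hypothetical set of paths that trigger collective RPCs on the workers.
PROTECTED_PATHS = {"/init_weight_transfer", "/update_weights", "/finalize_weight_update"}

@app.middleware("http")
async def guard_weight_sync_routes(request: Request, call_next):
    if request.url.path in PROTECTED_PATHS:
        expected = os.environ.get("SKYRL_SERVER_API_KEY", "")
        provided = request.headers.get("X-SkyRL-Auth", "")
        if not expected or not secrets.compare_digest(provided, expected):
            return JSONResponse(status_code=401, content={"detail": "Unauthorized"})
    return await call_next(request)
```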
```python
strategy_cls = init_info.strategy_type()

if hasattr(self, "_weight_receiver") and self._weight_receiver is not None:
    # TODO(haochen): we should get rid of this flag and override existing receiver.
```
**Summary:** Adds `RemoteInferenceClient`, a lightweight, fully serializable HTTP client that wraps inference server APIs. This client replaces the old `InferenceEngineInterface` for HTTP-based inference and can work with any HTTP-compatible inference backend (vLLM, sglang-router, Ray Serve LLM, etc.).

**Key Features:**
- `proxy_url` for data plane (round-robin / sticky-session router), `server_urls` for control plane (fan-out)
- `generate()`, `chat_completion()`, `completion()`, `tokenize()`, `detokenize()`
- `pause()`, `resume()`, `sleep()`, `wake_up()`, `reset_prefix_cache()`
- `init_weight_transfer()`, `update_weights()`, `finalize_weight_update()`
- `stop_reason="abort"` during weight sync

**Comparison vs `InferenceEngineInterface` + `InferenceEngineClient`:**
- `/tokenize` endpoint instead
- `X-Session-ID` header
- `get_world_size()` vs separate `tp_size()`, `pp_size()`, `dp_size()`

**Files Added:**
- `skyrl_train/inference_servers/remote_inference_client.py` - the client implementation
- `tests/cpu/inference_servers/test_remote_inference_client.py` - unit tests

**Next:** Integration with training code via the `setup_inference()` hook in `BasePPOExp`.
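For illustration only, a hypothetical usage sketch based on the API surface listed above; the constructor signature, URLs, and the async nature of the methods are assumptions, not taken from this PR:

```python
import asyncio

from skyrl_train.inference_servers.remote_inference_client import RemoteInferenceClient

async def main():
    # Assumed constructor: proxy_url for the data plane, server_urls for
    # control-plane fan-out (names taken from the feature list above).
    client = RemoteInferenceClient(
        proxy_url="http://router:8000",
        server_urls=["http://server-0:8001", "http://server-1:8001"],
    )

    # Control plane: pause generation around a weight sync, then resume.
    await client.pause()
    # ... init_weight_transfer() / update_weights() / finalize_weight_update() ...
    await client.resume()

    # Data plane: requests go through the router, which picks a backend.
    response = await client.chat_completion(
        {"messages": [{"role": "user", "content": "Hello!"}]}
    )
    print(response)

asyncio.run(main())
```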