Skip to content

Commit

Permalink
Introduce websocket POC for truss server
Browse files Browse the repository at this point in the history
  • Loading branch information
nnarayen committed Feb 7, 2025
1 parent ecc1b4f commit 8633002
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 23 deletions.
120 changes: 98 additions & 22 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ opentelemetry-sdk = ">=1.25.0"
truss_transfer="0.0.1"
uvicorn = ">=0.24.0"
uvloop = ">=0.17.0"
websockets = ">=13.0"


[build-system]
Expand Down
9 changes: 9 additions & 0 deletions truss/templates/server/model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class MethodName(str, enum.Enum):
PREDICT = "predict"
PREPROCESS = "preprocess"
SETUP_ENVIRONMENT = "setup_environment"
WEBSOCKET = "websocket"


InputType = Union[serialization.JSONType, serialization.MsgPackType, pydantic.BaseModel]
Expand Down Expand Up @@ -224,6 +225,7 @@ class ModelDescriptor:
is_healthy: Optional[MethodDescriptor]
completions: Optional[MethodDescriptor]
chat_completions: Optional[MethodDescriptor]
websocket: Optional[MethodDescriptor]

@cached_property
def skip_input_parsing(self) -> bool:
Expand Down Expand Up @@ -293,6 +295,12 @@ def from_model(cls, model_cls) -> "ModelDescriptor":
f"`{MethodName.IS_HEALTHY}` must have only one argument: `self`."
)

websocket = cls._safe_extract_descriptor(model_cls, MethodName.WEBSOCKET)
if websocket and websocket.arg_config != ArgConfig.INPUTS_ONLY:
raise errors.ModelDefinitionError(
f"`{MethodName.WEBSOCKET}` must have only one argument: `websocket`."
)

truss_schema = cls._gen_truss_schema(
model_cls=model_cls,
predict=predict,
Expand All @@ -308,6 +316,7 @@ def from_model(cls, model_cls) -> "ModelDescriptor":
is_healthy=is_healthy,
completions=completions,
chat_completions=chats,
websocket=websocket,
)


Expand Down
1 change: 1 addition & 0 deletions truss/templates/server/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ requests==2.31.0
uvicorn==0.24.0
uvloop==0.19.0
aiofiles==24.1.0
websockets==13.1
22 changes: 21 additions & 1 deletion truss/templates/server/truss_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
import yaml
from common import errors, tracing
from common.schema import TrussSchema
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi import Depends, FastAPI, HTTPException, Request, WebSocket
from fastapi.responses import ORJSONResponse, StreamingResponse
from fastapi.routing import APIRoute as FastAPIRoute
from fastapi.routing import APIWebSocketRoute as FastAPIWebSocketRoute
from model_wrapper import MODEL_BASENAME, MethodName, ModelWrapper
from opentelemetry import propagate as otel_propagate
from opentelemetry import trace
Expand Down Expand Up @@ -233,6 +234,23 @@ async def completions(
body_raw=body_raw,
)

async def websocket(self, websocket: WebSocket):
    """Serve the websocket endpoint by delegating to the model's handler.

    Looks up the model and, if it defines a `websocket` method, checks model
    health, opens a tracing span, accepts the connection and hands the
    socket to the model's handler.

    Args:
        websocket: The incoming, not-yet-accepted websocket connection.
    """
    model = self._safe_lookup_model()
    descriptor = model.model_descriptor.websocket
    if not descriptor:
        # Rejecting the connection upgrade is standard practice when we know
        # the request is invalid before accepting it. Close explicitly instead
        # of returning silently: per the ASGI spec, a close sent before accept
        # makes the server refuse the handshake, whereas returning without
        # sending anything leaves the outcome to the server implementation.
        await websocket.close()
        return

    self.check_healthy(model)
    trace_ctx = otel_propagate.extract(websocket.headers) or None
    # This is the top-level span in the truss-server, so we set the context here.
    # Nested spans "inherit" context automatically.
    with self._tracer.start_as_current_span(
        f"{MethodName.WEBSOCKET}-endpoint", context=trace_ctx
    ):
        await websocket.accept()
        await descriptor.method(websocket)

async def predict(
self, model_name: str, request: Request, body_raw: bytes = Depends(parse_body)
) -> Response:
Expand Down Expand Up @@ -407,6 +425,8 @@ def create_application(self):
methods=["POST"],
tags=["V1"],
),
# Websocket endpoint
FastAPIWebSocketRoute(r"/v1/websocket", self._endpoints.websocket),
# Endpoint aliases for Sagemaker hosting
FastAPIRoute(r"/ping", self._endpoints.invocations_ready),
FastAPIRoute(
Expand Down
Loading

0 comments on commit 8633002

Please sign in to comment.