Skip to content

Commit

Permalink
feat: Detect websocket availability and pass to client (#5224)
Browse files Browse the repository at this point in the history
* feat: Detect websocket availability and pass to client

Co-authored-by: Xander Song <axiomofjoy@gmail.com>

* Fix test

---------

Co-authored-by: Xander Song <axiomofjoy@gmail.com>
  • Loading branch information
cephalization and axiomofjoy authored Oct 29, 2024
1 parent 0fc91a6 commit 661ec17
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 2 deletions.
8 changes: 8 additions & 0 deletions src/phoenix/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@
"""
Whether or not to log migrations. Defaults to true.
"""
ENV_PHOENIX_ENABLE_WEBSOCKETS = "PHOENIX_ENABLE_WEBSOCKETS"
"""
Whether or not to enable websockets. Defaults to None.
"""

# Phoenix server OpenTelemetry instrumentation environment variables
ENV_PHOENIX_SERVER_INSTRUMENTATION_OTLP_TRACE_COLLECTOR_HTTP_ENDPOINT = (
Expand Down Expand Up @@ -371,6 +375,10 @@ def get_env_smtp_validate_certs() -> bool:
return _bool_val(ENV_PHOENIX_SMTP_VALIDATE_CERTS, True)


def get_env_enable_websockets() -> Optional[bool]:
return _bool_val(ENV_PHOENIX_ENABLE_WEBSOCKETS)


@dataclass(frozen=True)
class OAuth2ClientConfig:
idp_name: str
Expand Down
4 changes: 4 additions & 0 deletions src/phoenix/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ class AppConfig(NamedTuple):
web_manifest_path: Path
authentication_enabled: bool
""" Whether authentication is enabled """
websockets_enabled: bool
oauth2_idps: Sequence[OAuth2Idp]


Expand Down Expand Up @@ -216,6 +217,7 @@ async def get_response(self, path: str, scope: Scope) -> Response:
"manifest": self._web_manifest,
"authentication_enabled": self._app_config.authentication_enabled,
"oauth2_idps": self._app_config.oauth2_idps,
"websockets_enabled": self._app_config.websockets_enabled,
},
)
except Exception as e:
Expand Down Expand Up @@ -653,6 +655,7 @@ def create_app(
model: Model,
authentication_enabled: bool,
umap_params: UMAPParameters,
enable_websockets: bool,
corpus: Optional[Model] = None,
debug: bool = False,
dev: bool = False,
Expand Down Expand Up @@ -830,6 +833,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
authentication_enabled=authentication_enabled,
web_manifest_path=web_manifest_path,
oauth2_idps=oauth2_idps,
websockets_enabled=enable_websockets,
),
),
name="static",
Expand Down
15 changes: 14 additions & 1 deletion src/phoenix/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
get_env_database_schema,
get_env_db_logging_level,
get_env_enable_prometheus,
get_env_enable_websockets,
get_env_grpc_port,
get_env_host,
get_env_host_root_path,
Expand Down Expand Up @@ -95,6 +96,7 @@
| 🚀 Phoenix Server 🚀
| Phoenix UI: {{ ui_path }}
| Authentication: {{ auth_enabled }}
| Websockets: {{ websockets_enabled }}
| Log traces:
| - gRPC: {{ grpc_path }}
| - HTTP: {{ http_path }}
Expand Down Expand Up @@ -162,7 +164,7 @@ def main() -> None:
parser.add_argument("--debug", action="store_true", help=SUPPRESS)
parser.add_argument("--dev", action="store_true", help=SUPPRESS)
parser.add_argument("--no-ui", action="store_true", help=SUPPRESS)

parser.add_argument("--enable-websockets", type=str, help=SUPPRESS)
subparsers = parser.add_subparsers(dest="command", required=True, help=SUPPRESS)

serve_parser = subparsers.add_parser("serve")
Expand Down Expand Up @@ -348,6 +350,14 @@ def main() -> None:
corpus_model = (
None if corpus_inferences is None else create_model_from_inferences(corpus_inferences)
)

# Get enable_websockets from environment variable or command line argument
enable_websockets = get_env_enable_websockets()
if args.enable_websockets is not None:
enable_websockets = args.enable_websockets.lower() == "true"
if enable_websockets is None:
enable_websockets = True

# Print information about the server
root_path = urljoin(f"http://{host}:{port}", host_root_path)
msg = _WELCOME_MESSAGE.render(
Expand All @@ -358,6 +368,7 @@ def main() -> None:
storage=get_printable_db_url(db_connection_str),
schema=get_env_database_schema(),
auth_enabled=authentication_enabled,
websockets_enabled=enable_websockets,
)
if sys.platform.startswith("win"):
msg = codecs.encode(msg, "ascii", errors="ignore").decode("ascii").strip()
Expand All @@ -382,10 +393,12 @@ def main() -> None:
connection_method="STARTTLS",
validate_certs=get_env_smtp_validate_certs(),
)

app = create_app(
db=factory,
export_path=export_path,
model=model,
enable_websockets=enable_websockets,
authentication_enabled=authentication_enabled,
umap_params=umap_params,
corpus=corpus_model,
Expand Down
1 change: 1 addition & 0 deletions src/phoenix/server/templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
},
authenticationEnabled: Boolean("{{authentication_enabled}}" == "True"),
oAuth2Idps: {{ oauth2_idps | tojson }},
websocketsEnabled: Boolean("{{websockets_enabled}}" == "True"),
}),
writable: false
});
Expand Down
4 changes: 4 additions & 0 deletions src/phoenix/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def __init__(
reference_inferences_name: Optional[str],
corpus_inferences_name: Optional[str],
trace_dataset_name: Optional[str],
enable_websockets: bool,
):
self.database_url = database_url
self.export_path = export_path
Expand All @@ -129,6 +130,7 @@ def __init__(
self.__reference_inferences_name = reference_inferences_name
self.__corpus_inferences_name = corpus_inferences_name
self.__trace_dataset_name = trace_dataset_name
self.enable_websockets = enable_websockets
super().__init__()

@property
Expand Down Expand Up @@ -156,5 +158,7 @@ def command(self) -> list[str]:
command.extend(["--corpus", str(self.__corpus_inferences_name)])
if self.__trace_dataset_name is not None:
command.extend(["--trace", str(self.__trace_dataset_name)])
if self.enable_websockets:
command.append("--enable-websockets")
logger.info(f"command: {' '.join(command)}")
return command
16 changes: 15 additions & 1 deletion src/phoenix/session/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
ENV_PHOENIX_PORT,
ensure_working_dir,
get_env_database_connection_str,
get_env_enable_websockets,
get_env_host,
get_env_port,
get_exported_files,
Expand Down Expand Up @@ -269,6 +270,7 @@ def __init__(
self,
database_url: str,
primary_inferences: Inferences,
enable_websockets: bool,
reference_inferences: Optional[Inferences] = None,
corpus_inferences: Optional[Inferences] = None,
trace_dataset: Optional[TraceDataset] = None,
Expand Down Expand Up @@ -319,6 +321,7 @@ def __init__(
trace_dataset_name=(
self.trace_dataset.name if self.trace_dataset is not None else None
),
enable_websockets=enable_websockets,
)

@property
Expand All @@ -335,6 +338,7 @@ def __init__(
self,
database_url: str,
primary_inferences: Inferences,
enable_websockets: bool,
reference_inferences: Optional[Inferences] = None,
corpus_inferences: Optional[Inferences] = None,
trace_dataset: Optional[TraceDataset] = None,
Expand Down Expand Up @@ -375,6 +379,7 @@ def __init__(
export_path=self.export_path,
model=self.model,
authentication_enabled=False,
enable_websockets=enable_websockets,
corpus=self.corpus,
umap_params=self.umap_parameters,
initial_spans=trace_dataset.to_spans() if trace_dataset else None,
Expand Down Expand Up @@ -438,6 +443,7 @@ def launch_app(
run_in_thread: bool = True,
notebook_environment: Optional[Union[NotebookEnvironment, str]] = None,
use_temp_dir: bool = True,
enable_websockets: Optional[bool] = None,
) -> Optional[Session]:
"""
Launches the phoenix application and returns a session to interact with.
Expand Down Expand Up @@ -472,7 +478,8 @@ def launch_app(
use_temp_dir: bool, optional, default=True
Whether to use a temporary directory to store the data. If set to False, the data will be
stored in the directory specified by PHOENIX_WORKING_DIR environment variable via SQLite.
enable_websockets: bool, optional, default=False
Whether to enable websockets.
Returns
-------
Expand Down Expand Up @@ -553,10 +560,16 @@ def launch_app(
else:
database_url = get_env_database_connection_str()

enable_websockets_env = get_env_enable_websockets() or False
enable_websockets = (
enable_websockets if enable_websockets is not None else enable_websockets_env
)

if run_in_thread:
_session = ThreadSession(
database_url,
primary,
enable_websockets,
reference,
corpus,
trace,
Expand All @@ -570,6 +583,7 @@ def launch_app(
_session = ProcessSession(
database_url,
primary,
enable_websockets,
reference,
corpus,
trace,
Expand Down
1 change: 1 addition & 0 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ async def app(
umap_params=get_umap_parameters(None),
serve_ui=False,
bulk_inserter_factory=TestBulkInserter,
enable_websockets=True,
)
manager = await stack.enter_async_context(LifespanManager(app))
yield manager.app
Expand Down

0 comments on commit 661ec17

Please sign in to comment.