diff --git a/src/phoenix/config.py b/src/phoenix/config.py index fd20857ec1..9da8da077c 100644 --- a/src/phoenix/config.py +++ b/src/phoenix/config.py @@ -80,6 +80,10 @@ """ Whether or not to log migrations. Defaults to true. """ +ENV_PHOENIX_ENABLE_WEBSOCKETS = "PHOENIX_ENABLE_WEBSOCKETS" +""" +Whether or not to enable websockets. Defaults to None. +""" # Phoenix server OpenTelemetry instrumentation environment variables ENV_PHOENIX_SERVER_INSTRUMENTATION_OTLP_TRACE_COLLECTOR_HTTP_ENDPOINT = ( @@ -371,6 +375,10 @@ def get_env_smtp_validate_certs() -> bool: return _bool_val(ENV_PHOENIX_SMTP_VALIDATE_CERTS, True) +def get_env_enable_websockets() -> Optional[bool]: + return _bool_val(ENV_PHOENIX_ENABLE_WEBSOCKETS) + + @dataclass(frozen=True) class OAuth2ClientConfig: idp_name: str diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py index 3df21bf086..89b427ce9c 100644 --- a/src/phoenix/server/app.py +++ b/src/phoenix/server/app.py @@ -166,6 +166,7 @@ class AppConfig(NamedTuple): web_manifest_path: Path authentication_enabled: bool """ Whether authentication is enabled """ + websockets_enabled: bool oauth2_idps: Sequence[OAuth2Idp] @@ -216,6 +217,7 @@ async def get_response(self, path: str, scope: Scope) -> Response: "manifest": self._web_manifest, "authentication_enabled": self._app_config.authentication_enabled, "oauth2_idps": self._app_config.oauth2_idps, + "websockets_enabled": self._app_config.websockets_enabled, }, ) except Exception as e: @@ -653,6 +655,7 @@ def create_app( model: Model, authentication_enabled: bool, umap_params: UMAPParameters, + enable_websockets: bool, corpus: Optional[Model] = None, debug: bool = False, dev: bool = False, @@ -830,6 +833,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: authentication_enabled=authentication_enabled, web_manifest_path=web_manifest_path, oauth2_idps=oauth2_idps, + websockets_enabled=enable_websockets, ), ), name="static", diff --git a/src/phoenix/server/main.py b/src/phoenix/server/main.py index 1da4325339..5e061e5f23 100644 --- a/src/phoenix/server/main.py +++ b/src/phoenix/server/main.py @@ -21,6 +21,7 @@ get_env_database_schema, get_env_db_logging_level, get_env_enable_prometheus, + get_env_enable_websockets, get_env_grpc_port, get_env_host, get_env_host_root_path, @@ -95,6 +96,7 @@ | 🚀 Phoenix Server 🚀 | Phoenix UI: {{ ui_path }} | Authentication: {{ auth_enabled }} +| Websockets: {{ websockets_enabled }} | Log traces: | - gRPC: {{ grpc_path }} | - HTTP: {{ http_path }} @@ -162,7 +164,7 @@ def main() -> None: parser.add_argument("--debug", action="store_true", help=SUPPRESS) parser.add_argument("--dev", action="store_true", help=SUPPRESS) parser.add_argument("--no-ui", action="store_true", help=SUPPRESS) - + parser.add_argument("--enable-websockets", type=str, help=SUPPRESS) subparsers = parser.add_subparsers(dest="command", required=True, help=SUPPRESS) serve_parser = subparsers.add_parser("serve") @@ -348,6 +350,14 @@ def main() -> None: corpus_model = ( None if corpus_inferences is None else create_model_from_inferences(corpus_inferences) ) + + # Get enable_websockets from environment variable or command line argument + enable_websockets = get_env_enable_websockets() + if args.enable_websockets is not None: + enable_websockets = args.enable_websockets.lower() == "true" + if enable_websockets is None: + enable_websockets = True + # Print information about the server root_path = urljoin(f"http://{host}:{port}", host_root_path) msg = _WELCOME_MESSAGE.render( @@ -358,6 +368,7 @@ def main() -> None: storage=get_printable_db_url(db_connection_str), schema=get_env_database_schema(), auth_enabled=authentication_enabled, + websockets_enabled=enable_websockets, ) if sys.platform.startswith("win"): msg = codecs.encode(msg, "ascii", errors="ignore").decode("ascii").strip() @@ -382,10 +393,12 @@ def main() -> None: connection_method="STARTTLS", validate_certs=get_env_smtp_validate_certs(), ) + app = create_app( db=factory, export_path=export_path, model=model, + enable_websockets=enable_websockets, authentication_enabled=authentication_enabled, umap_params=umap_params, corpus=corpus_model, diff --git a/src/phoenix/server/templates/index.html b/src/phoenix/server/templates/index.html index 0a78f03a02..6353f3d7b2 100644 --- a/src/phoenix/server/templates/index.html +++ b/src/phoenix/server/templates/index.html @@ -88,6 +88,7 @@ }, authenticationEnabled: Boolean("{{authentication_enabled}}" == "True"), oAuth2Idps: {{ oauth2_idps | tojson }}, + websocketsEnabled: Boolean("{{websockets_enabled}}" == "True"), }), writable: false }); diff --git a/src/phoenix/services.py b/src/phoenix/services.py index 5aa191771a..0987c2e4f2 100644 --- a/src/phoenix/services.py +++ b/src/phoenix/services.py @@ -118,6 +118,7 @@ def __init__( reference_inferences_name: Optional[str], corpus_inferences_name: Optional[str], trace_dataset_name: Optional[str], + enable_websockets: bool, ): self.database_url = database_url self.export_path = export_path @@ -129,6 +130,7 @@ def __init__( self.__reference_inferences_name = reference_inferences_name self.__corpus_inferences_name = corpus_inferences_name self.__trace_dataset_name = trace_dataset_name + self.enable_websockets = enable_websockets super().__init__() @property @@ -156,5 +158,7 @@ def command(self) -> list[str]: command.extend(["--corpus", str(self.__corpus_inferences_name)]) if self.__trace_dataset_name is not None: command.extend(["--trace", str(self.__trace_dataset_name)]) + if self.enable_websockets: + command.append("--enable-websockets") logger.info(f"command: {' '.join(command)}") return command diff --git a/src/phoenix/session/session.py b/src/phoenix/session/session.py index dbea64522e..48267c64ff 100644 --- a/src/phoenix/session/session.py +++ b/src/phoenix/session/session.py @@ -24,6 +24,7 @@ ENV_PHOENIX_PORT, ensure_working_dir, get_env_database_connection_str, + get_env_enable_websockets, get_env_host, get_env_port, get_exported_files, @@ -269,6 +270,7 @@ def __init__( self, database_url: str, primary_inferences: Inferences, + enable_websockets: bool, reference_inferences: Optional[Inferences] = None, corpus_inferences: Optional[Inferences] = None, trace_dataset: Optional[TraceDataset] = None, @@ -319,6 +321,7 @@ def __init__( trace_dataset_name=( self.trace_dataset.name if self.trace_dataset is not None else None ), + enable_websockets=enable_websockets, ) @property @@ -335,6 +338,7 @@ def __init__( self, database_url: str, primary_inferences: Inferences, + enable_websockets: bool, reference_inferences: Optional[Inferences] = None, corpus_inferences: Optional[Inferences] = None, trace_dataset: Optional[TraceDataset] = None, @@ -375,6 +379,7 @@ def __init__( export_path=self.export_path, model=self.model, authentication_enabled=False, + enable_websockets=enable_websockets, corpus=self.corpus, umap_params=self.umap_parameters, initial_spans=trace_dataset.to_spans() if trace_dataset else None, @@ -438,6 +443,7 @@ def launch_app( run_in_thread: bool = True, notebook_environment: Optional[Union[NotebookEnvironment, str]] = None, use_temp_dir: bool = True, + enable_websockets: Optional[bool] = None, ) -> Optional[Session]: """ Launches the phoenix application and returns a session to interact with. @@ -472,7 +478,8 @@ def launch_app( use_temp_dir: bool, optional, default=True Whether to use a temporary directory to store the data. If set to False, the data will be stored in the directory specified by PHOENIX_WORKING_DIR environment variable via SQLite. - + enable_websockets: bool, optional, default=False + Whether to enable websockets. Returns ------- @@ -553,10 +560,16 @@ def launch_app( else: database_url = get_env_database_connection_str() + enable_websockets_env = get_env_enable_websockets() or False + enable_websockets = ( + enable_websockets if enable_websockets is not None else enable_websockets_env + ) + if run_in_thread: _session = ThreadSession( database_url, primary, + enable_websockets, reference, corpus, trace, @@ -570,6 +583,7 @@ def launch_app( _session = ProcessSession( database_url, primary, + enable_websockets, reference, corpus, trace, diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 4b05a563cb..46d20255f1 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -191,6 +191,7 @@ async def app( umap_params=get_umap_parameters(None), serve_ui=False, bulk_inserter_factory=TestBulkInserter, + enable_websockets=True, ) manager = await stack.enter_async_context(LifespanManager(app)) yield manager.app