From dce7551d00dd37e0ee005bafeed1757b0f5013bd Mon Sep 17 00:00:00 2001 From: Pranav Badhe Date: Wed, 4 Oct 2023 22:19:30 -0400 Subject: [PATCH] feat: Add default UMAP parameters in launch_app() (#1224) * Reaching till GraphQLWithContext * typing.Dict not dict() * Dict import * Suggested review changes Added validation, moved umap_params to pointcloud and removed Optional for self.umap_parameters in context * CamelCase -> SnakeCase Co-authored-by: Mikyo King * CamelCase -> SnakeCase all suggested changes Co-authored-by: Mikyo King * Default UMAPparam changes Removed hardcoded defaults for single source of truth. * typing hints & process_session() Process_session() error. Logs at -> https://textdoc.co/maM9oCZwexUA4EWQ * Prettier check I'd installed the pre-commit hooks, but "black ." was taking too long and checking venv lib packages, so I was doing it manually just on src/*. * Adjust frontend * parse string correctly * fix duplicate lines * Update umap_parameters.py Co-authored-by: Roger Yang <80478925+RogerHYang@users.noreply.github.com> * Update umap_parameters.py Co-authored-by: Roger Yang <80478925+RogerHYang@users.noreply.github.com> * Update umap_parameters.py Co-authored-by: Roger Yang <80478925+RogerHYang@users.noreply.github.com> * Update src/phoenix/session/session.py Co-authored-by: Roger Yang <80478925+RogerHYang@users.noreply.github.com> * Update src/phoenix/session/session.py Co-authored-by: Roger Yang <80478925+RogerHYang@users.noreply.github.com> * Update src/phoenix/session/session.py Co-authored-by: Roger Yang <80478925+RogerHYang@users.noreply.github.com> * use positional arguments * fix style and imports --------- Co-authored-by: Mikyo King Co-authored-by: Mikyo King Co-authored-by: Roger Yang <80478925+RogerHYang@users.noreply.github.com> --- app/package.json | 4 +- app/src/pages/embedding/EmbeddingPage.tsx | 18 ++++++-- app/src/store/pointCloudStore.ts | 3 -- app/src/window.d.ts | 14 ++++++ src/phoenix/pointcloud/umap_parameters.py | 52 +++++++++++++++++++++++ src/phoenix/server/app.py | 11 +++++ src/phoenix/server/main.py | 19 ++++++++- src/phoenix/server/templates/index.html | 5 +++ src/phoenix/services.py | 4 ++ src/phoenix/session/session.py | 36 ++++++++++++++-- 10 files changed, 154 insertions(+), 12 deletions(-) create mode 100644 app/src/window.d.ts create mode 100644 src/phoenix/pointcloud/umap_parameters.py diff --git a/app/package.json b/app/package.json index d9ca3adfca..28dc2a0737 100644 --- a/app/package.json +++ b/app/package.json @@ -74,8 +74,8 @@ "build:relay": "relay-compiler", "watch": "./esbuild.config.mjs dev", "test": "jest --config ./jest.config.js", - "dev": "npm run dev:server:traces:llama_index_rag & npm run build:static && npm run watch", - "dev:server:mnist": "python3 -m phoenix.server.main fixture fashion_mnist", + "dev": "npm run dev:server:mnist & npm run build:static && npm run watch", + "dev:server:mnist": "python3 -m phoenix.server.main --umap_params 0,30,550 fixture fashion_mnist", "dev:server:mnist:single": "python3 -m phoenix.server.main fixture fashion_mnist --primary-only true", "dev:server:sentiment": "python3 -m phoenix.server.main fixture sentiment_classification_language_drift", "dev:server:image": "python3 -m phoenix.server.main fixture image_classification", diff --git a/app/src/pages/embedding/EmbeddingPage.tsx b/app/src/pages/embedding/EmbeddingPage.tsx index fd788e9cfb..9d567d58c9 100644 --- a/app/src/pages/embedding/EmbeddingPage.tsx +++ b/app/src/pages/embedding/EmbeddingPage.tsx @@ -223,14 +223,26 @@ export function EmbeddingPage() { const { timeRange } = useTimeRange(); // Initialize the store based on whether or not there is a reference dataset const defaultPointCloudProps = useMemo>(() => { + let defaultPointCloudProps: Partial = {}; if (corpusDataset != null) { // If there is a corpus dataset, then initialize the page with the retrieval troubleshooting settings // TODO - this does make a bit of a leap of assumptions but is a short term solution in order to get the page working as intended - return DEFAULT_RETRIEVAL_TROUBLESHOOTING_POINT_CLOUD_PROPS; + defaultPointCloudProps = + DEFAULT_RETRIEVAL_TROUBLESHOOTING_POINT_CLOUD_PROPS; } else if (referenceDataset != null) { - return DEFAULT_DRIFT_POINT_CLOUD_PROPS; + defaultPointCloudProps = DEFAULT_DRIFT_POINT_CLOUD_PROPS; + } else { + defaultPointCloudProps = DEFAULT_SINGLE_DATASET_POINT_CLOUD_PROPS; } - return DEFAULT_SINGLE_DATASET_POINT_CLOUD_PROPS; + + // Apply the UMAP parameters from the server-sent config + defaultPointCloudProps = { + ...defaultPointCloudProps, + umapParameters: { + ...window.Config.UMAP, + }, + }; + return defaultPointCloudProps; }, [corpusDataset, referenceDataset]); return ( diff --git a/app/src/store/pointCloudStore.ts b/app/src/store/pointCloudStore.ts index e84132a21a..52481045ae 100644 --- a/app/src/store/pointCloudStore.ts +++ b/app/src/store/pointCloudStore.ts @@ -69,17 +69,14 @@ type DimensionMetadata = { export type UMAPParameters = { /** * Minimum distance between points in the eUMAP projection - * @default 0 */ minDist: number; /** * The number of neighbors to require for the UMAP projection - * @default 30 */ nNeighbors: number; /** * The number of samples to use for the UMAP projection. The sample number is per dataset. - * @default 500 */ nSamples: number; }; diff --git a/app/src/window.d.ts b/app/src/window.d.ts new file mode 100644 index 0000000000..fc7e2f4008 --- /dev/null +++ b/app/src/window.d.ts @@ -0,0 +1,14 @@ +export {}; + +declare global { + interface Window { + Config: { + hasCorpus: boolean; + UMAP: { + minDist: number; + nNeighbors: number; + nSamples: number; + }; + }; + } +} diff --git a/src/phoenix/pointcloud/umap_parameters.py b/src/phoenix/pointcloud/umap_parameters.py new file mode 100644 index 0000000000..a1f93c972d --- /dev/null +++ b/src/phoenix/pointcloud/umap_parameters.py @@ -0,0 +1,52 @@ +from dataclasses import dataclass +from typing import Any, Mapping, Optional + +DEFAULT_MIN_DIST = 0.0 +DEFAULT_N_NEIGHBORS = 30 +DEFAULT_N_SAMPLES = 500 + +MIN_NEIGHBORS = 5 +MAX_NEIGHBORS = 100 +MIN_SAMPLES = 1 +MAX_SAMPLES = 1000 +MIN_MIN_DIST = 0.0 +MAX_MIN_DIST = 0.99 + + +@dataclass +class UMAPParameters: + min_dist: float = DEFAULT_MIN_DIST + n_neighbors: int = DEFAULT_N_NEIGHBORS + n_samples: int = DEFAULT_N_SAMPLES + + def __post_init__(self) -> None: + if not isinstance(self.min_dist, float) or not ( + MIN_MIN_DIST <= self.min_dist <= MAX_MIN_DIST + ): + raise ValueError( + f"minDist must be float type, and between {MIN_MIN_DIST} and {MAX_MIN_DIST}" + ) + + if not isinstance(self.n_neighbors, int) or not ( + MIN_NEIGHBORS <= self.n_neighbors <= MAX_NEIGHBORS + ): + raise ValueError( + f"nNeighbors must be int type, and between {MIN_NEIGHBORS} and {MAX_NEIGHBORS}" + ) + + if not isinstance(self.n_samples, int) or not ( + MIN_SAMPLES <= self.n_samples <= MAX_SAMPLES + ): + raise ValueError( + f"nSamples must be int type, and between {MIN_SAMPLES} and {MAX_SAMPLES}" + ) + + +def get_umap_parameters(default_umap_parameters: Optional[Mapping[str, Any]]) -> UMAPParameters: + if not default_umap_parameters: + return UMAPParameters() + return UMAPParameters( + min_dist=float(default_umap_parameters.get("min_dist", DEFAULT_MIN_DIST)), + n_neighbors=int(default_umap_parameters.get("n_neighbors", DEFAULT_N_NEIGHBORS)), + n_samples=int(default_umap_parameters.get("n_samples", DEFAULT_N_SAMPLES)), + ) diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py index 922775424a..db75dcfbbc 100644 --- a/src/phoenix/server/app.py +++ b/src/phoenix/server/app.py @@ -22,6 +22,7 @@ from phoenix.config import SERVER_DIR from phoenix.core.model_schema import Model from phoenix.core.traces import Traces +from phoenix.pointcloud.umap_parameters import UMAPParameters from phoenix.server.api.context import Context from phoenix.server.api.schema import schema from phoenix.server.span_handler import SpanHandler @@ -33,6 +34,9 @@ class AppConfig(NamedTuple): has_corpus: bool + min_dist: float + n_neighbors: int + n_samples: int class Static(StaticFiles): @@ -58,6 +62,9 @@ async def get_response(self, path: str, scope: Scope) -> Response: "index.html", context={ "has_corpus": self._app_config.has_corpus, + "min_dist": self._app_config.min_dist, + "n_neighbors": self._app_config.n_neighbors, + "n_samples": self._app_config.n_samples, "request": Request(scope), }, ) @@ -130,6 +137,7 @@ async def version(_: Request) -> PlainTextResponse: def create_app( export_path: Path, model: Model, + umap_params: UMAPParameters, corpus: Optional[Model] = None, traces: Optional[Traces] = None, debug: bool = False, @@ -182,6 +190,9 @@ def create_app( directory=SERVER_DIR / "static", app_config=AppConfig( has_corpus=corpus is not None, + min_dist=umap_params.min_dist, + n_neighbors=umap_params.n_neighbors, + n_samples=umap_params.n_samples, ), ), name="static", diff --git a/src/phoenix/server/main.py b/src/phoenix/server/main.py index 42a0207860..cf6f651bfb 100644 --- a/src/phoenix/server/main.py +++ b/src/phoenix/server/main.py @@ -14,6 +14,12 @@ from phoenix.core.traces import Traces from phoenix.datasets.dataset import EMPTY_DATASET, Dataset from phoenix.datasets.fixtures import FIXTURES, get_datasets +from phoenix.pointcloud.umap_parameters import ( + DEFAULT_MIN_DIST, + DEFAULT_N_NEIGHBORS, + DEFAULT_N_SAMPLES, + UMAPParameters, +) from phoenix.server.app import create_app from phoenix.trace.fixtures import ( TRACES_FIXTURES, @@ -46,6 +52,8 @@ def _get_pid_file() -> Path: return get_pids_path() / str(os.getpid()) +DEFAULT_UMAP_PARAMS_STR = f"{DEFAULT_MIN_DIST},{DEFAULT_N_NEIGHBORS},{DEFAULT_N_SAMPLES}" + if __name__ == "__main__": primary_dataset_name: str reference_dataset_name: Optional[str] @@ -63,7 +71,8 @@ def _get_pid_file() -> Path: parser.add_argument("--host", type=str, required=False) parser.add_argument("--port", type=int, required=False) parser.add_argument("--no-internet", action="store_true") - parser.add_argument("--debug", action="store_false") # TODO: Disable before public launch + parser.add_argument("--umap_params", type=str, required=False, default=DEFAULT_UMAP_PARAMS_STR) + parser.add_argument("--debug", action="store_false") subparsers = parser.add_subparsers(dest="command", required=True) datasets_parser = subparsers.add_parser("datasets") datasets_parser.add_argument("--primary", type=str, required=True) @@ -120,9 +129,17 @@ def _get_pid_file() -> Path: ), ): traces.put(span) + umap_params_list = args.umap_params.split(",") + umap_params = UMAPParameters( + min_dist=float(umap_params_list[0]), + n_neighbors=int(umap_params_list[1]), + n_samples=int(umap_params_list[2]), + ) + logger.info(f"Server umap params: {umap_params}") app = create_app( export_path=export_path, model=model, + umap_params=umap_params, traces=traces, corpus=None if corpus_dataset is None else create_model_from_datasets(corpus_dataset), debug=args.debug, diff --git a/src/phoenix/server/templates/index.html b/src/phoenix/server/templates/index.html index 76e785ecc2..16784b246e 100644 --- a/src/phoenix/server/templates/index.html +++ b/src/phoenix/server/templates/index.html @@ -23,6 +23,11 @@ // injected into the client before React runs value: Object.freeze({ hasCorpus: Boolean("{{has_corpus}}"), + UMAP: { + minDist: parseFloat("{{min_dist}}"), + nNeighbors: parseInt("{{n_neighbors}}"), + nSamples: parseInt("{{n_samples}}"), + } }), writable: false }); diff --git a/src/phoenix/services.py b/src/phoenix/services.py index 83ac99fbd2..1b4c7544c5 100644 --- a/src/phoenix/services.py +++ b/src/phoenix/services.py @@ -111,6 +111,7 @@ def __init__( host: str, port: int, primary_dataset_name: str, + umap_params: str, reference_dataset_name: Optional[str], corpus_dataset_name: Optional[str], trace_dataset_name: Optional[str], @@ -119,6 +120,7 @@ def __init__( self.host = host self.port = port self.__primary_dataset_name = primary_dataset_name + self.__umap_params = umap_params self.__reference_dataset_name = reference_dataset_name self.__corpus_dataset_name = corpus_dataset_name self.__trace_dataset_name = trace_dataset_name @@ -138,6 +140,8 @@ def command(self) -> List[str]: "datasets", "--primary", str(self.__primary_dataset_name), + "--umap_params", + self.__umap_params, ] if self.__reference_dataset_name is not None: command.extend(["--reference", str(self.__reference_dataset_name)]) diff --git a/src/phoenix/session/session.py b/src/phoenix/session/session.py index 3f1fbcc0ff..4dfd1f0e01 100644 --- a/src/phoenix/session/session.py +++ b/src/phoenix/session/session.py @@ -5,7 +5,15 @@ from datetime import datetime from pathlib import Path from tempfile import TemporaryDirectory -from typing import TYPE_CHECKING, Iterable, List, Optional, Set +from typing import ( + TYPE_CHECKING, + Any, + Iterable, + List, + Mapping, + Optional, + Set, +) import pandas as pd @@ -13,6 +21,7 @@ from phoenix.core.model_schema_adapter import create_model_from_datasets from phoenix.core.traces import Traces from phoenix.datasets.dataset import EMPTY_DATASET, Dataset +from phoenix.pointcloud.umap_parameters import get_umap_parameters from phoenix.server.app import create_app from phoenix.server.thread_server import ThreadServer from phoenix.services import AppService @@ -69,6 +78,7 @@ def __init__( reference_dataset: Optional[Dataset] = None, corpus_dataset: Optional[Dataset] = None, trace_dataset: Optional[TraceDataset] = None, + default_umap_parameters: Optional[Mapping[str, Any]] = None, host: Optional[str] = None, port: Optional[int] = None, ): @@ -76,6 +86,7 @@ def __init__( self.reference_dataset = reference_dataset self.corpus_dataset = corpus_dataset self.trace_dataset = trace_dataset + self.umap_parameters = get_umap_parameters(default_umap_parameters) self.model = create_model_from_datasets( primary_dataset, reference_dataset, @@ -174,6 +185,7 @@ def __init__( reference_dataset: Optional[Dataset] = None, corpus_dataset: Optional[Dataset] = None, trace_dataset: Optional[TraceDataset] = None, + default_umap_parameters: Optional[Mapping[str, Any]] = None, host: Optional[str] = None, port: Optional[int] = None, ) -> None: @@ -182,6 +194,7 @@ def __init__( reference_dataset=reference_dataset, corpus_dataset=corpus_dataset, trace_dataset=trace_dataset, + default_umap_parameters=default_umap_parameters, host=host, port=port, ) @@ -192,12 +205,18 @@ def __init__( corpus_dataset.to_disc() if isinstance(trace_dataset, TraceDataset): trace_dataset.to_disc() + umap_params_str = ( + f"{self.umap_parameters.min_dist}," + f"{self.umap_parameters.n_neighbors}," + f"{self.umap_parameters.n_samples}" + ) # Initialize an app service that keeps the server running self.app_service = AppService( self.export_path, self.host, self.port, self.primary_dataset.name, + umap_params_str, reference_dataset_name=( self.reference_dataset.name if self.reference_dataset is not None else None ), @@ -225,6 +244,7 @@ def __init__( reference_dataset: Optional[Dataset] = None, corpus_dataset: Optional[Dataset] = None, trace_dataset: Optional[TraceDataset] = None, + default_umap_parameters: Optional[Mapping[str, Any]] = None, host: Optional[str] = None, port: Optional[int] = None, ): @@ -233,6 +253,7 @@ def __init__( reference_dataset=reference_dataset, corpus_dataset=corpus_dataset, trace_dataset=trace_dataset, + default_umap_parameters=default_umap_parameters, host=host, port=port, ) @@ -242,6 +263,7 @@ def __init__( model=self.model, corpus=self.corpus, traces=self.traces, + umap_params=self.umap_parameters, ) self.server = ThreadServer( app=self.app, @@ -265,6 +287,7 @@ def launch_app( reference: Optional[Dataset] = None, corpus: Optional[Dataset] = None, trace: Optional[TraceDataset] = None, + default_umap_parameters: Optional[Mapping[str, Any]] = None, host: Optional[str] = None, port: Optional[int] = None, run_in_thread: bool = True, @@ -291,6 +314,9 @@ def launch_app( variable `PHOENIX_PORT`, otherwise it defaults to 6060. run_in_thread: bool, optional, default=True Whether the server should run in a Thread or Process. + default_umap_parameters: Dict[str, Union[int, float]], optional, default=None + User specified default UMAP parameters + eg: {"n_neighbors": 10, "n_samples": 5, "min_dist": 0.5} Returns ------- @@ -323,10 +349,14 @@ def launch_app( port = port or get_env_port() if run_in_thread: - _session = ThreadSession(primary, reference, corpus, trace, host=host, port=port) + _session = ThreadSession( + primary, reference, corpus, trace, default_umap_parameters, host=host, port=port + ) # TODO: catch exceptions from thread else: - _session = ProcessSession(primary, reference, corpus, trace, host=host, port=port) + _session = ProcessSession( + primary, reference, corpus, trace, default_umap_parameters, host=host, port=port + ) if not _session.active: logger.error(