Skip to content

Commit

Permalink
feat: Add default UMAP parameters in launch_app() (#1224)
Browse files Browse the repository at this point in the history
* Reaching till GraphQLWithContext

* typing.Dict not dict()

* Dict import

* Suggested review changes

Added validation, moved umap_params to pointcloud and removed Optional for self.umap_parameters in context

* CamelCase -> SnakeCase

Co-authored-by: Mikyo King <mikeldking@gmail.com>

* CamelCase -> SnakeCase all suggested changes

Co-authored-by: Mikyo King <mikeldking@gmail.com>

* Default UMAPparam changes

Removed hardcoded defaults for single source of truth.

* typing hints & process_session()

Process_session() error. Logs at -> https://textdoc.co/maM9oCZwexUA4EWQ

* Prettier check

I'd installed the pre-commit hooks, but "black ." was taking too long and checking venv lib packages, so I was doing it manually just on src/*.

* Adjust frontend

* parse string correctly

* fix duplicate lines

* Update umap_parameters.py

Co-authored-by: Roger Yang <80478925+RogerHYang@users.noreply.github.com>

* Update umap_parameters.py

Co-authored-by: Roger Yang <80478925+RogerHYang@users.noreply.github.com>

* Update umap_parameters.py

Co-authored-by: Roger Yang <80478925+RogerHYang@users.noreply.github.com>

* Update src/phoenix/session/session.py

Co-authored-by: Roger Yang <80478925+RogerHYang@users.noreply.github.com>

* Update src/phoenix/session/session.py

Co-authored-by: Roger Yang <80478925+RogerHYang@users.noreply.github.com>

* Update src/phoenix/session/session.py

Co-authored-by: Roger Yang <80478925+RogerHYang@users.noreply.github.com>

* use positional arguments

* fix style and imports

---------

Co-authored-by: Mikyo King <mikeldking@gmail.com>
Co-authored-by: Mikyo King <mikyo@arize.com>
Co-authored-by: Roger Yang <80478925+RogerHYang@users.noreply.github.com>
  • Loading branch information
4 people authored Oct 5, 2023
1 parent 8bad5ac commit dce7551
Show file tree
Hide file tree
Showing 10 changed files with 154 additions and 12 deletions.
4 changes: 2 additions & 2 deletions app/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@
"build:relay": "relay-compiler",
"watch": "./esbuild.config.mjs dev",
"test": "jest --config ./jest.config.js",
"dev": "npm run dev:server:traces:llama_index_rag & npm run build:static && npm run watch",
"dev:server:mnist": "python3 -m phoenix.server.main fixture fashion_mnist",
"dev": "npm run dev:server:mnist & npm run build:static && npm run watch",
"dev:server:mnist": "python3 -m phoenix.server.main --umap_params 0,30,550 fixture fashion_mnist",
"dev:server:mnist:single": "python3 -m phoenix.server.main fixture fashion_mnist --primary-only true",
"dev:server:sentiment": "python3 -m phoenix.server.main fixture sentiment_classification_language_drift",
"dev:server:image": "python3 -m phoenix.server.main fixture image_classification",
Expand Down
18 changes: 15 additions & 3 deletions app/src/pages/embedding/EmbeddingPage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -223,14 +223,26 @@ export function EmbeddingPage() {
const { timeRange } = useTimeRange();
// Initialize the store based on whether or not there is a reference dataset
const defaultPointCloudProps = useMemo<Partial<PointCloudProps>>(() => {
let defaultPointCloudProps: Partial<PointCloudProps> = {};
if (corpusDataset != null) {
// If there is a corpus dataset, then initialize the page with the retrieval troubleshooting settings
// TODO - this does make a bit of a leap of assumptions but is a short term solution in order to get the page working as intended
return DEFAULT_RETRIEVAL_TROUBLESHOOTING_POINT_CLOUD_PROPS;
defaultPointCloudProps =
DEFAULT_RETRIEVAL_TROUBLESHOOTING_POINT_CLOUD_PROPS;
} else if (referenceDataset != null) {
return DEFAULT_DRIFT_POINT_CLOUD_PROPS;
defaultPointCloudProps = DEFAULT_DRIFT_POINT_CLOUD_PROPS;
} else {
defaultPointCloudProps = DEFAULT_SINGLE_DATASET_POINT_CLOUD_PROPS;
}
return DEFAULT_SINGLE_DATASET_POINT_CLOUD_PROPS;

// Apply the UMAP parameters from the server-sent config
defaultPointCloudProps = {
...defaultPointCloudProps,
umapParameters: {
...window.Config.UMAP,
},
};
return defaultPointCloudProps;
}, [corpusDataset, referenceDataset]);
return (
<TimeSliceContextProvider initialTimestamp={timeRange.end}>
Expand Down
3 changes: 0 additions & 3 deletions app/src/store/pointCloudStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,14 @@ type DimensionMetadata = {
export type UMAPParameters = {
/**
* Minimum distance between points in the UMAP projection
* @default 0
*/
minDist: number;
/**
* The number of neighbors to require for the UMAP projection
* @default 30
*/
nNeighbors: number;
/**
* The number of samples to use for the UMAP projection. The sample number is per dataset.
* @default 500
*/
nSamples: number;
};
Expand Down
14 changes: 14 additions & 0 deletions app/src/window.d.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
export {};

declare global {
  /**
   * Augments the browser `Window` type with the `Config` object that the
   * page reads as `window.Config`.
   * NOTE(review): the values appear to be injected by the server-rendered
   * index.html template before React runs — confirm against the template.
   */
  interface Window {
    Config: {
      /** Whether a corpus dataset is present. */
      hasCorpus: boolean;
      /** Server-provided default UMAP parameters for the point cloud. */
      UMAP: {
        /** Default `minDist` value. */
        minDist: number;
        /** Default `nNeighbors` value. */
        nNeighbors: number;
        /** Default `nSamples` value. */
        nSamples: number;
      };
    };
  }
}
52 changes: 52 additions & 0 deletions src/phoenix/pointcloud/umap_parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from dataclasses import dataclass
from typing import Any, Mapping, Optional

# Values used for any UMAP parameter the caller leaves unspecified.
DEFAULT_MIN_DIST = 0.0
DEFAULT_N_NEIGHBORS = 30
DEFAULT_N_SAMPLES = 500

# Inclusive (lower, upper) bounds checked by UMAPParameters.__post_init__.
MIN_MIN_DIST, MAX_MIN_DIST = 0.0, 0.99
MIN_NEIGHBORS, MAX_NEIGHBORS = 5, 100
MIN_SAMPLES, MAX_SAMPLES = 1, 1000


@dataclass
class UMAPParameters:
    """Validated bundle of the three UMAP settings.

    Every field falls back to its module-level ``DEFAULT_*`` constant.
    Construction fails with ``ValueError`` when a field has the wrong type or
    lies outside its inclusive ``MIN_*``/``MAX_*`` range.
    """

    min_dist: float = DEFAULT_MIN_DIST
    n_neighbors: int = DEFAULT_N_NEIGHBORS
    n_samples: int = DEFAULT_N_SAMPLES

    def __post_init__(self) -> None:
        """Check each field's type and range, in declaration order.

        Raises:
            ValueError: for the first field that is not of the expected type
                or is outside its allowed range.
        """
        # The isinstance test comes first in each predicate so the range
        # comparison is never evaluated against a value of the wrong type.
        min_dist_ok = isinstance(self.min_dist, float) and (
            MIN_MIN_DIST <= self.min_dist <= MAX_MIN_DIST
        )
        if not min_dist_ok:
            raise ValueError(
                f"minDist must be float type, and between {MIN_MIN_DIST} and {MAX_MIN_DIST}"
            )

        n_neighbors_ok = isinstance(self.n_neighbors, int) and (
            MIN_NEIGHBORS <= self.n_neighbors <= MAX_NEIGHBORS
        )
        if not n_neighbors_ok:
            raise ValueError(
                f"nNeighbors must be int type, and between {MIN_NEIGHBORS} and {MAX_NEIGHBORS}"
            )

        n_samples_ok = isinstance(self.n_samples, int) and (
            MIN_SAMPLES <= self.n_samples <= MAX_SAMPLES
        )
        if not n_samples_ok:
            raise ValueError(
                f"nSamples must be int type, and between {MIN_SAMPLES} and {MAX_SAMPLES}"
            )


def get_umap_parameters(default_umap_parameters: Optional[Mapping[str, Any]]) -> UMAPParameters:
    """Build a ``UMAPParameters`` from an optional mapping of overrides.

    Args:
        default_umap_parameters: Mapping that may contain the keys
            ``"min_dist"``, ``"n_neighbors"`` and ``"n_samples"``. ``None`` or
            an empty mapping selects the defaults for every field. Keys other
            than these three are not read.

    Returns:
        A ``UMAPParameters`` in which each supplied value has been coerced
        with ``float`` (``min_dist``) or ``int`` (``n_neighbors``,
        ``n_samples``), and each missing key uses its ``DEFAULT_*`` constant.

    Raises:
        ValueError: propagated from ``UMAPParameters`` validation, or from
            ``float``/``int`` when a value cannot be converted.
    """
    if default_umap_parameters:
        overrides = default_umap_parameters
        # Coerce in field order so conversion errors surface in the same
        # sequence as the fields are declared.
        min_dist = float(overrides.get("min_dist", DEFAULT_MIN_DIST))
        n_neighbors = int(overrides.get("n_neighbors", DEFAULT_N_NEIGHBORS))
        n_samples = int(overrides.get("n_samples", DEFAULT_N_SAMPLES))
        return UMAPParameters(
            min_dist=min_dist,
            n_neighbors=n_neighbors,
            n_samples=n_samples,
        )
    return UMAPParameters()
11 changes: 11 additions & 0 deletions src/phoenix/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from phoenix.config import SERVER_DIR
from phoenix.core.model_schema import Model
from phoenix.core.traces import Traces
from phoenix.pointcloud.umap_parameters import UMAPParameters
from phoenix.server.api.context import Context
from phoenix.server.api.schema import schema
from phoenix.server.span_handler import SpanHandler
Expand All @@ -33,6 +34,9 @@

class AppConfig(NamedTuple):
has_corpus: bool
min_dist: float
n_neighbors: int
n_samples: int


class Static(StaticFiles):
Expand All @@ -58,6 +62,9 @@ async def get_response(self, path: str, scope: Scope) -> Response:
"index.html",
context={
"has_corpus": self._app_config.has_corpus,
"min_dist": self._app_config.min_dist,
"n_neighbors": self._app_config.n_neighbors,
"n_samples": self._app_config.n_samples,
"request": Request(scope),
},
)
Expand Down Expand Up @@ -130,6 +137,7 @@ async def version(_: Request) -> PlainTextResponse:
def create_app(
export_path: Path,
model: Model,
umap_params: UMAPParameters,
corpus: Optional[Model] = None,
traces: Optional[Traces] = None,
debug: bool = False,
Expand Down Expand Up @@ -182,6 +190,9 @@ def create_app(
directory=SERVER_DIR / "static",
app_config=AppConfig(
has_corpus=corpus is not None,
min_dist=umap_params.min_dist,
n_neighbors=umap_params.n_neighbors,
n_samples=umap_params.n_samples,
),
),
name="static",
Expand Down
19 changes: 18 additions & 1 deletion src/phoenix/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@
from phoenix.core.traces import Traces
from phoenix.datasets.dataset import EMPTY_DATASET, Dataset
from phoenix.datasets.fixtures import FIXTURES, get_datasets
from phoenix.pointcloud.umap_parameters import (
DEFAULT_MIN_DIST,
DEFAULT_N_NEIGHBORS,
DEFAULT_N_SAMPLES,
UMAPParameters,
)
from phoenix.server.app import create_app
from phoenix.trace.fixtures import (
TRACES_FIXTURES,
Expand Down Expand Up @@ -46,6 +52,8 @@ def _get_pid_file() -> Path:
return get_pids_path() / str(os.getpid())


DEFAULT_UMAP_PARAMS_STR = f"{DEFAULT_MIN_DIST},{DEFAULT_N_NEIGHBORS},{DEFAULT_N_SAMPLES}"

if __name__ == "__main__":
primary_dataset_name: str
reference_dataset_name: Optional[str]
Expand All @@ -63,7 +71,8 @@ def _get_pid_file() -> Path:
parser.add_argument("--host", type=str, required=False)
parser.add_argument("--port", type=int, required=False)
parser.add_argument("--no-internet", action="store_true")
parser.add_argument("--debug", action="store_false") # TODO: Disable before public launch
parser.add_argument("--umap_params", type=str, required=False, default=DEFAULT_UMAP_PARAMS_STR)
parser.add_argument("--debug", action="store_false")
subparsers = parser.add_subparsers(dest="command", required=True)
datasets_parser = subparsers.add_parser("datasets")
datasets_parser.add_argument("--primary", type=str, required=True)
Expand Down Expand Up @@ -120,9 +129,17 @@ def _get_pid_file() -> Path:
),
):
traces.put(span)
umap_params_list = args.umap_params.split(",")
umap_params = UMAPParameters(
min_dist=float(umap_params_list[0]),
n_neighbors=int(umap_params_list[1]),
n_samples=int(umap_params_list[2]),
)
logger.info(f"Server umap params: {umap_params}")
app = create_app(
export_path=export_path,
model=model,
umap_params=umap_params,
traces=traces,
corpus=None if corpus_dataset is None else create_model_from_datasets(corpus_dataset),
debug=args.debug,
Expand Down
5 changes: 5 additions & 0 deletions src/phoenix/server/templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@
// injected into the client before React runs
value: Object.freeze({
hasCorpus: Boolean("{{has_corpus}}"),
UMAP: {
minDist: parseFloat("{{min_dist}}"),
nNeighbors: parseInt("{{n_neighbors}}"),
nSamples: parseInt("{{n_samples}}"),
}
}),
writable: false
});
Expand Down
4 changes: 4 additions & 0 deletions src/phoenix/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def __init__(
host: str,
port: int,
primary_dataset_name: str,
umap_params: str,
reference_dataset_name: Optional[str],
corpus_dataset_name: Optional[str],
trace_dataset_name: Optional[str],
Expand All @@ -119,6 +120,7 @@ def __init__(
self.host = host
self.port = port
self.__primary_dataset_name = primary_dataset_name
self.__umap_params = umap_params
self.__reference_dataset_name = reference_dataset_name
self.__corpus_dataset_name = corpus_dataset_name
self.__trace_dataset_name = trace_dataset_name
Expand All @@ -138,6 +140,8 @@ def command(self) -> List[str]:
"datasets",
"--primary",
str(self.__primary_dataset_name),
"--umap_params",
self.__umap_params,
]
if self.__reference_dataset_name is not None:
command.extend(["--reference", str(self.__reference_dataset_name)])
Expand Down
36 changes: 33 additions & 3 deletions src/phoenix/session/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,23 @@
from datetime import datetime
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Iterable, List, Optional, Set
from typing import (
TYPE_CHECKING,
Any,
Iterable,
List,
Mapping,
Optional,
Set,
)

import pandas as pd

from phoenix.config import get_env_host, get_env_port, get_exported_files
from phoenix.core.model_schema_adapter import create_model_from_datasets
from phoenix.core.traces import Traces
from phoenix.datasets.dataset import EMPTY_DATASET, Dataset
from phoenix.pointcloud.umap_parameters import get_umap_parameters
from phoenix.server.app import create_app
from phoenix.server.thread_server import ThreadServer
from phoenix.services import AppService
Expand Down Expand Up @@ -69,13 +78,15 @@ def __init__(
reference_dataset: Optional[Dataset] = None,
corpus_dataset: Optional[Dataset] = None,
trace_dataset: Optional[TraceDataset] = None,
default_umap_parameters: Optional[Mapping[str, Any]] = None,
host: Optional[str] = None,
port: Optional[int] = None,
):
self.primary_dataset = primary_dataset
self.reference_dataset = reference_dataset
self.corpus_dataset = corpus_dataset
self.trace_dataset = trace_dataset
self.umap_parameters = get_umap_parameters(default_umap_parameters)
self.model = create_model_from_datasets(
primary_dataset,
reference_dataset,
Expand Down Expand Up @@ -174,6 +185,7 @@ def __init__(
reference_dataset: Optional[Dataset] = None,
corpus_dataset: Optional[Dataset] = None,
trace_dataset: Optional[TraceDataset] = None,
default_umap_parameters: Optional[Mapping[str, Any]] = None,
host: Optional[str] = None,
port: Optional[int] = None,
) -> None:
Expand All @@ -182,6 +194,7 @@ def __init__(
reference_dataset=reference_dataset,
corpus_dataset=corpus_dataset,
trace_dataset=trace_dataset,
default_umap_parameters=default_umap_parameters,
host=host,
port=port,
)
Expand All @@ -192,12 +205,18 @@ def __init__(
corpus_dataset.to_disc()
if isinstance(trace_dataset, TraceDataset):
trace_dataset.to_disc()
umap_params_str = (
f"{self.umap_parameters.min_dist},"
f"{self.umap_parameters.n_neighbors},"
f"{self.umap_parameters.n_samples}"
)
# Initialize an app service that keeps the server running
self.app_service = AppService(
self.export_path,
self.host,
self.port,
self.primary_dataset.name,
umap_params_str,
reference_dataset_name=(
self.reference_dataset.name if self.reference_dataset is not None else None
),
Expand Down Expand Up @@ -225,6 +244,7 @@ def __init__(
reference_dataset: Optional[Dataset] = None,
corpus_dataset: Optional[Dataset] = None,
trace_dataset: Optional[TraceDataset] = None,
default_umap_parameters: Optional[Mapping[str, Any]] = None,
host: Optional[str] = None,
port: Optional[int] = None,
):
Expand All @@ -233,6 +253,7 @@ def __init__(
reference_dataset=reference_dataset,
corpus_dataset=corpus_dataset,
trace_dataset=trace_dataset,
default_umap_parameters=default_umap_parameters,
host=host,
port=port,
)
Expand All @@ -242,6 +263,7 @@ def __init__(
model=self.model,
corpus=self.corpus,
traces=self.traces,
umap_params=self.umap_parameters,
)
self.server = ThreadServer(
app=self.app,
Expand All @@ -265,6 +287,7 @@ def launch_app(
reference: Optional[Dataset] = None,
corpus: Optional[Dataset] = None,
trace: Optional[TraceDataset] = None,
default_umap_parameters: Optional[Mapping[str, Any]] = None,
host: Optional[str] = None,
port: Optional[int] = None,
run_in_thread: bool = True,
Expand All @@ -291,6 +314,9 @@ def launch_app(
variable `PHOENIX_PORT`, otherwise it defaults to 6060.
run_in_thread: bool, optional, default=True
Whether the server should run in a Thread or Process.
default_umap_parameters: Dict[str, Union[int, float]], optional, default=None
User specified default UMAP parameters
eg: {"n_neighbors": 10, "n_samples": 5, "min_dist": 0.5}
Returns
-------
Expand Down Expand Up @@ -323,10 +349,14 @@ def launch_app(
port = port or get_env_port()

if run_in_thread:
_session = ThreadSession(primary, reference, corpus, trace, host=host, port=port)
_session = ThreadSession(
primary, reference, corpus, trace, default_umap_parameters, host=host, port=port
)
# TODO: catch exceptions from thread
else:
_session = ProcessSession(primary, reference, corpus, trace, host=host, port=port)
_session = ProcessSession(
primary, reference, corpus, trace, default_umap_parameters, host=host, port=port
)

if not _session.active:
logger.error(
Expand Down

0 comments on commit dce7551

Please sign in to comment.