Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add default UMAP parameters in launch_app() #1224

Merged
merged 30 commits into from
Oct 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
8ea71bd
Reaching till GraphQLWithContext
pbadhe Aug 28, 2023
689b6f5
typing.Dict not dict()
pbadhe Aug 28, 2023
5146c6e
Merge branch 'main' into umap_params
pbadhe Sep 9, 2023
2834175
Dict import
pbadhe Sep 9, 2023
c532d04
Merge branch 'main' into umap_params
pbadhe Sep 15, 2023
5e195e4
Suggested review changes
pbadhe Sep 15, 2023
dc9992b
Merge branch 'main' into umap_params
pbadhe Sep 26, 2023
b71de1a
CamelCase -> SnakeCase
pbadhe Sep 26, 2023
e1f9304
CamelCase -> SnakeCase all suggested changes
pbadhe Sep 26, 2023
a2f525d
Merge branch 'main' into umap_params
pbadhe Sep 27, 2023
91f5f47
Default UMAPparam changes
pbadhe Sep 27, 2023
d4673ba
Merge branch 'main' into umap_params
pbadhe Sep 29, 2023
eb16a0d
typing hints & process_session()
pbadhe Sep 29, 2023
6f2f8e6
Merge branch 'main' into umap_params
pbadhe Sep 29, 2023
0be3164
Prettier check
pbadhe Sep 29, 2023
bc4a8bc
Merge branch 'main' into umap_params
mikeldking Oct 4, 2023
08f3a57
Adjust frontend
mikeldking Oct 4, 2023
4975122
parse string correctly
mikeldking Oct 4, 2023
aa2fb82
fix duplicate lines
mikeldking Oct 4, 2023
6756ee4
Merge branch 'main' into umap_params
mikeldking Oct 4, 2023
7687b3b
Update umap_parameters.py
mikeldking Oct 5, 2023
7aa84a6
Update umap_parameters.py
mikeldking Oct 5, 2023
83f6ea5
Update umap_parameters.py
mikeldking Oct 5, 2023
800f85b
Update src/phoenix/session/session.py
mikeldking Oct 5, 2023
7ff011f
Update src/phoenix/session/session.py
mikeldking Oct 5, 2023
194a1ea
Update src/phoenix/session/session.py
mikeldking Oct 5, 2023
4a30132
use positional arguments
mikeldking Oct 5, 2023
fbf87e4
fix style and imports
mikeldking Oct 5, 2023
5b05baa
Merge branch 'main' into umap_params
mikeldking Oct 5, 2023
4c411be
Merge branch 'main' into umap_params
mikeldking Oct 5, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions app/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@
"build:relay": "relay-compiler",
"watch": "./esbuild.config.mjs dev",
"test": "jest --config ./jest.config.js",
"dev": "npm run dev:server:traces:llama_index_rag & npm run build:static && npm run watch",
"dev:server:mnist": "python3 -m phoenix.server.main fixture fashion_mnist",
"dev": "npm run dev:server:mnist & npm run build:static && npm run watch",
"dev:server:mnist": "python3 -m phoenix.server.main --umap_params 0,30,550 fixture fashion_mnist",
"dev:server:mnist:single": "python3 -m phoenix.server.main fixture fashion_mnist --primary-only true",
"dev:server:sentiment": "python3 -m phoenix.server.main fixture sentiment_classification_language_drift",
"dev:server:image": "python3 -m phoenix.server.main fixture image_classification",
Expand Down
18 changes: 15 additions & 3 deletions app/src/pages/embedding/EmbeddingPage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -223,14 +223,26 @@ export function EmbeddingPage() {
const { timeRange } = useTimeRange();
// Initialize the store based on whether or not there is a reference dataset
const defaultPointCloudProps = useMemo<Partial<PointCloudProps>>(() => {
let defaultPointCloudProps: Partial<PointCloudProps> = {};
if (corpusDataset != null) {
// If there is a corpus dataset, then initialize the page with the retrieval troubleshooting settings
// TODO - this does make a bit of a leap of assumptions but is a short term solution in order to get the page working as intended
return DEFAULT_RETRIEVAL_TROUBLESHOOTING_POINT_CLOUD_PROPS;
defaultPointCloudProps =
DEFAULT_RETRIEVAL_TROUBLESHOOTING_POINT_CLOUD_PROPS;
} else if (referenceDataset != null) {
return DEFAULT_DRIFT_POINT_CLOUD_PROPS;
defaultPointCloudProps = DEFAULT_DRIFT_POINT_CLOUD_PROPS;
} else {
defaultPointCloudProps = DEFAULT_SINGLE_DATASET_POINT_CLOUD_PROPS;
}
return DEFAULT_SINGLE_DATASET_POINT_CLOUD_PROPS;

// Apply the UMAP parameters from the server-sent config
defaultPointCloudProps = {
...defaultPointCloudProps,
umapParameters: {
...window.Config.UMAP,
},
};
return defaultPointCloudProps;
}, [corpusDataset, referenceDataset]);
return (
<TimeSliceContextProvider initialTimestamp={timeRange.end}>
Expand Down
3 changes: 0 additions & 3 deletions app/src/store/pointCloudStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,14 @@ type DimensionMetadata = {
export type UMAPParameters = {
/**
* Minimum distance between points in the eUMAP projection
* @default 0
*/
minDist: number;
/**
* The number of neighbors to require for the UMAP projection
* @default 30
*/
nNeighbors: number;
/**
* The number of samples to use for the UMAP projection. The sample number is per dataset.
* @default 500
*/
nSamples: number;
};
Expand Down
14 changes: 14 additions & 0 deletions app/src/window.d.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
export {};
mikeldking marked this conversation as resolved.
Show resolved Hide resolved

declare global {
interface Window {
Config: {
hasCorpus: boolean;
UMAP: {
minDist: number;
nNeighbors: number;
nSamples: number;
};
};
}
}
52 changes: 52 additions & 0 deletions src/phoenix/pointcloud/umap_parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from dataclasses import dataclass
from typing import Any, Mapping, Optional

DEFAULT_MIN_DIST = 0.0
DEFAULT_N_NEIGHBORS = 30
DEFAULT_N_SAMPLES = 500

MIN_NEIGHBORS = 5
MAX_NEIGHBORS = 100
MIN_SAMPLES = 1
MAX_SAMPLES = 1000
MIN_MIN_DIST = 0.0
MAX_MIN_DIST = 0.99


@dataclass
class UMAPParameters:
min_dist: float = DEFAULT_MIN_DIST
n_neighbors: int = DEFAULT_N_NEIGHBORS
n_samples: int = DEFAULT_N_SAMPLES

def __post_init__(self) -> None:
if not isinstance(self.min_dist, float) or not (
MIN_MIN_DIST <= self.min_dist <= MAX_MIN_DIST
):
raise ValueError(
f"minDist must be float type, and between {MIN_MIN_DIST} and {MAX_MIN_DIST}"
)

if not isinstance(self.n_neighbors, int) or not (
MIN_NEIGHBORS <= self.n_neighbors <= MAX_NEIGHBORS
):
raise ValueError(
f"nNeighbors must be int type, and between {MIN_NEIGHBORS} and {MAX_NEIGHBORS}"
)

if not isinstance(self.n_samples, int) or not (
MIN_SAMPLES <= self.n_samples <= MAX_SAMPLES
):
raise ValueError(
f"nSamples must be int type, and between {MIN_SAMPLES} and {MAX_SAMPLES}"
)


def get_umap_parameters(default_umap_parameters: Optional[Mapping[str, Any]]) -> UMAPParameters:
if not default_umap_parameters:
return UMAPParameters()
return UMAPParameters(
min_dist=float(default_umap_parameters.get("min_dist", DEFAULT_MIN_DIST)),
n_neighbors=int(default_umap_parameters.get("n_neighbors", DEFAULT_N_NEIGHBORS)),
n_samples=int(default_umap_parameters.get("n_samples", DEFAULT_N_SAMPLES)),
)
11 changes: 11 additions & 0 deletions src/phoenix/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from phoenix.config import SERVER_DIR
from phoenix.core.model_schema import Model
from phoenix.core.traces import Traces
from phoenix.pointcloud.umap_parameters import UMAPParameters
from phoenix.server.api.context import Context
from phoenix.server.api.schema import schema
from phoenix.server.span_handler import SpanHandler
Expand All @@ -33,6 +34,9 @@

class AppConfig(NamedTuple):
has_corpus: bool
min_dist: float
n_neighbors: int
n_samples: int


class Static(StaticFiles):
Expand All @@ -58,6 +62,9 @@ async def get_response(self, path: str, scope: Scope) -> Response:
"index.html",
context={
"has_corpus": self._app_config.has_corpus,
"min_dist": self._app_config.min_dist,
"n_neighbors": self._app_config.n_neighbors,
"n_samples": self._app_config.n_samples,
"request": Request(scope),
},
)
Expand Down Expand Up @@ -130,6 +137,7 @@ async def version(_: Request) -> PlainTextResponse:
def create_app(
export_path: Path,
model: Model,
umap_params: UMAPParameters,
corpus: Optional[Model] = None,
traces: Optional[Traces] = None,
debug: bool = False,
Expand Down Expand Up @@ -182,6 +190,9 @@ def create_app(
directory=SERVER_DIR / "static",
app_config=AppConfig(
has_corpus=corpus is not None,
min_dist=umap_params.min_dist,
n_neighbors=umap_params.n_neighbors,
n_samples=umap_params.n_samples,
),
),
name="static",
Expand Down
19 changes: 18 additions & 1 deletion src/phoenix/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@
from phoenix.core.traces import Traces
from phoenix.datasets.dataset import EMPTY_DATASET, Dataset
from phoenix.datasets.fixtures import FIXTURES, get_datasets
from phoenix.pointcloud.umap_parameters import (
DEFAULT_MIN_DIST,
DEFAULT_N_NEIGHBORS,
DEFAULT_N_SAMPLES,
UMAPParameters,
)
from phoenix.server.app import create_app
from phoenix.trace.fixtures import (
TRACES_FIXTURES,
Expand Down Expand Up @@ -46,6 +52,8 @@ def _get_pid_file() -> Path:
return get_pids_path() / str(os.getpid())


DEFAULT_UMAP_PARAMS_STR = f"{DEFAULT_MIN_DIST},{DEFAULT_N_NEIGHBORS},{DEFAULT_N_SAMPLES}"

if __name__ == "__main__":
primary_dataset_name: str
reference_dataset_name: Optional[str]
Expand All @@ -63,7 +71,8 @@ def _get_pid_file() -> Path:
parser.add_argument("--host", type=str, required=False)
parser.add_argument("--port", type=int, required=False)
parser.add_argument("--no-internet", action="store_true")
parser.add_argument("--debug", action="store_false") # TODO: Disable before public launch
parser.add_argument("--umap_params", type=str, required=False, default=DEFAULT_UMAP_PARAMS_STR)
parser.add_argument("--debug", action="store_false")
subparsers = parser.add_subparsers(dest="command", required=True)
datasets_parser = subparsers.add_parser("datasets")
datasets_parser.add_argument("--primary", type=str, required=True)
Expand Down Expand Up @@ -120,9 +129,17 @@ def _get_pid_file() -> Path:
),
):
traces.put(span)
umap_params_list = args.umap_params.split(",")
umap_params = UMAPParameters(
min_dist=float(umap_params_list[0]),
n_neighbors=int(umap_params_list[1]),
n_samples=int(umap_params_list[2]),
)
logger.info(f"Server umap params: {umap_params}")
app = create_app(
export_path=export_path,
model=model,
umap_params=umap_params,
traces=traces,
corpus=None if corpus_dataset is None else create_model_from_datasets(corpus_dataset),
debug=args.debug,
Expand Down
5 changes: 5 additions & 0 deletions src/phoenix/server/templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@
// injected into the client before React runs
value: Object.freeze({
hasCorpus: Boolean("{{has_corpus}}"),
UMAP: {
minDist: parseFloat("{{min_dist}}"),
nNeighbors: parseInt("{{n_neighbors}}"),
nSamples: parseInt("{{n_samples}}"),
}
}),
writable: false
});
Expand Down
4 changes: 4 additions & 0 deletions src/phoenix/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def __init__(
host: str,
port: int,
primary_dataset_name: str,
umap_params: str,
reference_dataset_name: Optional[str],
corpus_dataset_name: Optional[str],
trace_dataset_name: Optional[str],
Expand All @@ -119,6 +120,7 @@ def __init__(
self.host = host
self.port = port
self.__primary_dataset_name = primary_dataset_name
self.__umap_params = umap_params
self.__reference_dataset_name = reference_dataset_name
self.__corpus_dataset_name = corpus_dataset_name
self.__trace_dataset_name = trace_dataset_name
Expand All @@ -138,6 +140,8 @@ def command(self) -> List[str]:
"datasets",
"--primary",
str(self.__primary_dataset_name),
"--umap_params",
self.__umap_params,
]
if self.__reference_dataset_name is not None:
command.extend(["--reference", str(self.__reference_dataset_name)])
Expand Down
36 changes: 33 additions & 3 deletions src/phoenix/session/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,23 @@
from datetime import datetime
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Iterable, List, Optional, Set
from typing import (
TYPE_CHECKING,
Any,
Iterable,
List,
Mapping,
Optional,
Set,
)

import pandas as pd

from phoenix.config import get_env_host, get_env_port, get_exported_files
from phoenix.core.model_schema_adapter import create_model_from_datasets
from phoenix.core.traces import Traces
from phoenix.datasets.dataset import EMPTY_DATASET, Dataset
from phoenix.pointcloud.umap_parameters import get_umap_parameters
from phoenix.server.app import create_app
from phoenix.server.thread_server import ThreadServer
from phoenix.services import AppService
Expand Down Expand Up @@ -69,13 +78,15 @@ def __init__(
reference_dataset: Optional[Dataset] = None,
corpus_dataset: Optional[Dataset] = None,
trace_dataset: Optional[TraceDataset] = None,
default_umap_parameters: Optional[Mapping[str, Any]] = None,
host: Optional[str] = None,
port: Optional[int] = None,
):
self.primary_dataset = primary_dataset
self.reference_dataset = reference_dataset
self.corpus_dataset = corpus_dataset
self.trace_dataset = trace_dataset
self.umap_parameters = get_umap_parameters(default_umap_parameters)
self.model = create_model_from_datasets(
primary_dataset,
reference_dataset,
Expand Down Expand Up @@ -174,6 +185,7 @@ def __init__(
reference_dataset: Optional[Dataset] = None,
corpus_dataset: Optional[Dataset] = None,
trace_dataset: Optional[TraceDataset] = None,
default_umap_parameters: Optional[Mapping[str, Any]] = None,
host: Optional[str] = None,
port: Optional[int] = None,
) -> None:
Expand All @@ -182,6 +194,7 @@ def __init__(
reference_dataset=reference_dataset,
corpus_dataset=corpus_dataset,
trace_dataset=trace_dataset,
default_umap_parameters=default_umap_parameters,
host=host,
port=port,
)
Expand All @@ -192,12 +205,18 @@ def __init__(
corpus_dataset.to_disc()
if isinstance(trace_dataset, TraceDataset):
trace_dataset.to_disc()
umap_params_str = (
f"{self.umap_parameters.min_dist},"
f"{self.umap_parameters.n_neighbors},"
f"{self.umap_parameters.n_samples}"
)
# Initialize an app service that keeps the server running
self.app_service = AppService(
self.export_path,
self.host,
self.port,
self.primary_dataset.name,
umap_params_str,
reference_dataset_name=(
self.reference_dataset.name if self.reference_dataset is not None else None
),
Expand Down Expand Up @@ -225,6 +244,7 @@ def __init__(
reference_dataset: Optional[Dataset] = None,
corpus_dataset: Optional[Dataset] = None,
trace_dataset: Optional[TraceDataset] = None,
default_umap_parameters: Optional[Mapping[str, Any]] = None,
host: Optional[str] = None,
port: Optional[int] = None,
):
Expand All @@ -233,6 +253,7 @@ def __init__(
reference_dataset=reference_dataset,
corpus_dataset=corpus_dataset,
trace_dataset=trace_dataset,
default_umap_parameters=default_umap_parameters,
host=host,
port=port,
)
Expand All @@ -242,6 +263,7 @@ def __init__(
model=self.model,
corpus=self.corpus,
traces=self.traces,
umap_params=self.umap_parameters,
)
self.server = ThreadServer(
app=self.app,
Expand All @@ -265,6 +287,7 @@ def launch_app(
reference: Optional[Dataset] = None,
corpus: Optional[Dataset] = None,
trace: Optional[TraceDataset] = None,
default_umap_parameters: Optional[Mapping[str, Any]] = None,
host: Optional[str] = None,
port: Optional[int] = None,
run_in_thread: bool = True,
Expand All @@ -291,6 +314,9 @@ def launch_app(
variable `PHOENIX_PORT`, otherwise it defaults to 6060.
run_in_thread: bool, optional, default=True
Whether the server should run in a Thread or Process.
default_umap_parameters: Dict[str, Union[int, float]], optional, default=None
User specified default UMAP parameters
eg: {"n_neighbors": 10, "n_samples": 5, "min_dist": 0.5}

Returns
-------
Expand Down Expand Up @@ -323,10 +349,14 @@ def launch_app(
port = port or get_env_port()

if run_in_thread:
_session = ThreadSession(primary, reference, corpus, trace, host=host, port=port)
_session = ThreadSession(
primary, reference, corpus, trace, default_umap_parameters, host=host, port=port
)
# TODO: catch exceptions from thread
else:
_session = ProcessSession(primary, reference, corpus, trace, host=host, port=port)
_session = ProcessSession(
primary, reference, corpus, trace, default_umap_parameters, host=host, port=port
)

if not _session.active:
logger.error(
Expand Down