-
Notifications
You must be signed in to change notification settings - Fork 285
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: allow download of exported parquet files #459
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1,6 +1,7 @@ | ||||||
import errno | ||||||
import tempfile | ||||||
from heapq import nlargest | ||||||
from pathlib import Path | ||||||
from typing import List | ||||||
|
||||||
|
||||||
def _get_temp_path() -> Path: | ||||||
|
@@ -14,15 +15,7 @@ def get_pids_path() -> Path: | |||||
on the host machine. The directory will be created if it does not exist. | ||||||
""" | ||||||
path = _get_temp_path() / "pids" | ||||||
try: | ||||||
path.mkdir(parents=True, exist_ok=True) | ||||||
except OSError as e: | ||||||
if e.errno == errno.EEXIST: | ||||||
pass | ||||||
else: | ||||||
raise | ||||||
else: | ||||||
path.chmod(0o777) | ||||||
path.mkdir(parents=True, exist_ok=True) | ||||||
Comment on lines
-17
to
+18
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🙏 |
||||||
return path | ||||||
|
||||||
|
||||||
|
@@ -38,3 +31,30 @@ def get_pids_path() -> Path: | |||||
SERVER_DIR = PHOENIX_DIR / "server" | ||||||
# The port the server will run on after launch_app is called | ||||||
PORT = 6060 | ||||||
|
||||||
|
||||||
def get_exported_files(
    n_latest: int = 5,
    directory: Path = EXPORT_DIR,
    extension: str = "parquet",
) -> List[Path]:
    """
    Returns the n most recently exported files by descending modification time.

    Parameters
    ----------
    n_latest: int, optional, default=5
        Specifies the number of the most recent exported files to return. If
        there are fewer than n exported files then fewer than n files will
        be returned.
    directory: Path, optional, default=EXPORT_DIR
        Directory whose immediate entries are searched for exported files.
    extension: str, optional, default="parquet"
        File extension, without the leading dot, that an entry must have to
        be considered.

    Returns
    -------
    list: List[Path]
        List of paths of the n most recent exported files.
    """
    timestamped = []
    for path in directory.glob("*." + extension):
        try:
            timestamped.append((path.stat().st_mtime, path))
        except FileNotFoundError:
            # The entry disappeared between being listed and being inspected
            # (e.g. it was deleted concurrently); leave it out of the result
            # instead of failing the whole listing.
            continue
    # Review note: nlargest is the standard-library heapq.nlargest.
    newest = nlargest(n_latest, timestamped, key=lambda pair: pair[0])
    return [path for _, path in newest]
Comment on lines
+36
to
+60
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure I would expect this function to live with the config. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. true. i didn't find a better home for it There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could make sense to put it in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yea that's where i had put it in the first place, but then i realize i may want to use this function in other places too, e.g. in |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,17 @@ | ||
from typing import Optional | ||
import asyncio | ||
from typing import List, Optional | ||
|
||
import strawberry | ||
from strawberry.types import Info | ||
from strawberry.unset import UNSET | ||
|
||
from phoenix.config import EXPORT_DIR, get_exported_files | ||
from phoenix.server.api.context import Context | ||
|
||
from .Dataset import Dataset, to_gql_dataset | ||
from .Dimension import Dimension, to_gql_dimension | ||
from .EmbeddingDimension import EmbeddingDimension, to_gql_embedding_dimension | ||
from .ExportedFile import ExportedFile | ||
from .pagination import Connection, ConnectionArgs, Cursor, connection_from_list | ||
|
||
|
||
|
@@ -43,13 +46,19 @@ def dimensions( | |
|
||
@strawberry.field | ||
def primary_dataset(self, info: Info[Context, None]) -> Dataset: | ||
return to_gql_dataset(dataset=info.context.model.primary_dataset, type="primary") | ||
return to_gql_dataset( | ||
dataset=info.context.model.primary_dataset, | ||
type="primary", | ||
) | ||
|
||
@strawberry.field | ||
def reference_dataset(self, info: Info[Context, None]) -> Optional[Dataset]: | ||
if info.context.model.reference_dataset is None: | ||
return None | ||
return to_gql_dataset(dataset=info.context.model.reference_dataset, type="reference") | ||
return to_gql_dataset( | ||
dataset=info.context.model.reference_dataset, | ||
type="reference", | ||
) | ||
|
||
@strawberry.field | ||
def embedding_dimensions( | ||
|
@@ -68,7 +77,9 @@ def embedding_dimensions( | |
return connection_from_list( | ||
[ | ||
to_gql_embedding_dimension(index, embedding_dimension) | ||
for index, embedding_dimension in enumerate(info.context.model.embedding_dimensions) | ||
for index, embedding_dimension in enumerate( | ||
info.context.model.embedding_dimensions, | ||
) | ||
], | ||
args=ConnectionArgs( | ||
first=first, | ||
|
@@ -77,3 +88,25 @@ def embedding_dimensions( | |
before=before if isinstance(before, Cursor) else None, | ||
), | ||
) | ||
|
||
@strawberry.field(
    description=(
        "Returns n most recent exported Parquet files sorted by descending modification time."
    ),
)  # type: ignore  # https://github.com/strawberry-graphql/strawberry/issues/1929
async def exported_files(
    self,
    n_latest: int = 5,
) -> List[ExportedFile]:
    """
    Resolve the most recently exported files as ``ExportedFile`` objects.

    Parameters
    ----------
    n_latest: int, optional, default=5
        Forwarded to ``get_exported_files`` as the number of files requested.

    Returns
    -------
    list: List[ExportedFile]
        One entry per path returned by ``get_exported_files``, carrying the
        path's stem as ``file_name`` and ``EXPORT_DIR`` as ``directory``.
    """
    # Listing the export directory is blocking file-system I/O, so it is
    # handed to the loop's default executor rather than run on the event loop.
    recent_paths = await asyncio.get_running_loop().run_in_executor(
        None,
        get_exported_files,
        n_latest,
    )
    export_directory = str(EXPORT_DIR)
    exported: List[ExportedFile] = []
    for exported_path in recent_paths:
        exported.append(
            ExportedFile(
                file_name=exported_path.stem,
                directory=export_directory,
            )
        )
    return exported
Comment on lines
+102
to
+112
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What does the async list comprehension do here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's just running the I/O operation (listing files) in a separate thread so it's not blocking the event loop. The comprehension itself is not async. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So the idea is that the call to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes. probably doesn't matter in reality. i think it's just good practice (for I/O operations in general) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Got it. I'm curious if there's a way of accomplishing this without explicitly invoking the event loop. It would be possible to make There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, but not in 3.8 haha There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i thought about making |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,19 +2,21 @@ | |
from typing import Optional, Union | ||
|
||
from starlette.applications import Starlette | ||
from starlette.datastructures import QueryParams | ||
from starlette.endpoints import HTTPEndpoint | ||
from starlette.exceptions import HTTPException | ||
from starlette.middleware import Middleware | ||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint | ||
from starlette.requests import Request | ||
from starlette.responses import Response | ||
from starlette.responses import FileResponse, Response | ||
from starlette.routing import Mount, Route, WebSocketRoute | ||
from starlette.staticfiles import StaticFiles | ||
from starlette.types import Scope | ||
from starlette.websockets import WebSocket | ||
from strawberry.asgi import GraphQL | ||
from strawberry.schema import BaseSchema | ||
|
||
from phoenix.config import SERVER_DIR | ||
from phoenix.config import EXPORT_DIR, SERVER_DIR | ||
from phoenix.core.model import Model | ||
from phoenix.datasets import Dataset | ||
|
||
|
@@ -81,16 +83,31 @@ async def get_context( | |
) | ||
|
||
|
||
class Download(HTTPEndpoint):
    """HTTP endpoint that serves an exported Parquet file as a download."""

    async def get(self, request: Request) -> FileResponse:
        """
        Return the exported file named by the ``filename`` query parameter.

        The ``.parquet`` suffix is appended to the supplied name and the
        result is looked up under ``EXPORT_DIR``.

        Raises
        ------
        HTTPException
            With status 404 when the requested path resolves outside of
            ``EXPORT_DIR``, cannot be resolved, or is not an existing file.
        """
        params = QueryParams(request.query_params)
        file = EXPORT_DIR / (params.get("filename", "") + ".parquet")
        # The name comes straight from the request, so it may contain "../"
        # segments or be an absolute path. Only serve it if, once resolved,
        # it still lives underneath the export directory.
        try:
            contained = EXPORT_DIR.resolve() in file.resolve().parents
        except (OSError, RuntimeError, ValueError):
            # Unresolvable input (e.g. embedded null byte, symlink loop) is
            # treated the same as a missing file.
            contained = False
        if not contained or not file.is_file():
            raise HTTPException(status_code=404)
        return FileResponse(
            path=file,
            filename=file.name,
            media_type="application/x-octet-stream",
        )
Comment on lines
+86
to
+96
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not just use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I couldn't figure out how to get |
||
|
||
|
||
def create_app( | ||
primary_dataset_name: str, | ||
reference_dataset_name: Optional[str], | ||
debug: bool = False, | ||
) -> Starlette: | ||
model = Model( | ||
primary_dataset=Dataset.from_name(primary_dataset_name), | ||
reference_dataset=Dataset.from_name(reference_dataset_name) | ||
if reference_dataset_name is not None | ||
else None, | ||
reference_dataset=( | ||
Dataset.from_name(reference_dataset_name) | ||
if reference_dataset_name is not None | ||
else None | ||
), | ||
) | ||
graphql = GraphQLWithContext( | ||
schema=schema, | ||
|
@@ -104,6 +121,10 @@ def create_app( | |
], | ||
debug=debug, | ||
routes=[ | ||
Route( | ||
"/exports", | ||
Download, | ||
), | ||
Route( | ||
"/graphql", | ||
graphql, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just checking: I'm not sure you're going to be able to execute GraphQL queries from the Python runtime for Colab — at least, I've failed to do so thus far.
It might be simplest to just read from the directory, with no need for network I/O? This is nice for the UI, since it's non-blocking.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's correct. I am working on the next PR for the console version. This is just for the GUI (e.g. in a modal).