Skip to content

Commit

Permalink
feat(client): support creating a handle from request_id
Browse files Browse the repository at this point in the history
  • Loading branch information
efiop committed Oct 3, 2024
1 parent 1b087b2 commit dcb2136
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 1 deletion.
84 changes: 83 additions & 1 deletion projects/fal_client/src/fal_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import base64
from dataclasses import dataclass, field
from functools import cached_property
from typing import Any, AsyncIterator, Iterator, TYPE_CHECKING
from typing import Any, AsyncIterator, Iterator, TYPE_CHECKING, Optional

import httpx
from httpx_sse import aconnect_sse, connect_sse
Expand Down Expand Up @@ -95,10 +95,71 @@ def _parse_status(self, data: AnyJSON) -> Status:
raise ValueError(f"Unknown status: {data['status']}")


APP_NAMESPACES = ["workflows", "comfy"]


def _ensure_app_id_format(id: str) -> str:
import re

parts = id.split("/")
if len(parts) > 1:
return id

match = re.match(r"^([0-9]+)-([a-zA-Z0-9-]+)$", id)
if match:
app_owner, app_id = match.groups()
return f"{app_owner}/{app_id}"

raise ValueError(f"Invalid app id: {id}. Must be in the format <appOwner>/<appId>")


@dataclass(frozen=True)
class AppId:
owner: str
alias: str
path: Optional[str]
namespace: Optional[str]

@classmethod
def from_endpoint_id(cls, endpoint_id: str) -> AppId:
normalized_id = _ensure_app_id_format(endpoint_id)
parts = normalized_id.split("/")

if parts[0] in APP_NAMESPACES:
return cls(
owner=parts[1],
alias=parts[2],
path="/".join(parts[3:]) or None,
namespace=parts[0],
)

return cls(
owner=parts[0],
alias=parts[1],
path="/".join(parts[2:]) or None,
namespace=None,
)


@dataclass(frozen=True)
class SyncRequestHandle(_BaseRequestHandle):
client: httpx.Client = field(repr=False)

@classmethod
def from_request_id(
cls, client: httpx.Client, application: str, request_id: str
) -> SyncRequestHandle:
app_id = AppId.from_endpoint_id(application)
prefix = f"{app_id.namespace}/" if app_id.namespace else ""
base_url = f"{QUEUE_URL_FORMAT}{prefix}{app_id.owner}/{app_id.alias}/requests/{request_id}"
return cls(
request_id=request_id,
response_url=base_url,
status_url=base_url + "/status",
cancel_url=base_url + "/cancel",
client=client,
)

def status(self, *, with_logs: bool = False) -> Status:
"""Returns the status of the request (which can be one of the following:
Queued, InProgress, Completed). If `with_logs` is True, logs will be included
Expand Down Expand Up @@ -143,6 +204,21 @@ def get(self) -> AnyJSON:
class AsyncRequestHandle(_BaseRequestHandle):
client: httpx.AsyncClient = field(repr=False)

@classmethod
def from_request_id(
cls, client: httpx.AsyncClient, application: str, request_id: str
) -> AsyncRequestHandle:
app_id = AppId.from_endpoint_id(application)
prefix = f"{app_id.namespace}/" if app_id.namespace else ""
base_url = f"{QUEUE_URL_FORMAT}{prefix}{app_id.owner}/{app_id.alias}/requests/{request_id}"
return cls(
request_id=request_id,
response_url=base_url,
status_url=base_url + "/status",
cancel_url=base_url + "/cancel",
client=client,
)

async def status(self, *, with_logs: bool = False) -> Status:
"""Returns the status of the request (which can be one of the following:
Queued, InProgress, Completed). If `with_logs` is True, logs will be included
Expand Down Expand Up @@ -269,6 +345,9 @@ async def submit(
client=self._client,
)

def get_handle(self, application: str, request_id: str) -> AsyncRequestHandle:
return AsyncRequestHandle.from_request_id(self._client, application, request_id)

async def stream(
self,
application: str,
Expand Down Expand Up @@ -415,6 +494,9 @@ def submit(
client=self._client,
)

def get_handle(self, application: str, request_id: str) -> SyncRequestHandle:
return SyncRequestHandle.from_request_id(self._client, application, request_id)

def stream(
self,
application: str,
Expand Down
3 changes: 3 additions & 0 deletions projects/fal_client/tests/test_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ async def test_fal_client(client: fal_client.AsyncClient):
assert isinstance(status, fal_client.Completed)
assert status.logs is None

new_handle = client.get_handle("fal-ai/fast-sdxl/image-to-image", handle.request_id)
assert new_handle == handle

status_w_logs = await handle.status(with_logs=True)
assert isinstance(status_w_logs, fal_client.Completed)
assert status_w_logs.logs is not None
Expand Down
3 changes: 3 additions & 0 deletions projects/fal_client/tests/test_sync_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ def test_fal_client(client: fal_client.SyncClient):
assert isinstance(status_w_logs, fal_client.Completed)
assert status_w_logs.logs is not None

new_handle = client.get_handle("fal-ai/fast-sdxl/image-to-image", handle.request_id)
assert new_handle == handle

output = client.run(
"fal-ai/fast-sdxl",
arguments={
Expand Down

0 comments on commit dcb2136

Please sign in to comment.