Skip to content

Commit

Permalink
feat(client): add subscribe
Browse files Browse the repository at this point in the history
  • Loading branch information
efiop committed Oct 7, 2024
1 parent 86f3b9a commit 5802795
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 0 deletions.
42 changes: 42 additions & 0 deletions projects/fal_client/src/fal_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,27 @@ async def submit(
client=self._client,
)

async def subscribe(
self,
application: str,
arguments: AnyJSON,
*,
path: str = "",
hint: str | None = None,
with_logs: bool = False,
on_enqueue: Optional[callable[[Queued], None]] = None,
on_queue_update: Optional[callable[[Status], None]] = None,
) -> AnyJSON:
handle = self.submit(application, arguments, path=path, hint=hint)

if on_enqueue is not None:
on_enqueue(handle.request_id)

async for event in handle.iter_events(with_logs=with_logs):
on_queue_update(event)

return await handle.get()

def get_handle(self, application: str, request_id: str) -> AsyncRequestHandle:
return AsyncRequestHandle.from_request_id(self._client, application, request_id)

Expand Down Expand Up @@ -504,6 +525,27 @@ def submit(
client=self._client,
)

def subscribe(
self,
application: str,
arguments: AnyJSON,
*,
path: str = "",
hint: str | None = None,
with_logs: bool = False,
on_enqueue: Optional[callable[[Queued], None]] = None,
on_queue_update: Optional[callable[[Status], None]] = None,
) -> AnyJSON:
handle = self.submit(application, arguments, path=path, hint=hint)

if on_enqueue is not None:
on_enqueue(handle.request_id)

for event in handle.iter_events(with_logs=with_logs):
on_queue_update(event)

return handle.get()

def get_handle(self, application: str, request_id: str) -> SyncRequestHandle:
return SyncRequestHandle.from_request_id(self._client, application, request_id)

Expand Down
9 changes: 9 additions & 0 deletions projects/fal_client/tests/test_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,15 @@ async def test_fal_client(client: fal_client.AsyncClient):
== status
)

output = await client.subscribe(
"fal-ai/fast-sdxl",
arguments={
"prompt": "a cat",
},
hint="lora:a",
)
assert len(output["images"]) == 1

output = await client.run(
"fal-ai/fast-sdxl",
arguments={
Expand Down
9 changes: 9 additions & 0 deletions projects/fal_client/tests/test_sync_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@ def test_fal_client(client: fal_client.SyncClient):

assert client.status("fal-ai/fast-sdxl/image-to-image", handle.request_id) == status

output = client.subscribe(
"fal-ai/fast-sdxl",
arguments={
"prompt": "a cat",
},
hint="lora:a",
)
assert len(output["images"]) == 1

output = client.run(
"fal-ai/fast-sdxl",
arguments={
Expand Down

0 comments on commit 5802795

Please sign in to comment.