From 58027954ca789edee7afe07136b7648074ddd885 Mon Sep 17 00:00:00 2001 From: Ruslan Kuprieiev Date: Mon, 7 Oct 2024 21:32:33 +0300 Subject: [PATCH] feat(client): add subscribe --- projects/fal_client/src/fal_client/client.py | 42 +++++++++++++++++++ .../fal_client/tests/test_async_client.py | 9 ++++ projects/fal_client/tests/test_sync_client.py | 9 ++++ 3 files changed, 60 insertions(+) diff --git a/projects/fal_client/src/fal_client/client.py b/projects/fal_client/src/fal_client/client.py index 1005d93b..37f16784 100644 --- a/projects/fal_client/src/fal_client/client.py +++ b/projects/fal_client/src/fal_client/client.py @@ -345,6 +345,27 @@ async def submit( client=self._client, ) + async def subscribe( + self, + application: str, + arguments: AnyJSON, + *, + path: str = "", + hint: str | None = None, + with_logs: bool = False, + on_enqueue: Optional[callable[[Queued], None]] = None, + on_queue_update: Optional[callable[[Status], None]] = None, + ) -> AnyJSON: + handle = self.submit(application, arguments, path=path, hint=hint) + + if on_enqueue is not None: + on_enqueue(handle.request_id) + + async for event in handle.iter_events(with_logs=with_logs): + on_queue_update(event) + + return await handle.get() + def get_handle(self, application: str, request_id: str) -> AsyncRequestHandle: return AsyncRequestHandle.from_request_id(self._client, application, request_id) @@ -504,6 +525,27 @@ def submit( client=self._client, ) + def subscribe( + self, + application: str, + arguments: AnyJSON, + *, + path: str = "", + hint: str | None = None, + with_logs: bool = False, + on_enqueue: Optional[callable[[Queued], None]] = None, + on_queue_update: Optional[callable[[Status], None]] = None, + ) -> AnyJSON: + handle = self.submit(application, arguments, path=path, hint=hint) + + if on_enqueue is not None: + on_enqueue(handle.request_id) + + for event in handle.iter_events(with_logs=with_logs): + on_queue_update(event) + + return handle.get() + def get_handle(self, application: str, request_id: str) -> SyncRequestHandle: return SyncRequestHandle.from_request_id(self._client, application, request_id) diff --git a/projects/fal_client/tests/test_async_client.py b/projects/fal_client/tests/test_async_client.py index 10625b2c..a5a6bb30 100644 --- a/projects/fal_client/tests/test_async_client.py +++ b/projects/fal_client/tests/test_async_client.py @@ -55,6 +55,15 @@ async def test_fal_client(client: fal_client.AsyncClient): == status ) + output = await client.subscribe( + "fal-ai/fast-sdxl", + arguments={ + "prompt": "a cat", + }, + hint="lora:a", + ) + assert len(output["images"]) == 1 + output = await client.run( "fal-ai/fast-sdxl", arguments={ diff --git a/projects/fal_client/tests/test_sync_client.py b/projects/fal_client/tests/test_sync_client.py index 88f31138..1f9f0bd4 100644 --- a/projects/fal_client/tests/test_sync_client.py +++ b/projects/fal_client/tests/test_sync_client.py @@ -46,6 +46,15 @@ def test_fal_client(client: fal_client.SyncClient): assert client.status("fal-ai/fast-sdxl/image-to-image", handle.request_id) == status + output = client.subscribe( + "fal-ai/fast-sdxl", + arguments={ + "prompt": "a cat", + }, + hint="lora:a", + ) + assert len(output["images"]) == 1 + output = client.run( "fal-ai/fast-sdxl", arguments={