Skip to content

Commit 53adf19

Browse files
authored
Merge pull request #180 from dispatchrun/run-forever-args
Optionally run a coroutine when using `dispatch.run_forever`
2 parents 90b97d9 + 57f3772 commit 53adf19

File tree

1 file changed

+23
-5
lines changed

1 file changed

+23
-5
lines changed

src/dispatch/__init__.py

+23-5
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import asyncio
66
import os
77
from http.server import ThreadingHTTPServer
8-
from typing import Any, Callable, Coroutine, Optional, TypeVar, overload
8+
from typing import Any, Awaitable, Callable, Coroutine, Optional, TypeVar, overload
99
from urllib.parse import urlsplit
1010

1111
from typing_extensions import ParamSpec, TypeAlias
@@ -96,7 +96,7 @@ async def main(coro: Coroutine[Any, Any, T], addr: Optional[str] = None) -> T:
9696

9797

9898
def run(coro: Coroutine[Any, Any, T], addr: Optional[str] = None) -> T:
99-
"""Run the default dispatch server. The default server uses a function
99+
"""Run the default Dispatch server. The default server uses a function
100100
registry where functions tagged by the `@dispatch.function` decorator are
101101
registered.
102102
@@ -119,9 +119,27 @@ def run(coro: Coroutine[Any, Any, T], addr: Optional[str] = None) -> T:
119119
return asyncio.run(main(coro, addr))
120120

121121

122-
def run_forever():
123-
"""Run the default dispatch server forever."""
124-
return run(asyncio.Event().wait())
122+
def run_forever(
123+
coro: Optional[Coroutine[Any, Any, T]] = None, addr: Optional[str] = None
124+
):
125+
"""Run the default Dispatch server forever.
126+
127+
Args:
128+
coro: A coroutine to optionally run as the entrypoint.
129+
130+
addr: The address to bind the server to. If not provided, the server
131+
will bind to the address specified by the `DISPATCH_ENDPOINT_ADDR`
132+
environment variable. If the environment variable is not set, the
133+
server will bind to `localhost:8000`.
134+
"""
135+
wait = asyncio.Event().wait()
136+
coro = chain(coro, wait) if coro is not None else wait
137+
return run(coro=coro, addr=addr)
138+
139+
140+
async def chain(*awaitables: Awaitable[Any]):
141+
for a in awaitables:
142+
await a
125143

126144

127145
def batch() -> Batch:

0 commit comments

Comments
 (0)