Skip to content

Commit

Permalink
Fix #470: Prevent error `AsyncToSync.main_wrap() got multiple values …
Browse files Browse the repository at this point in the history
…for argument '<kwarg>'` (#471)

Replace partial with pre-created awaitable to ensure kwargs/args can't be stepped on
  • Loading branch information
Krismix1 authored Sep 30, 2024
1 parent b7aaa79 commit a916061
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 6 deletions.
15 changes: 9 additions & 6 deletions asgiref/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,12 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
sys.exc_info(),
task_context,
context,
*args,
**kwargs,
# prepare an awaitable which can be passed as is to self.main_wrap,
# so that `args` and `kwargs` don't need to be
# destructured when passed to self.main_wrap
# (which is required by `ParamSpec`)
# as that may cause overlapping arguments
self.awaitable(*args, **kwargs),
)

if not (self.main_event_loop and self.main_event_loop.is_running()):
Expand Down Expand Up @@ -302,8 +306,7 @@ async def main_wrap(
exc_info: "OptExcInfo",
task_context: "Optional[List[asyncio.Task[Any]]]",
context: List[contextvars.Context],
*args: _P.args,
**kwargs: _P.kwargs,
awaitable: Union[Coroutine[Any, Any, _R], Awaitable[_R]],
) -> None:
"""
Wraps the awaitable with something that puts the result into the
Expand All @@ -326,9 +329,9 @@ async def main_wrap(
try:
raise exc_info[1]
except BaseException:
result = await self.awaitable(*args, **kwargs)
result = await awaitable
else:
result = await self.awaitable(*args, **kwargs)
result = await awaitable
except BaseException as e:
call_result.set_exception(e)
else:
Expand Down
34 changes: 34 additions & 0 deletions tests/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import warnings
from concurrent.futures import ThreadPoolExecutor
from functools import wraps
from typing import Any
from unittest import TestCase

import pytest
Expand Down Expand Up @@ -1174,3 +1175,36 @@ async def async_task():
assert task_complete

assert task_executed


def test_async_to_sync_overlapping_kwargs() -> None:
"""
Tests that AsyncToSync correctly passes through kwargs to the wrapped function,
particularly in the case where the wrapped function uses same names for the parameters
as the wrapper.
"""

@async_to_sync
async def test_function(**kwargs: Any) -> None:
assert kwargs

# AsyncToSync.main_wrap has a param named `context`.
# So we pass the same argument here to test for the error
# "AsyncToSync.main_wrap() got multiple values for argument '<kwarg>'"
test_function(context=1)


@pytest.mark.asyncio
async def test_sync_to_async_overlapping_kwargs() -> None:
"""
Tests that SyncToAsync correctly passes through kwargs to the wrapped function,
particularly in the case where the wrapped function uses same names for the parameters
as the wrapper.
"""

@sync_to_async
def test_function(**kwargs: Any) -> None:
assert kwargs

# SyncToAsync.__call__.loop.run_in_executor has a param named `task_context`.
await test_function(task_context=1)

0 comments on commit a916061

Please sign in to comment.