Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix cancellation tests with new Twisted. #17906

Merged
merged 5 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/17906.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix tests to run with latest Twisted.
107 changes: 83 additions & 24 deletions tests/http/server/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
Callable,
ContextManager,
Dict,
Generator,
List,
Optional,
Set,
Expand All @@ -49,7 +50,10 @@
respond_with_json,
)
from synapse.http.site import SynapseRequest
from synapse.logging.context import LoggingContext, make_deferred_yieldable
from synapse.logging.context import (
LoggingContext,
make_deferred_yieldable,
)
from synapse.types import JsonDict

from tests.server import FakeChannel, make_request
Expand Down Expand Up @@ -199,7 +203,7 @@ def make_request_with_cancellation_test(
#
# We would like to trigger a cancellation at the first `await`, re-run the
# request and cancel at the second `await`, and so on. By patching
# `Deferred.__next__`, we can intercept `await`s, track which ones we have or
# `Deferred.__await__`, we can intercept `await`s, track which ones we have or
# have not seen, and force them to block when they wouldn't have.

# The set of previously seen `await`s.
Expand All @@ -211,7 +215,7 @@ def make_request_with_cancellation_test(
)

for request_number in itertools.count(1):
deferred_patch = Deferred__next__Patch(seen_awaits, request_number)
deferred_patch = Deferred__await__Patch(seen_awaits, request_number)

try:
with mock.patch(
Expand Down Expand Up @@ -250,6 +254,8 @@ def make_request_with_cancellation_test(
)

if respond_mock.called:
_log_for_request(request_number, "--- response finished ---")

# The request ran to completion and we are done with testing it.

# `respond_with_json` writes the response asynchronously, so we
Expand Down Expand Up @@ -311,8 +317,8 @@ def make_request_with_cancellation_test(
assert False, "unreachable" # noqa: B011


class Deferred__next__Patch:
"""A `Deferred.__next__` patch that will intercept `await`s and force them
class Deferred__await__Patch:
"""A `Deferred.__await__` patch that will intercept `await`s and force them
to block once it sees a new `await`.

When done with the patch, `unblock_awaits()` must be called to clean up after any
Expand All @@ -322,7 +328,7 @@ class Deferred__next__Patch:

Usage:
seen_awaits = set()
deferred_patch = Deferred__next__Patch(seen_awaits, 1)
deferred_patch = Deferred__await__Patch(seen_awaits, 1)
try:
with deferred_patch.patch():
# do things
Expand All @@ -335,23 +341,28 @@ def __init__(self, seen_awaits: Set[Tuple[str, ...]], request_number: int):
"""
Args:
seen_awaits: The set of stack traces of `await`s that have been previously
seen. When the `Deferred.__next__` patch sees a new `await`, it will add
seen. When the `Deferred.__await__` patch sees a new `await`, it will add
it to the set.
request_number: The request number to log against.
"""
self._request_number = request_number
self._seen_awaits = seen_awaits

self._original_Deferred___next__ = Deferred.__next__ # type: ignore[misc,unused-ignore]
self._original_Deferred__await__ = Deferred.__await__ # type: ignore[misc,unused-ignore]

# The number of `await`s on `Deferred`s we have seen so far.
self.awaits_seen = 0

# Whether we have seen a new `await` not in `seen_awaits`.
self.new_await_seen = False

# Whether to block new await points we see. This gets set to False once
# we have cancelled the request to allow things to run after
# cancellation.
self._block_new_awaits = True

# To force `await`s on resolved `Deferred`s to block, we make up a new
# unresolved `Deferred` and return it out of `Deferred.__next__` /
# unresolved `Deferred` and return it out of `Deferred.__await__` /
# `coroutine.send()`. We have to resolve it later, in case the `await`ing
# coroutine is part of some shared processing, such as `@cached`.
self._to_unblock: Dict[Deferred, Union[object, Failure]] = {}
Expand All @@ -360,25 +371,59 @@ def __init__(self, seen_awaits: Set[Tuple[str, ...]], request_number: int):
self._previous_stack: List[inspect.FrameInfo] = []

def patch(self) -> ContextManager[Mock]:
"""Returns a context manager which patches `Deferred.__next__`."""
"""Returns a context manager which patches `Deferred.__await__`."""

def Deferred___next__(
deferred: "Deferred[T]", value: object = None
) -> "Deferred[T]":
"""Intercepts `await`s on `Deferred`s and rigs them to block once we have
seen enough of them.
def Deferred___await__(
deferred: "Deferred[T]",
) -> Generator["Deferred[T]", None, T]:
"""Intercepts calls to `__await__`, which returns a generator
yielding deferreds that we await on.

`Deferred.__next__` will normally:
The generator for `__await__` will normally:
* return `self` if the `Deferred` is unresolved, in which case
`coroutine.send()` will return the `Deferred`, and
`_defer.inlineCallbacks` will stop running the coroutine until the
`Deferred` is resolved.
* raise a `StopIteration(result)`, containing the result of the `await`.
* raise another exception, which will come out of the `await`.
"""

# Get the original generator.
gen = self._original_Deferred__await__(deferred)

# Run the generator, handling each iteration to see if we need to
# block.
try:
while True:
# We've hit a new await point (or the deferred has
# completed), handle it.
handle_next_iteration(deferred)

# Continue on.
yield gen.send(None)
except StopIteration as e:
# We need to convert `StopIteration` into a normal return.
return e.value

def handle_next_iteration(
deferred: "Deferred[T]",
) -> None:
"""Intercepts `await`s on `Deferred`s and rigs them to block once we have
seen enough of them.

Args:
deferred: The deferred that we've captured and are intercepting
`await` calls within.
"""
if not self._block_new_awaits:
# We're no longer blocking awaits points
return

self.awaits_seen += 1

stack = _get_stack(skip_frames=1)
stack = _get_stack(
skip_frames=2 # Ignore this function and `Deferred___await__` in stack trace
)
stack_hash = _hash_stack(stack)

if stack_hash not in self._seen_awaits:
Expand All @@ -389,20 +434,29 @@ def Deferred___next__(
if not self.new_await_seen:
# This `await` isn't interesting. Let it proceed normally.

_log_await_stack(
stack,
self._previous_stack,
self._request_number,
"already seen",
)

# Don't log the stack. It's been seen before in a previous run.
self._previous_stack = stack

return self._original_Deferred___next__(deferred, value)
return

# We want to block at the current `await`.
if deferred.called and not deferred.paused:
# This `Deferred` already has a result.
# We return a new, unresolved, `Deferred` for `_inlineCallbacks` to wait
# on. This blocks the coroutine that did this `await`.
# This `Deferred` already has a result. We chain a new,
# unresolved, `Deferred` to the end of this Deferred that it
# will wait on. This blocks the coroutine that did this `await`.
# We queue it up for unblocking later.
new_deferred: "Deferred[T]" = Deferred()
self._to_unblock[new_deferred] = deferred.result

deferred.addBoth(lambda _: make_deferred_yieldable(new_deferred))

_log_await_stack(
stack,
self._previous_stack,
Expand All @@ -411,7 +465,9 @@ def Deferred___next__(
)
self._previous_stack = stack

return make_deferred_yieldable(new_deferred)
# Continue iterating on the deferred now that we've blocked it
# again.
return

# This `Deferred` does not have a result yet.
# The `await` will block normally, so we don't have to do anything.
Expand All @@ -423,16 +479,19 @@ def Deferred___next__(
)
self._previous_stack = stack

return self._original_Deferred___next__(deferred, value)
return

return mock.patch.object(Deferred, "__next__", new=Deferred___next__)
return mock.patch.object(Deferred, "__await__", new=Deferred___await__)

def unblock_awaits(self) -> None:
"""Unblocks any shared processing that we forced to block.

Must be called when done, otherwise processing shared between multiple requests,
such as database queries started by `@cached`, will become permanently stuck.
"""
# Also disable blocking at future await points
self._block_new_awaits = False

to_unblock = self._to_unblock
self._to_unblock = {}
for deferred, result in to_unblock.items():
Expand Down
Loading