PYTHON-4782 Fix deadlock and blocking behavior in _ACondition.wait #1875

Merged: 12 commits, Sep 30, 2024
pymongo/asynchronous/pool.py (4 additions, 3 deletions)

@@ -992,7 +992,8 @@ def __init__(
        # from the right side.
        self.conns: collections.deque = collections.deque()
        self.active_contexts: set[_CancellationContext] = set()
-        self.lock = _ALock(_create_lock())
        _lock = _create_lock()
        self.lock = _ALock(_lock)
        self.active_sockets = 0
        # Monotonically increasing connection ID required for CMAP Events.
        self.next_connection_id = 1

@@ -1018,15 +1019,15 @@ def __init__(
        # The first portion of the wait queue.
        # Enforces: maxPoolSize
        # Also used for: clearing the wait queue
-        self.size_cond = _ACondition(threading.Condition(self.lock))  # type: ignore[arg-type]
        self.size_cond = _ACondition(threading.Condition(_lock))
        self.requests = 0
        self.max_pool_size = self.opts.max_pool_size
        if not self.max_pool_size:
            self.max_pool_size = float("inf")
        # The second portion of the wait queue.
        # Enforces: maxConnecting
        # Also used for: clearing the wait queue
-        self._max_connecting_cond = _ACondition(threading.Condition(self.lock))  # type: ignore[arg-type]
        self._max_connecting_cond = _ACondition(threading.Condition(_lock))
        self._max_connecting = self.opts.max_connecting
        self._pending = 0
        self._client_id = client_id
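
The point of the change is easiest to see outside the diff. Here is a minimal sketch (a stand-in class, not pymongo's real `_ALock`) of the wiring this hunk and the matching topology.py hunk below establish: one raw `threading.Lock` backs both the async-facing wrapper and every `threading.Condition`, so a release through the wrapper is observable by waiters on either condition.

```python
import threading


class FakeALock:
    """Stand-in for pymongo.lock._ALock: an async wrapper over a threading.Lock."""

    def __init__(self, lock: threading.Lock) -> None:
        self._lock = lock


_lock = threading.Lock()                          # the single shared primitive
pool_lock = FakeALock(_lock)                      # async-facing wrapper
size_cond = threading.Condition(_lock)            # shares the same lock
max_connecting_cond = threading.Condition(_lock)  # shares the same lock

# Before this fix, each Condition wrapped the _ALock object itself, which is
# not a real threading.Lock (hence the old `type: ignore[arg-type]` comments).
assert size_cond._lock is max_connecting_cond._lock is _lock
```
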
pymongo/asynchronous/topology.py (3 additions, 2 deletions)

@@ -170,8 +170,9 @@ def __init__(self, topology_settings: TopologySettings):
        self._seed_addresses = list(topology_description.server_descriptions())
        self._opened = False
        self._closed = False
-        self._lock = _ALock(_create_lock())
-        self._condition = _ACondition(self._settings.condition_class(self._lock))  # type: ignore[arg-type]
        _lock = _create_lock()
        self._lock = _ALock(_lock)
        self._condition = _ACondition(self._settings.condition_class(_lock))
        self._servers: dict[_Address, Server] = {}
        self._pid: Optional[int] = None
        self._max_cluster_time: Optional[ClusterTime] = None
pymongo/lock.py (126 additions, 21 deletions)

@@ -14,17 +14,20 @@
from __future__ import annotations

import asyncio
import collections
import os
import threading
import time
import weakref
-from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, TypeVar

_HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork")

# References to instances of _create_lock
_forkable_locks: weakref.WeakSet[threading.Lock] = weakref.WeakSet()

_T = TypeVar("_T")


def _create_lock() -> threading.Lock:
    """Represents a lock that is tracked upon instantiation using a WeakSet and

@@ -43,7 +46,14 @@ def _release_locks() -> None:
        lock.release()


# TODO: remove this.
def _Lock(lock: threading.Lock) -> threading.Lock:
    return lock


class _ALock:
    __slots__ = ("_lock",)

    def __init__(self, lock: threading.Lock) -> None:
        self._lock = lock

@@ -81,9 +91,18 @@ async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        self.release()


def _safe_set_result(fut: asyncio.Future) -> None:
    # Ensure the future hasn't been cancelled before calling set_result.
    if not fut.done():
        fut.set_result(False)


class _ACondition:
    __slots__ = ("_condition", "_waiters")

    def __init__(self, condition: threading.Condition) -> None:
        self._condition = condition
        self._waiters: collections.deque = collections.deque()

    async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
        if timeout > 0:

@@ -99,30 +118,116 @@ async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool:
            await asyncio.sleep(0)

    async def wait(self, timeout: Optional[float] = None) -> bool:
-        if timeout is not None:
-            tstart = time.monotonic()
-        while True:
-            notified = self._condition.wait(0.001)
-            if notified:
-                return True
-            if timeout is not None and (time.monotonic() - tstart) > timeout:
-                return False
-
-    async def wait_for(self, predicate: Callable, timeout: Optional[float] = None) -> bool:
-        if timeout is not None:
-            tstart = time.monotonic()
-        while True:
-            notified = self._condition.wait_for(predicate, 0.001)
-            if notified:
-                return True
-            if timeout is not None and (time.monotonic() - tstart) > timeout:
-                return False
"""Wait until notified.

If the calling task has not acquired the lock when this
method is called, a RuntimeError is raised.

This method releases the underlying lock, and then blocks
until it is awakened by a notify() or notify_all() call for
the same condition variable in another task. Once
awakened, it re-acquires the lock and returns True.

This method may return spuriously,
which is why the caller should always
re-check the state and be prepared to wait() again.
"""
loop = asyncio.get_running_loop()
fut = loop.create_future()
self._waiters.append((loop, fut))
self.release()
Contributor:
Is this just to ensure we don't hold the lock while waiting for it?

Member Author:
Yes. Without releasing the lock, this code would deadlock, since nothing would be able to notify the waiter.
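
To make the deadlock concrete: the replaced implementation above polled `self._condition.wait(0.001)` in 1 ms slices on the event loop, while the new implementation suspends on a future after releasing the lock. Here is a small runnable sketch of the caller-side protocol using asyncio's built-in `Condition` (not pymongo's `_ACondition`); `wait()` releasing the lock is exactly what lets the notifying task acquire it.

```python
import asyncio

cond = asyncio.Condition()
ready = False


async def wait_until_ready() -> None:
    async with cond:           # acquire the condition's lock
        while not ready:       # re-check: wakeups may be spurious
            await cond.wait()  # releases the lock while suspended


async def set_ready() -> None:
    global ready
    async with cond:  # acquirable only because wait() released it
        ready = True
        cond.notify_all()


async def main() -> None:
    await asyncio.gather(wait_until_ready(), set_ready())


asyncio.run(main())
```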

        try:
            try:
                try:
                    await asyncio.wait_for(fut, timeout)
                    return True
                except asyncio.TimeoutError:
                    return False  # Return false on timeout for sync pool compat.
Contributor:
Do we still need to acquire the lock if we time out here?

Member Author:
Yes. The API contract for wait() says you MUST hold the lock before calling it, and you MUST still hold the lock when it returns, even on timeout.

Contributor:
Got it, so you always have to acquire the lock if you call wait() and it doesn't raise an error?

Contributor:
Can you add a comment linking to the API contract here? It would be good to make it more understandable for readers unfamiliar with it.

Member Author:
There's already this comment:

    # Must re-acquire lock even if wait is cancelled.
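
For readers unfamiliar with that contract, a brief sketch with `threading.Condition` (whose semantics the async wrapper mirrors) shows that the caller still holds the lock after a timed-out `wait()`:

```python
import threading

cond = threading.Condition()


def wait_with_timeout() -> bool:
    with cond:  # the lock must be held before calling wait()
        notified = cond.wait(timeout=0.01)
        # Whether notified is True or False, the lock is held again here, so
        # touching shared state is still safe; the `with` block releases it.
        return notified


print(wait_with_timeout())  # False: nothing notified us within the timeout
```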

            finally:
                # Must re-acquire lock even if wait is cancelled.
                # We only catch CancelledError here, since we don't want any
                # other (fatal) errors with the future to cause us to spin.
                err = None
                while True:
                    try:
                        await self.acquire()
Contributor:
Can this possibly loop forever?

Member Author:
Yes, but only if something else holds the lock forever.

Contributor:
And we're asserting that this new code can't do that, got it!

                        break
                    except asyncio.exceptions.CancelledError as e:
                        err = e

                self._waiters.remove((loop, fut))
                if err is not None:
                    try:
                        raise err  # Re-raise most recent exception instance.
                    finally:
                        err = None  # Break reference cycles.
        except BaseException:
            # Any error raised out of here _may_ have occurred after this Task
            # believed to have been successfully notified.
            # Make sure to notify another Task instead. This may result
            # in a "spurious wakeup", which is allowed as part of the
            # Condition Variable protocol.
            self.notify(1)
            raise

    async def wait_for(self, predicate: Callable[[], _T]) -> _T:
        """Wait until a predicate becomes true.

        The predicate should be a callable whose result will be
        interpreted as a boolean value. The method will repeatedly
        wait() until it evaluates to true. The final predicate value is
        the return value.
        """
        result = predicate()
        while not result:
            await self.wait()
            result = predicate()
        return result

    def notify(self, n: int = 1) -> None:
-        self._condition.notify(n)
        """By default, wake up one coroutine waiting on this condition, if any.
        If the calling coroutine has not acquired the lock when this method
        is called, a RuntimeError is raised.

        This method wakes up at most n of the coroutines waiting for the
        condition variable; it is a no-op if no coroutines are waiting.

        Note: an awakened coroutine does not actually return from its
        wait() call until it can reacquire the lock. Since notify() does
        not release the lock, its caller should.
        """
        idx = 0
        to_remove = []
        for loop, fut in self._waiters:
            if idx >= n:
                break

            if fut.done():
                continue

            try:
                loop.call_soon_threadsafe(_safe_set_result, fut)
            except RuntimeError:
                # Loop was closed, ignore.
                to_remove.append((loop, fut))
                continue

            idx += 1

        for waiter in to_remove:
            self._waiters.remove(waiter)

    def notify_all(self) -> None:
-        self._condition.notify_all()
        """Wake up all coroutines waiting on this condition. This method acts
        like notify(), but wakes up all waiting coroutines instead of one. If
        the calling coroutine has not acquired the lock when this method is
        called, a RuntimeError is raised.
        """
        self.notify(len(self._waiters))

    def locked(self) -> bool:
        """Only needed for tests in test_locks."""
        return self._condition._lock.locked()  # type: ignore[attr-defined]

    def release(self) -> None:
        self._condition.release()
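
Stepping back from the diff: the thread-safety trick `notify()` relies on is resolving a loop-bound future from a foreign thread. A self-contained sketch (illustrative names, not pymongo internals) of that mechanism, mirroring the `_safe_set_result` guard above:

```python
import asyncio
import threading


def safe_set_result(fut: asyncio.Future) -> None:
    # Mirror the patch's guard: never resolve an already-done/cancelled future.
    if not fut.done():
        fut.set_result(True)


async def main() -> None:
    loop = asyncio.get_running_loop()
    fut: asyncio.Future = loop.create_future()

    def notifier() -> None:
        # Calling fut.set_result() directly from this worker thread would not
        # be thread-safe; call_soon_threadsafe hands it to the owning loop.
        loop.call_soon_threadsafe(safe_set_result, fut)

    threading.Thread(target=notifier).start()
    print(await fut)  # True, once the worker thread has notified us


asyncio.run(main())
```
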
pymongo/synchronous/pool.py (5 additions, 4 deletions)

@@ -62,7 +62,7 @@
    _CertificateError,
)
from pymongo.hello import Hello, HelloCompat
-from pymongo.lock import _create_lock
from pymongo.lock import _create_lock, _Lock
from pymongo.logger import (
    _CONNECTION_LOGGER,
    _ConnectionStatusMessage,

@@ -988,7 +988,8 @@ def __init__(
        # from the right side.
        self.conns: collections.deque = collections.deque()
        self.active_contexts: set[_CancellationContext] = set()
-        self.lock = _create_lock()
        _lock = _create_lock()
        self.lock = _Lock(_lock)
        self.active_sockets = 0
        # Monotonically increasing connection ID required for CMAP Events.
        self.next_connection_id = 1

@@ -1014,15 +1015,15 @@ def __init__(
        # The first portion of the wait queue.
        # Enforces: maxPoolSize
        # Also used for: clearing the wait queue
-        self.size_cond = threading.Condition(self.lock)  # type: ignore[arg-type]
        self.size_cond = threading.Condition(_lock)
        self.requests = 0
        self.max_pool_size = self.opts.max_pool_size
        if not self.max_pool_size:
            self.max_pool_size = float("inf")
        # The second portion of the wait queue.
        # Enforces: maxConnecting
        # Also used for: clearing the wait queue
-        self._max_connecting_cond = threading.Condition(self.lock)  # type: ignore[arg-type]
        self._max_connecting_cond = threading.Condition(_lock)
        self._max_connecting = self.opts.max_connecting
        self._pending = 0
        self._client_id = client_id
pymongo/synchronous/topology.py (4 additions, 3 deletions)

@@ -39,7 +39,7 @@
    WriteError,
)
from pymongo.hello import Hello
-from pymongo.lock import _create_lock
from pymongo.lock import _create_lock, _Lock
from pymongo.logger import (
    _SDAM_LOGGER,
    _SERVER_SELECTION_LOGGER,

@@ -170,8 +170,9 @@ def __init__(self, topology_settings: TopologySettings):
        self._seed_addresses = list(topology_description.server_descriptions())
        self._opened = False
        self._closed = False
-        self._lock = _create_lock()
-        self._condition = self._settings.condition_class(self._lock)  # type: ignore[arg-type]
        _lock = _create_lock()
        self._lock = _Lock(_lock)
        self._condition = self._settings.condition_class(_lock)
        self._servers: dict[_Address, Server] = {}
        self._pid: Optional[int] = None
        self._max_cluster_time: Optional[ClusterTime] = None
test/asynchronous/test_client.py (3 additions, 1 deletion)

@@ -2351,7 +2351,9 @@ async def test_reconnect(self):

        # But it can reconnect.
        c.revive_host("a:1")
-        await (await c._get_topology()).select_servers(writable_server_selector, _Op.TEST)
        await (await c._get_topology()).select_servers(
            writable_server_selector, _Op.TEST, server_selection_timeout=10
        )
        self.assertEqual(await c.address, ("a", 1))

    async def _test_network_error(self, operation_callback):
test/asynchronous/test_cursor.py (2 additions, 2 deletions)

@@ -1414,7 +1414,7 @@ async def test_to_list_length(self):
    async def test_to_list_csot_applied(self):
        client = await self.async_single_client(timeoutMS=500)
        # Initialize the client with a larger timeout to help make the test less flaky
-        with pymongo.timeout(2):
        with pymongo.timeout(10):
            await client.admin.command("ping")
        coll = client.pymongo.test
        await coll.insert_many([{} for _ in range(5)])

@@ -1456,7 +1456,7 @@ async def test_command_cursor_to_list_length(self):
    async def test_command_cursor_to_list_csot_applied(self):
        client = await self.async_single_client(timeoutMS=500)
        # Initialize the client with a larger timeout to help make the test less flaky
-        with pymongo.timeout(2):
        with pymongo.timeout(10):
            await client.admin.command("ping")
        coll = client.pymongo.test
        await coll.insert_many([{} for _ in range(5)])
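
For context, `pymongo.timeout` is the public context manager these tests tune. A minimal sketch of its intended use (the URI is illustrative, not from this PR):

```python
import pymongo
from pymongo import MongoClient

client = MongoClient("mongodb://localhost:27017", timeoutMS=500)

# Operations inside the block get a 10-second deadline instead of the
# client-wide 500 ms timeoutMS, which is why the tests raise it for the
# initial "ping" to reduce flakiness.
with pymongo.timeout(10):
    client.admin.command("ping")
```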