From d7dc659237303aeedbc3fb3f11f86707910b013e Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Fri, 20 Sep 2024 16:09:01 -0700 Subject: [PATCH 01/12] PYTHON-4782 Fix deadlock and blocking behavior in _ACondition.wait --- pymongo/lock.py | 141 +++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 116 insertions(+), 25 deletions(-) diff --git a/pymongo/lock.py b/pymongo/lock.py index b05f6acffb..aa1ed6fd61 100644 --- a/pymongo/lock.py +++ b/pymongo/lock.py @@ -14,11 +14,12 @@ from __future__ import annotations import asyncio +import collections import os import threading import time import weakref -from typing import Any, Callable, Optional +from typing import Any, Optional _HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork") @@ -44,6 +45,8 @@ def _release_locks() -> None: class _ALock: + __slots__ = ("_lock",) + def __init__(self, lock: threading.Lock) -> None: self._lock = lock @@ -82,8 +85,11 @@ async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: class _ACondition: + __slots__ = ("_condition", "_waiters") + def __init__(self, condition: threading.Condition) -> None: self._condition = condition + self._waiters = collections.deque() async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: if timeout > 0: @@ -99,30 +105,115 @@ async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: await asyncio.sleep(0) async def wait(self, timeout: Optional[float] = None) -> bool: - if timeout is not None: - tstart = time.monotonic() - while True: - notified = self._condition.wait(0.001) - if notified: - return True - if timeout is not None and (time.monotonic() - tstart) > timeout: - return False - - async def wait_for(self, predicate: Callable, timeout: Optional[float] = None) -> bool: - if timeout is not None: - tstart = time.monotonic() - while True: - notified = self._condition.wait_for(predicate, 0.001) - if notified: - return True - if timeout is not None and (time.monotonic() - tstart) > timeout: - return False - - def notify(self, n: int = 1) -> None: - self._condition.notify(n) - - def notify_all(self) -> None: - self._condition.notify_all() + """Wait until notified. + + If the calling task has not acquired the lock when this + method is called, a RuntimeError is raised. + + This method releases the underlying lock, and then blocks + until it is awakened by a notify() or notify_all() call for + the same condition variable in another task. Once + awakened, it re-acquires the lock and returns True. + + This method may return spuriously, + which is why the caller should always + re-check the state and be prepared to wait() again. + """ + loop = asyncio.get_running_loop() + fut = loop.create_future() + self._waiters.append((loop, fut)) + self.release() + try: + try: + try: + await asyncio.wait_for(fut, timeout) + return True + except asyncio.TimeoutError: + return False # Return false on timeout for sync pool compat. + finally: + # Must re-acquire lock even if wait is cancelled. + # We only catch CancelledError here, since we don't want any + # other (fatal) errors with the future to cause us to spin. + err = None + while True: + try: + await self.acquire() + break + except asyncio.exceptions.CancelledError as e: + err = e + + self._waiters.remove((loop, fut)) + if err is not None: + try: + raise err # Re-raise most recent exception instance. + finally: + err = None # Break reference cycles. 
+ except BaseException: + # Any error raised out of here _may_ have occurred after this Task + # believed to have been successfully notified. + # Make sure to notify another Task instead. This may result + # in a "spurious wakeup", which is allowed as part of the + # Condition Variable protocol. + self.notify(1) + raise + + async def wait_for(self, predicate): + """Wait until a predicate becomes true. + + The predicate should be a callable which result will be + interpreted as a boolean value. The final predicate value is + the return value. + """ + result = predicate() + while not result: + await self.wait() + result = predicate() + return result + + def notify(self, n=1): + """By default, wake up one coroutine waiting on this condition, if any. + If the calling coroutine has not acquired the lock when this method + is called, a RuntimeError is raised. + + This method wakes up at most n of the coroutines waiting for the + condition variable; it is a no-op if no coroutines are waiting. + + Note: an awakened coroutine does not actually return from its + wait() call until it can reacquire the lock. Since notify() does + not release the lock, its caller should. + """ + idx = 0 + to_remove = [] + for loop, fut in self._waiters: + if idx >= n: + break + + if fut.done(): + continue + + def safe_set_result(fut): + if not fut.done(): + fut.set_result(False) + + try: + loop.call_soon_threadsafe(safe_set_result, fut) + except RuntimeError: + # Loop was closed, ignore. + to_remove.append((loop, fut)) + continue + + idx += 1 + + for waiter in to_remove: + self._waiters.remove(waiter) + + def notify_all(self): + """Wake up all threads waiting on this condition. This method acts + like notify(), but wakes up all waiting threads instead of one. If the + calling thread has not acquired the lock when this method is called, + a RuntimeError is raised. + """ + self.notify(len(self._waiters)) def release(self) -> None: self._condition.release() From 514432606ecb450f39f73e235e4aa571d6089114 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Mon, 23 Sep 2024 13:16:49 -0700 Subject: [PATCH 02/12] PYTHON-4782 Cleanup --- pymongo/lock.py | 31 ++++++++++--------------------- 1 file changed, 10 insertions(+), 21 deletions(-) diff --git a/pymongo/lock.py b/pymongo/lock.py index aa1ed6fd61..d23d25f062 100644 --- a/pymongo/lock.py +++ b/pymongo/lock.py @@ -84,12 +84,18 @@ async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: self.release() +def _safe_set_result(fut: asyncio.Future) -> None: + # Ensure the future hasn't been cancelled before calling set_result. + if not fut.done(): + fut.set_result(False) + + class _ACondition: __slots__ = ("_condition", "_waiters") def __init__(self, condition: threading.Condition) -> None: self._condition = condition - self._waiters = collections.deque() + self._waiters: collections.deque = collections.deque() async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: if timeout > 0: @@ -157,20 +163,7 @@ async def wait(self, timeout: Optional[float] = None) -> bool: self.notify(1) raise - async def wait_for(self, predicate): - """Wait until a predicate becomes true. - - The predicate should be a callable which result will be - interpreted as a boolean value. The final predicate value is - the return value. - """ - result = predicate() - while not result: - await self.wait() - result = predicate() - return result - - def notify(self, n=1): + def notify(self, n: int = 1) -> None: """By default, wake up one coroutine waiting on this condition, if any. 
If the calling coroutine has not acquired the lock when this method is called, a RuntimeError is raised. @@ -191,12 +184,8 @@ def notify(self, n=1): if fut.done(): continue - def safe_set_result(fut): - if not fut.done(): - fut.set_result(False) - try: - loop.call_soon_threadsafe(safe_set_result, fut) + loop.call_soon_threadsafe(_safe_set_result, fut) except RuntimeError: # Loop was closed, ignore. to_remove.append((loop, fut)) @@ -207,7 +196,7 @@ def safe_set_result(fut): for waiter in to_remove: self._waiters.remove(waiter) - def notify_all(self): + def notify_all(self) -> None: """Wake up all threads waiting on this condition. This method acts like notify(), but wakes up all waiting threads instead of one. If the calling thread has not acquired the lock when this method is called, From 45ca4395e3cd3ad9c9c78a4da53c90124f1f883b Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Mon, 23 Sep 2024 18:10:17 -0700 Subject: [PATCH 03/12] PYTHON-4782 Add tests for _ACondition --- pymongo/lock.py | 22 +- test/asynchronous/test_locks.py | 505 ++++++++++++++++++++++++++++++++ 2 files changed, 526 insertions(+), 1 deletion(-) create mode 100644 test/asynchronous/test_locks.py diff --git a/pymongo/lock.py b/pymongo/lock.py index d23d25f062..22cbf8d96b 100644 --- a/pymongo/lock.py +++ b/pymongo/lock.py @@ -19,13 +19,15 @@ import threading import time import weakref -from typing import Any, Optional +from typing import Any, Callable, Optional, TypeVar _HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork") # References to instances of _create_lock _forkable_locks: weakref.WeakSet[threading.Lock] = weakref.WeakSet() +_T = TypeVar("_T") + def _create_lock() -> threading.Lock: """Represents a lock that is tracked upon instantiation using a WeakSet and @@ -163,6 +165,20 @@ async def wait(self, timeout: Optional[float] = None) -> bool: self.notify(1) raise + async def wait_for(self, predicate: Callable[[], _T]) -> _T: + """Wait until a predicate becomes true. + + The predicate should be a callable whose result will be + interpreted as a boolean value. The method will repeatedly + wait() until it evaluates to true. The final predicate value is + the return value. + """ + result = predicate() + while not result: + await self.wait() + result = predicate() + return result + def notify(self, n: int = 1) -> None: """By default, wake up one coroutine waiting on this condition, if any. If the calling coroutine has not acquired the lock when this method @@ -204,6 +220,10 @@ def notify_all(self) -> None: """ self.notify(len(self._waiters)) + def locked(self) -> bool: + """Only needed for tests in test_locks.""" + return self._condition._lock.locked() # type: ignore[attr-defined] + def release(self) -> None: self._condition.release() diff --git a/test/asynchronous/test_locks.py b/test/asynchronous/test_locks.py new file mode 100644 index 0000000000..9233a83b03 --- /dev/null +++ b/test/asynchronous/test_locks.py @@ -0,0 +1,505 @@ +# Copyright 2024-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for lock.py""" +from __future__ import annotations + +import asyncio +import sys +import threading +import unittest + +sys.path[0:0] = [""] + +from pymongo.lock import _ACondition + + +# Tests adapted from: https://github.com/python/cpython/blob/v3.13.0rc2/Lib/test/test_asyncio/test_locks.py +# Includes tests for: +# - https://github.com/python/cpython/issues/111693 +# - https://github.com/python/cpython/issues/112202 +class TestConditionStdlib(unittest.IsolatedAsyncioTestCase): + async def test_wait(self): + cond = _ACondition(threading.Condition(threading.Lock())) + result = [] + + async def c1(result): + await cond.acquire() + if await cond.wait(): + result.append(1) + return True + + async def c2(result): + await cond.acquire() + if await cond.wait(): + result.append(2) + return True + + async def c3(result): + await cond.acquire() + if await cond.wait(): + result.append(3) + return True + + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + t3 = asyncio.create_task(c3(result)) + + await asyncio.sleep(0) + self.assertEqual([], result) + self.assertFalse(cond.locked()) + + self.assertTrue(await cond.acquire()) + cond.notify() + await asyncio.sleep(0) + self.assertEqual([], result) + self.assertTrue(cond.locked()) + + cond.release() + await asyncio.sleep(0) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.notify(2) + await asyncio.sleep(0) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.release() + await asyncio.sleep(0) + self.assertEqual([1, 2], result) + self.assertTrue(cond.locked()) + + cond.release() + await asyncio.sleep(0) + self.assertEqual([1, 2, 3], result) + self.assertTrue(cond.locked()) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + async def test_wait_cancel(self): + cond = _ACondition(threading.Condition(threading.Lock())) + await cond.acquire() + + wait = asyncio.create_task(cond.wait()) + asyncio.get_running_loop().call_soon(wait.cancel) + with self.assertRaises(asyncio.CancelledError): + await wait + self.assertFalse(cond._waiters) + self.assertTrue(cond.locked()) + + async def test_wait_cancel_contested(self): + cond = _ACondition(threading.Condition(threading.Lock())) + + await cond.acquire() + self.assertTrue(cond.locked()) + + wait_task = asyncio.create_task(cond.wait()) + await asyncio.sleep(0) + self.assertFalse(cond.locked()) + + # Notify, but contest the lock before cancelling + await cond.acquire() + self.assertTrue(cond.locked()) + cond.notify() + asyncio.get_running_loop().call_soon(wait_task.cancel) + asyncio.get_running_loop().call_soon(cond.release) + + try: + await wait_task + except asyncio.CancelledError: + # Should not happen, since no cancellation points + pass + + self.assertTrue(cond.locked()) + + async def test_wait_cancel_after_notify(self): + # See bpo-32841 + waited = False + + cond = _ACondition(threading.Condition(threading.Lock())) + + async def wait_on_cond(): + nonlocal waited + async with cond: + waited = True # Make sure this area was reached + await cond.wait() + + waiter = asyncio.create_task(wait_on_cond()) + await asyncio.sleep(0) # Start waiting + + await cond.acquire() + cond.notify() + await asyncio.sleep(0) # Get to acquire() + waiter.cancel() + await asyncio.sleep(0) # Activate cancellation + cond.release() + await asyncio.sleep(0) # Cancellation should occur + + self.assertTrue(waiter.cancelled()) + 
self.assertTrue(waited) + + async def test_wait_unacquired(self): + cond = _ACondition(threading.Condition(threading.Lock())) + with self.assertRaises(RuntimeError): + await cond.wait() + + async def test_wait_for(self): + cond = _ACondition(threading.Condition(threading.Lock())) + presult = False + + def predicate(): + return presult + + result = [] + + async def c1(result): + await cond.acquire() + if await cond.wait_for(predicate): + result.append(1) + cond.release() + return True + + t = asyncio.create_task(c1(result)) + + await asyncio.sleep(0) + self.assertEqual([], result) + + await cond.acquire() + cond.notify() + cond.release() + await asyncio.sleep(0) + self.assertEqual([], result) + + presult = True + await cond.acquire() + cond.notify() + cond.release() + await asyncio.sleep(0) + self.assertEqual([1], result) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + async def test_wait_for_unacquired(self): + cond = _ACondition(threading.Condition(threading.Lock())) + + # predicate can return true immediately + res = await cond.wait_for(lambda: [1, 2, 3]) + self.assertEqual([1, 2, 3], res) + + with self.assertRaises(RuntimeError): + await cond.wait_for(lambda: False) + + async def test_notify(self): + cond = _ACondition(threading.Condition(threading.Lock())) + result = [] + + async def c1(result): + async with cond: + if await cond.wait(): + result.append(1) + return True + + async def c2(result): + async with cond: + if await cond.wait(): + result.append(2) + return True + + async def c3(result): + async with cond: + if await cond.wait(): + result.append(3) + return True + + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + t3 = asyncio.create_task(c3(result)) + + await asyncio.sleep(0) + self.assertEqual([], result) + + async with cond: + cond.notify(1) + await asyncio.sleep(1) + self.assertEqual([1], result) + + async with cond: + cond.notify(1) + cond.notify(2048) + await asyncio.sleep(1) + self.assertEqual([1, 2, 3], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + async def test_notify_all(self): + cond = _ACondition(threading.Condition(threading.Lock())) + + result = [] + + async def c1(result): + async with cond: + if await cond.wait(): + result.append(1) + return True + + async def c2(result): + async with cond: + if await cond.wait(): + result.append(2) + return True + + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + + await asyncio.sleep(0) + self.assertEqual([], result) + + async with cond: + cond.notify_all() + await asyncio.sleep(1) + self.assertEqual([1, 2], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + + async def test_context_manager(self): + cond = _ACondition(threading.Condition(threading.Lock())) + self.assertFalse(cond.locked()) + async with cond: + self.assertTrue(cond.locked()) + self.assertFalse(cond.locked()) + + async def test_timeout_in_block(self): + condition = _ACondition(threading.Condition(threading.Lock())) + async with condition: + with self.assertRaises(asyncio.TimeoutError): + await asyncio.wait_for(condition.wait(), timeout=0.5) + + async def test_cancelled_error_wakeup(self): + # Test that a cancelled error, received when awaiting wakeup, + # will be re-raised un-modified. 
+ wake = False + raised = None + cond = _ACondition(threading.Condition(threading.Lock())) + + async def func(): + nonlocal raised + async with cond: + with self.assertRaises(asyncio.CancelledError) as err: + await cond.wait_for(lambda: wake) + raised = err.exception + raise raised + + task = asyncio.create_task(func()) + await asyncio.sleep(0) + # Task is waiting on the condition, cancel it there. + task.cancel(msg="foo") + with self.assertRaises(asyncio.CancelledError) as err: + await task + self.assertEqual(err.exception.args, ("foo",)) + # We should have got the _same_ exception instance as the one + # originally raised. + self.assertIs(err.exception, raised) + + async def test_cancelled_error_re_aquire(self): + # Test that a cancelled error, received when re-aquiring lock, + # will be re-raised un-modified. + wake = False + raised = None + cond = _ACondition(threading.Condition(threading.Lock())) + + async def func(): + nonlocal raised + async with cond: + with self.assertRaises(asyncio.CancelledError) as err: + await cond.wait_for(lambda: wake) + raised = err.exception + raise raised + + task = asyncio.create_task(func()) + await asyncio.sleep(0) + # Task is waiting on the condition + await cond.acquire() + wake = True + cond.notify() + await asyncio.sleep(0) + # Task is now trying to re-acquire the lock, cancel it there. + task.cancel(msg="foo") + cond.release() + with self.assertRaises(asyncio.CancelledError) as err: + await task + self.assertEqual(err.exception.args, ("foo",)) + # We should have got the _same_ exception instance as the one + # originally raised. + self.assertIs(err.exception, raised) + + async def test_cancelled_wakeup(self): + # Test that a task cancelled at the "same" time as it is woken + # up as part of a Condition.notify() does not result in a lost wakeup. + # This test simulates a cancel while the target task is awaiting initial + # wakeup on the wakeup queue. + condition = _ACondition(threading.Condition(threading.Lock())) + state = 0 + + async def consumer(): + nonlocal state + async with condition: + while True: + await condition.wait_for(lambda: state != 0) + if state < 0: + return + state -= 1 + + # create two consumers + c = [asyncio.create_task(consumer()) for _ in range(2)] + # wait for them to settle + await asyncio.sleep(0) + async with condition: + # produce one item and wake up one + state += 1 + condition.notify(1) + + # Cancel it while it is awaiting to be run. + # This cancellation could come from the outside + c[0].cancel() + + # now wait for the item to be consumed + # if it doesn't means that our "notify" didn"t take hold. + # because it raced with a cancel() + try: + async with asyncio.timeout(0.01): + await condition.wait_for(lambda: state == 0) + except TimeoutError: + pass + self.assertEqual(state, 0) + + # clean up + state = -1 + condition.notify_all() + await c[1] + + async def test_cancelled_wakeup_relock(self): + # Test that a task cancelled at the "same" time as it is woken + # up as part of a Condition.notify() does not result in a lost wakeup. + # This test simulates a cancel while the target task is acquiring the lock + # again. 
+ condition = _ACondition(threading.Condition(threading.Lock())) + state = 0 + + async def consumer(): + nonlocal state + async with condition: + while True: + await condition.wait_for(lambda: state != 0) + if state < 0: + return + state -= 1 + + # create two consumers + c = [asyncio.create_task(consumer()) for _ in range(2)] + # wait for them to settle + await asyncio.sleep(0) + async with condition: + # produce one item and wake up one + state += 1 + condition.notify(1) + + # now we sleep for a bit. This allows the target task to wake up and + # settle on re-aquiring the lock + await asyncio.sleep(0) + + # Cancel it while awaiting the lock + # This cancel could come the outside. + c[0].cancel() + + # now wait for the item to be consumed + # if it doesn't means that our "notify" didn"t take hold. + # because it raced with a cancel() + try: + async with asyncio.timeout(0.01): + await condition.wait_for(lambda: state == 0) + except TimeoutError: + pass + self.assertEqual(state, 0) + + # clean up + state = -1 + condition.notify_all() + await c[1] + + +class TestCondition(unittest.IsolatedAsyncioTestCase): + async def test_multiple_loops_notify(self): + cond = _ACondition(threading.Condition(threading.Lock())) + + def tmain(cond): + async def atmain(cond): + await asyncio.sleep(1) + async with cond: + cond.notify(1) + + asyncio.run(atmain(cond)) + + t = threading.Thread(target=tmain, args=(cond,)) + t.start() + + async with cond: + self.assertTrue(await cond.wait(30)) + t.join() + + async def test_multiple_loops_notify_all(self): + cond = _ACondition(threading.Condition(threading.Lock())) + results = [] + + def tmain(cond, results): + async def atmain(cond, results): + await asyncio.sleep(1) + async with cond: + res = await cond.wait(30) + results.append(res) + + asyncio.run(atmain(cond, results)) + + nthreads = 5 + threads = [] + for _ in range(nthreads): + threads.append(threading.Thread(target=tmain, args=(cond, results))) + for t in threads: + t.start() + + await asyncio.sleep(2) + async with cond: + cond.notify_all() + + for t in threads: + t.join() + + self.assertEqual(results, [True] * nthreads) + + +if __name__ == "__main__": + unittest.main() From b2327b6f17582280cc8c8fda7cd538ae7243e687 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Tue, 24 Sep 2024 15:03:34 -0700 Subject: [PATCH 04/12] PYTHON-4782 Fix tests --- test/asynchronous/test_locks.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/asynchronous/test_locks.py b/test/asynchronous/test_locks.py index 9233a83b03..26223dcaa6 100644 --- a/test/asynchronous/test_locks.py +++ b/test/asynchronous/test_locks.py @@ -299,6 +299,9 @@ async def test_timeout_in_block(self): with self.assertRaises(asyncio.TimeoutError): await asyncio.wait_for(condition.wait(), timeout=0.5) + @unittest.skipIf( + sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11" + ) async def test_cancelled_error_wakeup(self): # Test that a cancelled error, received when awaiting wakeup, # will be re-raised un-modified. @@ -325,6 +328,9 @@ async def func(): # originally raised. self.assertIs(err.exception, raised) + @unittest.skipIf( + sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11" + ) async def test_cancelled_error_re_aquire(self): # Test that a cancelled error, received when re-aquiring lock, # will be re-raised un-modified. @@ -357,6 +363,7 @@ async def func(): # originally raised. 
self.assertIs(err.exception, raised) + @unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11") async def test_cancelled_wakeup(self): # Test that a task cancelled at the "same" time as it is woken # up as part of a Condition.notify() does not result in a lost wakeup. @@ -402,6 +409,7 @@ async def consumer(): condition.notify_all() await c[1] + @unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11") async def test_cancelled_wakeup_relock(self): # Test that a task cancelled at the "same" time as it is woken # up as part of a Condition.notify() does not result in a lost wakeup. From 0e9394b6ea8768db9defe999d54b3826dac8719c Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Tue, 24 Sep 2024 15:11:06 -0700 Subject: [PATCH 05/12] PYTHON-4782 Don't convert test_locks.py to sync --- tools/synchro.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tools/synchro.py b/tools/synchro.py index 59d6e653e5..8aeecc7e99 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -144,7 +144,17 @@ _gridfs_base + f for f in listdir(_gridfs_base) if (Path(_gridfs_base) / f).is_file() ] -test_files = [_test_base + f for f in listdir(_test_base) if (Path(_test_base) / f).is_file()] + +def async_only_test(f: str) -> bool: + """Return True for async tests that should not be converted to sync.""" + return f in ["test_locks.py"] + + +test_files = [ + _test_base + f + for f in listdir(_test_base) + if (Path(_test_base) / f).is_file() and not async_only_test(f) +] sync_files = [ _pymongo_dest_base + f From 2aa02e05a423f68569d7f55bb90a5da321d2e7f5 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Tue, 24 Sep 2024 15:43:12 -0700 Subject: [PATCH 06/12] PYTHON-4782 Final type ignore --- test/asynchronous/test_locks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/asynchronous/test_locks.py b/test/asynchronous/test_locks.py index 26223dcaa6..ca9bc90a93 100644 --- a/test/asynchronous/test_locks.py +++ b/test/asynchronous/test_locks.py @@ -320,7 +320,7 @@ async def func(): task = asyncio.create_task(func()) await asyncio.sleep(0) # Task is waiting on the condition, cancel it there. - task.cancel(msg="foo") + task.cancel(msg="foo") # type: ignore[call-arg] with self.assertRaises(asyncio.CancelledError) as err: await task self.assertEqual(err.exception.args, ("foo",)) @@ -354,7 +354,7 @@ async def func(): cond.notify() await asyncio.sleep(0) # Task is now trying to re-acquire the lock, cancel it there. - task.cancel(msg="foo") + task.cancel(msg="foo") # type: ignore[call-arg] cond.release() with self.assertRaises(asyncio.CancelledError) as err: await task From 966ad507d16fe8e1c05c02cbfe4d3a6de61b81be Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Thu, 26 Sep 2024 16:26:30 -0700 Subject: [PATCH 07/12] PYTHON-4782 Fix _ACondition initializations --- pymongo/asynchronous/pool.py | 7 ++++--- pymongo/asynchronous/topology.py | 5 +++-- pymongo/lock.py | 5 +++++ pymongo/synchronous/pool.py | 9 +++++---- pymongo/synchronous/topology.py | 7 ++++--- tools/synchro.py | 2 +- 6 files changed, 22 insertions(+), 13 deletions(-) diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index a657042423..087dbb0e73 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -992,7 +992,8 @@ def __init__( # from the right side. 
self.conns: collections.deque = collections.deque() self.active_contexts: set[_CancellationContext] = set() - self.lock = _ALock(_create_lock()) + _lock = _create_lock() + self.lock = _ALock(_lock) self.active_sockets = 0 # Monotonically increasing connection ID required for CMAP Events. self.next_connection_id = 1 @@ -1018,7 +1019,7 @@ def __init__( # The first portion of the wait queue. # Enforces: maxPoolSize # Also used for: clearing the wait queue - self.size_cond = _ACondition(threading.Condition(self.lock)) # type: ignore[arg-type] + self.size_cond = _ACondition(threading.Condition(_lock)) self.requests = 0 self.max_pool_size = self.opts.max_pool_size if not self.max_pool_size: @@ -1026,7 +1027,7 @@ def __init__( # The second portion of the wait queue. # Enforces: maxConnecting # Also used for: clearing the wait queue - self._max_connecting_cond = _ACondition(threading.Condition(self.lock)) # type: ignore[arg-type] + self._max_connecting_cond = _ACondition(threading.Condition(_lock)) self._max_connecting = self.opts.max_connecting self._pending = 0 self._client_id = client_id diff --git a/pymongo/asynchronous/topology.py b/pymongo/asynchronous/topology.py index 4e778cbc17..82af4257ba 100644 --- a/pymongo/asynchronous/topology.py +++ b/pymongo/asynchronous/topology.py @@ -170,8 +170,9 @@ def __init__(self, topology_settings: TopologySettings): self._seed_addresses = list(topology_description.server_descriptions()) self._opened = False self._closed = False - self._lock = _ALock(_create_lock()) - self._condition = _ACondition(self._settings.condition_class(self._lock)) # type: ignore[arg-type] + _lock = _create_lock() + self._lock = _ALock(_lock) + self._condition = _ACondition(self._settings.condition_class(_lock)) self._servers: dict[_Address, Server] = {} self._pid: Optional[int] = None self._max_cluster_time: Optional[ClusterTime] = None diff --git a/pymongo/lock.py b/pymongo/lock.py index 22cbf8d96b..957e28cb82 100644 --- a/pymongo/lock.py +++ b/pymongo/lock.py @@ -46,6 +46,11 @@ def _release_locks() -> None: lock.release() +# TODO: remove this. +def _Lock(lock: threading.Lock) -> threading.Lock: + return lock + + class _ALock: __slots__ = ("_lock",) diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 94a1d10436..7a7be0f40b 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -62,7 +62,7 @@ _CertificateError, ) from pymongo.hello import Hello, HelloCompat -from pymongo.lock import _create_lock +from pymongo.lock import _create_lock, _Lock from pymongo.logger import ( _CONNECTION_LOGGER, _ConnectionStatusMessage, @@ -988,7 +988,8 @@ def __init__( # from the right side. self.conns: collections.deque = collections.deque() self.active_contexts: set[_CancellationContext] = set() - self.lock = _create_lock() + _lock = _create_lock() + self.lock = _Lock(_lock) self.active_sockets = 0 # Monotonically increasing connection ID required for CMAP Events. self.next_connection_id = 1 @@ -1014,7 +1015,7 @@ def __init__( # The first portion of the wait queue. # Enforces: maxPoolSize # Also used for: clearing the wait queue - self.size_cond = threading.Condition(self.lock) # type: ignore[arg-type] + self.size_cond = threading.Condition(_lock) self.requests = 0 self.max_pool_size = self.opts.max_pool_size if not self.max_pool_size: @@ -1022,7 +1023,7 @@ def __init__( # The second portion of the wait queue. 
# Enforces: maxConnecting # Also used for: clearing the wait queue - self._max_connecting_cond = threading.Condition(self.lock) # type: ignore[arg-type] + self._max_connecting_cond = threading.Condition(_lock) self._max_connecting = self.opts.max_connecting self._pending = 0 self._client_id = client_id diff --git a/pymongo/synchronous/topology.py b/pymongo/synchronous/topology.py index e8070e30ab..a350c1702e 100644 --- a/pymongo/synchronous/topology.py +++ b/pymongo/synchronous/topology.py @@ -39,7 +39,7 @@ WriteError, ) from pymongo.hello import Hello -from pymongo.lock import _create_lock +from pymongo.lock import _create_lock, _Lock from pymongo.logger import ( _SDAM_LOGGER, _SERVER_SELECTION_LOGGER, @@ -170,8 +170,9 @@ def __init__(self, topology_settings: TopologySettings): self._seed_addresses = list(topology_description.server_descriptions()) self._opened = False self._closed = False - self._lock = _create_lock() - self._condition = self._settings.condition_class(self._lock) # type: ignore[arg-type] + _lock = _create_lock() + self._lock = _Lock(_lock) + self._condition = self._settings.condition_class(_lock) self._servers: dict[_Address, Server] = {} self._pid: Optional[int] = None self._max_cluster_time: Optional[ClusterTime] = None diff --git a/tools/synchro.py b/tools/synchro.py index 8aeecc7e99..4cb3b84639 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -250,7 +250,7 @@ def translate_locks(lines: list[str]) -> list[str]: lock_lines = [line for line in lines if "_Lock(" in line] cond_lines = [line for line in lines if "_Condition(" in line] for line in lock_lines: - res = re.search(r"_Lock\(([^()]*\(\))\)", line) + res = re.search(r"_Lock\(([^()]*\([^()]*\))\)", line) if res: old = res[0] index = lines.index(line) From f41a04cc9497d5c037212446a62827e9a0e1a11f Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Fri, 27 Sep 2024 13:30:54 -0700 Subject: [PATCH 08/12] PYTHON-4782 Make test_to_list_csot_applied less flaky --- test/asynchronous/test_cursor.py | 4 ++-- test/test_cursor.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/asynchronous/test_cursor.py b/test/asynchronous/test_cursor.py index 33eaacee96..e79ad00641 100644 --- a/test/asynchronous/test_cursor.py +++ b/test/asynchronous/test_cursor.py @@ -1414,7 +1414,7 @@ async def test_to_list_length(self): async def test_to_list_csot_applied(self): client = await self.async_single_client(timeoutMS=500) # Initialize the client with a larger timeout to help make test less flakey - with pymongo.timeout(2): + with pymongo.timeout(10): await client.admin.command("ping") coll = client.pymongo.test await coll.insert_many([{} for _ in range(5)]) @@ -1456,7 +1456,7 @@ async def test_command_cursor_to_list_length(self): async def test_command_cursor_to_list_csot_applied(self): client = await self.async_single_client(timeoutMS=500) # Initialize the client with a larger timeout to help make test less flakey - with pymongo.timeout(2): + with pymongo.timeout(10): await client.admin.command("ping") coll = client.pymongo.test await coll.insert_many([{} for _ in range(5)]) diff --git a/test/test_cursor.py b/test/test_cursor.py index d99732aec3..7c073bf351 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -1405,7 +1405,7 @@ def test_to_list_length(self): def test_to_list_csot_applied(self): client = self.single_client(timeoutMS=500) # Initialize the client with a larger timeout to help make test less flakey - with pymongo.timeout(2): + with pymongo.timeout(10): client.admin.command("ping") coll = 
client.pymongo.test coll.insert_many([{} for _ in range(5)]) @@ -1447,7 +1447,7 @@ def test_command_cursor_to_list_length(self): def test_command_cursor_to_list_csot_applied(self): client = self.single_client(timeoutMS=500) # Initialize the client with a larger timeout to help make test less flakey - with pymongo.timeout(2): + with pymongo.timeout(10): client.admin.command("ping") coll = client.pymongo.test coll.insert_many([{} for _ in range(5)]) From d1ffa63c9d4152ea5b611dd6766dd01aa4b3b69a Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Fri, 27 Sep 2024 16:25:19 -0700 Subject: [PATCH 09/12] PYTHON-4782 Make test_reconnect less flaky --- test/asynchronous/test_client.py | 4 +++- test/test_client.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index f610f32779..d574307dce 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -2351,7 +2351,9 @@ async def test_reconnect(self): # But it can reconnect. c.revive_host("a:1") - await (await c._get_topology()).select_servers(writable_server_selector, _Op.TEST) + await (await c._get_topology()).select_servers( + writable_server_selector, _Op.TEST, server_selection_timeout=10 + ) self.assertEqual(await c.address, ("a", 1)) async def _test_network_error(self, operation_callback): diff --git a/test/test_client.py b/test/test_client.py index bc45325f0b..f21dbbec6a 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -2307,7 +2307,9 @@ def test_reconnect(self): # But it can reconnect. c.revive_host("a:1") - (c._get_topology()).select_servers(writable_server_selector, _Op.TEST) + (c._get_topology()).select_servers( + writable_server_selector, _Op.TEST, server_selection_timeout=10 + ) self.assertEqual(c.address, ("a", 1)) def _test_network_error(self, operation_callback): From 8e1155641e7d172d61fd1c2829a87daa2b1e237e Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Mon, 30 Sep 2024 14:01:10 -0700 Subject: [PATCH 10/12] PYTHON-4782 Improve reliability of test_cancelled_wakeup on windows --- test/asynchronous/test_locks.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/asynchronous/test_locks.py b/test/asynchronous/test_locks.py index ca9bc90a93..e0e7f2fc8d 100644 --- a/test/asynchronous/test_locks.py +++ b/test/asynchronous/test_locks.py @@ -384,7 +384,7 @@ async def consumer(): # create two consumers c = [asyncio.create_task(consumer()) for _ in range(2)] # wait for them to settle - await asyncio.sleep(0) + await asyncio.sleep(0.1) async with condition: # produce one item and wake up one state += 1 @@ -398,7 +398,7 @@ async def consumer(): # if it doesn't means that our "notify" didn"t take hold. # because it raced with a cancel() try: - async with asyncio.timeout(0.01): + async with asyncio.timeout(1): await condition.wait_for(lambda: state == 0) except TimeoutError: pass @@ -430,7 +430,7 @@ async def consumer(): # create two consumers c = [asyncio.create_task(consumer()) for _ in range(2)] # wait for them to settle - await asyncio.sleep(0) + await asyncio.sleep(0.1) async with condition: # produce one item and wake up one state += 1 @@ -448,7 +448,7 @@ async def consumer(): # if it doesn't means that our "notify" didn"t take hold. 
# because it raced with a cancel() try: - async with asyncio.timeout(0.01): + async with asyncio.timeout(1): await condition.wait_for(lambda: state == 0) except TimeoutError: pass From afe38db1101beed85c6a3b3662937789ed2d27eb Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Mon, 30 Sep 2024 14:22:38 -0700 Subject: [PATCH 11/12] PYTHON-4782 Make test_timeout_in_multi_batch_bulk_write less flaky --- pymongo/lock.py | 2 +- test/asynchronous/test_client_bulk_write.py | 5 ++++- test/test_client_bulk_write.py | 5 ++++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/pymongo/lock.py b/pymongo/lock.py index 957e28cb82..0cbfb4a57e 100644 --- a/pymongo/lock.py +++ b/pymongo/lock.py @@ -46,7 +46,7 @@ def _release_locks() -> None: lock.release() -# TODO: remove this. +# Needed only for synchro.py compat. def _Lock(lock: threading.Lock) -> threading.Lock: return lock diff --git a/test/asynchronous/test_client_bulk_write.py b/test/asynchronous/test_client_bulk_write.py index 3a17299453..80cfd30bde 100644 --- a/test/asynchronous/test_client_bulk_write.py +++ b/test/asynchronous/test_client_bulk_write.py @@ -30,6 +30,7 @@ ) from unittest.mock import patch +import pymongo from pymongo.asynchronous.client_bulk import _AsyncClientBulk from pymongo.encryption_options import _HAVE_PYMONGOCRYPT, AutoEncryptionOpts from pymongo.errors import ( @@ -597,7 +598,9 @@ async def test_timeout_in_multi_batch_bulk_write(self): timeoutMS=2000, w="majority", ) - await client.admin.command("ping") # Init the client first. + # Initialize the client with a larger timeout to help make test less flakey + with pymongo.timeout(10): + await client.admin.command("ping") with self.assertRaises(ClientBulkWriteException) as context: await client.bulk_write(models=models) self.assertIsInstance(context.exception.error, NetworkTimeout) diff --git a/test/test_client_bulk_write.py b/test/test_client_bulk_write.py index ebbdc74c1c..d1aff03fc9 100644 --- a/test/test_client_bulk_write.py +++ b/test/test_client_bulk_write.py @@ -30,6 +30,7 @@ ) from unittest.mock import patch +import pymongo from pymongo.encryption_options import _HAVE_PYMONGOCRYPT, AutoEncryptionOpts from pymongo.errors import ( ClientBulkWriteException, @@ -597,7 +598,9 @@ def test_timeout_in_multi_batch_bulk_write(self): timeoutMS=2000, w="majority", ) - client.admin.command("ping") # Init the client first. 
+ # Initialize the client with a larger timeout to help make test less flakey + with pymongo.timeout(10): + client.admin.command("ping") with self.assertRaises(ClientBulkWriteException) as context: client.bulk_write(models=models) self.assertIsInstance(context.exception.error, NetworkTimeout) From e69b773d00d7e21e212f53a6dd1d337108549990 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Mon, 30 Sep 2024 16:15:02 -0700 Subject: [PATCH 12/12] PYTHON-4782 Make test_load_balancing less flaky --- test/test_server_selection_in_window.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/test/test_server_selection_in_window.py b/test/test_server_selection_in_window.py index 8e030f61e8..7cab42cca2 100644 --- a/test/test_server_selection_in_window.py +++ b/test/test_server_selection_in_window.py @@ -19,6 +19,7 @@ import threading from test import IntegrationTest, client_context, unittest from test.utils import ( + CMAPListener, OvertCommandListener, SpecTestCreator, get_pool, @@ -27,6 +28,7 @@ from test.utils_selection_tests import create_topology from pymongo.common import clean_node +from pymongo.monitoring import ConnectionReadyEvent from pymongo.operations import _Op from pymongo.read_preferences import ReadPreference @@ -131,19 +133,20 @@ def frequencies(self, client, listener, n_finds=10): @client_context.require_multiple_mongoses def test_load_balancing(self): listener = OvertCommandListener() + cmap_listener = CMAPListener() # PYTHON-2584: Use a large localThresholdMS to avoid the impact of # varying RTTs. client = self.rs_client( client_context.mongos_seeds(), appName="loadBalancingTest", - event_listeners=[listener], + event_listeners=[listener, cmap_listener], localThresholdMS=30000, minPoolSize=10, ) - self.addCleanup(client.close) wait_until(lambda: len(client.nodes) == 2, "discover both nodes") - wait_until(lambda: len(get_pool(client).conns) >= 10, "create 10 connections") - # Delay find commands on + # Wait for both pools to be populated. + cmap_listener.wait_for_event(ConnectionReadyEvent, 20) + # Delay find commands on only one mongos. delay_finds = { "configureFailPoint": "failCommand", "mode": {"times": 10000}, @@ -161,7 +164,7 @@ def test_load_balancing(self): freqs = self.frequencies(client, listener) self.assertLessEqual(freqs[delayed_server], 0.25) listener.reset() - freqs = self.frequencies(client, listener, n_finds=100) + freqs = self.frequencies(client, listener, n_finds=150) self.assertAlmostEqual(freqs[delayed_server], 0.50, delta=0.15)
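
Note on the core technique (illustrative commentary, not part of the commits above): before patch 01, _ACondition.wait() polled the shared threading.Condition with wait(0.001) in a loop, blocking the event loop on every iteration -- the blocking and deadlock behavior named in the commit subject. The rewrite parks each waiter on an asyncio.Future created on its own running loop and wakes it from notify() via loop.call_soon_threadsafe(), which is safe to call from a different thread or event loop (exercised by TestCondition.test_multiple_loops_notify). The sketch below isolates that wakeup primitive in a hypothetical CrossLoopEvent class; it deliberately omits the lock re-acquisition, cancellation handling, and closed-loop cleanup that the real _ACondition.wait()/notify() perform.

    import asyncio
    import collections
    import threading
    import time


    def _safe_set_result(fut: asyncio.Future) -> None:
        # The future may already be done (e.g. cancelled by asyncio.wait_for
        # on timeout), so only complete it if it is still pending.
        if not fut.done():
            fut.set_result(None)


    class CrossLoopEvent:
        """One-shot wakeup usable from any thread or event loop (sketch only)."""

        def __init__(self) -> None:
            self._waiters: collections.deque = collections.deque()

        async def wait(self, timeout: float) -> bool:
            loop = asyncio.get_running_loop()
            fut = loop.create_future()
            self._waiters.append((loop, fut))
            try:
                # Suspend this task until notify() completes the future or the
                # timeout expires; no polling, so the event loop stays free.
                await asyncio.wait_for(fut, timeout)
                return True
            except asyncio.TimeoutError:
                return False
            finally:
                self._waiters.remove((loop, fut))

        def notify(self) -> None:
            # May be called from any thread: call_soon_threadsafe hands the
            # wakeup to the waiter's own event loop.
            for loop, fut in list(self._waiters):
                if not fut.done():
                    loop.call_soon_threadsafe(_safe_set_result, fut)
                    break


    async def main() -> None:
        event = CrossLoopEvent()

        def notifier() -> None:
            # Runs outside the waiter's event loop, much like the sync-pool
            # callers that share a threading.Condition with async code.
            time.sleep(0.1)
            event.notify()

        t = threading.Thread(target=notifier)
        t.start()
        print(await event.wait(timeout=5))  # True once notified
        t.join()


    if __name__ == "__main__":
        asyncio.run(main())

The per-waiter-future design lets notify() wake a waiter with a single scheduled callback while spurious wakeups remain allowed, which is why the new wait() docstring tells callers to re-check their state (or use wait_for()) after waking.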