diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index a657042423..087dbb0e73 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -992,7 +992,8 @@ def __init__( # from the right side. self.conns: collections.deque = collections.deque() self.active_contexts: set[_CancellationContext] = set() - self.lock = _ALock(_create_lock()) + _lock = _create_lock() + self.lock = _ALock(_lock) self.active_sockets = 0 # Monotonically increasing connection ID required for CMAP Events. self.next_connection_id = 1 @@ -1018,7 +1019,7 @@ def __init__( # The first portion of the wait queue. # Enforces: maxPoolSize # Also used for: clearing the wait queue - self.size_cond = _ACondition(threading.Condition(self.lock)) # type: ignore[arg-type] + self.size_cond = _ACondition(threading.Condition(_lock)) self.requests = 0 self.max_pool_size = self.opts.max_pool_size if not self.max_pool_size: @@ -1026,7 +1027,7 @@ def __init__( # The second portion of the wait queue. # Enforces: maxConnecting # Also used for: clearing the wait queue - self._max_connecting_cond = _ACondition(threading.Condition(self.lock)) # type: ignore[arg-type] + self._max_connecting_cond = _ACondition(threading.Condition(_lock)) self._max_connecting = self.opts.max_connecting self._pending = 0 self._client_id = client_id diff --git a/pymongo/asynchronous/topology.py b/pymongo/asynchronous/topology.py index 4e778cbc17..82af4257ba 100644 --- a/pymongo/asynchronous/topology.py +++ b/pymongo/asynchronous/topology.py @@ -170,8 +170,9 @@ def __init__(self, topology_settings: TopologySettings): self._seed_addresses = list(topology_description.server_descriptions()) self._opened = False self._closed = False - self._lock = _ALock(_create_lock()) - self._condition = _ACondition(self._settings.condition_class(self._lock)) # type: ignore[arg-type] + _lock = _create_lock() + self._lock = _ALock(_lock) + self._condition = _ACondition(self._settings.condition_class(_lock)) self._servers: dict[_Address, Server] = {} self._pid: Optional[int] = None self._max_cluster_time: Optional[ClusterTime] = None diff --git a/pymongo/lock.py b/pymongo/lock.py index b05f6acffb..0cbfb4a57e 100644 --- a/pymongo/lock.py +++ b/pymongo/lock.py @@ -14,17 +14,20 @@ from __future__ import annotations import asyncio +import collections import os import threading import time import weakref -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, TypeVar _HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork") # References to instances of _create_lock _forkable_locks: weakref.WeakSet[threading.Lock] = weakref.WeakSet() +_T = TypeVar("_T") + def _create_lock() -> threading.Lock: """Represents a lock that is tracked upon instantiation using a WeakSet and @@ -43,7 +46,14 @@ def _release_locks() -> None: lock.release() +# Needed only for synchro.py compat. +def _Lock(lock: threading.Lock) -> threading.Lock: + return lock + + class _ALock: + __slots__ = ("_lock",) + def __init__(self, lock: threading.Lock) -> None: self._lock = lock @@ -81,9 +91,18 @@ async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: self.release() +def _safe_set_result(fut: asyncio.Future) -> None: + # Ensure the future hasn't been cancelled before calling set_result. + if not fut.done(): + fut.set_result(False) + + class _ACondition: + __slots__ = ("_condition", "_waiters") + def __init__(self, condition: threading.Condition) -> None: self._condition = condition + self._waiters: collections.deque = collections.deque() async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: if timeout > 0: @@ -99,30 +118,116 @@ async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: await asyncio.sleep(0) async def wait(self, timeout: Optional[float] = None) -> bool: - if timeout is not None: - tstart = time.monotonic() - while True: - notified = self._condition.wait(0.001) - if notified: - return True - if timeout is not None and (time.monotonic() - tstart) > timeout: - return False - - async def wait_for(self, predicate: Callable, timeout: Optional[float] = None) -> bool: - if timeout is not None: - tstart = time.monotonic() - while True: - notified = self._condition.wait_for(predicate, 0.001) - if notified: - return True - if timeout is not None and (time.monotonic() - tstart) > timeout: - return False + """Wait until notified. + + If the calling task has not acquired the lock when this + method is called, a RuntimeError is raised. + + This method releases the underlying lock, and then blocks + until it is awakened by a notify() or notify_all() call for + the same condition variable in another task. Once + awakened, it re-acquires the lock and returns True. + + This method may return spuriously, + which is why the caller should always + re-check the state and be prepared to wait() again. + """ + loop = asyncio.get_running_loop() + fut = loop.create_future() + self._waiters.append((loop, fut)) + self.release() + try: + try: + try: + await asyncio.wait_for(fut, timeout) + return True + except asyncio.TimeoutError: + return False # Return false on timeout for sync pool compat. + finally: + # Must re-acquire lock even if wait is cancelled. + # We only catch CancelledError here, since we don't want any + # other (fatal) errors with the future to cause us to spin. + err = None + while True: + try: + await self.acquire() + break + except asyncio.exceptions.CancelledError as e: + err = e + + self._waiters.remove((loop, fut)) + if err is not None: + try: + raise err # Re-raise most recent exception instance. + finally: + err = None # Break reference cycles. + except BaseException: + # Any error raised out of here _may_ have occurred after this Task + # believed to have been successfully notified. + # Make sure to notify another Task instead. This may result + # in a "spurious wakeup", which is allowed as part of the + # Condition Variable protocol. + self.notify(1) + raise + + async def wait_for(self, predicate: Callable[[], _T]) -> _T: + """Wait until a predicate becomes true. + + The predicate should be a callable whose result will be + interpreted as a boolean value. The method will repeatedly + wait() until it evaluates to true. The final predicate value is + the return value. + """ + result = predicate() + while not result: + await self.wait() + result = predicate() + return result def notify(self, n: int = 1) -> None: - self._condition.notify(n) + """By default, wake up one coroutine waiting on this condition, if any. + If the calling coroutine has not acquired the lock when this method + is called, a RuntimeError is raised. + + This method wakes up at most n of the coroutines waiting for the + condition variable; it is a no-op if no coroutines are waiting. + + Note: an awakened coroutine does not actually return from its + wait() call until it can reacquire the lock. Since notify() does + not release the lock, its caller should. + """ + idx = 0 + to_remove = [] + for loop, fut in self._waiters: + if idx >= n: + break + + if fut.done(): + continue + + try: + loop.call_soon_threadsafe(_safe_set_result, fut) + except RuntimeError: + # Loop was closed, ignore. + to_remove.append((loop, fut)) + continue + + idx += 1 + + for waiter in to_remove: + self._waiters.remove(waiter) def notify_all(self) -> None: - self._condition.notify_all() + """Wake up all threads waiting on this condition. This method acts + like notify(), but wakes up all waiting threads instead of one. If the + calling thread has not acquired the lock when this method is called, + a RuntimeError is raised. + """ + self.notify(len(self._waiters)) + + def locked(self) -> bool: + """Only needed for tests in test_locks.""" + return self._condition._lock.locked() # type: ignore[attr-defined] def release(self) -> None: self._condition.release() diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 94a1d10436..7a7be0f40b 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -62,7 +62,7 @@ _CertificateError, ) from pymongo.hello import Hello, HelloCompat -from pymongo.lock import _create_lock +from pymongo.lock import _create_lock, _Lock from pymongo.logger import ( _CONNECTION_LOGGER, _ConnectionStatusMessage, @@ -988,7 +988,8 @@ def __init__( # from the right side. self.conns: collections.deque = collections.deque() self.active_contexts: set[_CancellationContext] = set() - self.lock = _create_lock() + _lock = _create_lock() + self.lock = _Lock(_lock) self.active_sockets = 0 # Monotonically increasing connection ID required for CMAP Events. self.next_connection_id = 1 @@ -1014,7 +1015,7 @@ def __init__( # The first portion of the wait queue. # Enforces: maxPoolSize # Also used for: clearing the wait queue - self.size_cond = threading.Condition(self.lock) # type: ignore[arg-type] + self.size_cond = threading.Condition(_lock) self.requests = 0 self.max_pool_size = self.opts.max_pool_size if not self.max_pool_size: @@ -1022,7 +1023,7 @@ def __init__( # The second portion of the wait queue. # Enforces: maxConnecting # Also used for: clearing the wait queue - self._max_connecting_cond = threading.Condition(self.lock) # type: ignore[arg-type] + self._max_connecting_cond = threading.Condition(_lock) self._max_connecting = self.opts.max_connecting self._pending = 0 self._client_id = client_id diff --git a/pymongo/synchronous/topology.py b/pymongo/synchronous/topology.py index e8070e30ab..a350c1702e 100644 --- a/pymongo/synchronous/topology.py +++ b/pymongo/synchronous/topology.py @@ -39,7 +39,7 @@ WriteError, ) from pymongo.hello import Hello -from pymongo.lock import _create_lock +from pymongo.lock import _create_lock, _Lock from pymongo.logger import ( _SDAM_LOGGER, _SERVER_SELECTION_LOGGER, @@ -170,8 +170,9 @@ def __init__(self, topology_settings: TopologySettings): self._seed_addresses = list(topology_description.server_descriptions()) self._opened = False self._closed = False - self._lock = _create_lock() - self._condition = self._settings.condition_class(self._lock) # type: ignore[arg-type] + _lock = _create_lock() + self._lock = _Lock(_lock) + self._condition = self._settings.condition_class(_lock) self._servers: dict[_Address, Server] = {} self._pid: Optional[int] = None self._max_cluster_time: Optional[ClusterTime] = None diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index f610f32779..d574307dce 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -2351,7 +2351,9 @@ async def test_reconnect(self): # But it can reconnect. c.revive_host("a:1") - await (await c._get_topology()).select_servers(writable_server_selector, _Op.TEST) + await (await c._get_topology()).select_servers( + writable_server_selector, _Op.TEST, server_selection_timeout=10 + ) self.assertEqual(await c.address, ("a", 1)) async def _test_network_error(self, operation_callback): diff --git a/test/asynchronous/test_client_bulk_write.py b/test/asynchronous/test_client_bulk_write.py index 3a17299453..80cfd30bde 100644 --- a/test/asynchronous/test_client_bulk_write.py +++ b/test/asynchronous/test_client_bulk_write.py @@ -30,6 +30,7 @@ ) from unittest.mock import patch +import pymongo from pymongo.asynchronous.client_bulk import _AsyncClientBulk from pymongo.encryption_options import _HAVE_PYMONGOCRYPT, AutoEncryptionOpts from pymongo.errors import ( @@ -597,7 +598,9 @@ async def test_timeout_in_multi_batch_bulk_write(self): timeoutMS=2000, w="majority", ) - await client.admin.command("ping") # Init the client first. + # Initialize the client with a larger timeout to help make test less flakey + with pymongo.timeout(10): + await client.admin.command("ping") with self.assertRaises(ClientBulkWriteException) as context: await client.bulk_write(models=models) self.assertIsInstance(context.exception.error, NetworkTimeout) diff --git a/test/asynchronous/test_cursor.py b/test/asynchronous/test_cursor.py index 33eaacee96..e79ad00641 100644 --- a/test/asynchronous/test_cursor.py +++ b/test/asynchronous/test_cursor.py @@ -1414,7 +1414,7 @@ async def test_to_list_length(self): async def test_to_list_csot_applied(self): client = await self.async_single_client(timeoutMS=500) # Initialize the client with a larger timeout to help make test less flakey - with pymongo.timeout(2): + with pymongo.timeout(10): await client.admin.command("ping") coll = client.pymongo.test await coll.insert_many([{} for _ in range(5)]) @@ -1456,7 +1456,7 @@ async def test_command_cursor_to_list_length(self): async def test_command_cursor_to_list_csot_applied(self): client = await self.async_single_client(timeoutMS=500) # Initialize the client with a larger timeout to help make test less flakey - with pymongo.timeout(2): + with pymongo.timeout(10): await client.admin.command("ping") coll = client.pymongo.test await coll.insert_many([{} for _ in range(5)]) diff --git a/test/asynchronous/test_locks.py b/test/asynchronous/test_locks.py new file mode 100644 index 0000000000..e0e7f2fc8d --- /dev/null +++ b/test/asynchronous/test_locks.py @@ -0,0 +1,513 @@ +# Copyright 2024-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for lock.py""" +from __future__ import annotations + +import asyncio +import sys +import threading +import unittest + +sys.path[0:0] = [""] + +from pymongo.lock import _ACondition + + +# Tests adapted from: https://github.com/python/cpython/blob/v3.13.0rc2/Lib/test/test_asyncio/test_locks.py +# Includes tests for: +# - https://github.com/python/cpython/issues/111693 +# - https://github.com/python/cpython/issues/112202 +class TestConditionStdlib(unittest.IsolatedAsyncioTestCase): + async def test_wait(self): + cond = _ACondition(threading.Condition(threading.Lock())) + result = [] + + async def c1(result): + await cond.acquire() + if await cond.wait(): + result.append(1) + return True + + async def c2(result): + await cond.acquire() + if await cond.wait(): + result.append(2) + return True + + async def c3(result): + await cond.acquire() + if await cond.wait(): + result.append(3) + return True + + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + t3 = asyncio.create_task(c3(result)) + + await asyncio.sleep(0) + self.assertEqual([], result) + self.assertFalse(cond.locked()) + + self.assertTrue(await cond.acquire()) + cond.notify() + await asyncio.sleep(0) + self.assertEqual([], result) + self.assertTrue(cond.locked()) + + cond.release() + await asyncio.sleep(0) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.notify(2) + await asyncio.sleep(0) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.release() + await asyncio.sleep(0) + self.assertEqual([1, 2], result) + self.assertTrue(cond.locked()) + + cond.release() + await asyncio.sleep(0) + self.assertEqual([1, 2, 3], result) + self.assertTrue(cond.locked()) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + async def test_wait_cancel(self): + cond = _ACondition(threading.Condition(threading.Lock())) + await cond.acquire() + + wait = asyncio.create_task(cond.wait()) + asyncio.get_running_loop().call_soon(wait.cancel) + with self.assertRaises(asyncio.CancelledError): + await wait + self.assertFalse(cond._waiters) + self.assertTrue(cond.locked()) + + async def test_wait_cancel_contested(self): + cond = _ACondition(threading.Condition(threading.Lock())) + + await cond.acquire() + self.assertTrue(cond.locked()) + + wait_task = asyncio.create_task(cond.wait()) + await asyncio.sleep(0) + self.assertFalse(cond.locked()) + + # Notify, but contest the lock before cancelling + await cond.acquire() + self.assertTrue(cond.locked()) + cond.notify() + asyncio.get_running_loop().call_soon(wait_task.cancel) + asyncio.get_running_loop().call_soon(cond.release) + + try: + await wait_task + except asyncio.CancelledError: + # Should not happen, since no cancellation points + pass + + self.assertTrue(cond.locked()) + + async def test_wait_cancel_after_notify(self): + # See bpo-32841 + waited = False + + cond = _ACondition(threading.Condition(threading.Lock())) + + async def wait_on_cond(): + nonlocal waited + async with cond: + waited = True # Make sure this area was reached + await cond.wait() + + waiter = asyncio.create_task(wait_on_cond()) + await asyncio.sleep(0) # Start waiting + + await cond.acquire() + cond.notify() + await asyncio.sleep(0) # Get to acquire() + waiter.cancel() + await asyncio.sleep(0) # Activate cancellation + cond.release() + await asyncio.sleep(0) # Cancellation should occur + + self.assertTrue(waiter.cancelled()) + self.assertTrue(waited) + + async def test_wait_unacquired(self): + cond = _ACondition(threading.Condition(threading.Lock())) + with self.assertRaises(RuntimeError): + await cond.wait() + + async def test_wait_for(self): + cond = _ACondition(threading.Condition(threading.Lock())) + presult = False + + def predicate(): + return presult + + result = [] + + async def c1(result): + await cond.acquire() + if await cond.wait_for(predicate): + result.append(1) + cond.release() + return True + + t = asyncio.create_task(c1(result)) + + await asyncio.sleep(0) + self.assertEqual([], result) + + await cond.acquire() + cond.notify() + cond.release() + await asyncio.sleep(0) + self.assertEqual([], result) + + presult = True + await cond.acquire() + cond.notify() + cond.release() + await asyncio.sleep(0) + self.assertEqual([1], result) + + self.assertTrue(t.done()) + self.assertTrue(t.result()) + + async def test_wait_for_unacquired(self): + cond = _ACondition(threading.Condition(threading.Lock())) + + # predicate can return true immediately + res = await cond.wait_for(lambda: [1, 2, 3]) + self.assertEqual([1, 2, 3], res) + + with self.assertRaises(RuntimeError): + await cond.wait_for(lambda: False) + + async def test_notify(self): + cond = _ACondition(threading.Condition(threading.Lock())) + result = [] + + async def c1(result): + async with cond: + if await cond.wait(): + result.append(1) + return True + + async def c2(result): + async with cond: + if await cond.wait(): + result.append(2) + return True + + async def c3(result): + async with cond: + if await cond.wait(): + result.append(3) + return True + + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + t3 = asyncio.create_task(c3(result)) + + await asyncio.sleep(0) + self.assertEqual([], result) + + async with cond: + cond.notify(1) + await asyncio.sleep(1) + self.assertEqual([1], result) + + async with cond: + cond.notify(1) + cond.notify(2048) + await asyncio.sleep(1) + self.assertEqual([1, 2, 3], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + + async def test_notify_all(self): + cond = _ACondition(threading.Condition(threading.Lock())) + + result = [] + + async def c1(result): + async with cond: + if await cond.wait(): + result.append(1) + return True + + async def c2(result): + async with cond: + if await cond.wait(): + result.append(2) + return True + + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + + await asyncio.sleep(0) + self.assertEqual([], result) + + async with cond: + cond.notify_all() + await asyncio.sleep(1) + self.assertEqual([1, 2], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + + async def test_context_manager(self): + cond = _ACondition(threading.Condition(threading.Lock())) + self.assertFalse(cond.locked()) + async with cond: + self.assertTrue(cond.locked()) + self.assertFalse(cond.locked()) + + async def test_timeout_in_block(self): + condition = _ACondition(threading.Condition(threading.Lock())) + async with condition: + with self.assertRaises(asyncio.TimeoutError): + await asyncio.wait_for(condition.wait(), timeout=0.5) + + @unittest.skipIf( + sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11" + ) + async def test_cancelled_error_wakeup(self): + # Test that a cancelled error, received when awaiting wakeup, + # will be re-raised un-modified. + wake = False + raised = None + cond = _ACondition(threading.Condition(threading.Lock())) + + async def func(): + nonlocal raised + async with cond: + with self.assertRaises(asyncio.CancelledError) as err: + await cond.wait_for(lambda: wake) + raised = err.exception + raise raised + + task = asyncio.create_task(func()) + await asyncio.sleep(0) + # Task is waiting on the condition, cancel it there. + task.cancel(msg="foo") # type: ignore[call-arg] + with self.assertRaises(asyncio.CancelledError) as err: + await task + self.assertEqual(err.exception.args, ("foo",)) + # We should have got the _same_ exception instance as the one + # originally raised. + self.assertIs(err.exception, raised) + + @unittest.skipIf( + sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11" + ) + async def test_cancelled_error_re_aquire(self): + # Test that a cancelled error, received when re-aquiring lock, + # will be re-raised un-modified. + wake = False + raised = None + cond = _ACondition(threading.Condition(threading.Lock())) + + async def func(): + nonlocal raised + async with cond: + with self.assertRaises(asyncio.CancelledError) as err: + await cond.wait_for(lambda: wake) + raised = err.exception + raise raised + + task = asyncio.create_task(func()) + await asyncio.sleep(0) + # Task is waiting on the condition + await cond.acquire() + wake = True + cond.notify() + await asyncio.sleep(0) + # Task is now trying to re-acquire the lock, cancel it there. + task.cancel(msg="foo") # type: ignore[call-arg] + cond.release() + with self.assertRaises(asyncio.CancelledError) as err: + await task + self.assertEqual(err.exception.args, ("foo",)) + # We should have got the _same_ exception instance as the one + # originally raised. + self.assertIs(err.exception, raised) + + @unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11") + async def test_cancelled_wakeup(self): + # Test that a task cancelled at the "same" time as it is woken + # up as part of a Condition.notify() does not result in a lost wakeup. + # This test simulates a cancel while the target task is awaiting initial + # wakeup on the wakeup queue. + condition = _ACondition(threading.Condition(threading.Lock())) + state = 0 + + async def consumer(): + nonlocal state + async with condition: + while True: + await condition.wait_for(lambda: state != 0) + if state < 0: + return + state -= 1 + + # create two consumers + c = [asyncio.create_task(consumer()) for _ in range(2)] + # wait for them to settle + await asyncio.sleep(0.1) + async with condition: + # produce one item and wake up one + state += 1 + condition.notify(1) + + # Cancel it while it is awaiting to be run. + # This cancellation could come from the outside + c[0].cancel() + + # now wait for the item to be consumed + # if it doesn't means that our "notify" didn"t take hold. + # because it raced with a cancel() + try: + async with asyncio.timeout(1): + await condition.wait_for(lambda: state == 0) + except TimeoutError: + pass + self.assertEqual(state, 0) + + # clean up + state = -1 + condition.notify_all() + await c[1] + + @unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11") + async def test_cancelled_wakeup_relock(self): + # Test that a task cancelled at the "same" time as it is woken + # up as part of a Condition.notify() does not result in a lost wakeup. + # This test simulates a cancel while the target task is acquiring the lock + # again. + condition = _ACondition(threading.Condition(threading.Lock())) + state = 0 + + async def consumer(): + nonlocal state + async with condition: + while True: + await condition.wait_for(lambda: state != 0) + if state < 0: + return + state -= 1 + + # create two consumers + c = [asyncio.create_task(consumer()) for _ in range(2)] + # wait for them to settle + await asyncio.sleep(0.1) + async with condition: + # produce one item and wake up one + state += 1 + condition.notify(1) + + # now we sleep for a bit. This allows the target task to wake up and + # settle on re-aquiring the lock + await asyncio.sleep(0) + + # Cancel it while awaiting the lock + # This cancel could come the outside. + c[0].cancel() + + # now wait for the item to be consumed + # if it doesn't means that our "notify" didn"t take hold. + # because it raced with a cancel() + try: + async with asyncio.timeout(1): + await condition.wait_for(lambda: state == 0) + except TimeoutError: + pass + self.assertEqual(state, 0) + + # clean up + state = -1 + condition.notify_all() + await c[1] + + +class TestCondition(unittest.IsolatedAsyncioTestCase): + async def test_multiple_loops_notify(self): + cond = _ACondition(threading.Condition(threading.Lock())) + + def tmain(cond): + async def atmain(cond): + await asyncio.sleep(1) + async with cond: + cond.notify(1) + + asyncio.run(atmain(cond)) + + t = threading.Thread(target=tmain, args=(cond,)) + t.start() + + async with cond: + self.assertTrue(await cond.wait(30)) + t.join() + + async def test_multiple_loops_notify_all(self): + cond = _ACondition(threading.Condition(threading.Lock())) + results = [] + + def tmain(cond, results): + async def atmain(cond, results): + await asyncio.sleep(1) + async with cond: + res = await cond.wait(30) + results.append(res) + + asyncio.run(atmain(cond, results)) + + nthreads = 5 + threads = [] + for _ in range(nthreads): + threads.append(threading.Thread(target=tmain, args=(cond, results))) + for t in threads: + t.start() + + await asyncio.sleep(2) + async with cond: + cond.notify_all() + + for t in threads: + t.join() + + self.assertEqual(results, [True] * nthreads) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_client.py b/test/test_client.py index bc45325f0b..f21dbbec6a 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -2307,7 +2307,9 @@ def test_reconnect(self): # But it can reconnect. c.revive_host("a:1") - (c._get_topology()).select_servers(writable_server_selector, _Op.TEST) + (c._get_topology()).select_servers( + writable_server_selector, _Op.TEST, server_selection_timeout=10 + ) self.assertEqual(c.address, ("a", 1)) def _test_network_error(self, operation_callback): diff --git a/test/test_client_bulk_write.py b/test/test_client_bulk_write.py index ebbdc74c1c..d1aff03fc9 100644 --- a/test/test_client_bulk_write.py +++ b/test/test_client_bulk_write.py @@ -30,6 +30,7 @@ ) from unittest.mock import patch +import pymongo from pymongo.encryption_options import _HAVE_PYMONGOCRYPT, AutoEncryptionOpts from pymongo.errors import ( ClientBulkWriteException, @@ -597,7 +598,9 @@ def test_timeout_in_multi_batch_bulk_write(self): timeoutMS=2000, w="majority", ) - client.admin.command("ping") # Init the client first. + # Initialize the client with a larger timeout to help make test less flakey + with pymongo.timeout(10): + client.admin.command("ping") with self.assertRaises(ClientBulkWriteException) as context: client.bulk_write(models=models) self.assertIsInstance(context.exception.error, NetworkTimeout) diff --git a/test/test_cursor.py b/test/test_cursor.py index d99732aec3..7c073bf351 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -1405,7 +1405,7 @@ def test_to_list_length(self): def test_to_list_csot_applied(self): client = self.single_client(timeoutMS=500) # Initialize the client with a larger timeout to help make test less flakey - with pymongo.timeout(2): + with pymongo.timeout(10): client.admin.command("ping") coll = client.pymongo.test coll.insert_many([{} for _ in range(5)]) @@ -1447,7 +1447,7 @@ def test_command_cursor_to_list_length(self): def test_command_cursor_to_list_csot_applied(self): client = self.single_client(timeoutMS=500) # Initialize the client with a larger timeout to help make test less flakey - with pymongo.timeout(2): + with pymongo.timeout(10): client.admin.command("ping") coll = client.pymongo.test coll.insert_many([{} for _ in range(5)]) diff --git a/test/test_server_selection_in_window.py b/test/test_server_selection_in_window.py index 8e030f61e8..7cab42cca2 100644 --- a/test/test_server_selection_in_window.py +++ b/test/test_server_selection_in_window.py @@ -19,6 +19,7 @@ import threading from test import IntegrationTest, client_context, unittest from test.utils import ( + CMAPListener, OvertCommandListener, SpecTestCreator, get_pool, @@ -27,6 +28,7 @@ from test.utils_selection_tests import create_topology from pymongo.common import clean_node +from pymongo.monitoring import ConnectionReadyEvent from pymongo.operations import _Op from pymongo.read_preferences import ReadPreference @@ -131,19 +133,20 @@ def frequencies(self, client, listener, n_finds=10): @client_context.require_multiple_mongoses def test_load_balancing(self): listener = OvertCommandListener() + cmap_listener = CMAPListener() # PYTHON-2584: Use a large localThresholdMS to avoid the impact of # varying RTTs. client = self.rs_client( client_context.mongos_seeds(), appName="loadBalancingTest", - event_listeners=[listener], + event_listeners=[listener, cmap_listener], localThresholdMS=30000, minPoolSize=10, ) - self.addCleanup(client.close) wait_until(lambda: len(client.nodes) == 2, "discover both nodes") - wait_until(lambda: len(get_pool(client).conns) >= 10, "create 10 connections") - # Delay find commands on + # Wait for both pools to be populated. + cmap_listener.wait_for_event(ConnectionReadyEvent, 20) + # Delay find commands on only one mongos. delay_finds = { "configureFailPoint": "failCommand", "mode": {"times": 10000}, @@ -161,7 +164,7 @@ def test_load_balancing(self): freqs = self.frequencies(client, listener) self.assertLessEqual(freqs[delayed_server], 0.25) listener.reset() - freqs = self.frequencies(client, listener, n_finds=100) + freqs = self.frequencies(client, listener, n_finds=150) self.assertAlmostEqual(freqs[delayed_server], 0.50, delta=0.15) diff --git a/tools/synchro.py b/tools/synchro.py index 59d6e653e5..4cb3b84639 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -144,7 +144,17 @@ _gridfs_base + f for f in listdir(_gridfs_base) if (Path(_gridfs_base) / f).is_file() ] -test_files = [_test_base + f for f in listdir(_test_base) if (Path(_test_base) / f).is_file()] + +def async_only_test(f: str) -> bool: + """Return True for async tests that should not be converted to sync.""" + return f in ["test_locks.py"] + + +test_files = [ + _test_base + f + for f in listdir(_test_base) + if (Path(_test_base) / f).is_file() and not async_only_test(f) +] sync_files = [ _pymongo_dest_base + f @@ -240,7 +250,7 @@ def translate_locks(lines: list[str]) -> list[str]: lock_lines = [line for line in lines if "_Lock(" in line] cond_lines = [line for line in lines if "_Condition(" in line] for line in lock_lines: - res = re.search(r"_Lock\(([^()]*\(\))\)", line) + res = re.search(r"_Lock\(([^()]*\([^()]*\))\)", line) if res: old = res[0] index = lines.index(line)