Skip to content

Commit

Permalink
PYTHON-5053 - AsyncMongoClient.close() should await all background ta…
Browse files Browse the repository at this point in the history
…sks (mongodb#2127)
  • Loading branch information
NoahStapp authored Feb 5, 2025
1 parent f344eb7 commit 1b81847
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 6 deletions.
6 changes: 6 additions & 0 deletions pymongo/asynchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1565,6 +1565,12 @@ async def close(self) -> None:
# TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened.
await self._encrypter.close()
self._closed = True
if not _IS_SYNC:
await asyncio.gather(
self._topology.cleanup_monitors(), # type: ignore[func-returns-value]
self._kill_cursors_executor.join(), # type: ignore[func-returns-value]
return_exceptions=True,
)

if not _IS_SYNC:
# Add support for contextlib.aclosing.
Expand Down
9 changes: 7 additions & 2 deletions pymongo/asynchronous/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,9 @@ async def close(self) -> None:
"""
self.gc_safe_close()

async def join(self, timeout: Optional[int] = None) -> None:
async def join(self) -> None:
"""Wait for the monitor to stop."""
await self._executor.join(timeout)
await self._executor.join()

def request_check(self) -> None:
"""If the monitor is sleeping, wake it soon."""
Expand Down Expand Up @@ -189,6 +189,11 @@ def gc_safe_close(self) -> None:
self._rtt_monitor.gc_safe_close()
self.cancel_check()

async def join(self) -> None:
await asyncio.gather(
self._executor.join(), self._rtt_monitor.join(), return_exceptions=True
) # type: ignore[func-returns-value]

async def close(self) -> None:
self.gc_safe_close()
await self._rtt_monitor.close()
Expand Down
29 changes: 28 additions & 1 deletion pymongo/asynchronous/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from __future__ import annotations

import asyncio
import logging
import os
import queue
Expand All @@ -29,7 +30,7 @@

from pymongo import _csot, common, helpers_shared, periodic_executor
from pymongo.asynchronous.client_session import _ServerSession, _ServerSessionPool
from pymongo.asynchronous.monitor import SrvMonitor
from pymongo.asynchronous.monitor import MonitorBase, SrvMonitor
from pymongo.asynchronous.pool import Pool
from pymongo.asynchronous.server import Server
from pymongo.errors import (
Expand Down Expand Up @@ -207,6 +208,9 @@ async def target() -> bool:
if self._settings.fqdn is not None and not self._settings.load_balanced:
self._srv_monitor = SrvMonitor(self, self._settings)

# Stores all monitor tasks that need to be joined on close or server selection
self._monitor_tasks: list[MonitorBase] = []

async def open(self) -> None:
"""Start monitoring, or restart after a fork.
Expand Down Expand Up @@ -241,6 +245,8 @@ async def open(self) -> None:
# Close servers and clear the pools.
for server in self._servers.values():
await server.close()
if not _IS_SYNC:
self._monitor_tasks.append(server._monitor)
# Reset the session pool to avoid duplicate sessions in
# the child process.
self._session_pool.reset()
Expand Down Expand Up @@ -283,6 +289,10 @@ async def select_servers(
else:
server_timeout = server_selection_timeout

# Cleanup any completed monitor tasks safely
if not _IS_SYNC and self._monitor_tasks:
await self.cleanup_monitors()

async with self._lock:
server_descriptions = await self._select_servers_loop(
selector, server_timeout, operation, operation_id, address
Expand Down Expand Up @@ -520,6 +530,8 @@ async def _process_change(
and self._description.topology_type not in SRV_POLLING_TOPOLOGIES
):
await self._srv_monitor.close()
if not _IS_SYNC:
self._monitor_tasks.append(self._srv_monitor)

# Clear the pool from a failed heartbeat.
if reset_pool:
Expand Down Expand Up @@ -695,6 +707,8 @@ async def close(self) -> None:
old_td = self._description
for server in self._servers.values():
await server.close()
if not _IS_SYNC:
self._monitor_tasks.append(server._monitor)

# Mark all servers Unknown.
self._description = self._description.reset()
Expand All @@ -705,6 +719,8 @@ async def close(self) -> None:
# Stop SRV polling thread.
if self._srv_monitor:
await self._srv_monitor.close()
if not _IS_SYNC:
self._monitor_tasks.append(self._srv_monitor)

self._opened = False
self._closed = True
Expand Down Expand Up @@ -944,6 +960,8 @@ async def _update_servers(self) -> None:
for address, server in list(self._servers.items()):
if not self._description.has_server(address):
await server.close()
if not _IS_SYNC:
self._monitor_tasks.append(server._monitor)
self._servers.pop(address)

def _create_pool_for_server(self, address: _Address) -> Pool:
Expand Down Expand Up @@ -1031,6 +1049,15 @@ def _error_message(self, selector: Callable[[Selection], Selection]) -> str:
else:
return ",".join(str(server.error) for server in servers if server.error)

async def cleanup_monitors(self) -> None:
tasks = []
try:
while self._monitor_tasks:
tasks.append(self._monitor_tasks.pop())
except IndexError:
pass
await asyncio.gather(*[t.join() for t in tasks], return_exceptions=True) # type: ignore[func-returns-value]

def __repr__(self) -> str:
msg = ""
if not self._opened:
Expand Down
2 changes: 2 additions & 0 deletions pymongo/periodic_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ def close(self, dummy: Any = None) -> None:
callback; see monitor.py.
"""
self._stopped = True
if self._task is not None:
self._task.cancel()

async def join(self, timeout: Optional[int] = None) -> None:
if self._task is not None:
Expand Down
6 changes: 6 additions & 0 deletions pymongo/synchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1559,6 +1559,12 @@ def close(self) -> None:
# TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened.
self._encrypter.close()
self._closed = True
if not _IS_SYNC:
asyncio.gather(
self._topology.cleanup_monitors(), # type: ignore[func-returns-value]
self._kill_cursors_executor.join(), # type: ignore[func-returns-value]
return_exceptions=True,
)

if not _IS_SYNC:
# Add support for contextlib.closing.
Expand Down
7 changes: 5 additions & 2 deletions pymongo/synchronous/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,9 @@ def close(self) -> None:
"""
self.gc_safe_close()

def join(self, timeout: Optional[int] = None) -> None:
def join(self) -> None:
"""Wait for the monitor to stop."""
self._executor.join(timeout)
self._executor.join()

def request_check(self) -> None:
"""If the monitor is sleeping, wake it soon."""
Expand Down Expand Up @@ -189,6 +189,9 @@ def gc_safe_close(self) -> None:
self._rtt_monitor.gc_safe_close()
self.cancel_check()

def join(self) -> None:
asyncio.gather(self._executor.join(), self._rtt_monitor.join(), return_exceptions=True) # type: ignore[func-returns-value]

def close(self) -> None:
self.gc_safe_close()
self._rtt_monitor.close()
Expand Down
29 changes: 28 additions & 1 deletion pymongo/synchronous/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from __future__ import annotations

import asyncio
import logging
import os
import queue
Expand Down Expand Up @@ -61,7 +62,7 @@
writable_server_selector,
)
from pymongo.synchronous.client_session import _ServerSession, _ServerSessionPool
from pymongo.synchronous.monitor import SrvMonitor
from pymongo.synchronous.monitor import MonitorBase, SrvMonitor
from pymongo.synchronous.pool import Pool
from pymongo.synchronous.server import Server
from pymongo.topology_description import (
Expand Down Expand Up @@ -207,6 +208,9 @@ def target() -> bool:
if self._settings.fqdn is not None and not self._settings.load_balanced:
self._srv_monitor = SrvMonitor(self, self._settings)

# Stores all monitor tasks that need to be joined on close or server selection
self._monitor_tasks: list[MonitorBase] = []

def open(self) -> None:
"""Start monitoring, or restart after a fork.
Expand Down Expand Up @@ -241,6 +245,8 @@ def open(self) -> None:
# Close servers and clear the pools.
for server in self._servers.values():
server.close()
if not _IS_SYNC:
self._monitor_tasks.append(server._monitor)
# Reset the session pool to avoid duplicate sessions in
# the child process.
self._session_pool.reset()
Expand Down Expand Up @@ -283,6 +289,10 @@ def select_servers(
else:
server_timeout = server_selection_timeout

# Cleanup any completed monitor tasks safely
if not _IS_SYNC and self._monitor_tasks:
self.cleanup_monitors()

with self._lock:
server_descriptions = self._select_servers_loop(
selector, server_timeout, operation, operation_id, address
Expand Down Expand Up @@ -520,6 +530,8 @@ def _process_change(
and self._description.topology_type not in SRV_POLLING_TOPOLOGIES
):
self._srv_monitor.close()
if not _IS_SYNC:
self._monitor_tasks.append(self._srv_monitor)

# Clear the pool from a failed heartbeat.
if reset_pool:
Expand Down Expand Up @@ -693,6 +705,8 @@ def close(self) -> None:
old_td = self._description
for server in self._servers.values():
server.close()
if not _IS_SYNC:
self._monitor_tasks.append(server._monitor)

# Mark all servers Unknown.
self._description = self._description.reset()
Expand All @@ -703,6 +717,8 @@ def close(self) -> None:
# Stop SRV polling thread.
if self._srv_monitor:
self._srv_monitor.close()
if not _IS_SYNC:
self._monitor_tasks.append(self._srv_monitor)

self._opened = False
self._closed = True
Expand Down Expand Up @@ -942,6 +958,8 @@ def _update_servers(self) -> None:
for address, server in list(self._servers.items()):
if not self._description.has_server(address):
server.close()
if not _IS_SYNC:
self._monitor_tasks.append(server._monitor)
self._servers.pop(address)

def _create_pool_for_server(self, address: _Address) -> Pool:
Expand Down Expand Up @@ -1029,6 +1047,15 @@ def _error_message(self, selector: Callable[[Selection], Selection]) -> str:
else:
return ",".join(str(server.error) for server in servers if server.error)

def cleanup_monitors(self) -> None:
tasks = []
try:
while self._monitor_tasks:
tasks.append(self._monitor_tasks.pop())
except IndexError:
pass
asyncio.gather(*[t.join() for t in tasks], return_exceptions=True) # type: ignore[func-returns-value]

def __repr__(self) -> str:
msg = ""
if not self._opened:
Expand Down

0 comments on commit 1b81847

Please sign in to comment.