diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 8d80dc29..86259be3 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -729,7 +729,7 @@ async def _connect(*, loop, timeout, connection_class, record_class, **kwargs): for addr in addrs: before = time.monotonic() try: - con = await _connect_addr( + return await _connect_addr( addr=addr, loop=loop, timeout=timeout, @@ -740,8 +740,6 @@ async def _connect(*, loop, timeout, connection_class, record_class, **kwargs): ) except (OSError, asyncio.TimeoutError, ConnectionError) as ex: last_error = ex - else: - return con finally: timeout -= time.monotonic() - before diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 043c6ddd..3f678d16 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -9,11 +9,13 @@ import asyncpg import collections import collections.abc +import functools import itertools import sys import time import traceback import warnings +import weakref from . import compat from . import connect_utils @@ -70,7 +72,8 @@ def __init__(self, protocol, transport, loop, self._stmt_cache = _StatementCache( loop=loop, max_size=config.statement_cache_size, - on_remove=self._maybe_gc_stmt, + on_remove=functools.partial( + _weak_maybe_gc_stmt, weakref.ref(self)), max_lifetime=config.max_cached_statement_lifetime) self._stmts_to_close = set() @@ -2260,4 +2263,10 @@ def _check_record_class(record_class): ) +def _weak_maybe_gc_stmt(weak_ref, stmt): + self = weak_ref() + if self is not None: + self._maybe_gc_stmt(stmt) + + _uid = 0 diff --git a/tests/test_connect.py b/tests/test_connect.py index 7b08f93d..ff884af8 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -7,7 +7,6 @@ import asyncio import contextlib -import gc import ipaddress import os import platform @@ -1448,14 +1447,11 @@ class TestConnectionGC(tb.ClusterTestCase): async def _run_no_explicit_close_test(self): con = await self.connect() + await con.fetchval("select 123") proto = con._protocol conref = weakref.ref(con) del con - gc.collect() - gc.collect() - gc.collect() - self.assertIsNone(conref()) self.assertTrue(proto.is_closed())