Skip to content

Commit 33a912c

Browse files
committed
Correct edgedb.Client.close() timeout behavior
Also added tests in test_sync_query.py, and use only sync client in sync_* tests.
1 parent 6d0d6ab commit 33a912c

13 files changed

+628
-110
lines changed

edgedb/_testbase.py

Lines changed: 46 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -352,8 +352,13 @@ def connection(self):
352352
def is_proto_lt_1_0(self):
353353
return self.connection._protocol.is_legacy
354354

355+
@property
356+
def dbname(self):
357+
return self._impl._working_params.database
358+
355359

356360
class ConnectedTestCaseMixin:
361+
is_client_async = True
357362

358363
@classmethod
359364
def make_test_client(
@@ -362,11 +367,17 @@ def make_test_client(
362367
database='edgedb',
363368
user='edgedb',
364369
password='test',
365-
connection_class=asyncio_client.AsyncIOConnection,
370+
connection_class=...,
366371
):
367372
conargs = cls.get_connect_args(
368373
cluster=cluster, database=database, user=user, password=password)
369-
return TestAsyncIOClient(
374+
if connection_class is ...:
375+
connection_class = (
376+
asyncio_client.AsyncIOConnection
377+
if cls.is_client_async
378+
else blocking_client.BlockingIOConnection
379+
)
380+
return (TestAsyncIOClient if cls.is_client_async else TestClient)(
370381
connection_class=connection_class,
371382
max_concurrency=1,
372383
**conargs,
@@ -384,6 +395,10 @@ def get_connect_args(cls, *,
384395
database=database))
385396
return conargs
386397

398+
@classmethod
399+
def adapt_call(cls, coro):
400+
return cls.loop.run_until_complete(coro)
401+
387402

388403
class DatabaseTestCase(ClusterTestCase, ConnectedTestCaseMixin):
389404
SETUP = None
@@ -398,15 +413,15 @@ class DatabaseTestCase(ClusterTestCase, ConnectedTestCaseMixin):
398413

399414
def setUp(self):
400415
if self.SETUP_METHOD:
401-
self.loop.run_until_complete(
416+
self.adapt_call(
402417
self.client.execute(self.SETUP_METHOD))
403418

404419
super().setUp()
405420

406421
def tearDown(self):
407422
try:
408423
if self.TEARDOWN_METHOD:
409-
self.loop.run_until_complete(
424+
self.adapt_call(
410425
self.client.execute(self.TEARDOWN_METHOD))
411426
finally:
412427
try:
@@ -431,7 +446,7 @@ def setUpClass(cls):
431446
if not class_set_up:
432447
script = f'CREATE DATABASE {dbname};'
433448
cls.admin_client = cls.make_test_client()
434-
cls.loop.run_until_complete(cls.admin_client.execute(script))
449+
cls.adapt_call(cls.admin_client.execute(script))
435450

436451
cls.client = cls.make_test_client(database=dbname)
437452

@@ -440,11 +455,17 @@ def setUpClass(cls):
440455
if script:
441456
# The setup is expected to contain a CREATE MIGRATION,
442457
# which needs to be wrapped in a transaction.
443-
async def execute():
444-
async for tr in cls.client.transaction():
445-
async with tr:
446-
await tr.execute(script)
447-
cls.loop.run_until_complete(execute())
458+
if cls.is_client_async:
459+
async def execute():
460+
async for tr in cls.client.transaction():
461+
async with tr:
462+
await tr.execute(script)
463+
else:
464+
def execute():
465+
for tr in cls.client.transaction():
466+
with tr:
467+
tr.execute(script)
468+
cls.adapt_call(execute())
448469

449470
@classmethod
450471
def get_database_name(cls):
@@ -507,19 +528,22 @@ def tearDownClass(cls):
507528

508529
try:
509530
if script:
510-
cls.loop.run_until_complete(
531+
cls.adapt_call(
511532
cls.client.execute(script))
512533
finally:
513534
try:
514-
cls.loop.run_until_complete(cls.client.aclose())
535+
if cls.is_client_async:
536+
cls.adapt_call(cls.client.aclose())
537+
else:
538+
cls.client.close()
515539

516540
dbname = cls.get_database_name()
517541
script = f'DROP DATABASE {dbname};'
518542

519543
retry = cls.TEARDOWN_RETRY_DROP_DB
520544
for i in range(retry):
521545
try:
522-
cls.loop.run_until_complete(
546+
cls.adapt_call(
523547
cls.admin_client.execute(script))
524548
except edgedb.errors.ExecutionError:
525549
if i < retry - 1:
@@ -536,8 +560,11 @@ def tearDownClass(cls):
536560
finally:
537561
try:
538562
if cls.admin_client is not None:
539-
cls.loop.run_until_complete(
540-
cls.admin_client.aclose())
563+
if cls.is_client_async:
564+
cls.adapt_call(
565+
cls.admin_client.aclose())
566+
else:
567+
cls.admin_client.close()
541568
finally:
542569
super().tearDownClass()
543570

@@ -549,27 +576,11 @@ class AsyncQueryTestCase(DatabaseTestCase):
549576
class SyncQueryTestCase(DatabaseTestCase):
550577
BASE_TEST_CLASS = True
551578
TEARDOWN_RETRY_DROP_DB = 5
579+
is_client_async = False
552580

553-
def setUp(self):
554-
super().setUp()
555-
556-
cls = type(self)
557-
cls.async_client = cls.client
558-
559-
conargs = cls.get_connect_args().copy()
560-
conargs.update(dict(database=cls.async_client.dbname))
561-
562-
cls.client = TestClient(
563-
connection_class=blocking_client.BlockingIOConnection,
564-
max_concurrency=1,
565-
**conargs
566-
)
567-
568-
def tearDown(self):
569-
cls = type(self)
570-
cls.client.close()
571-
cls.client = cls.async_client
572-
del cls.async_client
581+
@classmethod
582+
def adapt_call(cls, result):
583+
return result
573584

574585

575586
_lock_cnt = 0

edgedb/asyncio_client.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242

4343
class AsyncIOConnection(base_client.BaseConnection):
4444
__slots__ = ("_loop",)
45-
_close_exceptions = (Exception, asyncio.CancelledError)
4645

4746
def __init__(self, loop, *args, **kwargs):
4847
super().__init__(*args, **kwargs)
@@ -61,6 +60,18 @@ async def connect_addr(self, addr, timeout):
6160
async def sleep(self, seconds):
6261
await asyncio.sleep(seconds)
6362

63+
async def aclose(self):
64+
"""Send graceful termination message wait for connection to drop."""
65+
if not self.is_closed():
66+
try:
67+
self._protocol.terminate()
68+
await self._protocol.wait_for_disconnect()
69+
except (Exception, asyncio.CancelledError):
70+
self.terminate()
71+
raise
72+
finally:
73+
self._cleanup()
74+
6475
def _protocol_factory(self):
6576
return asyncio_proto.AsyncIOProtocol(self._params, self._loop)
6677

@@ -104,7 +115,7 @@ async def _connect_addr(self, addr):
104115
if tr is not None:
105116
tr.close()
106117
raise con_utils.wrap_error(e) from e
107-
except Exception:
118+
except BaseException:
108119
if tr is not None:
109120
tr.close()
110121
raise
@@ -125,9 +136,9 @@ async def close(self, *, wait=True):
125136
if self._con is None:
126137
return
127138
if wait:
128-
await self._con.close()
139+
await self._con.aclose()
129140
else:
130-
self._pool._loop.create_task(self._con.close())
141+
self._pool._loop.create_task(self._con.aclose())
131142

132143
async def wait_until_released(self, timeout=None):
133144
await self._release_event.wait()

edgedb/base_client.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ class BaseConnection(metaclass=abc.ABCMeta):
4242
_log_listeners: typing.Set[
4343
typing.Callable[[BaseConnection_T, errors.EdgeDBMessage], None]
4444
]
45-
_close_exceptions = (Exception,)
4645
__slots__ = (
4746
"__weakref__",
4847
"_protocol",
@@ -313,18 +312,6 @@ def terminate(self):
313312
finally:
314313
self._cleanup()
315314

316-
async def close(self):
317-
"""Send graceful termination message wait for connection to drop."""
318-
if not self.is_closed():
319-
try:
320-
self._protocol.terminate()
321-
await self._protocol.wait_for_disconnect()
322-
except self._close_exceptions:
323-
self.terminate()
324-
raise
325-
finally:
326-
self._cleanup()
327-
328315
def __repr__(self):
329316
if self.is_closed():
330317
return '<{classname} [closed] {id:#x}>'.format(

edgedb/blocking_client.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,23 @@ def is_closed(self):
110110
return not (proto and proto.sock is not None and
111111
proto.sock.fileno() >= 0 and proto.connected)
112112

113+
async def close(self, timeout=None):
114+
"""Send graceful termination message wait for connection to drop."""
115+
if not self.is_closed():
116+
try:
117+
self._protocol.terminate()
118+
if timeout is None:
119+
await self._protocol.wait_for_disconnect()
120+
else:
121+
await self._protocol.wait_for(
122+
self._protocol.wait_for_disconnect(), timeout
123+
)
124+
except Exception:
125+
self.terminate()
126+
raise
127+
finally:
128+
self._cleanup()
129+
113130
def _dispatch_log_message(self, msg):
114131
for cb in self._log_listeners:
115132
cb(self, msg)
@@ -119,13 +136,13 @@ class _PoolConnectionHolder(base_client.PoolConnectionHolder):
119136
__slots__ = ()
120137
_event_class = threading.Event
121138

122-
async def close(self, *, wait=True):
139+
async def close(self, *, wait=True, timeout=None):
123140
if self._con is None:
124141
return
125-
await self._con.close()
142+
await self._con.close(timeout=timeout)
126143

127144
async def wait_until_released(self, timeout=None):
128-
self._release_event.wait(timeout)
145+
return self._release_event.wait(timeout)
129146

130147

131148
class _PoolImpl(base_client.BasePoolImpl):
@@ -200,17 +217,27 @@ async def close(self, timeout=None):
200217
if timeout is None:
201218
for ch in self._holders:
202219
await ch.wait_until_released()
220+
for ch in self._holders:
221+
await ch.close()
203222
else:
204-
remaining = timeout
223+
deadline = time.monotonic() + timeout
224+
for ch in self._holders:
225+
secs = deadline - time.monotonic()
226+
if secs <= 0:
227+
raise TimeoutError
228+
if not await ch.wait_until_released(secs):
229+
raise TimeoutError
205230
for ch in self._holders:
206-
start = time.monotonic()
207-
await ch.wait_until_released(remaining)
208-
remaining -= time.monotonic() - start
209-
if remaining <= 0:
210-
self.terminate()
211-
return
212-
for ch in self._holders:
213-
await ch.close()
231+
secs = deadline - time.monotonic()
232+
if secs <= 0:
233+
raise TimeoutError
234+
await ch.close(timeout=secs)
235+
except TimeoutError as e:
236+
self.terminate()
237+
raise errors.InterfaceError(
238+
"client is not fully closed in {} seconds; "
239+
"terminating now.".format(timeout)
240+
) from e
214241
except Exception:
215242
self.terminate()
216243
raise

edgedb/protocol/asyncio_proto.pyx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import asyncio
2121

2222
from edgedb import errors
23+
from edgedb import compat
2324
from edgedb.pgproto.pgproto cimport (
2425
WriteBuffer,
2526
ReadBuffer,

edgedb/protocol/blocking_proto.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,5 +26,6 @@ cdef class BlockingIOProtocol(protocol.SansIOProtocolBackwardsCompatible):
2626

2727
cdef:
2828
readonly object sock
29+
float deadline
2930

3031
cdef _disconnect(self)

0 commit comments

Comments
 (0)