@@ -352,8 +352,13 @@ def connection(self):
352352 def is_proto_lt_1_0 (self ):
353353 return self .connection ._protocol .is_legacy
354354
355+ @property
356+ def dbname (self ):
357+ return self ._impl ._working_params .database
358+
355359
356360class ConnectedTestCaseMixin :
361+ is_client_async = True
357362
358363 @classmethod
359364 def make_test_client (
@@ -362,11 +367,17 @@ def make_test_client(
362367 database = 'edgedb' ,
363368 user = 'edgedb' ,
364369 password = 'test' ,
365- connection_class = asyncio_client . AsyncIOConnection ,
370+ connection_class = ... ,
366371 ):
367372 conargs = cls .get_connect_args (
368373 cluster = cluster , database = database , user = user , password = password )
369- return TestAsyncIOClient (
374+ if connection_class is ...:
375+ connection_class = (
376+ asyncio_client .AsyncIOConnection
377+ if cls .is_client_async
378+ else blocking_client .BlockingIOConnection
379+ )
380+ return (TestAsyncIOClient if cls .is_client_async else TestClient )(
370381 connection_class = connection_class ,
371382 max_concurrency = 1 ,
372383 ** conargs ,
@@ -384,6 +395,10 @@ def get_connect_args(cls, *,
384395 database = database ))
385396 return conargs
386397
398+ @classmethod
399+ def adapt_call (cls , coro ):
400+ return cls .loop .run_until_complete (coro )
401+
387402
388403class DatabaseTestCase (ClusterTestCase , ConnectedTestCaseMixin ):
389404 SETUP = None
@@ -398,15 +413,15 @@ class DatabaseTestCase(ClusterTestCase, ConnectedTestCaseMixin):
398413
399414 def setUp (self ):
400415 if self .SETUP_METHOD :
401- self .loop . run_until_complete (
416+ self .adapt_call (
402417 self .client .execute (self .SETUP_METHOD ))
403418
404419 super ().setUp ()
405420
406421 def tearDown (self ):
407422 try :
408423 if self .TEARDOWN_METHOD :
409- self .loop . run_until_complete (
424+ self .adapt_call (
410425 self .client .execute (self .TEARDOWN_METHOD ))
411426 finally :
412427 try :
@@ -431,7 +446,7 @@ def setUpClass(cls):
431446 if not class_set_up :
432447 script = f'CREATE DATABASE { dbname } ;'
433448 cls .admin_client = cls .make_test_client ()
434- cls .loop . run_until_complete (cls .admin_client .execute (script ))
449+ cls .adapt_call (cls .admin_client .execute (script ))
435450
436451 cls .client = cls .make_test_client (database = dbname )
437452
@@ -440,11 +455,17 @@ def setUpClass(cls):
440455 if script :
441456 # The setup is expected to contain a CREATE MIGRATION,
442457 # which needs to be wrapped in a transaction.
443- async def execute ():
444- async for tr in cls .client .transaction ():
445- async with tr :
446- await tr .execute (script )
447- cls .loop .run_until_complete (execute ())
458+ if cls .is_client_async :
459+ async def execute ():
460+ async for tr in cls .client .transaction ():
461+ async with tr :
462+ await tr .execute (script )
463+ else :
464+ def execute ():
465+ for tr in cls .client .transaction ():
466+ with tr :
467+ tr .execute (script )
468+ cls .adapt_call (execute ())
448469
449470 @classmethod
450471 def get_database_name (cls ):
@@ -507,19 +528,22 @@ def tearDownClass(cls):
507528
508529 try :
509530 if script :
510- cls .loop . run_until_complete (
531+ cls .adapt_call (
511532 cls .client .execute (script ))
512533 finally :
513534 try :
514- cls .loop .run_until_complete (cls .client .aclose ())
535+ if cls .is_client_async :
536+ cls .adapt_call (cls .client .aclose ())
537+ else :
538+ cls .client .close ()
515539
516540 dbname = cls .get_database_name ()
517541 script = f'DROP DATABASE { dbname } ;'
518542
519543 retry = cls .TEARDOWN_RETRY_DROP_DB
520544 for i in range (retry ):
521545 try :
522- cls .loop . run_until_complete (
546+ cls .adapt_call (
523547 cls .admin_client .execute (script ))
524548 except edgedb .errors .ExecutionError :
525549 if i < retry - 1 :
@@ -536,8 +560,11 @@ def tearDownClass(cls):
536560 finally :
537561 try :
538562 if cls .admin_client is not None :
539- cls .loop .run_until_complete (
540- cls .admin_client .aclose ())
563+ if cls .is_client_async :
564+ cls .adapt_call (
565+ cls .admin_client .aclose ())
566+ else :
567+ cls .admin_client .close ()
541568 finally :
542569 super ().tearDownClass ()
543570
@@ -549,27 +576,11 @@ class AsyncQueryTestCase(DatabaseTestCase):
549576class SyncQueryTestCase (DatabaseTestCase ):
550577 BASE_TEST_CLASS = True
551578 TEARDOWN_RETRY_DROP_DB = 5
579+ is_client_async = False
552580
553- def setUp (self ):
554- super ().setUp ()
555-
556- cls = type (self )
557- cls .async_client = cls .client
558-
559- conargs = cls .get_connect_args ().copy ()
560- conargs .update (dict (database = cls .async_client .dbname ))
561-
562- cls .client = TestClient (
563- connection_class = blocking_client .BlockingIOConnection ,
564- max_concurrency = 1 ,
565- ** conargs
566- )
567-
568- def tearDown (self ):
569- cls = type (self )
570- cls .client .close ()
571- cls .client = cls .async_client
572- del cls .async_client
581+ @classmethod
582+ def adapt_call (cls , result ):
583+ return result
573584
574585
575586_lock_cnt = 0
0 commit comments