From a5c7a8a3dcc48ebdb8b0ad66550d5c8e54db63fd Mon Sep 17 00:00:00 2001 From: jigold Date: Thu, 2 Nov 2023 17:43:51 -0400 Subject: [PATCH] [batch] Fix async exit stacks (#13969) I couldn't find the best issue for this. Should fix #13908, but I thought there was another issue about reducing noisy grafana alerts which this PR also addresses. --- batch/batch/cloud/azure/driver/driver.py | 19 ++++++--- batch/batch/cloud/driver.py | 6 +-- batch/batch/cloud/gcp/driver/driver.py | 15 +++++-- batch/batch/driver/canceller.py | 6 +-- batch/batch/driver/main.py | 50 ++++++++++++------------ batch/batch/front_end/front_end.py | 19 +++++---- batch/batch/worker/worker.py | 17 ++++---- hail/python/hailtop/aiotools/tasks.py | 3 +- hail/python/hailtop/utils/utils.py | 4 ++ 9 files changed, 80 insertions(+), 59 deletions(-) diff --git a/batch/batch/cloud/azure/driver/driver.py b/batch/batch/cloud/azure/driver/driver.py index e05e0d986b5..58ffb1f3570 100644 --- a/batch/batch/cloud/azure/driver/driver.py +++ b/batch/batch/cloud/azure/driver/driver.py @@ -28,7 +28,6 @@ async def create( machine_name_prefix: str, namespace: str, inst_coll_configs: InstanceCollectionConfigs, - task_manager: aiotools.BackgroundTaskManager, # BORROWED ) -> 'AzureDriver': azure_config = get_azure_config() subscription_id = azure_config.subscription_id @@ -68,6 +67,8 @@ async def create( app, subscription_id, resource_group, ssh_public_key, arm_client, compute_client, billing_manager ) + task_manager = aiotools.BackgroundTaskManager() + create_pools_coros = [ Pool.create( app, @@ -110,6 +111,7 @@ async def create( inst_coll_manager, jpim, billing_manager, + task_manager, ) task_manager.ensure_future(periodically_call(60, driver.delete_orphaned_nics)) @@ -135,6 +137,7 @@ def __init__( inst_coll_manager: InstanceCollectionManager, job_private_inst_manager: JobPrivateInstanceManager, billing_manager: AzureBillingManager, + task_manager: aiotools.BackgroundTaskManager, ): self.db = db self.machine_name_prefix = machine_name_prefix @@ -150,6 +153,7 @@ def __init__( self.job_private_inst_manager = job_private_inst_manager self._billing_manager = billing_manager self._inst_coll_manager = inst_coll_manager + self._task_manager = task_manager @property def billing_manager(self) -> AzureBillingManager: @@ -161,18 +165,21 @@ def inst_coll_manager(self) -> InstanceCollectionManager: async def shutdown(self) -> None: try: - await self.arm_client.close() + await self._task_manager.shutdown_and_wait() finally: try: - await self.compute_client.close() + await self.arm_client.close() finally: try: - await self.resources_client.close() + await self.compute_client.close() finally: try: - await self.network_client.close() + await self.resources_client.close() finally: - await self.pricing_client.close() + try: + await self.network_client.close() + finally: + await self.pricing_client.close() def _resource_is_orphaned(self, resource_name: str) -> bool: instance_name = resource_name.rsplit('-', maxsplit=1)[0] diff --git a/batch/batch/cloud/driver.py b/batch/batch/cloud/driver.py index a9b92fdb36b..0be00d9a749 100644 --- a/batch/batch/cloud/driver.py +++ b/batch/batch/cloud/driver.py @@ -1,6 +1,5 @@ from gear import Database from gear.cloud_config import get_global_config -from hailtop import aiotools from ..driver.driver import CloudDriver from ..inst_coll_config import InstanceCollectionConfigs @@ -14,12 +13,11 @@ async def get_cloud_driver( machine_name_prefix: str, namespace: str, inst_coll_configs: InstanceCollectionConfigs, - task_manager: aiotools.BackgroundTaskManager, ) -> CloudDriver: cloud = get_global_config()['cloud'] if cloud == 'azure': - return await AzureDriver.create(app, db, machine_name_prefix, namespace, inst_coll_configs, task_manager) + return await AzureDriver.create(app, db, machine_name_prefix, namespace, inst_coll_configs) assert cloud == 'gcp', cloud - return await GCPDriver.create(app, db, machine_name_prefix, namespace, inst_coll_configs, task_manager) + return await GCPDriver.create(app, db, machine_name_prefix, namespace, inst_coll_configs) diff --git a/batch/batch/cloud/gcp/driver/driver.py b/batch/batch/cloud/gcp/driver/driver.py index 339cc7e7c2a..4000b650469 100644 --- a/batch/batch/cloud/gcp/driver/driver.py +++ b/batch/batch/cloud/gcp/driver/driver.py @@ -25,7 +25,6 @@ async def create( machine_name_prefix: str, namespace: str, inst_coll_configs: InstanceCollectionConfigs, - task_manager: aiotools.BackgroundTaskManager, # BORROWED ) -> 'GCPDriver': gcp_config = get_gcp_config() project = gcp_config.project @@ -67,6 +66,8 @@ async def create( inst_coll_manager = InstanceCollectionManager(db, machine_name_prefix, zone_monitor, region, regions) resource_manager = GCPResourceManager(project, compute_client, billing_manager) + task_manager = aiotools.BackgroundTaskManager() + create_pools_coros = [ Pool.create( app, @@ -105,6 +106,7 @@ async def create( inst_coll_manager, jpim, billing_manager, + task_manager, ) task_manager.ensure_future(periodically_call(15, driver.process_activity_logs)) @@ -126,6 +128,7 @@ def __init__( inst_coll_manager: InstanceCollectionManager, job_private_inst_manager: JobPrivateInstanceManager, billing_manager: GCPBillingManager, + task_manager: aiotools.BackgroundTaskManager, ): self.db = db self.machine_name_prefix = machine_name_prefix @@ -137,6 +140,7 @@ def __init__( self.job_private_inst_manager = job_private_inst_manager self._billing_manager = billing_manager self._inst_coll_manager = inst_coll_manager + self._task_manager = task_manager @property def billing_manager(self) -> GCPBillingManager: @@ -148,12 +152,15 @@ def inst_coll_manager(self) -> InstanceCollectionManager: async def shutdown(self) -> None: try: - await self.compute_client.close() + await self._task_manager.shutdown_and_wait() finally: try: - await self.activity_logs_client.close() + await self.compute_client.close() finally: - await self._billing_manager.close() + try: + await self.activity_logs_client.close() + finally: + await self._billing_manager.close() async def process_activity_logs(self) -> None: async def _process_activity_log_events_since(mark): diff --git a/batch/batch/driver/canceller.py b/batch/batch/driver/canceller.py index b939c8e9fef..09a4efa86ae 100644 --- a/batch/batch/driver/canceller.py +++ b/batch/batch/driver/canceller.py @@ -67,11 +67,11 @@ def __init__(self, app): self.task_manager = aiotools.BackgroundTaskManager() - def shutdown(self): + async def shutdown_and_wait(self): try: - self.task_manager.shutdown() + await self.task_manager.shutdown_and_wait() finally: - self.async_worker_pool.shutdown() + await self.async_worker_pool.shutdown_and_wait() async def cancel_cancelled_ready_jobs_loop_body(self): records = self.db.select_and_fetchall( diff --git a/batch/batch/driver/main.py b/batch/batch/driver/main.py index 18fedb8ca6f..f6e6656918b 100644 --- a/batch/batch/driver/main.py +++ b/batch/batch/driver/main.py @@ -1558,18 +1558,25 @@ def log(self, request, response, time): async def on_startup(app): - task_manager = aiotools.BackgroundTaskManager() - app['task_manager'] = task_manager - - app['client_session'] = httpx.client_session() + exit_stack = AsyncExitStack() + app['exit_stack'] = exit_stack kubernetes_asyncio.config.load_incluster_config() app['k8s_client'] = kubernetes_asyncio.client.CoreV1Api() app['k8s_cache'] = K8sCache(app['k8s_client']) + async def close_and_wait(): + # - Following warning mitigation described here: https://github.com/aio-libs/aiohttp/pull/2045 + # - Fixed in aiohttp 4.0.0: https://github.com/aio-libs/aiohttp/issues/1925 + await app['k8s_client'].api_client.close() + await asyncio.sleep(0.250) + + exit_stack.push_async_callback(close_and_wait) + db = Database() await db.async_init(maxsize=50) app['db'] = db + exit_stack.push_async_callback(app['db'].async_close) row = await db.select_and_fetchone( ''' @@ -1590,18 +1597,28 @@ async def on_startup(app): app['cancel_ready_state_changed'] = asyncio.Event() app['cancel_creating_state_changed'] = asyncio.Event() app['cancel_running_state_changed'] = asyncio.Event() + app['async_worker_pool'] = AsyncWorkerPool(100, queue_size=100) + exit_stack.push_async_callback(app['async_worker_pool'].shutdown_and_wait) fs = get_cloud_async_fs() app['file_store'] = FileStore(fs, BATCH_STORAGE_URI, instance_id) + exit_stack.push_async_callback(app['file_store'].close) inst_coll_configs = await InstanceCollectionConfigs.create(db) - app['driver'] = await get_cloud_driver( - app, db, MACHINE_NAME_PREFIX, DEFAULT_NAMESPACE, inst_coll_configs, task_manager - ) + app['client_session'] = httpx.client_session() + exit_stack.push_async_callback(app['client_session'].close) + + app['driver'] = await get_cloud_driver(app, db, MACHINE_NAME_PREFIX, DEFAULT_NAMESPACE, inst_coll_configs) + exit_stack.push_async_callback(app['driver'].shutdown) app['canceller'] = await Canceller.create(app) + exit_stack.push_async_callback(app['canceller'].shutdown_and_wait) + + task_manager = aiotools.BackgroundTaskManager() + app['task_manager'] = task_manager + exit_stack.push_async_callback(app['task_manager'].shutdown_and_wait) task_manager.ensure_future(periodically_call(10, monitor_billing_limits, app)) task_manager.ensure_future(periodically_call(10, cancel_fast_failing_batches, app)) @@ -1614,24 +1631,7 @@ async def on_startup(app): async def on_cleanup(app): try: - async with AsyncExitStack() as cleanup: - cleanup.callback(app['canceller'].shutdown) - cleanup.callback(app['task_manager'].shutdown) - cleanup.push_async_callback(app['driver'].shutdown) - cleanup.push_async_callback(app['file_store'].shutdown) - cleanup.push_async_callback(app['client_session'].close) - cleanup.callback(app['async_worker_pool'].shutdown) - cleanup.push_async_callback(app['db'].async_close) - - k8s: kubernetes_asyncio.client.CoreV1Api = app['k8s_client'] - - async def close_and_wait(): - # - Following warning mitigation described here: https://github.com/aio-libs/aiohttp/pull/2045 - # - Fixed in aiohttp 4.0.0: https://github.com/aio-libs/aiohttp/issues/1925 - await k8s.api_client.close() - await asyncio.sleep(0.250) - - cleanup.push_async_callback(close_and_wait) + await app['exit_stack'].aclose() finally: await asyncio.gather(*(t for t in asyncio.all_tasks() if t is not asyncio.current_task())) diff --git a/batch/batch/front_end/front_end.py b/batch/batch/front_end/front_end.py index a9740788ef4..960de5d193c 100644 --- a/batch/batch/front_end/front_end.py +++ b/batch/batch/front_end/front_end.py @@ -2903,12 +2903,16 @@ def log(self, request, response, time): async def on_startup(app): - app['task_manager'] = aiotools.BackgroundTaskManager() + exit_stack = AsyncExitStack() + app['exit_stack'] = exit_stack + app['client_session'] = httpx.client_session() + exit_stack.push_async_callback(app['client_session'].close) db = Database() await db.async_init() app['db'] = db + exit_stack.push_async_callback(app['db'].async_close) row = await db.select_and_fetchone( ''' @@ -2923,6 +2927,7 @@ async def on_startup(app): app['instance_id'] = instance_id app['hail_credentials'] = hail_credentials() + exit_stack.push_async_callback(app['hail_credentials'].close) app['frozen'] = row['frozen'] @@ -2937,8 +2942,13 @@ async def on_startup(app): fs = get_cloud_async_fs() app['file_store'] = FileStore(fs, BATCH_STORAGE_URI, instance_id) + exit_stack.push_async_callback(app['file_store'].close) + + app['task_manager'] = aiotools.BackgroundTaskManager() + exit_stack.callback(app['task_manager'].shutdown) app['inst_coll_configs'] = await InstanceCollectionConfigs.create(db) + exit_stack.push_async_callback(app['file_store'].close) cancel_batch_state_changed = asyncio.Event() app['cancel_batch_state_changed'] = cancel_batch_state_changed @@ -2958,12 +2968,7 @@ async def on_startup(app): async def on_cleanup(app): - async with AsyncExitStack() as stack: - stack.callback(app['task_manager'].shutdown) - stack.push_async_callback(app['hail_credentials'].close) - stack.push_async_callback(app['client_session'].close) - stack.push_async_callback(app['file_store'].close) - stack.push_async_callback(app['db'].async_close) + await app['exit_stack'].aclose() def run(): diff --git a/batch/batch/worker/worker.py b/batch/batch/worker/worker.py index b40a377f3c8..5fd7dc68586 100644 --- a/batch/batch/worker/worker.py +++ b/batch/batch/worker/worker.py @@ -3039,15 +3039,15 @@ async def shutdown(self): log.info('Worker.shutdown') self._jvm_initializer_task.cancel() async with AsyncExitStack() as cleanup: + cleanup.push_async_callback(self.client_session.close) + if self.fs: + cleanup.push_async_callback(self.fs.close) + if self.file_store: + cleanup.push_async_callback(self.file_store.close) for jvmqueue in self._jvmpools_by_cores.values(): while not jvmqueue.queue.empty(): cleanup.push_async_callback(jvmqueue.queue.get_nowait().kill) cleanup.push_async_callback(self.task_manager.shutdown_and_wait) - if self.file_store: - cleanup.push_async_callback(self.file_store.close) - if self.fs: - cleanup.push_async_callback(self.fs.close) - cleanup.push_async_callback(self.client_session.close) async def run_job(self, job): try: @@ -3475,11 +3475,10 @@ async def async_main(): with aiomonitor.start_monitor(asyncio.get_event_loop(), locals=locals()): try: async with AsyncExitStack() as cleanup: - cleanup.push_async_callback(worker.shutdown) - cleanup.push_async_callback(CLOUD_WORKER_API.close) - cleanup.push_async_callback(network_allocator_task_manager.shutdown_and_wait) cleanup.push_async_callback(docker.close) - + cleanup.push_async_callback(network_allocator_task_manager.shutdown_and_wait) + cleanup.push_async_callback(CLOUD_WORKER_API.close) + cleanup.push_async_callback(worker.shutdown) await worker.run() finally: asyncio.get_event_loop().set_debug(True) diff --git a/hail/python/hailtop/aiotools/tasks.py b/hail/python/hailtop/aiotools/tasks.py index 57c39b731bf..a467793d5b6 100644 --- a/hail/python/hailtop/aiotools/tasks.py +++ b/hail/python/hailtop/aiotools/tasks.py @@ -45,4 +45,5 @@ def shutdown(self): async def shutdown_and_wait(self): self.shutdown() - await asyncio.wait(self.tasks, return_when=asyncio.ALL_COMPLETED) + if self.tasks: + await asyncio.wait(self.tasks, return_when=asyncio.ALL_COMPLETED) diff --git a/hail/python/hailtop/utils/utils.py b/hail/python/hailtop/utils/utils.py index f161cc988cf..ae4805a10c6 100644 --- a/hail/python/hailtop/utils/utils.py +++ b/hail/python/hailtop/utils/utils.py @@ -224,6 +224,10 @@ def shutdown(self): except Exception: pass + async def shutdown_and_wait(self): + self.shutdown() + await asyncio.gather(*self.workers, return_exceptions=True) + class WaitableSharedPool: def __init__(self, worker_pool: AsyncWorkerPool):