Skip to content

Commit

Permalink
[batch] Fix async exit stacks (#13969)
Browse files Browse the repository at this point in the history
I couldn't find the best issue for this. Should fix #13908, but I
thought there was another issue about reducing noisy grafana alerts
which this PR also addresses.
  • Loading branch information
jigold authored Nov 2, 2023
1 parent d231b40 commit a5c7a8a
Show file tree
Hide file tree
Showing 9 changed files with 80 additions and 59 deletions.
19 changes: 13 additions & 6 deletions batch/batch/cloud/azure/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ async def create(
machine_name_prefix: str,
namespace: str,
inst_coll_configs: InstanceCollectionConfigs,
task_manager: aiotools.BackgroundTaskManager, # BORROWED
) -> 'AzureDriver':
azure_config = get_azure_config()
subscription_id = azure_config.subscription_id
Expand Down Expand Up @@ -68,6 +67,8 @@ async def create(
app, subscription_id, resource_group, ssh_public_key, arm_client, compute_client, billing_manager
)

task_manager = aiotools.BackgroundTaskManager()

create_pools_coros = [
Pool.create(
app,
Expand Down Expand Up @@ -110,6 +111,7 @@ async def create(
inst_coll_manager,
jpim,
billing_manager,
task_manager,
)

task_manager.ensure_future(periodically_call(60, driver.delete_orphaned_nics))
Expand All @@ -135,6 +137,7 @@ def __init__(
inst_coll_manager: InstanceCollectionManager,
job_private_inst_manager: JobPrivateInstanceManager,
billing_manager: AzureBillingManager,
task_manager: aiotools.BackgroundTaskManager,
):
self.db = db
self.machine_name_prefix = machine_name_prefix
Expand All @@ -150,6 +153,7 @@ def __init__(
self.job_private_inst_manager = job_private_inst_manager
self._billing_manager = billing_manager
self._inst_coll_manager = inst_coll_manager
self._task_manager = task_manager

@property
def billing_manager(self) -> AzureBillingManager:
Expand All @@ -161,18 +165,21 @@ def inst_coll_manager(self) -> InstanceCollectionManager:

async def shutdown(self) -> None:
try:
await self.arm_client.close()
await self._task_manager.shutdown_and_wait()
finally:
try:
await self.compute_client.close()
await self.arm_client.close()
finally:
try:
await self.resources_client.close()
await self.compute_client.close()
finally:
try:
await self.network_client.close()
await self.resources_client.close()
finally:
await self.pricing_client.close()
try:
await self.network_client.close()
finally:
await self.pricing_client.close()

def _resource_is_orphaned(self, resource_name: str) -> bool:
instance_name = resource_name.rsplit('-', maxsplit=1)[0]
Expand Down
6 changes: 2 additions & 4 deletions batch/batch/cloud/driver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from gear import Database
from gear.cloud_config import get_global_config
from hailtop import aiotools

from ..driver.driver import CloudDriver
from ..inst_coll_config import InstanceCollectionConfigs
Expand All @@ -14,12 +13,11 @@ async def get_cloud_driver(
machine_name_prefix: str,
namespace: str,
inst_coll_configs: InstanceCollectionConfigs,
task_manager: aiotools.BackgroundTaskManager,
) -> CloudDriver:
cloud = get_global_config()['cloud']

if cloud == 'azure':
return await AzureDriver.create(app, db, machine_name_prefix, namespace, inst_coll_configs, task_manager)
return await AzureDriver.create(app, db, machine_name_prefix, namespace, inst_coll_configs)

assert cloud == 'gcp', cloud
return await GCPDriver.create(app, db, machine_name_prefix, namespace, inst_coll_configs, task_manager)
return await GCPDriver.create(app, db, machine_name_prefix, namespace, inst_coll_configs)
15 changes: 11 additions & 4 deletions batch/batch/cloud/gcp/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ async def create(
machine_name_prefix: str,
namespace: str,
inst_coll_configs: InstanceCollectionConfigs,
task_manager: aiotools.BackgroundTaskManager, # BORROWED
) -> 'GCPDriver':
gcp_config = get_gcp_config()
project = gcp_config.project
Expand Down Expand Up @@ -67,6 +66,8 @@ async def create(
inst_coll_manager = InstanceCollectionManager(db, machine_name_prefix, zone_monitor, region, regions)
resource_manager = GCPResourceManager(project, compute_client, billing_manager)

task_manager = aiotools.BackgroundTaskManager()

create_pools_coros = [
Pool.create(
app,
Expand Down Expand Up @@ -105,6 +106,7 @@ async def create(
inst_coll_manager,
jpim,
billing_manager,
task_manager,
)

task_manager.ensure_future(periodically_call(15, driver.process_activity_logs))
Expand All @@ -126,6 +128,7 @@ def __init__(
inst_coll_manager: InstanceCollectionManager,
job_private_inst_manager: JobPrivateInstanceManager,
billing_manager: GCPBillingManager,
task_manager: aiotools.BackgroundTaskManager,
):
self.db = db
self.machine_name_prefix = machine_name_prefix
Expand All @@ -137,6 +140,7 @@ def __init__(
self.job_private_inst_manager = job_private_inst_manager
self._billing_manager = billing_manager
self._inst_coll_manager = inst_coll_manager
self._task_manager = task_manager

@property
def billing_manager(self) -> GCPBillingManager:
Expand All @@ -148,12 +152,15 @@ def inst_coll_manager(self) -> InstanceCollectionManager:

async def shutdown(self) -> None:
try:
await self.compute_client.close()
await self._task_manager.shutdown_and_wait()
finally:
try:
await self.activity_logs_client.close()
await self.compute_client.close()
finally:
await self._billing_manager.close()
try:
await self.activity_logs_client.close()
finally:
await self._billing_manager.close()

async def process_activity_logs(self) -> None:
async def _process_activity_log_events_since(mark):
Expand Down
6 changes: 3 additions & 3 deletions batch/batch/driver/canceller.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ def __init__(self, app):

self.task_manager = aiotools.BackgroundTaskManager()

def shutdown(self):
async def shutdown_and_wait(self):
try:
self.task_manager.shutdown()
await self.task_manager.shutdown_and_wait()
finally:
self.async_worker_pool.shutdown()
await self.async_worker_pool.shutdown_and_wait()

async def cancel_cancelled_ready_jobs_loop_body(self):
records = self.db.select_and_fetchall(
Expand Down
50 changes: 25 additions & 25 deletions batch/batch/driver/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1558,18 +1558,25 @@ def log(self, request, response, time):


async def on_startup(app):
task_manager = aiotools.BackgroundTaskManager()
app['task_manager'] = task_manager

app['client_session'] = httpx.client_session()
exit_stack = AsyncExitStack()
app['exit_stack'] = exit_stack

kubernetes_asyncio.config.load_incluster_config()
app['k8s_client'] = kubernetes_asyncio.client.CoreV1Api()
app['k8s_cache'] = K8sCache(app['k8s_client'])

async def close_and_wait():
# - Following warning mitigation described here: https://github.com/aio-libs/aiohttp/pull/2045
# - Fixed in aiohttp 4.0.0: https://github.com/aio-libs/aiohttp/issues/1925
await app['k8s_client'].api_client.close()
await asyncio.sleep(0.250)

exit_stack.push_async_callback(close_and_wait)

db = Database()
await db.async_init(maxsize=50)
app['db'] = db
exit_stack.push_async_callback(app['db'].async_close)

row = await db.select_and_fetchone(
'''
Expand All @@ -1590,18 +1597,28 @@ async def on_startup(app):
app['cancel_ready_state_changed'] = asyncio.Event()
app['cancel_creating_state_changed'] = asyncio.Event()
app['cancel_running_state_changed'] = asyncio.Event()

app['async_worker_pool'] = AsyncWorkerPool(100, queue_size=100)
exit_stack.push_async_callback(app['async_worker_pool'].shutdown_and_wait)

fs = get_cloud_async_fs()
app['file_store'] = FileStore(fs, BATCH_STORAGE_URI, instance_id)
exit_stack.push_async_callback(app['file_store'].close)

inst_coll_configs = await InstanceCollectionConfigs.create(db)

app['driver'] = await get_cloud_driver(
app, db, MACHINE_NAME_PREFIX, DEFAULT_NAMESPACE, inst_coll_configs, task_manager
)
app['client_session'] = httpx.client_session()
exit_stack.push_async_callback(app['client_session'].close)

app['driver'] = await get_cloud_driver(app, db, MACHINE_NAME_PREFIX, DEFAULT_NAMESPACE, inst_coll_configs)
exit_stack.push_async_callback(app['driver'].shutdown)

app['canceller'] = await Canceller.create(app)
exit_stack.push_async_callback(app['canceller'].shutdown_and_wait)

task_manager = aiotools.BackgroundTaskManager()
app['task_manager'] = task_manager
exit_stack.push_async_callback(app['task_manager'].shutdown_and_wait)

task_manager.ensure_future(periodically_call(10, monitor_billing_limits, app))
task_manager.ensure_future(periodically_call(10, cancel_fast_failing_batches, app))
Expand All @@ -1614,24 +1631,7 @@ async def on_startup(app):

async def on_cleanup(app):
try:
async with AsyncExitStack() as cleanup:
cleanup.callback(app['canceller'].shutdown)
cleanup.callback(app['task_manager'].shutdown)
cleanup.push_async_callback(app['driver'].shutdown)
cleanup.push_async_callback(app['file_store'].shutdown)
cleanup.push_async_callback(app['client_session'].close)
cleanup.callback(app['async_worker_pool'].shutdown)
cleanup.push_async_callback(app['db'].async_close)

k8s: kubernetes_asyncio.client.CoreV1Api = app['k8s_client']

async def close_and_wait():
# - Following warning mitigation described here: https://github.com/aio-libs/aiohttp/pull/2045
# - Fixed in aiohttp 4.0.0: https://github.com/aio-libs/aiohttp/issues/1925
await k8s.api_client.close()
await asyncio.sleep(0.250)

cleanup.push_async_callback(close_and_wait)
await app['exit_stack'].aclose()
finally:
await asyncio.gather(*(t for t in asyncio.all_tasks() if t is not asyncio.current_task()))

Expand Down
19 changes: 12 additions & 7 deletions batch/batch/front_end/front_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -2903,12 +2903,16 @@ def log(self, request, response, time):


async def on_startup(app):
app['task_manager'] = aiotools.BackgroundTaskManager()
exit_stack = AsyncExitStack()
app['exit_stack'] = exit_stack

app['client_session'] = httpx.client_session()
exit_stack.push_async_callback(app['client_session'].close)

db = Database()
await db.async_init()
app['db'] = db
exit_stack.push_async_callback(app['db'].async_close)

row = await db.select_and_fetchone(
'''
Expand All @@ -2923,6 +2927,7 @@ async def on_startup(app):
app['instance_id'] = instance_id

app['hail_credentials'] = hail_credentials()
exit_stack.push_async_callback(app['hail_credentials'].close)

app['frozen'] = row['frozen']

Expand All @@ -2937,8 +2942,13 @@ async def on_startup(app):

fs = get_cloud_async_fs()
app['file_store'] = FileStore(fs, BATCH_STORAGE_URI, instance_id)
exit_stack.push_async_callback(app['file_store'].close)

app['task_manager'] = aiotools.BackgroundTaskManager()
exit_stack.callback(app['task_manager'].shutdown)

app['inst_coll_configs'] = await InstanceCollectionConfigs.create(db)
exit_stack.push_async_callback(app['file_store'].close)

cancel_batch_state_changed = asyncio.Event()
app['cancel_batch_state_changed'] = cancel_batch_state_changed
Expand All @@ -2958,12 +2968,7 @@ async def on_startup(app):


async def on_cleanup(app):
async with AsyncExitStack() as stack:
stack.callback(app['task_manager'].shutdown)
stack.push_async_callback(app['hail_credentials'].close)
stack.push_async_callback(app['client_session'].close)
stack.push_async_callback(app['file_store'].close)
stack.push_async_callback(app['db'].async_close)
await app['exit_stack'].aclose()


def run():
Expand Down
17 changes: 8 additions & 9 deletions batch/batch/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3039,15 +3039,15 @@ async def shutdown(self):
log.info('Worker.shutdown')
self._jvm_initializer_task.cancel()
async with AsyncExitStack() as cleanup:
cleanup.push_async_callback(self.client_session.close)
if self.fs:
cleanup.push_async_callback(self.fs.close)
if self.file_store:
cleanup.push_async_callback(self.file_store.close)
for jvmqueue in self._jvmpools_by_cores.values():
while not jvmqueue.queue.empty():
cleanup.push_async_callback(jvmqueue.queue.get_nowait().kill)
cleanup.push_async_callback(self.task_manager.shutdown_and_wait)
if self.file_store:
cleanup.push_async_callback(self.file_store.close)
if self.fs:
cleanup.push_async_callback(self.fs.close)
cleanup.push_async_callback(self.client_session.close)

async def run_job(self, job):
try:
Expand Down Expand Up @@ -3475,11 +3475,10 @@ async def async_main():
with aiomonitor.start_monitor(asyncio.get_event_loop(), locals=locals()):
try:
async with AsyncExitStack() as cleanup:
cleanup.push_async_callback(worker.shutdown)
cleanup.push_async_callback(CLOUD_WORKER_API.close)
cleanup.push_async_callback(network_allocator_task_manager.shutdown_and_wait)
cleanup.push_async_callback(docker.close)

cleanup.push_async_callback(network_allocator_task_manager.shutdown_and_wait)
cleanup.push_async_callback(CLOUD_WORKER_API.close)
cleanup.push_async_callback(worker.shutdown)
await worker.run()
finally:
asyncio.get_event_loop().set_debug(True)
Expand Down
3 changes: 2 additions & 1 deletion hail/python/hailtop/aiotools/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,5 @@ def shutdown(self):

async def shutdown_and_wait(self):
self.shutdown()
await asyncio.wait(self.tasks, return_when=asyncio.ALL_COMPLETED)
if self.tasks:
await asyncio.wait(self.tasks, return_when=asyncio.ALL_COMPLETED)
4 changes: 4 additions & 0 deletions hail/python/hailtop/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,10 @@ def shutdown(self):
except Exception:
pass

async def shutdown_and_wait(self):
self.shutdown()
await asyncio.gather(*self.workers, return_exceptions=True)


class WaitableSharedPool:
def __init__(self, worker_pool: AsyncWorkerPool):
Expand Down

0 comments on commit a5c7a8a

Please sign in to comment.