diff --git a/src/promptflow-core/promptflow/_utils/async_utils.py b/src/promptflow-core/promptflow/_utils/async_utils.py index 41454e0d0f2..feb707c7cc4 100644 --- a/src/promptflow-core/promptflow/_utils/async_utils.py +++ b/src/promptflow-core/promptflow/_utils/async_utils.py @@ -4,6 +4,8 @@ import asyncio import functools +import signal +import threading from promptflow.tracing import ThreadPoolExecutorWithContext @@ -22,6 +24,60 @@ def _has_running_loop() -> bool: return False +class _AsyncTaskSigIntHandler: + """The handler to cancel the current task if SIGINT is received. + This is only for python<3.11 where the default cancelling behavior is not supported. + The code is similar to the python>=3.11 builtin implementation. + https://github.com/python/cpython/blob/46c808172fd3148e3397234b23674bf70734fb55/Lib/asyncio/runners.py#L150 + """ + + def __init__(self, task: asyncio.Task, loop: asyncio.AbstractEventLoop): + self._loop = loop + self._task = task + self._interrupt_count = 0 + + def on_sigint(self, signum, frame): + self._interrupt_count += 1 + if self._interrupt_count == 1 and not self._task.done(): + self._task.cancel() + # This call_soon_threadsafe would schedule the call as soon as possible, + # it would force the event loop to wake up then handle the cancellation request. + # This is to avoid the loop blocking with long timeout. + self._loop.call_soon_threadsafe(lambda: None) + return + raise KeyboardInterrupt() + + +async def _invoke_async_with_sigint_handler(async_func, *args, **kwargs): + """In python>=3.11, when sigint is hit, + asyncio.run in default cancel the running tasks before raising the KeyboardInterrupt, + this introduces the chance to handle the cancelled error. + So we have a similar implementation here so python<3.11 also have such feature. + https://github.com/python/cpython/blob/46c808172fd3148e3397234b23674bf70734fb55/Lib/asyncio/runners.py#L150 + """ + # For the scenario that we don't need to update sigint, just return. + # The scenarios include: + # For python >= 3.11, asyncio.run already updated the sigint for cancelling tasks. + # The user already has his own customized sigint. + # The current code is not in main thread. + if not _should_update_sigint(): + return await async_func(*args, **kwargs) + try: + loop = asyncio.get_running_loop() + task = asyncio.create_task(async_func(*args, **kwargs)) + signal.signal(signal.SIGINT, _AsyncTaskSigIntHandler(task, loop).on_sigint) + return await task + finally: + signal.signal(signal.SIGINT, signal.default_int_handler) + + +def _should_update_sigint(): + return ( + threading.current_thread() is threading.main_thread() + and signal.getsignal(signal.SIGINT) is signal.default_int_handler + ) + + def async_run_allowing_running_loop(async_func, *args, **kwargs): """Run an async function in a new thread, allowing the current thread to have a running event loop. @@ -36,7 +92,7 @@ def async_run_allowing_running_loop(async_func, *args, **kwargs): with ThreadPoolExecutorWithContext() as executor: return executor.submit(lambda: asyncio.run(async_func(*args, **kwargs))).result() else: - return asyncio.run(async_func(*args, **kwargs)) + return asyncio.run(_invoke_async_with_sigint_handler(async_func, *args, **kwargs)) def async_to_sync(func): diff --git a/src/promptflow-core/promptflow/_utils/process_utils.py b/src/promptflow-core/promptflow/_utils/process_utils.py index d4c4bd35eaa..9fc24331fd8 100644 --- a/src/promptflow-core/promptflow/_utils/process_utils.py +++ b/src/promptflow-core/promptflow/_utils/process_utils.py @@ -29,7 +29,7 @@ def block_terminate_signal_to_parent(): signal.set_wakeup_fd(-1) signal.signal(signal.SIGTERM, signal.SIG_DFL) - signal.signal(signal.SIGINT, signal.SIG_DFL) + signal.signal(signal.SIGINT, signal.default_int_handler) def get_available_max_worker_count(logger: logging.Logger = bulk_logger): diff --git a/src/promptflow-core/promptflow/executor/_async_nodes_scheduler.py b/src/promptflow-core/promptflow/executor/_async_nodes_scheduler.py index 01b2e926f34..19f5661f024 100644 --- a/src/promptflow-core/promptflow/executor/_async_nodes_scheduler.py +++ b/src/promptflow-core/promptflow/executor/_async_nodes_scheduler.py @@ -6,7 +6,6 @@ import contextvars import inspect import os -import signal import threading import time import traceback @@ -46,15 +45,6 @@ async def execute( inputs: Dict[str, Any], context: FlowExecutionContext, ) -> Tuple[dict, dict]: - # TODO: Provide cancel API - if threading.current_thread() is threading.main_thread(): - signal.signal(signal.SIGINT, signal_handler) - signal.signal(signal.SIGTERM, signal_handler) - else: - flow_logger.info( - "Current thread is not main thread, skip signal handler registration in AsyncNodesScheduler." - ) - # Semaphore should be created in the loop, otherwise it will not work. loop = asyncio.get_running_loop() self._semaphore = asyncio.Semaphore(self._node_concurrency) @@ -62,7 +52,11 @@ async def execute( monitor = ThreadWithContextVars( target=monitor_long_running_coroutine, args=( - interval, loop, self._task_start_time, self._task_last_log_time, self._dag_manager_completed_event + interval, + loop, + self._task_start_time, + self._task_last_log_time, + self._dag_manager_completed_event, ), daemon=True, ) @@ -80,7 +74,11 @@ async def execute( # This is because it will always call `executor.shutdown()` when exiting the `with` block. # Then the event loop will wait for all tasks to be completed before raising the cancellation error. # See reference: https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.Executor - outputs = await self._execute_with_thread_pool(executor, nodes, inputs, context) + try: + outputs = await self._execute_with_thread_pool(executor, nodes, inputs, context) + except asyncio.CancelledError: + await self.cancel() + raise executor.shutdown() return outputs @@ -171,16 +169,11 @@ async def _sync_function_to_async_task( # The task will not be executed before calling create_task. return await asyncio.get_running_loop().run_in_executor(executor, context.invoke_tool, node, f, kwargs) - -def signal_handler(sig, frame): - """ - Start a thread to monitor coroutines after receiving signal. - """ - flow_logger.info(f"Received signal {sig}({signal.Signals(sig).name}), start coroutine monitor thread.") - loop = asyncio.get_running_loop() - monitor = ThreadWithContextVars(target=monitor_coroutine_after_cancellation, args=(loop,)) - monitor.start() - raise KeyboardInterrupt + async def cancel(self): + flow_logger.info("Cancel requested, monitoring coroutines after cancellation.") + loop = asyncio.get_running_loop() + monitor = ThreadWithContextVars(target=monitor_coroutine_after_cancellation, args=(loop,)) + monitor.start() def log_stack_recursively(task: asyncio.Task, elapse_time: float): diff --git a/src/promptflow-core/promptflow/executor/flow_executor.py b/src/promptflow-core/promptflow/executor/flow_executor.py index 68100007737..df3099404b4 100644 --- a/src/promptflow-core/promptflow/executor/flow_executor.py +++ b/src/promptflow-core/promptflow/executor/flow_executor.py @@ -1098,15 +1098,12 @@ async def _exec_async( context, allow_generator_output, ) - except KeyboardInterrupt as ex: - # Run will be cancelled when the process receives a SIGINT signal. - # KeyboardInterrupt will be raised after asyncio finishes its signal handling - # End run with the KeyboardInterrupt exception, so that its status will be Canceled - flow_logger.info("Received KeyboardInterrupt, cancel the run.") - # Update the run info of those running nodes to a canceled status. + except asyncio.CancelledError as ex: + flow_logger.info("Received cancelled error, cancel the run.") run_tracker.cancel_node_runs(run_id) run_tracker.end_run(line_run_id, ex=ex) - raise + if self._raise_ex: + raise except Exception as e: run_tracker.end_run(line_run_id, ex=e) if self._raise_ex: