diff --git a/iree/turbine/kernel/wave/scheduling/graph_utils.py b/iree/turbine/kernel/wave/scheduling/graph_utils.py index e3d759dc..610c56cd 100644 --- a/iree/turbine/kernel/wave/scheduling/graph_utils.py +++ b/iree/turbine/kernel/wave/scheduling/graph_utils.py @@ -15,6 +15,7 @@ from functools import partial from ..utils import safe_subs import multiprocessing as mp +from typing import Optional T = index_symbol("$INITIATION_INTERVAL") @@ -203,7 +204,7 @@ def all_pairs_longest_paths_symbolic( def all_pairs_longest_paths( - graph: fx.Graph, edges: list[Edge], T: int, pool: mp.Pool + graph: fx.Graph, edges: list[Edge], T: int, pool: Optional[mp.Pool] ) -> dict[tuple[fx.Node, fx.Node], IndexExpr]: """ For each node in the graph, compute the longest path to all other nodes. @@ -229,7 +230,11 @@ def all_pairs_longest_paths( # Parallel implementation for k in range(N): func = partial(all_pairs_longest_path_parallel, N, D, k) - results = pool.map(func, range(N)) + if pool is not None: + results = pool.map(func, range(N)) + else: + results = map(func, range(N)) + for result in results: D[result[0]] = result[1] diff --git a/iree/turbine/kernel/wave/scheduling/modulo_scheduling.py b/iree/turbine/kernel/wave/scheduling/modulo_scheduling.py index 7de08162..fb7794e4 100644 --- a/iree/turbine/kernel/wave/scheduling/modulo_scheduling.py +++ b/iree/turbine/kernel/wave/scheduling/modulo_scheduling.py @@ -111,7 +111,16 @@ def schedule_graph(self) -> tuple[dict[fx.Node, int], bool]: # TODO: Come up with a better heuristic on an upper bound for the initiation interval. T_max_range = 3 * T0 success = False - pool = mp.get_context("fork").Pool(processes=mp.cpu_count()) + + # We cannot create create child processes when running in daemon process + # so just run sequentially. + # TODO: Find a way to reuse processes from the outside pool if we are + # already running inside. + if mp.current_process().daemon: + pool = None + else: + pool = mp.get_context("fork").Pool(processes=mp.cpu_count()) + for T in range(T0, T0 + T_max_range): logger.debug(f"Trying initiation interval: {T}.") self.RT = np.zeros((T, len(self.resources))) @@ -150,8 +159,10 @@ def schedule_graph(self) -> tuple[dict[fx.Node, int], bool]: break else: raise Exception("Failed to schedule the graph.") - pool.close() - pool.join() + + if pool is not None: + pool.close() + pool.join() self._initiation_interval = T return self.schedule, success