diff --git a/cylc/flow/network/scan_nt.py b/cylc/flow/network/scan_nt.py index 8e1ba11766e..a47440a83a0 100644 --- a/cylc/flow/network/scan_nt.py +++ b/cylc/flow/network/scan_nt.py @@ -71,7 +71,7 @@ async def dir_is_flow(listing): @pipe -async def scan(run_dir=None): +async def scan(run_dir=None, max_depth=3): """List flows installed on the filesystem. This is an async generator so use and async for to extract results:: @@ -80,8 +80,13 @@ async def scan(run_dir=None): print(flow['name']) Args: - directory (pathlib.Path): + run_dir (pathlib.Path): The directory to scan, defaults to the cylc run directory. + max_depth (int): + The maximum number of levels to descend before bailing. + + * ``max_depth=1`` will pick up top-level suites (e.g. ``foo``). + * ``max_depth=2`` will pick up nested suites (e.g. ``foo/bar``). Yields: dict - Dictionary containing information about the flow. @@ -94,10 +99,10 @@ async def scan(run_dir=None): stack = asyncio.Queue() for subdir in await scandir(run_dir): if subdir.is_dir(): - await stack.put(subdir) + await stack.put((1, subdir)) # for path in stack: - async for path in asyncqgen(stack): + async for depth, path in asyncqgen(stack): contents = await scandir(path) if await dir_is_flow(contents): # this is a flow directory @@ -105,11 +110,11 @@ async def scan(run_dir=None): 'name': str(path.relative_to(run_dir)), 'path': path, } - else: + elif depth < max_depth: # we may have a nested flow, lets see... for subdir in contents: if subdir.is_dir(): - await stack.put(subdir) + await stack.put((depth + 1, subdir)) def join_regexes(*patterns): diff --git a/tests/integration/test_scan_nt.py b/tests/integration/test_scan_nt.py index e16f47e91ac..f6e85854e0e 100644 --- a/tests/integration/test_scan_nt.py +++ b/tests/integration/test_scan_nt.py @@ -124,6 +124,18 @@ def run_dir_with_really_nasty_symlinks(): rmtree(tmp_path) +@pytest.fixture(scope='session') +def nested_run_dir(): + tmp_path = Path(TemporaryDirectory().name) + tmp_path.mkdir() + init_flows( + tmp_path, + running=('a', 'b/c', 'd/e/f', 'g/h/i/j'), + ) + yield tmp_path + rmtree(tmp_path) + + async def listify(async_gen, field='name'): """Convert an async generator into a list.""" ret = [] @@ -219,3 +231,21 @@ async def test_is_active(sample_run_dir): {'path': sample_run_dir / 'elephant'}, True ) + + +@pytest.mark.asyncio +async def test_max_depth(nested_run_dir): + """It should descend only as far as permitted.""" + assert await listify( + scan(nested_run_dir, max_depth=1) + ) == [ + 'a' + ] + + assert await listify( + scan(nested_run_dir, max_depth=3) + ) == [ + 'a', + 'b/c', + 'd/e/f' + ]