From 172dd0859bb97095cc56acee6451ff281517a7ca Mon Sep 17 00:00:00 2001 From: Stu Hood Date: Fri, 28 Oct 2022 17:01:36 -0700 Subject: [PATCH] Use mutable nodes to memoize `_DependencyMappingRequest`. --- src/python/pants/engine/internals/graph.py | 56 ++++++++++++++++++++-- 1 file changed, 52 insertions(+), 4 deletions(-) diff --git a/src/python/pants/engine/internals/graph.py b/src/python/pants/engine/internals/graph.py index 781f87f4476..1855f7bbf78 100644 --- a/src/python/pants/engine/internals/graph.py +++ b/src/python/pants/engine/internals/graph.py @@ -559,6 +559,26 @@ def visit(address: Address): ) +@dataclass +class _DependencyMappings: + """A mutable @rule output which stores memoized dependencies.""" + + dependencies: dict[Address, tuple[Target, ...]] + + +@dataclass(frozen=True) +class _DependencyMappingsRequest: + expanded_targets: bool + include_special_cased_deps: bool + + +@rule(_mutable=True) +async def dependency_mappings(_: _DependencyMappingsRequest) -> _DependencyMappings: + # NB: A new `_DependencyMappings` object will be created per distinct + # `_DependencyMappingsRequest` argument, but there is no need to actually consume it. + return _DependencyMappings({}) + + @dataclass(frozen=True) class _DependencyMappingRequest: tt_request: TransitiveTargetsRequest @@ -579,11 +599,31 @@ async def transitive_dependency_mapping(request: _DependencyMappingRequest) -> _ Unlike a traditional BFS algorithm, we batch each round of traversals via `MultiGet` for improved performance / concurrency. """ + memo = await Get( + _DependencyMappings, + _DependencyMappingsRequest( + expanded_targets=request.expanded_targets, + include_special_cased_deps=request.tt_request.include_special_cased_deps, + ), + ) + roots_as_targets = await Get(UnexpandedTargets, Addresses(request.tt_request.roots)) visited: OrderedSet[Target] = OrderedSet() - queued = FrozenOrderedSet(roots_as_targets) + queued = OrderedSet(roots_as_targets) dependency_mapping: dict[Address, tuple[Address, ...]] = {} while queued: + # Collect any dependencies which have already been computed by other callers. + memoized_dependencies = [ + (target.address, memo.dependencies[target.address]) + for target in queued + if target.address in memo.dependencies + ] + dependency_mapping.update( + (a, tuple(d.address for d in deps)) for a, deps in memoized_dependencies + ) + queued = OrderedSet(t for t in queued if t.address not in dependency_mapping) + + # Then compute any that were not memoized. direct_dependencies: tuple[Collection[Target], ...] if request.expanded_targets: direct_dependencies = await MultiGet( @@ -608,6 +648,7 @@ async def transitive_dependency_mapping(request: _DependencyMappingRequest) -> _ for tgt in queued ) + memo.dependencies.update(zip((t.address for t in queued), direct_dependencies)) dependency_mapping.update( zip( (t.address for t in queued), @@ -615,9 +656,16 @@ async def transitive_dependency_mapping(request: _DependencyMappingRequest) -> _ ) ) - queued = FrozenOrderedSet(itertools.chain.from_iterable(direct_dependencies)).difference( - visited - ) + queued = OrderedSet( + ( + *itertools.chain.from_iterable(direct_dependencies), + *( + dep + for _, dependencies_list in memoized_dependencies + for dep in dependencies_list + ), + ) + ).difference(visited) visited.update(queued) # NB: We use `roots_as_targets` to get the root addresses, rather than `request.roots`. This