From f2771d507895eaed357cdba7a2f8d4dd6d582483 Mon Sep 17 00:00:00 2001 From: Pratyai Mazumder Date: Tue, 29 Oct 2024 16:29:52 +0100 Subject: [PATCH] Privatize the helper functions. --- .../dataflow/const_assignment_fusion.py | 98 +++++++++---------- 1 file changed, 49 insertions(+), 49 deletions(-) diff --git a/dace/transformation/dataflow/const_assignment_fusion.py b/dace/transformation/dataflow/const_assignment_fusion.py index ab5027a4f4..70a249c3dd 100644 --- a/dace/transformation/dataflow/const_assignment_fusion.py +++ b/dace/transformation/dataflow/const_assignment_fusion.py @@ -14,7 +14,7 @@ from dace.transformation.interstate import StateFusionExtended -def unique_top_level_map_node(graph: SDFGState) -> Optional[Tuple[MapEntry, MapExit]]: +def _unique_top_level_map_node(graph: SDFGState) -> Optional[Tuple[MapEntry, MapExit]]: all_top_nodes = [n for n, s in graph.scope_dict().items() if s is None] if not all(isinstance(n, (MapEntry, AccessNode)) for n in all_top_nodes): return None @@ -25,14 +25,14 @@ def unique_top_level_map_node(graph: SDFGState) -> Optional[Tuple[MapEntry, MapE return en[0], ex[0] -def floating_nodes_graph(*args): +def _floating_nodes_graph(*args): g = OrderedDiGraph() for n in args: g.add_node(n) return g -def consistent_branch_const_assignment_table(graph: Node) -> Tuple[bool, dict]: +def _consistent_branch_const_assignment_table(graph: Node) -> Tuple[bool, dict]: """ If the graph consists of only conditional consistent constant assignments, produces a table mapping data arrays and memlets to their consistent constant assignments. See the class docstring for what is considered consistent. @@ -86,10 +86,10 @@ def consistent_branch_const_assignment_table(graph: Node) -> Tuple[bool, dict]: # ...must assign... return False, table op = n.code.code[0] - if not is_constant_or_numerical_literal(op.value) or len(op.targets) != 1: + if not _is_constant_or_numerical_literal(op.value) or len(op.targets) != 1: # ...a constant to a single target. return False, table - const = value_of_constant_or_numerical_literal(op.value) + const = _value_of_constant_or_numerical_literal(op.value) for oe in body.out_edges(n): dst = oe.data dst_arr = oe.data.data @@ -101,17 +101,17 @@ def consistent_branch_const_assignment_table(graph: Node) -> Tuple[bool, dict]: return True, table -def is_constant_or_numerical_literal(n: ast.Expr): +def _is_constant_or_numerical_literal(n: ast.Expr): """Work around the API differences between Python versions (e.g., 3.7 and 3.12)""" return isinstance(n, (ast.Constant, ast.Num)) -def value_of_constant_or_numerical_literal(n: ast.Expr): +def _value_of_constant_or_numerical_literal(n: ast.Expr): """Work around the API differences between Python versions (e.g., 3.7 and 3.12)""" return n.value if isinstance(n, ast.Constant) else n.n -def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExit) -> Tuple[bool, dict]: +def _consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExit) -> Tuple[bool, dict]: """ If the graph consists of only (conditional or unconditional) consistent constant assignments, produces a table mapping data arrays and memlets to their consistent constant assignments. See the class docstring for what is @@ -121,7 +121,7 @@ def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExi for n in graph.all_nodes_between(en, ex): if isinstance(n, NestedSDFG): # First handle the case of conditional constant assignment. - is_branch_const_assignment, internal_table = consistent_branch_const_assignment_table(n) + is_branch_const_assignment, internal_table = _consistent_branch_const_assignment_table(n) if not is_branch_const_assignment: return False, table for oe in graph.out_edges(n): @@ -133,7 +133,7 @@ def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExi table[dst] = internal_table[oe.src_conn] table[dst_arr] = internal_table[oe.src_conn] elif isinstance(n, MapEntry): - is_const_assignment, internal_table = consistent_const_assignment_table(graph, n, graph.exit_node(n)) + is_const_assignment, internal_table = _consistent_const_assignment_table(graph, n, graph.exit_node(n)) if not is_const_assignment: return False, table for k, v in internal_table.items(): @@ -151,10 +151,10 @@ def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExi # ...that assigns... return False, table op = n.code.code[0] - if not is_constant_or_numerical_literal(op.value) or len(op.targets) != 1: + if not _is_constant_or_numerical_literal(op.value) or len(op.targets) != 1: # ...a constant to a single target. return False, table - const = value_of_constant_or_numerical_literal(op.value) + const = _value_of_constant_or_numerical_literal(op.value) for oe in graph.out_edges(n): dst = oe.data dst_arr = oe.data.data @@ -166,14 +166,14 @@ def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExi return True, table -def removeprefix(c: str, p: str): +def _removeprefix(c: str, p: str): """Since `str.removeprefix()` wasn't added until Python 3.9""" if not c.startswith(p): return c return c[len(p):] -def add_equivalent_connectors(dst: Union[EntryNode, ExitNode], src: Union[EntryNode, ExitNode]): +def _add_equivalent_connectors(dst: Union[EntryNode, ExitNode], src: Union[EntryNode, ExitNode]): """ Create the additional connectors in the first exit node that matches the second exit node (which will be removed later). @@ -181,7 +181,7 @@ def add_equivalent_connectors(dst: Union[EntryNode, ExitNode], src: Union[EntryN conn_map = defaultdict() for c, v in src.in_connectors.items(): assert c.startswith('IN_') - cbase = removeprefix(c, 'IN_') + cbase = _removeprefix(c, 'IN_') sc = dst.next_connector(cbase) conn_map[f"IN_{cbase}"] = f"IN_{sc}" conn_map[f"OUT_{cbase}"] = f"OUT_{sc}" @@ -192,19 +192,19 @@ def add_equivalent_connectors(dst: Union[EntryNode, ExitNode], src: Union[EntryN return conn_map -def connector_counterpart(c: Union[str, None]) -> Union[str, None]: +def _connector_counterpart(c: Union[str, None]) -> Union[str, None]: """If it's an input connector, find the corresponding output connector, and vice versa.""" if c is None: return None assert isinstance(c, str) if c.startswith('IN_'): - return f"OUT_{removeprefix(c, 'IN_')}" + return f"OUT_{_removeprefix(c, 'IN_')}" elif c.startswith('OUT_'): - return f"IN_{removeprefix(c, 'OUT_')}" + return f"IN_{_removeprefix(c, 'OUT_')}" return None -def consolidate_empty_dependencies(graph: SDFGState, first_entry: MapEntry, second_entry: MapEntry): +def _consolidate_empty_dependencies(graph: SDFGState, first_entry: MapEntry, second_entry: MapEntry): """ Remove all the incoming edges of the two maps and add empty edges from the union of the access nodes they depended on before. @@ -245,7 +245,7 @@ def consolidate_empty_dependencies(graph: SDFGState, first_entry: MapEntry, seco graph.add_memlet_path(n, en, memlet=Memlet()) -def consolidate_written_nodes(graph: SDFGState, first_exit: MapExit, second_exit: MapExit): +def _consolidate_written_nodes(graph: SDFGState, first_exit: MapExit, second_exit: MapExit): """ If the two maps write to the same underlying data array through two access nodes, replace those edges' destination with a single shared copy. @@ -291,7 +291,7 @@ def consolidate_written_nodes(graph: SDFGState, first_exit: MapExit, second_exit graph.remove_node(n) -def consume_map_exactly(graph: SDFGState, dst: Tuple[MapEntry, MapExit], src: Tuple[MapEntry, MapExit]): +def _consume_map_exactly(graph: SDFGState, dst: Tuple[MapEntry, MapExit], src: Tuple[MapEntry, MapExit]): """ Transfer the entirety of `src` map's body into `dst` map. Only possible when the two maps' ranges are identical. """ @@ -299,7 +299,7 @@ def consume_map_exactly(graph: SDFGState, dst: Tuple[MapEntry, MapExit], src: Tu src_en, src_ex = src assert all(e.data.is_empty() for e in graph.in_edges(src_en)) - cmap = add_equivalent_connectors(dst_en, src_en) + cmap = _add_equivalent_connectors(dst_en, src_en) for e in graph.in_edges(src_en): graph.add_memlet_path(e.src, dst_en, src_conn=e.src_conn, dst_conn=cmap.get(e.dst_conn), @@ -311,7 +311,7 @@ def consume_map_exactly(graph: SDFGState, dst: Tuple[MapEntry, MapExit], src: Tu memlet=Memlet.from_memlet(e.data)) graph.remove_edge(e) - cmap = add_equivalent_connectors(dst_ex, src_ex) + cmap = _add_equivalent_connectors(dst_ex, src_ex) for e in graph.in_edges(src_ex): graph.add_memlet_path(e.src, dst_ex, src_conn=e.src_conn, dst_conn=cmap.get(e.dst_conn), @@ -327,8 +327,8 @@ def consume_map_exactly(graph: SDFGState, dst: Tuple[MapEntry, MapExit], src: Tu graph.remove_node(src_ex) -def consume_map_with_grid_strided_loop(graph: SDFGState, dst: Tuple[MapEntry, MapExit], - src: Tuple[MapEntry, MapExit]): +def _consume_map_with_grid_strided_loop(graph: SDFGState, dst: Tuple[MapEntry, MapExit], + src: Tuple[MapEntry, MapExit]): """ Transfer the entirety of `src` map's body into `dst` map, guarded behind a _grid-strided_ loop. Prerequisite: `dst` map's range must cover `src` map's range in entirety. Statically checking this may not @@ -350,10 +350,10 @@ def range_for_grid_stride(r, val, bound): en, ex = graph.add_map(graph.sdfg._find_new_name('gsl'), ndrange={k: v for k, v in zip(gsl_params, gsl_ranges)}, schedule=ScheduleType.Sequential) - consume_map_exactly(graph, (en, ex), src) + _consume_map_exactly(graph, (en, ex), src) assert all(e.data.is_empty() for e in graph.in_edges(en)) - cmap = add_equivalent_connectors(dst_en, en) + cmap = _add_equivalent_connectors(dst_en, en) for e in graph.in_edges(en): graph.add_memlet_path(e.src, dst_en, src_conn=e.src_conn, dst_conn=cmap.get(e.dst_conn), @@ -363,11 +363,11 @@ def range_for_grid_stride(r, val, bound): memlet=Memlet.from_memlet(e.data)) graph.remove_edge(e) - cmap = add_equivalent_connectors(dst_ex, ex) + cmap = _add_equivalent_connectors(dst_ex, ex) for e in graph.out_edges(ex): graph.add_memlet_path(e.src, dst_ex, src_conn=e.src_conn, - dst_conn=connector_counterpart(cmap.get(e.src_conn)), + dst_conn=_connector_counterpart(cmap.get(e.src_conn)), memlet=Memlet.from_memlet(e.data)) graph.add_memlet_path(dst_ex, e.dst, src_conn=cmap.get(e.src_conn), dst_conn=e.dst_conn, @@ -379,7 +379,7 @@ def range_for_grid_stride(r, val, bound): graph.add_memlet_path(ex, dst_ex, memlet=Memlet()) -def fused_range(r1: Range, r2: Range) -> Optional[Range]: +def _fused_range(r1: Range, r2: Range) -> Optional[Range]: if r1 == r2: return r1 if len(r1) != len(r2): @@ -398,7 +398,7 @@ def fused_range(r1: Range, r2: Range) -> Optional[Range]: return r -def maps_have_compatible_ranges(first_entry: MapEntry, second_entry: MapEntry, use_grid_strided_loops: bool) -> bool: +def _maps_have_compatible_ranges(first_entry: MapEntry, second_entry: MapEntry, use_grid_strided_loops: bool) -> bool: """Decide if the two ranges are compatible. See the class docstring for what is considered compatible.""" if first_entry.map.schedule != second_entry.map.schedule: # If the two maps are not to be scheduled on the same device, don't fuse them. @@ -446,7 +446,7 @@ class ConstAssignmentMapFusion(MapFusion): def expressions(cls): # Take any two maps, then check that _every_ path from the first map to second map has exactly one access node # in the middle and the second edge of the path is empty. - return [floating_nodes_graph(cls.first_map_entry, cls.second_map_entry)] + return [_floating_nodes_graph(cls.first_map_entry, cls.second_map_entry)] def map_nodes(self, graph: SDFGState): """Return the entry and exit nodes of the relevant maps as a tuple: entry_1, exit_1, entry_2, exit_2.""" @@ -486,15 +486,15 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi return False first_entry, first_exit, second_entry, second_exit = self.map_nodes(graph) - if not maps_have_compatible_ranges(first_entry, second_entry, - use_grid_strided_loops=self.use_grid_strided_loops): + if not _maps_have_compatible_ranges(first_entry, second_entry, + use_grid_strided_loops=self.use_grid_strided_loops): return False # Both maps must have consistent constant assignment for the target arrays. - is_const_assignment, assignments = consistent_const_assignment_table(graph, first_entry, first_exit) + is_const_assignment, assignments = _consistent_const_assignment_table(graph, first_entry, first_exit) if not is_const_assignment: return False - is_const_assignment, further_assignments = consistent_const_assignment_table(graph, second_entry, second_exit) + is_const_assignment, further_assignments = _consistent_const_assignment_table(graph, second_entry, second_exit) if not is_const_assignment: return False for k, v in further_assignments.items(): @@ -508,7 +508,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG): # By now, we know that the two maps are compatible, not reading anything, and just blindly writing constants # _consistently_. - is_const_assignment, assignments = consistent_const_assignment_table(graph, first_entry, first_exit) + is_const_assignment, assignments = _consistent_const_assignment_table(graph, first_entry, first_exit) assert is_const_assignment # Rename in case loop variables are named differently. @@ -519,28 +519,28 @@ def apply(self, graph: SDFGState, sdfg: SDFG): second_entry.map.params = first_entry.map.params # Consolidate the incoming dependencies of the two maps. - consolidate_empty_dependencies(graph, first_entry, second_entry) + _consolidate_empty_dependencies(graph, first_entry, second_entry) # Consolidate the written access nodes of the two maps. - consolidate_written_nodes(graph, first_exit, second_exit) + _consolidate_written_nodes(graph, first_exit, second_exit) # If the ranges are identical, then simply fuse the two maps. Otherwise, use grid-strided loops. - assert fused_range(first_entry.map.range, second_entry.map.range) is not None + assert _fused_range(first_entry.map.range, second_entry.map.range) is not None en, ex = graph.add_map(sdfg._find_new_name('map_fusion_wrapper'), ndrange={k: v for k, v in zip(first_entry.map.params, - fused_range(first_entry.map.range, - second_entry.map.range))}, + _fused_range(first_entry.map.range, + second_entry.map.range))}, schedule=first_entry.map.schedule) if first_entry.map.range == second_entry.map.range: for cur_en, cur_ex in [(first_entry, first_exit), (second_entry, second_exit)]: - consume_map_exactly(graph, (en, ex), (cur_en, cur_ex)) + _consume_map_exactly(graph, (en, ex), (cur_en, cur_ex)) elif self.use_grid_strided_loops: assert ScheduleType.Sequential not in [first_entry.map.schedule, second_entry.map.schedule] for cur_en, cur_ex in [(first_entry, first_exit), (second_entry, second_exit)]: if en.map.range == cur_en.map.range: - consume_map_exactly(graph, (en, ex), (cur_en, cur_ex)) + _consume_map_exactly(graph, (en, ex), (cur_en, cur_ex)) else: - consume_map_with_grid_strided_loop(graph, (en, ex), (cur_en, cur_ex)) + _consume_map_with_grid_strided_loop(graph, (en, ex), (cur_en, cur_ex)) # Cleanup: remove duplicate empty dependencies. seen = set() @@ -577,13 +577,13 @@ def can_be_applied(self, graph: ControlFlowRegion, expr_index: int, sdfg: SDFG, # Moreover, the states together must contain a consistent constant assignment map. assignments = {} for st in [st0, st1]: - en_ex = unique_top_level_map_node(st) + en_ex = _unique_top_level_map_node(st) if not en_ex: return False en, ex = en_ex if any(not e.data.is_empty for e in st.in_edges(en)): return False - is_const_assignment, further_assignments = consistent_const_assignment_table(st, en, ex) + is_const_assignment, further_assignments = _consistent_const_assignment_table(st, en, ex) if not is_const_assignment: return False for k, v in further_assignments.items(): @@ -592,8 +592,8 @@ def can_be_applied(self, graph: ControlFlowRegion, expr_index: int, sdfg: SDFG, assignments[k] = v # Moreover, both states' ranges must be compatible. - if not maps_have_compatible_ranges(unique_top_level_map_node(st0)[0], unique_top_level_map_node(st1)[0], - use_grid_strided_loops=self.use_grid_strided_loops): + if not _maps_have_compatible_ranges(_unique_top_level_map_node(st0)[0], _unique_top_level_map_node(st1)[0], + use_grid_strided_loops=self.use_grid_strided_loops): return False return True