Skip to content

Commit

Permalink
Privatize the helper functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
pratyai committed Oct 29, 2024
1 parent a4e6071 commit f2771d5
Showing 1 changed file with 49 additions and 49 deletions.
98 changes: 49 additions & 49 deletions dace/transformation/dataflow/const_assignment_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from dace.transformation.interstate import StateFusionExtended


def unique_top_level_map_node(graph: SDFGState) -> Optional[Tuple[MapEntry, MapExit]]:
def _unique_top_level_map_node(graph: SDFGState) -> Optional[Tuple[MapEntry, MapExit]]:
all_top_nodes = [n for n, s in graph.scope_dict().items() if s is None]
if not all(isinstance(n, (MapEntry, AccessNode)) for n in all_top_nodes):
return None
Expand All @@ -25,14 +25,14 @@ def unique_top_level_map_node(graph: SDFGState) -> Optional[Tuple[MapEntry, MapE
return en[0], ex[0]


def _floating_nodes_graph(*args):
    """
    Build an `OrderedDiGraph` that contains exactly the given nodes and no edges.

    :param args: The nodes to add to the graph, in the order given.
    :return: The newly created graph.
    """
    g = OrderedDiGraph()
    for n in args:
        g.add_node(n)
    return g


def consistent_branch_const_assignment_table(graph: Node) -> Tuple[bool, dict]:
def _consistent_branch_const_assignment_table(graph: Node) -> Tuple[bool, dict]:
"""
If the graph consists of only conditional consistent constant assignments, produces a table mapping data arrays
and memlets to their consistent constant assignments. See the class docstring for what is considered consistent.
Expand Down Expand Up @@ -86,10 +86,10 @@ def consistent_branch_const_assignment_table(graph: Node) -> Tuple[bool, dict]:
# ...must assign...
return False, table
op = n.code.code[0]
if not is_constant_or_numerical_literal(op.value) or len(op.targets) != 1:
if not _is_constant_or_numerical_literal(op.value) or len(op.targets) != 1:
# ...a constant to a single target.
return False, table
const = value_of_constant_or_numerical_literal(op.value)
const = _value_of_constant_or_numerical_literal(op.value)
for oe in body.out_edges(n):
dst = oe.data
dst_arr = oe.data.data
Expand All @@ -101,17 +101,17 @@ def consistent_branch_const_assignment_table(graph: Node) -> Tuple[bool, dict]:
return True, table


def _is_constant_or_numerical_literal(n: ast.Expr):
    """
    Work around the API differences between Python versions (e.g., 3.7 and 3.12).

    :param n: The AST node to inspect.
    :return: `True` if `n` is an `ast.Constant` (or, on interpreters older than 3.8, a legacy `ast.Num`),
             `False` otherwise.
    """
    import sys

    if isinstance(n, ast.Constant):
        return True
    if sys.version_info >= (3, 8):
        # From 3.8 on, the legacy `ast.Num` is only a deprecated alias whose instances are `ast.Constant` objects,
        # so there is nothing left to match. Do not even look the name up: doing so warns on newer interpreters
        # and fails once the alias is removed.
        return False
    return isinstance(n, ast.Num)


def _value_of_constant_or_numerical_literal(n: ast.Expr):
    """
    Work around the API differences between Python versions (e.g., 3.7 and 3.12).

    :param n: The AST node to read the literal from.
    :return: `n.value` if `n` is an `ast.Constant`, otherwise the legacy `n.n` attribute.
    """
    if isinstance(n, ast.Constant):
        return n.value
    # Any other node is read through the legacy `.n` attribute.
    return n.n


def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExit) -> Tuple[bool, dict]:
def _consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExit) -> Tuple[bool, dict]:
"""
If the graph consists of only (conditional or unconditional) consistent constant assignments, produces a table
mapping data arrays and memlets to their consistent constant assignments. See the class docstring for what is
Expand All @@ -121,7 +121,7 @@ def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExi
for n in graph.all_nodes_between(en, ex):
if isinstance(n, NestedSDFG):
# First handle the case of conditional constant assignment.
is_branch_const_assignment, internal_table = consistent_branch_const_assignment_table(n)
is_branch_const_assignment, internal_table = _consistent_branch_const_assignment_table(n)
if not is_branch_const_assignment:
return False, table
for oe in graph.out_edges(n):
Expand All @@ -133,7 +133,7 @@ def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExi
table[dst] = internal_table[oe.src_conn]
table[dst_arr] = internal_table[oe.src_conn]
elif isinstance(n, MapEntry):
is_const_assignment, internal_table = consistent_const_assignment_table(graph, n, graph.exit_node(n))
is_const_assignment, internal_table = _consistent_const_assignment_table(graph, n, graph.exit_node(n))
if not is_const_assignment:
return False, table
for k, v in internal_table.items():
Expand All @@ -151,10 +151,10 @@ def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExi
# ...that assigns...
return False, table
op = n.code.code[0]
if not is_constant_or_numerical_literal(op.value) or len(op.targets) != 1:
if not _is_constant_or_numerical_literal(op.value) or len(op.targets) != 1:
# ...a constant to a single target.
return False, table
const = value_of_constant_or_numerical_literal(op.value)
const = _value_of_constant_or_numerical_literal(op.value)
for oe in graph.out_edges(n):
dst = oe.data
dst_arr = oe.data.data
Expand All @@ -166,22 +166,22 @@ def consistent_const_assignment_table(graph: SDFGState, en: MapEntry, ex: MapExi
return True, table


def _removeprefix(c: str, p: str):
    """
    Since `str.removeprefix()` wasn't added until Python 3.9.

    :param c: The string to strip.
    :param p: The prefix to remove.
    :return: `c` without the leading `p` if `c` starts with `p`, otherwise `c` unchanged.
    """
    if not c.startswith(p):
        return c
    return c[len(p):]


def add_equivalent_connectors(dst: Union[EntryNode, ExitNode], src: Union[EntryNode, ExitNode]):
def _add_equivalent_connectors(dst: Union[EntryNode, ExitNode], src: Union[EntryNode, ExitNode]):
"""
Create the additional connectors in the first exit node that matches the second exit node (which will be removed
later).
"""
conn_map = defaultdict()
for c, v in src.in_connectors.items():
assert c.startswith('IN_')
cbase = removeprefix(c, 'IN_')
cbase = _removeprefix(c, 'IN_')
sc = dst.next_connector(cbase)
conn_map[f"IN_{cbase}"] = f"IN_{sc}"
conn_map[f"OUT_{cbase}"] = f"OUT_{sc}"
Expand All @@ -192,19 +192,19 @@ def add_equivalent_connectors(dst: Union[EntryNode, ExitNode], src: Union[EntryN
return conn_map


def _connector_counterpart(c: Union[str, None]) -> Union[str, None]:
    """
    If it's an input connector, find the corresponding output connector, and vice versa.

    :param c: A connector name, or `None`.
    :return: `OUT_<name>` for an `IN_<name>` connector, `IN_<name>` for an `OUT_<name>` connector, and `None` if
             `c` is `None` or carries neither prefix.
    """
    if c is None:
        return None
    assert isinstance(c, str)
    if c.startswith('IN_'):
        return f"OUT_{_removeprefix(c, 'IN_')}"
    elif c.startswith('OUT_'):
        return f"IN_{_removeprefix(c, 'OUT_')}"
    return None


def consolidate_empty_dependencies(graph: SDFGState, first_entry: MapEntry, second_entry: MapEntry):
def _consolidate_empty_dependencies(graph: SDFGState, first_entry: MapEntry, second_entry: MapEntry):
"""
Remove all the incoming edges of the two maps and add empty edges from the union of the access nodes they
depended on before.
Expand Down Expand Up @@ -245,7 +245,7 @@ def consolidate_empty_dependencies(graph: SDFGState, first_entry: MapEntry, seco
graph.add_memlet_path(n, en, memlet=Memlet())


def consolidate_written_nodes(graph: SDFGState, first_exit: MapExit, second_exit: MapExit):
def _consolidate_written_nodes(graph: SDFGState, first_exit: MapExit, second_exit: MapExit):
"""
If the two maps write to the same underlying data array through two access nodes, replace those edges'
destination with a single shared copy.
Expand Down Expand Up @@ -291,15 +291,15 @@ def consolidate_written_nodes(graph: SDFGState, first_exit: MapExit, second_exit
graph.remove_node(n)


def consume_map_exactly(graph: SDFGState, dst: Tuple[MapEntry, MapExit], src: Tuple[MapEntry, MapExit]):
def _consume_map_exactly(graph: SDFGState, dst: Tuple[MapEntry, MapExit], src: Tuple[MapEntry, MapExit]):
"""
Transfer the entirety of `src` map's body into `dst` map. Only possible when the two maps' ranges are identical.
"""
dst_en, dst_ex = dst
src_en, src_ex = src

assert all(e.data.is_empty() for e in graph.in_edges(src_en))
cmap = add_equivalent_connectors(dst_en, src_en)
cmap = _add_equivalent_connectors(dst_en, src_en)
for e in graph.in_edges(src_en):
graph.add_memlet_path(e.src, dst_en,
src_conn=e.src_conn, dst_conn=cmap.get(e.dst_conn),
Expand All @@ -311,7 +311,7 @@ def consume_map_exactly(graph: SDFGState, dst: Tuple[MapEntry, MapExit], src: Tu
memlet=Memlet.from_memlet(e.data))
graph.remove_edge(e)

cmap = add_equivalent_connectors(dst_ex, src_ex)
cmap = _add_equivalent_connectors(dst_ex, src_ex)
for e in graph.in_edges(src_ex):
graph.add_memlet_path(e.src, dst_ex,
src_conn=e.src_conn, dst_conn=cmap.get(e.dst_conn),
Expand All @@ -327,8 +327,8 @@ def consume_map_exactly(graph: SDFGState, dst: Tuple[MapEntry, MapExit], src: Tu
graph.remove_node(src_ex)


def consume_map_with_grid_strided_loop(graph: SDFGState, dst: Tuple[MapEntry, MapExit],
src: Tuple[MapEntry, MapExit]):
def _consume_map_with_grid_strided_loop(graph: SDFGState, dst: Tuple[MapEntry, MapExit],
src: Tuple[MapEntry, MapExit]):
"""
Transfer the entirety of `src` map's body into `dst` map, guarded behind a _grid-strided_ loop.
Prerequisite: `dst` map's range must cover `src` map's range in entirety. Statically checking this may not
Expand All @@ -350,10 +350,10 @@ def range_for_grid_stride(r, val, bound):
en, ex = graph.add_map(graph.sdfg._find_new_name('gsl'),
ndrange={k: v for k, v in zip(gsl_params, gsl_ranges)},
schedule=ScheduleType.Sequential)
consume_map_exactly(graph, (en, ex), src)
_consume_map_exactly(graph, (en, ex), src)

assert all(e.data.is_empty() for e in graph.in_edges(en))
cmap = add_equivalent_connectors(dst_en, en)
cmap = _add_equivalent_connectors(dst_en, en)
for e in graph.in_edges(en):
graph.add_memlet_path(e.src, dst_en,
src_conn=e.src_conn, dst_conn=cmap.get(e.dst_conn),
Expand All @@ -363,11 +363,11 @@ def range_for_grid_stride(r, val, bound):
memlet=Memlet.from_memlet(e.data))
graph.remove_edge(e)

cmap = add_equivalent_connectors(dst_ex, ex)
cmap = _add_equivalent_connectors(dst_ex, ex)
for e in graph.out_edges(ex):
graph.add_memlet_path(e.src, dst_ex,
src_conn=e.src_conn,
dst_conn=connector_counterpart(cmap.get(e.src_conn)),
dst_conn=_connector_counterpart(cmap.get(e.src_conn)),
memlet=Memlet.from_memlet(e.data))
graph.add_memlet_path(dst_ex, e.dst,
src_conn=cmap.get(e.src_conn), dst_conn=e.dst_conn,
Expand All @@ -379,7 +379,7 @@ def range_for_grid_stride(r, val, bound):
graph.add_memlet_path(ex, dst_ex, memlet=Memlet())


def fused_range(r1: Range, r2: Range) -> Optional[Range]:
def _fused_range(r1: Range, r2: Range) -> Optional[Range]:
if r1 == r2:
return r1
if len(r1) != len(r2):
Expand All @@ -398,7 +398,7 @@ def fused_range(r1: Range, r2: Range) -> Optional[Range]:
return r


def maps_have_compatible_ranges(first_entry: MapEntry, second_entry: MapEntry, use_grid_strided_loops: bool) -> bool:
def _maps_have_compatible_ranges(first_entry: MapEntry, second_entry: MapEntry, use_grid_strided_loops: bool) -> bool:
"""Decide if the two ranges are compatible. See the class docstring for what is considered compatible."""
if first_entry.map.schedule != second_entry.map.schedule:
# If the two maps are not to be scheduled on the same device, don't fuse them.
Expand Down Expand Up @@ -446,7 +446,7 @@ class ConstAssignmentMapFusion(MapFusion):
def expressions(cls):
# Take any two maps, then check that _every_ path from the first map to second map has exactly one access node
# in the middle and the second edge of the path is empty.
return [floating_nodes_graph(cls.first_map_entry, cls.second_map_entry)]
return [_floating_nodes_graph(cls.first_map_entry, cls.second_map_entry)]

def map_nodes(self, graph: SDFGState):
"""Return the entry and exit nodes of the relevant maps as a tuple: entry_1, exit_1, entry_2, exit_2."""
Expand Down Expand Up @@ -486,15 +486,15 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi
return False

first_entry, first_exit, second_entry, second_exit = self.map_nodes(graph)
if not maps_have_compatible_ranges(first_entry, second_entry,
use_grid_strided_loops=self.use_grid_strided_loops):
if not _maps_have_compatible_ranges(first_entry, second_entry,
use_grid_strided_loops=self.use_grid_strided_loops):
return False

# Both maps must have consistent constant assignment for the target arrays.
is_const_assignment, assignments = consistent_const_assignment_table(graph, first_entry, first_exit)
is_const_assignment, assignments = _consistent_const_assignment_table(graph, first_entry, first_exit)
if not is_const_assignment:
return False
is_const_assignment, further_assignments = consistent_const_assignment_table(graph, second_entry, second_exit)
is_const_assignment, further_assignments = _consistent_const_assignment_table(graph, second_entry, second_exit)
if not is_const_assignment:
return False
for k, v in further_assignments.items():
Expand All @@ -508,7 +508,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG):

# By now, we know that the two maps are compatible, not reading anything, and just blindly writing constants
# _consistently_.
is_const_assignment, assignments = consistent_const_assignment_table(graph, first_entry, first_exit)
is_const_assignment, assignments = _consistent_const_assignment_table(graph, first_entry, first_exit)
assert is_const_assignment

# Rename in case loop variables are named differently.
Expand All @@ -519,28 +519,28 @@ def apply(self, graph: SDFGState, sdfg: SDFG):
second_entry.map.params = first_entry.map.params

# Consolidate the incoming dependencies of the two maps.
consolidate_empty_dependencies(graph, first_entry, second_entry)
_consolidate_empty_dependencies(graph, first_entry, second_entry)

# Consolidate the written access nodes of the two maps.
consolidate_written_nodes(graph, first_exit, second_exit)
_consolidate_written_nodes(graph, first_exit, second_exit)

# If the ranges are identical, then simply fuse the two maps. Otherwise, use grid-strided loops.
assert fused_range(first_entry.map.range, second_entry.map.range) is not None
assert _fused_range(first_entry.map.range, second_entry.map.range) is not None
en, ex = graph.add_map(sdfg._find_new_name('map_fusion_wrapper'),
ndrange={k: v for k, v in zip(first_entry.map.params,
fused_range(first_entry.map.range,
second_entry.map.range))},
_fused_range(first_entry.map.range,
second_entry.map.range))},
schedule=first_entry.map.schedule)
if first_entry.map.range == second_entry.map.range:
for cur_en, cur_ex in [(first_entry, first_exit), (second_entry, second_exit)]:
consume_map_exactly(graph, (en, ex), (cur_en, cur_ex))
_consume_map_exactly(graph, (en, ex), (cur_en, cur_ex))
elif self.use_grid_strided_loops:
assert ScheduleType.Sequential not in [first_entry.map.schedule, second_entry.map.schedule]
for cur_en, cur_ex in [(first_entry, first_exit), (second_entry, second_exit)]:
if en.map.range == cur_en.map.range:
consume_map_exactly(graph, (en, ex), (cur_en, cur_ex))
_consume_map_exactly(graph, (en, ex), (cur_en, cur_ex))
else:
consume_map_with_grid_strided_loop(graph, (en, ex), (cur_en, cur_ex))
_consume_map_with_grid_strided_loop(graph, (en, ex), (cur_en, cur_ex))

# Cleanup: remove duplicate empty dependencies.
seen = set()
Expand Down Expand Up @@ -577,13 +577,13 @@ def can_be_applied(self, graph: ControlFlowRegion, expr_index: int, sdfg: SDFG,
# Moreover, the states together must contain a consistent constant assignment map.
assignments = {}
for st in [st0, st1]:
en_ex = unique_top_level_map_node(st)
en_ex = _unique_top_level_map_node(st)
if not en_ex:
return False
en, ex = en_ex
if any(not e.data.is_empty for e in st.in_edges(en)):
return False
is_const_assignment, further_assignments = consistent_const_assignment_table(st, en, ex)
is_const_assignment, further_assignments = _consistent_const_assignment_table(st, en, ex)
if not is_const_assignment:
return False
for k, v in further_assignments.items():
Expand All @@ -592,8 +592,8 @@ def can_be_applied(self, graph: ControlFlowRegion, expr_index: int, sdfg: SDFG,
assignments[k] = v

# Moreover, both states' ranges must be compatible.
if not maps_have_compatible_ranges(unique_top_level_map_node(st0)[0], unique_top_level_map_node(st1)[0],
use_grid_strided_loops=self.use_grid_strided_loops):
if not _maps_have_compatible_ranges(_unique_top_level_map_node(st0)[0], _unique_top_level_map_node(st1)[0],
use_grid_strided_loops=self.use_grid_strided_loops):
return False

return True
Expand Down

0 comments on commit f2771d5

Please sign in to comment.