Skip to content

Commit

Permalink
Add more explanations to StateReplication and make its only test
Browse files Browse the repository at this point in the history
narrower.
  • Loading branch information
pratyai committed Oct 29, 2024
1 parent 99226dc commit ad097eb
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 97 deletions.
65 changes: 31 additions & 34 deletions dace/transformation/interstate/state_replication.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
""" State replication transformation """

from dace import data as dt, sdfg as sd
from copy import deepcopy

from dace import SDFG, data
from dace.properties import make_properties
from dace.sdfg import utils as sdutil
from dace.sdfg.state import SDFGState
from dace.sdfg.state import SDFGState, ControlFlowRegion
from dace.transformation import transformation
from copy import deepcopy
from dace.transformation.interstate.loop_detection import DetectLoop
from dace.properties import make_properties


@make_properties
class StateReplication(transformation.MultiStateTransformation):
Expand All @@ -16,34 +18,27 @@ class StateReplication(transformation.MultiStateTransformation):
This results in states with only one incoming edge.
"""

target_state = transformation.PatternNode(sd.SDFGState)

@staticmethod
def annotates_memlets():
return True
target_state = transformation.PatternNode(SDFGState)

@classmethod
def expressions(cls):
return [sdutil.node_path_graph(cls.target_state)]


def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
target_state: SDFGState = self.target_state

out_edges = graph.out_edges(target_state)
in_edges = graph.in_edges(target_state)

def can_be_applied(self, graph: ControlFlowRegion, expr_index: int, sdfg: SDFG, permissive: bool = False):
in_edges, out_edges = graph.in_edges(self.target_state), graph.out_edges(self.target_state)
if len(in_edges) < 2:
# If it has only one incoming edge, then there is nothing to replicate.
return False

# avoid useless replications
if target_state.is_empty() and len(out_edges) < 2:
if self.target_state.is_empty() and len(out_edges) < 2:
# No point replicating an empty state that does not branch out again.
# TODO: But _why_ are we focusing on "branching out again"?
return False

# make sure this is not a loop guard
# Make sure this is not a loop guard.
# TODO: But _why_?
if len(out_edges) == 2:
detect = DetectLoop()
detect.loop_guard = target_state
detect.loop_guard = self.target_state
detect.loop_begin = out_edges[0].dst
detect.exit_state = out_edges[1].dst
if detect.can_be_applied(graph, 0, sdfg):
Expand All @@ -52,27 +47,29 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
detect.loop_begin = out_edges[1].dst
if detect.can_be_applied(graph, 0, sdfg):
return False

return True

def apply(self, _, sdfg: sd.SDFG):
target_state: SDFGState = self.target_state

if len(sdfg.out_edges(target_state)) == 0:
sdfg.add_state_after(target_state)
def apply(self, graph: ControlFlowRegion, sdfg: SDFG):
state = self.target_state
blueprint = state.to_json()

state_names = set(s.label for s in sdfg.nodes())
in_edges, out_edges = sdfg.in_edges(state), sdfg.out_edges(state)
if not out_edges:
# If this was a sink state, then create an extra sink state to synchronize on.
sdfg.add_state_after(state)

root_blueprint = target_state.to_json()
for e in sdfg.in_edges(target_state)[1:]:
state_copy = sd.SDFGState.from_json(root_blueprint, context={'sdfg': sdfg})
state_copy.label = dt.find_new_name(state_copy.label, state_names)
state_names = set(s.label for s in sdfg.nodes())
for e in in_edges[1:]:
state_copy = SDFGState.from_json(blueprint, context={'sdfg': sdfg})
state_copy.label = data.find_new_name(state_copy.label, state_names)
state_names.add(state_copy.label)
sdfg.add_node(state_copy)

# Replace the `e.src -> state` edge with an `e.src -> state_copy` edge.
sdfg.remove_edge(e)
sdfg.add_edge(e.src, state_copy, e.data)

# connect out edges
for oe in sdfg.out_edges(target_state):
# Replicate the outgoing edges of `state` to `state_copy` too.
for oe in sdfg.out_edges(state):
sdfg.add_edge(state_copy, oe.dst, deepcopy(oe.data))
2 changes: 1 addition & 1 deletion tests/transformations/if_raising_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def make_branched_sdfg_with_raisable_if():
g.add_symbol('flag', dace.bool)

# Do something in the guard state.
t = st0.add_tasklet('write_0', {}, {'__out'}, '__out = 0')
t = st0.add_tasklet('write_0', {}, {'__out'}, '__out = -1')
A = st0.add_access('A')
st0.add_memlet_path(t, A, src_conn='__out', memlet=Memlet(expr='A[0]'))

Expand Down
62 changes: 0 additions & 62 deletions tests/transformations/raise_and_duplicate_test.py

This file was deleted.

93 changes: 93 additions & 0 deletions tests/transformations/state_replication_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.

import os
from copy import deepcopy

import numpy as np

import dace
from dace import SDFG, Memlet, InterstateEdge
from dace.transformation.interstate import StateReplication


def make_branched_sdfg_with_replicable_branch_terminal():
"""
Construct a simple SDFG of the structure:
guard_state
/ \\
branch_1 branch_2
\\ /
terminal_state
(+ interior)
"""
g = SDFG('prog')
st0 = g.add_state("guard_state", is_start_block=True)
st1 = g.add_state("branch_1")
st2 = g.add_state("branch_2")
st3 = g.add_state("terminal_state")
g.add_array('A', (2,), dace.float32)
g.add_symbol('flag', dace.bool)

# Do something on the branches.
t = st1.add_tasklet('write_1', {}, {'__out'}, '__out = 1')
A = st1.add_access('A')
st1.add_memlet_path(t, A, src_conn='__out', memlet=Memlet(expr='A[0]'))
t = st2.add_tasklet('write_2', {}, {'__out'}, '__out = 2')
A = st2.add_access('A')
st2.add_memlet_path(t, A, src_conn='__out', memlet=Memlet(expr='A[0]'))

# Do something in the terminal state.
t = st3.add_tasklet('write_0', {}, {'__out'}, '__out = 3')
A = st3.add_access('A')
st3.add_memlet_path(t, A, src_conn='__out', memlet=Memlet(expr='A[1]'))

# Connect the states.
g.add_edge(st0, st1, InterstateEdge(condition='(flag)'))
g.add_edge(st0, st2, InterstateEdge(condition='(not flag)'))
g.add_edge(st1, st3, InterstateEdge())
g.add_edge(st2, st3, InterstateEdge())

g.fill_scope_connectors()

return g


def test_replicable_branch_terminal():
origA = np.zeros((2,), np.float32)

g = make_branched_sdfg_with_replicable_branch_terminal()
g.save(os.path.join('_dacegraphs', 'simple-0.sdfg'))
g.validate()
g.compile()

# Get the expected values.
wantA_1 = deepcopy(origA)
wantA_2 = deepcopy(origA)
g(A=wantA_1, flag=True)
g(A=wantA_2, flag=False)

# Before, the outer graph had four states.
assert len(g.nodes()) == 4

assert g.apply_transformations_repeated([StateReplication]) == 1

g.save(os.path.join('_dacegraphs', 'simple-1.sdfg'))
g.validate()
g.compile()

# But now, the graph have six states: the terminal state spawned two additional states on the branches.
assert len(g.nodes()) == 6

# Get the values from transformed program.
gotA_1 = deepcopy(origA)
gotA_2 = deepcopy(origA)
g(A=gotA_1, flag=True)
g(A=gotA_2, flag=False)

# Verify numerically.
assert all(np.equal(wantA_1, gotA_1))
assert all(np.equal(wantA_2, gotA_2))


if __name__ == '__main__':
test_replicable_branch_terminal()

0 comments on commit ad097eb

Please sign in to comment.