Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scalar To Symbol Promotion and Implicit Type Conversion Tasklet #1665

Open
philip-paul-mueller opened this issue Sep 23, 2024 · 0 comments
Open

Comments

@philip-paul-mueller
Copy link
Collaborator

I am not sure whether this is really a bug or whether the SDFG is simply invalid.
The problem is about the following code:

def foo(a: dace.float64, b: dace.float64[10, 10]) -> dace.float64[10, 10]:
    """Return ``b + 2`` if ``a < 0.5`` and ``b + 1`` otherwise."""
    # The branch is selected by the value of the scalar argument `a`.
    if a < 0.5:
        return b + 2
    else:
        return b + 1

and the reproducer attached below.
If simplify is run on the reproducer below then the false branch is eliminated.
I traced down the error to the tasklet that casts the bool to an integer, which has the form __out = __in.
If the tasklet is changed to something like __out = (1 if __in else 0) then it works as expected.
I agree that the tasklet only (really) makes sense if it is interpreted as C++ rather than as Python, and if the language of the tasklet is changed to CPP it actually works.
But I think that this behaviour should be classified as a bug.

# Create the SDFG and register its three non-transient data containers:
#  the 10x10 array `in_arg_1`, the scalar `in_arg_2` and the 10x10 array `ret_arg`.
sdfg = dace.SDFG("Test")
sdfg.add_array("in_arg_1", shape=(10, 10), dtype=dace.float64, transient=False)
sdfg.add_scalar("in_arg_2", dtype=dace.float64, transient=False)
sdfg.add_array("ret_arg", shape=(10, 10), dtype=dace.float64, transient=False)

# Selects the code of the `offset_computation` tasklet further down:
#  `False` uses the plain assignment `__out = __in`,
#  `True` uses the explicit cast `__out = dace.int32(__in)`.
make_it_work = False

# Copy both input arguments into transients in the start state.
#  Without this it does not work
state0 = sdfg.add_state(is_start_block=True)
sdfg.add_array("in_arg_1_transient", shape=(10, 10), dtype=dace.float64, transient=True)
sdfg.add_scalar("in_arg_2_transient", dtype=dace.float64, transient=True)

# NOTE(review): the memlet below only covers the subset `[0:4, 0:4]`, although
#  `in_arg_1` and `in_arg_1_transient` are both declared with shape (10, 10) and the
#  branch states read the full `0:10` range -- confirm whether `[0:10, 0:10]` was intended.
state0.add_nedge(
    state0.add_access("in_arg_1"),
    state0.add_access("in_arg_1_transient"),
    dace.Memlet("in_arg_1[0:4, 0:4]"),
)
# Copy the scalar argument into its transient counterpart.
state0.add_nedge(
    state0.add_access("in_arg_2"),
    state0.add_access("in_arg_2_transient"),
    dace.Memlet("in_arg_2[0]"),
)

# State 1: evaluate the condition `in_arg_2_transient < 0.5` inside a tasklet and
#  store the result in the boolean transient scalar `cond_res`.
state1 = sdfg.add_state_after(state0)
sdfg.add_scalar("cond_res", dtype=dace.bool_, transient=True)
cond_tasklet = state1.add_tasklet(
    "condition",
    inputs={"__in"},
    code="__out = __in < 0.5",
    outputs={"__out"},
)
# Feed the transient scalar into the tasklet's `__in` connector.
state1.add_edge(
    state1.add_read("in_arg_2_transient"),
    None,
    cond_tasklet,
    "__in",
    dace.Memlet("in_arg_2_transient[0]"),
)
# Write the tasklet's `__out` connector to `cond_res`.
state1.add_edge(
    cond_tasklet,
    "__out",
    state1.add_write("cond_res"),
    None,
    dace.Memlet("cond_res[0]"),
)
    
# State 2: move the boolean `cond_res` into the `int32` transient scalar `offset`.
#  This is the tasklet the report is about: with `make_it_work == False` its code is the
#  plain assignment `__out = __in`, so the bool -> int32 conversion is only implicit.
state2 = sdfg.add_state_after(state1)
sdfg.add_scalar("offset", dtype=dace.int32, transient=True)
offset_tasklet = state2.add_tasklet(
    "offset_computation",
    inputs={"__in"},
    code= "__out = dace.int32(__in)" if make_it_work else "__out = __in" ,
    #code= "__out = 1 if __in else 0" if make_it_work else "__out = __in" ,  # This is an alternative that also works
    outputs={"__out"},
)
# Read the boolean condition result into the tasklet's `__in` connector.
state2.add_edge(
    state2.add_read("cond_res"),
    None,
    offset_tasklet,
    "__in",
    dace.Memlet("cond_res[0]"),
)
# Write the tasklet's `__out` connector to the integer scalar `offset`.
state2.add_edge(
    offset_tasklet,
    "__out",
    state2.add_write("offset"),
    None,
    dace.Memlet("offset[0]"),
)

# Add the state from which the two branches leave. The `assignments` passed here set
#  `_offset` from `offset[0]`; the two outgoing edges below branch on `_offset`.
branch_state = sdfg.add_state_after(state2, assignments={"_offset": "offset[0]"})

# Branch taken for `_offset == 0`: writes `in_arg_1_transient + 1` to `ret_arg`
#  (corresponds to the `else` branch `b + 1` of `foo`).
state3_1 = sdfg.add_state()
state3_1.add_mapped_tasklet(
    "offsetter",
    map_ranges=[("__i0", "0:10"), ("__i1", "0:10")],
    inputs={"__in0": dace.Memlet("in_arg_1_transient[__i0, __i1]")},
    code="__out = __in0 + 1",
    outputs={"__out": dace.Memlet("ret_arg[__i0, __i1]")},
    external_edges=True,
)
sdfg.add_edge(branch_state, state3_1, dace.sdfg.InterstateEdge(condition="_offset == 0"))

# Branch taken for `_offset != 0`: writes `in_arg_1_transient + 2` to `ret_arg`
#  (corresponds to the `if a < 0.5` branch `b + 2` of `foo`).
state3_2 = sdfg.add_state()
sdfg.add_edge(branch_state, state3_2, dace.sdfg.InterstateEdge(condition="_offset != 0"))
state3_2.add_mapped_tasklet(
    "offsetter",
    map_ranges=[("__i0", "0:10"), ("__i1", "0:10")],
    inputs={"__in0": dace.Memlet("in_arg_1_transient[__i0, __i1]")},
    code="__out = __in0 + 2",
    outputs={"__out": dace.Memlet("ret_arg[__i0, __i1]")},
    external_edges=True,
)

# Sanity check on the hand-built SDFG before simplification.
assert sdfg.number_of_nodes() >= 3

sdfg.validate()
sdfg.simplify()

# Because the `if` depends on runtime data, it cannot be removed, so we know that at least
#  3 states must remain (2 for the branches and one before them).
# NOTE(review): presumably this is the assertion that fails when `make_it_work` is `False`,
#  since the report states that simplify eliminates one branch -- confirm.
assert sdfg.number_of_nodes() >= 3
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant