Skip to content

Commit

Permalink
Use unique input names in numba_funcify_Scan
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf authored and brandonwillard committed Sep 21, 2022
1 parent cd3a3ce commit 0d69809
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
2 changes: 1 addition & 1 deletion aesara/link/numba/dispatch/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def numba_funcify_Scan(op, node, **kwargs):
p_outer_in_nit_sot = p_outer_in_shared + n_shared_outs
p_outer_in_non_seqs = p_outer_in_nit_sot + n_nit_sot

input_names = [n.auto_name for n in node.inputs[1:]]
input_names = [f"{n.auto_name}_{i}" for i, n in enumerate(node.inputs[1:])]
outer_in_seqs_names = input_names[:n_seqs]
outer_in_mit_mot_names = input_names[p_in_mit_mot : p_in_mit_mot + n_mit_mot]
outer_in_mit_sot_names = input_names[p_in_mit_sot : p_in_mit_sot + n_mit_sot]
Expand Down
19 changes: 19 additions & 0 deletions tests/link/numba/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,22 @@ def power_of_2(previous_power, max_value):
np.array(45).astype(config.floatX),
]
compare_numba_and_py(out_fg, test_input_vals)


def test_scan_multiple_none_output():
A = at.dvector("A")

def power_step(prior_result, x):
return prior_result * x, prior_result * x * x, prior_result * x * x * x

result, _ = scan(
power_step,
non_sequences=[A],
outputs_info=[at.ones_like(A), None, None],
n_steps=3,
)

out_fg = FunctionGraph([A], result)
test_input_vals = (np.array([1.0, 2.0]),)

compare_numba_and_py(out_fg, test_input_vals)

0 comments on commit 0d69809

Please sign in to comment.