Skip to content

Commit

Permalink
Enclose non-sequences in the loop body function
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Oct 17, 2022
1 parent bf97e57 commit a1b7b5c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
5 changes: 2 additions & 3 deletions aesara/link/jax/dispatch/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ def scan(*outer_inputs):
raise NotImplementedError("sit-sot not supported")
if len(outer_in_shared):
raise NotImplementedError("shared variables not supported")
if len(outer_in_non_seqs):
raise NotImplementedError("non sequence are not supported")

# If `output_infos` is empty we need to create an empty initial carry
# value with the output's shape and dtype
Expand All @@ -39,6 +37,7 @@ def scan(*outer_inputs):

init_carry = outer_in_sit_sot
sequences = outer_in_seqs
non_sequences = outer_in_non_seqs

def scan_inner_in_args(carry, x):
"""Create an inner-input expression.
Expand All @@ -59,7 +58,7 @@ def scan_inner_in_args(carry, x):
else:
inner_in_sit_sot = carry

return sum([inner_in_seqs, inner_in_sit_sot], [])
return sum([inner_in_seqs, inner_in_sit_sot, non_sequences], [])

def scan_new_carry(inner_outputs):
"""Create a new carry expression
Expand Down
11 changes: 11 additions & 0 deletions tests/link/jax/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,17 @@
None,
lambda op: op.info.n_nit_sot > 0,
),
# nit-sot, non_seq
(
lambda c: at.as_tensor(2.0) * c,
[],
[{}],
[at.dscalar("c")],
3,
[1.0],
None,
lambda op: op.info.n_nit_sot > 0 and op.info.n_non_seqs > 0,
),
],
)
def test_xit_xot_types(
Expand Down

0 comments on commit a1b7b5c

Please sign in to comment.