Need a little guidance on using pytensor.scan() #518
mike-lawrence asked this question in Q&A (unanswered, 1 reply)
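Here is one option:

```python
import pytensor
import pytensor.tensor as pt
from pytensor.ifelse import ifelse


def step(a, last_a_copied, copy_count):
    # Once the current value has been emitted `last_a_copied` times,
    # take the incoming element `a` and restart the count at 1;
    # otherwise keep emitting the same value and increment the count.
    last_a_copied, copy_count = ifelse(
        copy_count >= last_a_copied,
        (a, pt.ones_like(copy_count)),
        (last_a_copied, copy_count + 1),
    )
    return last_a_copied, copy_count


a = pt.vector("a", dtype="int64")
copy_count = a.max() + 1  # Or any value larger than any in a (>= a[0] is enough)
last_a_copied = a[0]
# scan passes each element of `a`, then the previous states, to `step`
[b, _], _ = pytensor.scan(
    step,
    sequences=[a],
    outputs_info=[last_a_copied, copy_count],
)
b.eval({a: [2, 1, 3, 2, 4, 4, 3, 3, 5, 5, 2]})  # array([2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5])
```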
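To check the recurrence without building a PyTensor graph, the scan above can be modelled as a plain Python loop that threads the two carried states (`last_a_copied`, `copy_count`) through the same rule and keeps the first state at every step. This is a minimal sketch; `step_py` and `scan_reference` are hypothetical helper names, not part of the PyTensor API, and `scan_reference` only models the simple case used here (one sequence, recurrent states with the default previous-step tap).

```python
from typing import Callable, List, Sequence, Tuple


def step_py(a: int, last_a_copied: int, copy_count: int) -> Tuple[int, int]:
    # Same rule as the PyTensor `step`: once the current value has been
    # emitted `last_a_copied` times, switch to the incoming element.
    if copy_count >= last_a_copied:
        return a, 1
    return last_a_copied, copy_count + 1


def scan_reference(
    fn: Callable[..., Tuple[int, ...]],
    sequence: Sequence[int],
    init: Tuple[int, ...],
) -> List[Tuple[int, ...]]:
    # Hypothetical stand-in for pytensor.scan: the sequence element comes
    # first, then the previous states; every step's new states are
    # collected, and the initial state itself is not part of the result.
    states = init
    history = []
    for x in sequence:
        states = fn(x, *states)
        history.append(states)
    return history


a = [2, 1, 3, 2, 4, 4, 3, 3, 5, 5, 2]
init = (a[0], max(a) + 1)  # mirrors last_a_copied = a[0], copy_count = a.max() + 1
b = [last for last, _ in scan_reference(step_py, a, init)]
print(b)  # [2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5]
```

Each value `v` that gets picked is emitted `v` times in a row, and input elements that arrive during that run are skipped, so the output has the same length as the input.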
How would the following function be implemented using scan?