Possible to implement with bounded while loop? #24
Definitely +1 for vmap. I recently ran into a use case for this and was going to look into a solution myself, so I'd be interested to see how you got this working.
Are you suggesting we add a wrapper around an existing bounded while loop implementation?
Yes, a wrapper that works with a bounded while loop is what I'm looking to implement! Here's a simplified example that demonstrates my current use case in a scan. Admittedly, I am new to JAX and not too confident in this implementation. When I use this scheme to fit simple HMMs, it appears to work (the log likelihood increases monotonically with each iteration). However, when I move to more expensive or complex models, the likelihoods jump around erratically and no longer increase monotonically when I vmap over more than a single model fit. I'm not sure whether this is a problem with my code or with the package I am using (dynamax). Interestingly, in my linked example the early-stopping scan seems to work, yet it doesn't in my actual implementation. (EDIT: I now realize this example works because its convergence criterion does not depend on a dynamic variable. In my actual use case, the criterion depends on a variable in the carry.) Open to any thoughts or critiques.
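A minimal sketch of the early-stopping scan pattern described above, assuming a hypothetical `expensive_update` that stands in for one EM step (it is not a dynamax API). The convergence flag lives in the carry and depends on the previous log likelihood, so `lax.cond` can skip the update once the change drops below `tol`, and the remaining scan steps just pass the state through. Note that under `jax.vmap` the predicate becomes batched and `lax.cond` is lowered to a select, so both branches are evaluated for every element.

```python
import jax
import jax.numpy as jnp
from jax import lax


def expensive_update(params):
    # Hypothetical stand-in for one EM step: moves params halfway toward 1.0
    # and returns a "log likelihood" that increases toward 0.
    new_params = params + 0.5 * (1.0 - params)
    log_lik = -jnp.abs(1.0 - new_params)
    return new_params, log_lik


def fit_scan(params, num_iters=50, tol=1e-4):
    params = jnp.asarray(params, dtype=jnp.float32)

    def step(carry, _):
        params, prev_ll, converged = carry

        def do_update(p):
            return expensive_update(p)

        def skip(p):
            # Already converged: pass params and log likelihood through unchanged.
            return p, prev_ll

        new_params, ll = lax.cond(converged, skip, do_update, params)
        # The criterion depends on a dynamic value in the carry (prev_ll).
        now_converged = converged | (jnp.abs(ll - prev_ll) < tol)
        return (new_params, ll, now_converged), ll

    init = (params, jnp.array(-jnp.inf, dtype=jnp.float32), jnp.array(False))
    (final_params, _, converged), lls = lax.scan(step, init, None, length=num_iters)
    return final_params, lls, converged


params, lls, converged = fit_scan(jnp.float32(0.0))
# Batched version: still converges, but the skip no longer saves any work.
batch_params, batch_lls, batch_conv = jax.vmap(fit_scan)(
    jnp.array([0.0, 0.5], dtype=jnp.float32)
)
```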
I see what you are doing with the vmap and would be interested in adding similar functionality. I may need to think about how to implement it as part of the API (i.e. would it still be a wrapper around the scan's inner function, or would it be better to map the vmap somehow?). For the bounded while loop, it's like the for loop, but it should initialise the progress bar with the upper bound and then stop if the while loop terminates?
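One way to read that suggestion, sketched under assumptions: `ProgressLog` is a hypothetical stand-in for a real progress bar such as tqdm, and `bounded_while_with_progress` is an illustrative helper, not part of any existing library. The bar is created with the loop's upper bound (`max_steps`), the loop body reports each completed iteration through `jax.debug.callback`, and the bar is closed after `lax.while_loop` returns, whether the loop hit the bound or stopped early.

```python
import jax
import jax.numpy as jnp
from jax import lax


class ProgressLog:
    """Hypothetical stand-in for a progress bar: counts reported iterations."""

    def __init__(self, total):
        self.total = total  # initialised with the loop's upper bound
        self.n = 0
        self.closed = False

    def update(self, _i):
        self.n += 1  # one completed iteration

    def close(self):
        self.closed = True


def bounded_while_with_progress(cond_fun, body_fun, init_val, max_steps, bar):
    def cond(state):
        i, val = state
        # Stop at the upper bound or when the user's condition fails.
        return (i < max_steps) & cond_fun(val)

    def body(state):
        i, val = state
        val = body_fun(val)
        i = i + 1
        jax.debug.callback(bar.update, i)  # host-side progress update
        return i, val

    steps, final_val = lax.while_loop(cond, body, (jnp.int32(0), init_val))
    # Close the bar once the loop terminates, possibly before max_steps.
    jax.debug.callback(lambda _: bar.close(), steps)
    return steps, final_val


# Usage: halve a value until it drops below 1e-3, with at most 100 steps.
bar = ProgressLog(total=100)
steps, x = bounded_while_with_progress(
    lambda v: v >= 1e-3, lambda v: v * 0.5, jnp.float32(1.0), max_steps=100, bar=bar
)
jax.effects_barrier()  # make sure all host callbacks have run
```

The bar's `total` stays at the upper bound, so an early exit simply leaves it partially filled when it is closed.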
@andrewlesak I just released an update with functionality to handle multiple progress bars inside a vmapped function. Instructions are in the README. I will try to look at the bounded while loop at some point.
Thanks for the updates! I was able to adapt the new code to get a wrapper for the bounded while loop working, so I can put together a draft PR soon for you to check out.
Is it possible to add a function that integrates with a bounded while loop?
Currently, I use a scan with an early stopping condition. If the condition is met, I skip the computationally expensive updates and simply iterate the remaining steps of the progress bar until completion. I've gotten my code to work with vmap and display multiple progress bars (yay!), but the issue is that vmap-of-cond executes both branches, so I'm forced to execute the expensive branch regardless of whether I've met my convergence threshold (boo!). Since it seems like we won't be seeing a scan with early stopping any time soon (see this PR), I'll have to settle for a bounded while loop. Do you have any thoughts or insights into whether this is possible? Thanks.
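A small sketch of what a bounded while loop buys under vmap, using a hypothetical `fit_until_converged` whose halving step stands in for an expensive update. When the predicate is batched, `lax.while_loop` keeps iterating while any element's predicate is true; elements that have already finished still have the body evaluated, but their state is held fixed by a select. So the expensive work is not skipped per element, yet the whole batch stops as soon as the slowest element converges instead of always running to `max_steps` as a fixed-length scan does.

```python
import jax
import jax.numpy as jnp
from jax import lax


def fit_until_converged(x, max_steps=100, tol=1e-3):
    def cond(state):
        i, x = state
        # Bounded: stop at max_steps even if never converged.
        return (i < max_steps) & (x >= tol)

    def body(state):
        i, x = state
        # Hypothetical "expensive" update: halve x.
        return i + 1, x * 0.5

    steps, x_final = lax.while_loop(cond, body, (jnp.int32(0), x))
    return steps, x_final


starts = jnp.array([1.0, 0.01, 1000.0], dtype=jnp.float32)
steps, finals = jax.vmap(fit_until_converged)(starts)
# steps -> [10, 4, 20]: each element's counter freezes when it converges,
# and the batched loop runs 20 iterations in total, not 100.
```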