
Possible to implement with bounded while loop? #24

Open
andrewlesak opened this issue Oct 1, 2024 · 6 comments

Comments

@andrewlesak

Is it possible to add a function that:

  1. integrates with a bounded while loop, to allow for early stopping, and
  2. is also vmap-able, to allow for multiple progress bars?

Currently, I use a scan with an early-stopping condition. If the condition is met, I skip the computationally expensive updates and simply iterate through the remaining steps of the progress bar until completion. I've gotten my code to work with vmap and display multiple progress bars (yay!), but the issue is that vmap-of-cond executes both branches (with a batched predicate, lax.cond is converted to a select), so I'm forced to execute the expensive branch regardless of whether I've met my convergence threshold (boo!). Since it seems we won't be seeing a scan with early stopping any time soon (see this PR), I'll have to settle for a bounded while loop. Do you have any thoughts or insights into whether this is possible? Thanks.
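A minimal sketch of the early-stopping scan described above, assuming a hypothetical `expensive_update` (it just halves `x`) in place of a real model update. The helper `has_cond` (also hypothetical) inspects the traced jaxpr: without vmap the scan body contains a `cond`; under vmap the predicate is batched, the `cond` disappears and both branches are evaluated with a select on every step.

```python
import jax
import jax.numpy as jnp
from jax import lax


def expensive_update(x):
    # Hypothetical stand-in for a costly step (e.g. one EM iteration).
    return 0.5 * x


def fit_scan(x0, num_steps=20, tol=1e-3):
    """Early-stopping scan: after convergence the remaining steps are no-ops."""

    def step(carry, _):
        x, done, n_updates = carry
        # Without vmap, lax.cond runs only the selected branch.
        x_new = lax.cond(done, lambda v: v, expensive_update, x)
        n_updates = n_updates + jnp.where(done, 0, 1)
        done = done | (jnp.abs(x_new) < tol)
        return (x_new, done, n_updates), None

    init = (jnp.asarray(x0, jnp.float32), jnp.array(False), jnp.array(0))
    (x, done, n_updates), _ = lax.scan(step, init, None, length=num_steps)
    return x, done, n_updates


def has_cond(fn, *args):
    # True if a `cond` primitive survives tracing of fn(*args).
    return "cond[" in str(jax.make_jaxpr(fn)(*args))


x, done, n = fit_scan(1.0)
print(int(n), bool(done))  # 10 True

# Under vmap the carried `done` flag is batched, so lax.cond is lowered to a
# select: the expensive branch runs for every element on all 20 steps.
xs, dones, ns = jax.vmap(fit_scan)(jnp.array([1.0, 1e-2]))
print(ns.tolist())  # [10, 4]
print(has_cond(fit_scan, 1.0), has_cond(jax.vmap(fit_scan), jnp.ones(2)))
```

The update counts stay correct under vmap because they are tracked with `jnp.where`; only the compute saving is lost.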

@zombie-einstein
Collaborator

Definitely +1 for vmap. I recently ran into a use case for this and was going to look into a solution; I'd be interested to see how you got this working.

@zombie-einstein
Collaborator

Are you suggesting we add a wrapper around an existing bounded while loop implementation?

@andrewlesak
Author

andrewlesak commented Oct 1, 2024

Yes, a wrapper that works with a bounded while loop is what I am looking to implement!
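For reference, a bounded while loop can be sketched in plain JAX by carrying a step counter and making the loop condition `(step < max_steps) & ~converged`. This sketch assumes the same hypothetical halving update as a stand-in for the expensive step. Under vmap, a batched `lax.while_loop` keeps running the body for every element until all of them have stopped (finished elements are frozen with a select), so the expensive work ends once the slowest fit converges instead of always running for the full bound.

```python
import jax
import jax.numpy as jnp
from jax import lax


def expensive_update(x):
    # Hypothetical stand-in for a costly step.
    return 0.5 * x


def fit_bounded_while(x0, max_steps=20, tol=1e-3):
    """Run at most `max_steps` updates, stopping early once |x| < tol."""

    def cond_fn(carry):
        step, _, done = carry
        return (step < max_steps) & jnp.logical_not(done)

    def body_fn(carry):
        step, x, _ = carry
        x_new = expensive_update(x)
        return step + 1, x_new, jnp.abs(x_new) < tol

    init = (jnp.array(0), jnp.asarray(x0, jnp.float32), jnp.array(False))
    return lax.while_loop(cond_fn, body_fn, init)


steps, x, done = fit_bounded_while(1.0)
print(int(steps), bool(done))  # 10 True

# Hitting the bound before convergence:
steps, x, done = fit_bounded_while(1.0, max_steps=5)
print(int(steps), bool(done))  # 5 False

# Batched: each element keeps its own step count.
steps, xs, dones = jax.vmap(fit_bounded_while)(jnp.array([1.0, 1e-2]))
print(steps.tolist())  # [10, 4]
```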

Here's a simplified example that demonstrates my current use case in a scan. Admittedly, I am new to JAX and not too confident in this implementation. When I use this scheme to fit simple HMMs, it appears to work (the log likelihoods increase monotonically with each iteration). However, when I move to more expensive/complex models and vmap over more than a single model fit, the likelihoods jump around erratically and no longer increase monotonically. I'm not sure whether this is a problem with my code or with the package I am using (dynamax).

Interestingly, in my linked example the early-stopping scan seems to work, yet it doesn't in my actual implementation. (EDIT: I now realize this example works because the convergence criterion does not depend on a dynamic variable. In my actual use case, the criterion depends on a variable in the carry.)
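To illustrate the distinction in the EDIT, here is a sketch in which the stopping test reads dynamic values from the carry: the loop stops when the improvement in the objective between two iterations drops below `tol`. The objective `log_lik` and the `update` rule are hypothetical stand-ins for an HMM fit, not dynamax's API.

```python
import jax
import jax.numpy as jnp
from jax import lax


def log_lik(x):
    # Hypothetical concave objective with its maximum at x = 3.
    return -((x - 3.0) ** 2)


def update(x):
    # Hypothetical update: move halfway towards the optimum.
    return x + 0.5 * (3.0 - x)


def fit_until_converged(x0, max_steps=100, tol=1e-3):
    """Stop when the log-likelihood improvement falls below `tol`."""

    def cond_fn(carry):
        step, _, ll, prev_ll = carry
        # The criterion depends on dynamic values stored in the carry.
        return (step < max_steps) & (ll - prev_ll >= tol)

    def body_fn(carry):
        step, x, ll, _ = carry
        x_new = update(x)
        return step + 1, x_new, log_lik(x_new), ll

    x0 = jnp.asarray(x0, jnp.float32)
    # prev_ll = -inf guarantees that the first iteration runs.
    init = (jnp.array(0), x0, log_lik(x0), jnp.array(-jnp.inf, jnp.float32))
    return lax.while_loop(cond_fn, body_fn, init)


steps, x, ll, prev_ll = fit_until_converged(0.0)
print(int(steps))  # 8

# Each vmapped fit stops on its own dynamic criterion.
steps, *_ = jax.vmap(fit_until_converged)(jnp.array([0.0, 2.0]))
print(steps.tolist())  # [8, 6]
```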

Open to any thoughts or critiques.

@zombie-einstein
Collaborator

I see what you are doing with the vmap; I'd be interested in adding similar functionality. We may need to think about how to implement it as part of the API (i.e., would it still be a wrapper around the scan's inner function, or would it be better to wrap the vmapped function somehow?).

For the bounded while loop, it's like the for loop, but it should initialise the progress bar with the upper bound and then stop if the while loop terminates early?
Looking at the code this seems reasonable: I guess you'd initialise and update the progress bar, but would also have to check for termination. I wonder whether there would be some issue with step counting? The JAX while loop only carries a single value through; it doesn't pass in a step number or mapped value (as we use in the scan).
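One way around the step-counting concern is for the wrapper to augment the user's carry with its own counter, `(step, val)`, so neither the user's `cond_fun` nor `body_fun` needs to know about it. The sketch below is a hypothetical wrapper, not this library's API: `on_update` and `on_close` are assumed host-side Python callables standing in for the progress-bar update and close calls, reached via `jax.debug.callback`. It is shown without vmap; under vmap the body (and its callbacks) may keep firing for elements that have already stopped, which a real implementation would need to handle.

```python
import jax
import jax.numpy as jnp
from jax import lax


def bounded_while_with_progress(cond_fun, body_fun, init_val, max_steps,
                                on_update, on_close, print_rate=1):
    """Hypothetical wrapper: a step-capped lax.while_loop with progress hooks.

    on_update(step) and on_close(step) are plain Python callables that run on
    the host via jax.debug.callback.
    """

    def wrapped_cond(carry):
        step, val = carry
        # Stop at the upper bound, or earlier if the user's condition fails.
        return (step < max_steps) & cond_fun(val)

    def report(step):
        jax.debug.callback(on_update, step)

    def skip(step):
        return None

    def wrapped_body(carry):
        step, val = carry
        val = body_fun(val)
        step = step + 1
        # Only call back to the host every `print_rate` steps.
        lax.cond(step % print_rate == 0, report, skip, step)
        return step, val

    step, val = lax.while_loop(wrapped_cond, wrapped_body,
                               (jnp.array(0), init_val))
    # Close the bar with the number of steps actually taken (<= max_steps).
    jax.debug.callback(on_close, step)
    return val


updates, closes = [], []
result = bounded_while_with_progress(
    cond_fun=lambda x: x >= 1e-3,      # keep going until x < 1e-3
    body_fun=lambda x: 0.5 * x,        # toy update
    init_val=jnp.float32(1.0),
    max_steps=100,
    on_update=lambda s: updates.append(int(s)),
    on_close=lambda s: closes.append(int(s)),
    print_rate=2,
)
jax.effects_barrier()  # make sure all host callbacks have run
print(sorted(updates), closes)  # [2, 4, 6, 8, 10] [10]
```

Because the wrapper owns the counter, the bar can be created with `max_steps` as its total and closed at the actual step count when the loop exits early.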

@zombie-einstein
Collaborator

zombie-einstein commented Oct 7, 2024

@andrewlesak I just released an update with functionality to handle multiple progress bars inside a vmapped function. Instructions are in the README.

Will try to look at the bounded while loop at some point.

@andrewlesak
Author

Thanks for the updates! I was able to adapt the new code to get a wrapper for the bounded while loop working, so I can put together a draft PR soon for you to check out.
