Support forward mode differentiation for SVI #1731

Merged: 16 commits, Feb 8, 2024

juanitorduz
Contributor

Closes #1726

Trying here a "good first issue".

@juanitorduz juanitorduz marked this pull request as draft February 7, 2024 20:31
@fehiepsi
Member

fehiepsi commented Feb 7, 2024

Looks great to me. Could you add a simple test with a while loop in the model?
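For context on why a while loop is the natural test case, here is a minimal sketch in plain JAX (not the NumPyro test from this PR). The function `power_by_loop` is a hypothetical stand-in for a model containing `lax.while_loop`. Forward mode should differentiate through it, while reverse mode is expected to raise an error.

```python
import jax
import jax.numpy as jnp
from jax import lax


def power_by_loop(x):
    # Hypothetical example: computes x**3 with a while loop whose carry is
    # (iteration counter, running product).
    def cond_fn(carry):
        i, _ = carry
        return i < 3

    def body_fn(carry):
        i, val = carry
        return i + 1, val * x

    _, out = lax.while_loop(cond_fn, body_fn, (0, jnp.ones_like(x)))
    return out


x = jnp.asarray(2.0)

# Forward mode pushes tangents through the loop alongside the primal values.
forward_grad = jax.jacfwd(power_by_loop)(x)

# Reverse mode has no transpose rule for while_loop, so this is expected to
# fail; the exception is caught broadly because its exact type is not assumed.
try:
    jax.grad(power_by_loop)(x)
    reverse_failed = False
except Exception:
    reverse_failed = True
```

Since the loop computes x**3, `forward_grad` should match 3 * x**2 at x = 2.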

@juanitorduz
Contributor Author

I added a simple test, test/infer/test_svi.py::test_forward_mode_differentiation, in 532fb5b, but it is failing with:

    @functools.wraps(update)
    def tree_update(i, grad_tree, opt_state):
      states_flat, tree, subtrees = opt_state
      grad_flat, tree2 = tree_flatten(grad_tree)
      if tree2 != tree:
        msg = ("optimizer update function was passed a gradient tree that did "
               "not match the parameter tree structure with which it was "
               "initialized: parameter tree {} and grad tree {}.")
>       raise TypeError(msg.format(tree, tree2))
E       TypeError: optimizer update function was passed a gradient tree that did not match the parameter tree structure with which it was initialized: parameter tree PyTreeDef({'loc': *, 'scale': *}) and grad tree PyTreeDef(({'loc': *, 'scale': *}, None))

I am not sure whether this is because the test implementation is wrong 😑. Any tips? Thanks!
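The tree mismatch in the traceback can be reproduced outside NumPyro. The sketch below assumes a hypothetical `loss_with_aux` that returns `(loss, None)`; it is not the actual SVI loss. It shows that `jacfwd(..., has_aux=True)` returns a `(jacobian, aux)` tuple, so passing the whole result to an optimizer yields a tree of the form `({'loc': *, 'scale': *}, None)` instead of the parameter tree.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_structure


def loss_with_aux(params):
    # Hypothetical loss returning (value, aux) with aux=None.
    loss = jnp.sum((params["loc"] - 1.0) ** 2) + jnp.sum(params["scale"] ** 2)
    return loss, None


params = {"loc": jnp.asarray(0.0), "scale": jnp.asarray(1.0)}

# With has_aux=True, jacfwd returns a (jacobian, aux) pair, not the jacobian.
result = jax.jacfwd(loss_with_aux, has_aux=True)(params)
grads, aux = result

param_tree = tree_structure(params)
# The un-split pair does not match the parameter tree ...
same_as_result = tree_structure(result) == param_tree
# ... but the jacobian alone does.
same_as_grads = tree_structure(grads) == param_tree
```

Unpacking the pair before handing the gradients to the optimizer is one way to avoid the mismatch.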

numpyro/optim.py Outdated
@@ -34,6 +34,11 @@
_OptState = TypeVar("_OptState")
_IterOptState = tuple[int, _OptState]

def _value_and_grad(f, x, forward_mode_differentiation=False):
    if forward_mode_differentiation:
        return f(x), jacfwd(f, has_aux=True)(x)
Member

@fehiepsi fehiepsi Feb 8, 2024

I think you can set has_aux to False.

A faster way, which just does 1 forward pass, is to redefine f:

def wrapper(x):
  out, aux = f(x)
  return out, (out, aux)

grad, (out, aux) = jacfwd(wrapper, has_aux=True)(x)
return (out, aux), grad

We can also apply this trick in hmc forward mode implementation.
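A self-contained sketch of this wrapper trick, assuming a hypothetical helper name `value_and_grad_fwd` and a toy `loss_fn` (neither is the code merged in this PR), might look like the following. It compares the result against `jax.value_and_grad` as a sanity check.

```python
import jax
import jax.numpy as jnp


def value_and_grad_fwd(f, x):
    # Hypothetical forward-mode analogue of jax.value_and_grad(f, has_aux=True).
    def wrapper(x):
        out, aux = f(x)
        # Return the primal output a second time as auxiliary data, so a
        # single jacfwd call yields both the value and the gradient.
        return out, (out, aux)

    # jacfwd returns the gradient first, then the auxiliary data.
    grads, (out, aux) = jax.jacfwd(wrapper, has_aux=True)(x)
    return (out, aux), grads


def loss_fn(params):
    # Toy loss with an auxiliary output, for illustration only.
    residual = params["loc"] - 3.0
    loss = residual**2 + params["scale"] ** 2
    return loss, {"residual": residual}


params = {"loc": jnp.asarray(1.0), "scale": jnp.asarray(2.0)}

(fwd_value, fwd_aux), fwd_grads = value_and_grad_fwd(loss_fn, params)
(rev_value, rev_aux), rev_grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
```

Note the ordering: `jacfwd` with `has_aux=True` returns `(grad, aux)`, whereas `value_and_grad` returns `((value, aux), grad)`, so the helper reorders the result to keep the calling code unchanged.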

Contributor Author

@juanitorduz juanitorduz Feb 8, 2024

OK! I added your suggestion in ef3a3f7, and now I am getting an error downstream 🤔. Do I need to unpack somewhere further?

        if progress_bar:
            losses = []
            with tqdm.trange(1, num_steps + 1) as t:
                batch = max(num_steps // 20, 1)
                for i in t:
                    svi_state, loss = jit(body_fn)(svi_state, None)
                    losses.append(jax.device_get(loss))
                    if i % batch == 0:
                        if stable_update:
                            valid_losses = [x for x in losses[i - batch :] if x == x]
                            num_valid = len(valid_losses)
                            if num_valid == 0:
                                avg_loss = float("nan")
                            else:
                                avg_loss = sum(valid_losses) / num_valid
                        else:
>                           avg_loss = sum(losses[i - batch :]) / batch
E                           TypeError: unsupported operand type(s) for +: 'int' and 'dict'

numpyro/infer/svi.py:409: TypeError

Member

@fehiepsi fehiepsi Feb 8, 2024

Sorry, it should be jacfwd of wrapper, not f.

Contributor Author

It worked in a64c092!

Contributor Author

I applied the same trick to HMC in 1259db7.

@juanitorduz juanitorduz requested a review from fehiepsi February 8, 2024 12:31
@juanitorduz juanitorduz marked this pull request as ready for review February 8, 2024 15:18
Member

@fehiepsi fehiepsi left a comment

Woohoo, thanks for supporting this feature!

@fehiepsi fehiepsi merged commit aec6bd5 into pyro-ppl:master Feb 8, 2024
4 checks passed
@juanitorduz
Contributor Author

Thank you for your guidance @fehiepsi 🙏🙂

@juanitorduz juanitorduz deleted the issue_726 branch April 30, 2024 19:03
OlaRonning pushed a commit to aleatory-science/numpyro that referenced this pull request May 6, 2024
* allow forward pass

* fix params

* add missing param in docstring

* add flag to svi

* typo docs

* reorder arguments

* order args

* kw argument internal function

* add arg to minimize

* decouple aux function

* rm kw argument unused

* nicer doctrings

* simple test

* add wrapper

* fix wrapper order

* add wrapper trick to hmc
Successfully merging this pull request may close these issues.

Support forward mode differentiation for SVI
2 participants