Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added gamma to LBFGSState #320

Merged
merged 1 commit into from
Oct 7, 2022

Conversation

zaccharieramzi
Copy link
Contributor

@zaccharieramzi zaccharieramzi commented Oct 7, 2022

This PR adds a new attribute to the LBFGState named tuple, gamma.
This attribute is needed when one wants to use the approximation of the Hessian inverse after the algorithm has run using inv_hessian_product.

Indeed I realized that in this implementation, the initial value of the approximation of the Hessian is recomputed at each iteration, as suggested by Nocedal (something that is not present in scipy to the best of my knowledge and understanding of Fortran).

The application I have in mind for the approximation of the Hessian inverse is https://arxiv.org/abs/2106.00553.

(I didn't write an issue first because it's such a small PR, and I was going to write this part anyway for my own xps)

@zaccharieramzi
Copy link
Contributor Author

One aspect I am not too sure about though in the arguments of inv_hessian_product is start.
It indeed seems like it's used to perform the matrix multiplication in a certain order, but I don't understand why that's necessary.
Is it because with numerical errors the additions become non-commutative?

@mblondel
Copy link
Collaborator

mblondel commented Oct 7, 2022

One aspect I am not too sure about though in the arguments of inv_hessian_product is start

We use a circular buffer to maintain the history without lists.

Copy link
Collaborator

@mblondel mblondel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks a lot for the contribution

@zaccharieramzi
Copy link
Contributor Author

I think I don't have access to the logs of the copybara failure, or at least I am not sure how to access them in order to correct the failure.

@@ -142,6 +142,7 @@ class LbfgsState(NamedTuple):
s_history: Any
y_history: Any
rho_history: jnp.ndarray
gamma: Any = jnp.array(1.0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I overlooked this in my previous review but you shouldn't initialize the array here. Please do

Suggested change
gamma: Any = jnp.array(1.0)
gamma: Any = jnp.ndarray

and instead initialize gamma in init_state. CC @froystig

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I meant

  gamma: jnp.ndarray

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@copybara-service copybara-service bot merged commit 418bce3 into google:main Oct 7, 2022
junpenglao added a commit to blackjax-devs/blackjax that referenced this pull request Oct 23, 2022
rlouf pushed a commit to rlouf/blackjax that referenced this pull request Oct 26, 2022
junpenglao added a commit to blackjax-devs/blackjax that referenced this pull request Mar 12, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants