specifying `pdf` in `nk.jax.jacobian` has no effect by default #1979

Open
MarcMachaczek opened this issue Jan 16, 2025 · 8 comments

@MarcMachaczek

Hi,
While playing around with `nk.jax.jacobian` to compute the mean gradients for a full-sum state, along the lines of:

```python
jacobian = nk.jax.jacobian(vqs._apply_fun, vqs.parameters, samples, mode=S.mode, dense=True, pdf=vqs.probability_distribution())
mean_grad = jacobian.sum(axis=0)
```

I noticed that the means were unreasonably large and it made no difference whether I specified `pdf` or not.

If I understand the implementation of `nk.jax.jacobian` correctly, the `pdf` argument is essentially ignored if both `centered==False` and `_sqrt_rescale==False`.

Happy to open a PR (if this is indeed a bug), but it might make sense to think about this internal flag first.

I'm using NetKet version 3.14.3.

Best,
Marc

@PhilipVinc
Member

> I noticed that the means were unreasonably large

This depends on your state, I think, so it should not be surprising?

> and it made no difference whether I specified `pdf` or not.
> If I understand the implementation of `nk.jax.jacobian` correctly, the `pdf` argument is essentially ignored if both `centered==False` and `_sqrt_rescale==False`.

Yes, indeed. We do not multiply by $p(x)$ by default, so unless you set either `centered` or `_sqrt_rescale`, the `pdf` argument is effectively ignored.
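For reference, here is a minimal pure-JAX sketch of the post-processing behaviour described in this thread. It is not NetKet's source: the function name `postprocess_jacobian` is hypothetical, its `centered` and `sqrt_rescale` flags stand in for `centered` and `_sqrt_rescale`, and the pdf-weighted centering branch is an assumption. It assumes a dense per-sample Jacobian of shape `(n_samples, n_params)` and shows that when neither flag is set, `pdf` never enters the computation.

```python
import jax.numpy as jnp


def postprocess_jacobian(jac, pdf=None, centered=False, sqrt_rescale=False):
    """Hypothetical sketch of per-sample Jacobian post-processing (not NetKet code).

    jac: per-sample Jacobian, shape (n_samples, n_params).
    pdf: optional normalized weights p(x), shape (n_samples,).
    """
    n_samples = jac.shape[0]
    if centered:
        # Assumed behaviour: subtract the pdf-weighted mean, or the plain mean without pdf.
        if pdf is None:
            jac = jac - jac.mean(axis=0, keepdims=True)
        else:
            jac = jac - (pdf[:, None] * jac).sum(axis=0, keepdims=True)
    if sqrt_rescale:
        # Rows are scaled by sqrt(p(x)), or by sqrt(1/Ns) when no pdf is given.
        if pdf is None:
            jac = jac / jnp.sqrt(n_samples)
        else:
            jac = jac * jnp.sqrt(pdf)[:, None]
    # With both flags False, `pdf` is never read.
    return jac


jac = jnp.arange(6.0).reshape(3, 2)
pdf = jnp.array([0.5, 0.25, 0.25])
print(jnp.allclose(postprocess_jacobian(jac), postprocess_jacobian(jac, pdf=pdf)))
```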

By the way, the idea at the time was that not specifying `pdf` is roughly equivalent to passing a flat probability distribution, `pdf=jnp.ones(n_samples)/n_samples`.
Now, multiplying by `pdf` by default would mean that, for consistency, we should consider always multiplying by $1/N_s$ when no `pdf` is given.
Unless we decide not to be consistent?
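The equivalence mentioned above can be checked on toy data: with a flat distribution `pdf = ones(Ns)/Ns`, weighting by `pdf` and summing over samples gives the sample mean, and rescaling by `sqrt(pdf)` is the same as dividing by `sqrt(Ns)`. The arrays here are arbitrary and only illustrate the arithmetic.

```python
import jax.numpy as jnp

n_samples, n_params = 4, 3
# Toy per-sample Jacobian, shape (n_samples, n_params).
jac = jnp.arange(n_samples * n_params, dtype=jnp.float32).reshape(n_samples, n_params)
flat_pdf = jnp.ones(n_samples) / n_samples

# Weighting by the flat pdf and summing reproduces the plain sample mean ...
weighted_sum = (flat_pdf[:, None] * jac).sum(axis=0)
sample_mean = jac.mean(axis=0)

# ... and rescaling rows by sqrt(pdf) is the same as dividing by sqrt(Ns).
sqrt_weighted = jnp.sqrt(flat_pdf)[:, None] * jac
sqrt_ns_scaled = jac / jnp.sqrt(n_samples)

print(jnp.allclose(weighted_sum, sample_mean), jnp.allclose(sqrt_weighted, sqrt_ns_scaled))
```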

I'm unsure what the best solution is.
If we are not consistent, we would:
- multiply by $p(x)$ when `pdf` is specified and `_sqrt_rescale==False`;
- multiply by $\sqrt{p(x)}$ when `_sqrt_rescale==True` and `pdf` is specified;
- multiply by $\sqrt{1/N_s}$ when `_sqrt_rescale==True` and no `pdf` is specified;
- do nothing when `_sqrt_rescale==False` and no `pdf` is specified.

Would that make it harder to implement algorithms on top of it? (People like @lgravina1997 or @alleSini99 have used these a lot recently; could you comment?)

If instead we strive for consistency, this would mean breaking a lot of existing code (again, people please comment).
A possible alternative is to print a warning when `pdf` is specified but ignored?

(As a historical note, `_sqrt_rescale` is unfinished, which is why it has a leading underscore.
I never knew how to expose it properly, but it was often used internally, so I left it there.
However, I have never thoroughly thought through this part of the interface.)

@lgravina1997
Contributor

Regarding the mean of the gradient, I would point out that in the example above you are only averaging the Jacobian and therefore what you see is not the gradient, unless you have a flat local energy equal to one.
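Schematically, with $O_k(s) = \partial_{\theta_k}\log\psi(s)$ and $p(s) = |\psi(s)|^2 / \sum_{s'}|\psi(s')|^2$, the term $\langle O_k^* E_{\mathrm{loc}}\rangle$ that enters the energy gradient (its centered version also subtracts $\langle O_k^*\rangle\langle E_{\mathrm{loc}}\rangle$) weights the Jacobian by the local energy as well as by $p(s)$. It only reduces to the pdf-weighted mean of the Jacobian when $E_{\mathrm{loc}}(s) \equiv 1$:

$$
\sum_s p(s)\, O_k^*(s)\, E_{\mathrm{loc}}(s) \;\xrightarrow{\;E_{\mathrm{loc}} \equiv 1\;}\; \sum_s p(s)\, O_k^*(s) = \langle O_k^* \rangle .
$$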

I do agree with the principle of treating `pdf` and $1/N_s$ at the same level. And I like the consistency of returning only the Jacobian unless otherwise specified. I tend to see the multiplication by the pdf (or its finite-sample equivalent) as a step that should not be automatic or assumed.

The way I see it is that `_sqrt_rescale==True` is needed whenever you want to compute expectation values.
I agree that exposing it could be a good idea.
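One way to see why the square-root rescaling suits expectation values: if every row of the Jacobian is multiplied by $\sqrt{p(s)}$, a plain matrix product $\tilde{J}^\dagger \tilde{J}$ already equals the pdf-weighted expectation $\sum_s p(s)\, O_k^*(s) O_l(s)$. The sketch below checks this on random toy data generated with `jax.random`; the explicit `jnp.sqrt(pdf)` row scaling is an assumption standing in for `_sqrt_rescale=True` with a `pdf`. With a centered Jacobian, the same identity gives the covariance structure of the quantum geometric tensor.

```python
import jax
import jax.numpy as jnp

n_samples, n_params = 5, 3
keys = jax.random.split(jax.random.PRNGKey(0), 3)

# Toy complex per-sample Jacobian O[s, k] and a normalized pdf p[s].
jac = jax.random.normal(keys[0], (n_samples, n_params)) + 1j * jax.random.normal(keys[1], (n_samples, n_params))
weights = jax.random.uniform(keys[2], (n_samples,))
pdf = weights / weights.sum()

# Rescale each row by sqrt(p(s)).
jac_tilde = jnp.sqrt(pdf)[:, None] * jac

# A plain matrix product of the rescaled Jacobian ...
s_from_rescaled = jac_tilde.conj().T @ jac_tilde
# ... equals the weighted expectation sum_s p(s) O_k^*(s) O_l(s).
s_expected = jnp.einsum("s,sk,sl->kl", pdf, jac.conj(), jac)

print(jnp.allclose(s_from_rescaled, s_expected, atol=1e-5))
```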

@PhilipVinc
Member

> The way I see it is that `_sqrt_rescale==True` is needed whenever you want to compute expectation values.
> I agree that exposing it could be a good idea.

Expose it how, exactly? Like it is now, or differently?

@lgravina1997
Contributor

I do not see a problem with how it is now, but if it is unclear, maybe something could be added to the docs?

@PhilipVinc
Member

I added the underscore because it felt a bit like a hack (but was internally useful).
It seems people use it, so maybe it should become standard?

The main point was that I was not sure whether we want to multiply by `pdf` or by the square root of `pdf`.
In NetKet we usually need the latter, but I think you might sometimes want the former.

@lgravina1997
Contributor

I have personally used both versions before. Maybe two flags implementing the two variants, raising an error if both are true, could be a solution that makes everyone happy.

@MarcMachaczek
Author

> Regarding the mean of the gradient, I would point out that in the example above you are only averaging the Jacobian and therefore what you see is not the gradient, unless you have a flat local energy equal to one.

Yes, correct. I'm referring to the mean gradient of the network's log-amplitude with respect to its parameters, i.e.
$\sum_s |\psi(s)|^2 \partial_\theta \log\psi(s)$.
Being able to specify a `pdf` here without any additional factor or square root would be ideal for a full-sum state. To be honest, multiplying manually is easy and only adds one line of code, so I don't think this particular problem justifies a "breaking" change. I just thought this behavior was unintended.
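The manual fix can be sketched in pure JAX for a tiny full-sum state. The model `log_psi` below is a hypothetical real-valued toy, not a NetKet model; the Jacobian is built with `jax.vmap(jax.grad(...))` and assumed to have shape `(n_configs, n_params)`. The one extra line is `pdf @ jac`. As a self-check, for a real $\log\psi$ this weighted mean equals the gradient of $\tfrac{1}{2}\log\sum_s|\psi(s)|^2$.

```python
import itertools

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

n_sites = 3
# All 2**n_sites spin configurations: the "samples" of a full-sum state.
configs = jnp.array(list(itertools.product([-1.0, 1.0], repeat=n_sites)))


def log_psi(theta, s):
    """Hypothetical toy real-valued log-amplitude for a single configuration."""
    return jnp.log(jnp.cosh(theta @ s)) + 0.5 * theta[0] * s[0]


theta = jnp.array([0.3, -0.7, 0.2])
log_psi_all = jax.vmap(log_psi, in_axes=(None, 0))(theta, configs)

# Per-configuration Jacobian O[s, k] = d log psi(s) / d theta_k.
jac = jax.vmap(jax.grad(log_psi), in_axes=(None, 0))(theta, configs)

# Exact Born distribution p(s) = |psi(s)|^2 / sum_s' |psi(s')|^2 (real log psi).
pdf = jax.nn.softmax(2.0 * log_psi_all)

# The one extra line: weight by pdf before summing over configurations.
mean_grad = pdf @ jac
unweighted_sum = jac.sum(axis=0)  # what jacobian.sum(axis=0) alone gives


def half_log_norm(th):
    # 0.5 * log sum_s |psi(s)|^2; its gradient is sum_s p(s) d log psi(s) / d theta.
    return 0.5 * logsumexp(2.0 * jax.vmap(log_psi, in_axes=(None, 0))(th, configs))


reference = jax.grad(half_log_norm)(theta)
print(jnp.allclose(mean_grad, reference, atol=1e-5))
```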

@PhilipVinc
Member

So, summing up: no problem.
Maybe we can warn if `pdf` is passed but unused.

To do what Luca proposes, we could add a single kwarg such as `multiply_by_pdf=None/False/"sqrt"/"pdf"` or something like that, instead of multiple kwargs, which would be messier and harder to document.
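A sketch of how such a single keyword might behave, combined with the suggested warning. Everything here is hypothetical: `rescale_jacobian` and `multiply_by_pdf` are not an existing NetKet API, the input is assumed to be a dense Jacobian of shape `(n_samples, n_params)`, and falling back to flat weights $1/N_s$ when no `pdf` is given is one possible choice for the consistency question above.

```python
import warnings

import jax.numpy as jnp


def rescale_jacobian(jac, pdf=None, multiply_by_pdf=None):
    """Hypothetical single-kwarg design (not an existing NetKet API).

    multiply_by_pdf: None/False (no rescaling), "pdf" (rows times p(x)),
    or "sqrt" (rows times sqrt(p(x))).
    """
    if multiply_by_pdf in (None, False):
        if pdf is not None:
            # Warn instead of silently ignoring pdf.
            warnings.warn("pdf was passed but is unused because multiply_by_pdf is not set.")
        return jac
    if multiply_by_pdf not in ("pdf", "sqrt"):
        raise ValueError(f"multiply_by_pdf must be None, False, 'pdf' or 'sqrt', got {multiply_by_pdf!r}")
    n_samples = jac.shape[0]
    # Design choice: without a pdf, fall back to flat weights 1/Ns.
    weights = jnp.full((n_samples,), 1.0 / n_samples) if pdf is None else pdf
    if multiply_by_pdf == "sqrt":
        weights = jnp.sqrt(weights)
    return jac * weights[:, None]


jac = jnp.arange(6.0).reshape(3, 2)
pdf = jnp.array([0.5, 0.25, 0.25])
# Weighted mean over samples, as needed for a full-sum state.
mean_grad = rescale_jacobian(jac, pdf, multiply_by_pdf="pdf").sum(axis=0)
print(mean_grad)
```

A single string-valued keyword makes the "both variants at once" combination impossible to express, so no separate mutual-exclusion check is needed.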
