Linear regression example in rust. #3

adrodgers opened this issue Mar 17, 2023 · 7 comments

@adrodgers

Hello,

I am having trouble understanding how to set up the logp function for my use case, specifically the gradient term. I have managed to implement my model using emcee, as this doesn't require a gradient term but I would really like to use NUTS directly from rust!

A slightly more involved example, along the lines of this, would be very helpful for understanding how to fit a model (including priors) to data, which I could then extrapolate to my own use case.

Thank you for your work on this excellent crate.

@twiecki
Member

twiecki commented Mar 17, 2023

Any reason you can't use PyMC?

@aseyboldt
Member

Just using pymc is the easiest way to do this, for sure.
But if you really want to do it in rust (depending on the problem, you might indeed get better performance that way), you can; you will have to derive the gradient yourself, though, because I don't think there is a well-established autodiff library in rust at the moment.
If you have already worked out the logp function for your model, you only have to compute it in the logp function of the CpuLogpFunction trait and write its gradient into grad.
If you are stuck somewhere, can you be more specific about where exactly?
If you end up writing an example, I'd love a PR for it! :-)
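
For anyone landing here later, here is a minimal sketch of what "compute logp and write its gradient into grad" can look like for a straight-line fit. All names here (`logp_linreg`, the parameter layout) are my own illustration, not the crate's actual trait signature; the body is the kind of computation you would place inside the trait's logp method.

```rust
// Hypothetical sketch (not the crate's exact API): log posterior and gradient
// for y = a + b*x + noise, noise ~ N(0, 1), with N(0, 10) priors on a and b.
// The partial derivatives with respect to each parameter go into `grad`.
fn logp_linreg(params: &[f64], grad: &mut [f64], x: &[f64], y: &[f64]) -> f64 {
    let (a, b) = (params[0], params[1]);
    let sigma2 = 1.0; // fixed noise variance, for simplicity
    let prior_var = 100.0; // N(0, 10) prior => variance 100

    // Priors (up to additive constants)
    let mut logp = -0.5 * (a * a + b * b) / prior_var;
    grad[0] = -a / prior_var; // d logp / d a from the prior
    grad[1] = -b / prior_var; // d logp / d b from the prior

    // Gaussian likelihood and its gradient
    for (&xi, &yi) in x.iter().zip(y.iter()) {
        let r = yi - (a + b * xi); // residual
        logp -= 0.5 * r * r / sigma2;
        grad[0] += r / sigma2; // d logp / d a
        grad[1] += r * xi / sigma2; // d logp / d b
    }
    logp
}
```

Since the gradient is derived by hand, it is worth checking it against finite differences of the returned logp value before trusting the sampler's output.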

@adrodgers
Author

@twiecki Having recently started coding in rust, I have really enjoyed the experience and want to see how far I can take my work with it, but perhaps this is a bridge too far at the moment.
I have used PyMC(3) previously, but I ran into some issues with control flow when defining my model that I wasn't able to solve. For context, I am calculating the gravity signal due to spheres, cuboids, polyhedra, etc., which tends to require looping over faces and edges. This led me to switch to numpyro, which allowed me to use control flow via jax and which I found a bit more intuitive. I will invest some time in updating my code to the latest version of PyMC, especially as it now has a plethora of backends available. Are there any tutorials that you know of for dealing with control flow in the definition of a PyMC model? If so, that would be a big help.

@aseyboldt The main sticking point was conceptualizing what exactly the gradient is that I should assign in the CpuLogpFunction. Is it correct that it is the partial derivative of the log posterior with respect to each model parameter?
I have done some searching, and I think you are correct that there isn't a stable, settled-upon autodiff library in rust just yet; this seems like it could be a possibility.

Thank you both for your quick replies. Much appreciated.

@twiecki
Member

twiecki commented Mar 20, 2023

You can also just use JAX in your PyMC model definition, if that helps. Otherwise it's the scan route, which is pretty confusing indeed and tends to be slow, although current work tries to fix that.

@aseyboldt
Member

aseyboldt commented Mar 20, 2023

Is it correct that it is the partial derivative of the log posterior with respect to each model parameter?

Yes, that's it :-)

The library you mentioned won't work too well, though. It only supports forward-mode autodiff, which is a pretty bad fit when you want to compute the gradient of a scalar logp with respect to many parameters. We'd really like to have reverse-mode (backward) autodiff.
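
Since hand-derived gradients are the way to go for now, a small finite-difference checker is handy before handing a logp function to the sampler. This helper (`check_grad` is my own sketch, not part of the crate) compares each analytic partial against a central difference:

```rust
// Sanity-check a hand-derived gradient against central finite differences.
// `logp` fills `grad` with analytic partials and returns the logp value.
fn check_grad<F>(mut logp: F, pos: &[f64], tol: f64) -> bool
where
    F: FnMut(&[f64], &mut [f64]) -> f64,
{
    let n = pos.len();
    let mut grad = vec![0.0; n];
    logp(pos, &mut grad);

    let eps = 1e-5;
    let mut scratch = vec![0.0; n];
    for i in 0..n {
        let mut plus = pos.to_vec();
        let mut minus = pos.to_vec();
        plus[i] += eps;
        minus[i] -= eps;
        // Central difference approximation of the i-th partial derivative
        let fd = (logp(&plus, &mut scratch) - logp(&minus, &mut scratch)) / (2.0 * eps);
        if (fd - grad[i]).abs() > tol {
            return false;
        }
    }
    true
}
```

A mismatch here almost always means a sign or missing-term error in the hand derivation, which would otherwise show up only as mysteriously bad sampling.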

@kbvernon

kbvernon commented Dec 1, 2024

Piggybacking on this issue to request an example with parameter transforms like the logit.

Also, wrt autodiff, there's an ongoing project to integrate Enzyme AD into rustc.

And thanks for the awesome crate!
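
In the meantime, the logit transform itself is mechanical enough to sketch: sample an unconstrained z, map it to p in (0, 1) with the inverse logit, and add the log-Jacobian log(dp/dz) = log(p) + log(1 - p) so the density on z is correct. The names here (`inv_logit`, `transform_logit`) are my own illustration, not crate API:

```rust
// Inverse logit (sigmoid): maps the real line onto (0, 1).
fn inv_logit(z: f64) -> f64 {
    1.0 / (1.0 + (-z).exp())
}

// Given the target's (logp, d logp / d p) on the constrained scale,
// return (logp, d logp / d z) on the unconstrained scale, including
// the log-Jacobian term and its derivative.
fn transform_logit(z: f64, logp_p: impl Fn(f64) -> (f64, f64)) -> (f64, f64) {
    let p = inv_logit(z);
    let (lp, dlp_dp) = logp_p(p);
    let jac = p * (1.0 - p); // dp/dz
    let logp_z = lp + jac.ln(); // add log-Jacobian
    // Chain rule plus the derivative of the log-Jacobian:
    // d/dz log(p(1-p)) = 1 - 2p
    let dlogp_dz = dlp_dp * jac + (1.0 - 2.0 * p);
    (logp_z, dlogp_dz)
}
```

With a uniform target on (0, 1), for example, the transformed logp reduces to the log-Jacobian alone, which is an easy case to verify by hand.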

@aseyboldt
Member

Even right now we could use dfdx, candle, burn, or tch-rs.
I have to admit that using this from rust isn't really my main motivation, but an example with any of those would be great.
I'd also be quite curious how easy it is to beat jax on the gpu. Implementing the Math trait should allow us to avoid copies between the GPU and main memory, for instance.
I don't think I'll have time to do this before I release the Fisher HMC addition in nutpie, though, but if someone beats me to it, I'd happily merge an example. :-)
