-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support vector lengthscales for RBF and Matern kernels #1819
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! At first glance, this looks good! But let's see that all tests pass :) @brendancooley do you want to take a look :) ?
if isinstance(length, float | int): | ||
exact = _exact_matern(length) | ||
elif length.ndim == 1: | ||
exact = _exact_matern(length) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do you need two conditions for exact = _exact_matern(length)
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(this is just a question, no need to use the OR statement in view of readability)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was copying the code from test_kernel_approx_squared_exponential
. Yes, I believe this is in service to readability. If we use an 'or' statement here, we should probably use it there as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
either way fine by me
|
||
import jax | ||
|
||
ARRAY_TYPE = Union[jax.Array, np.ndarray] # jax.Array covers tracers |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if this type can be used in other NumPyro modules.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using ArrayImpl
is only an issue when a model gets compiled and the arrays turn into tracers. isinstance(X, jax.Array)
will work for both jax arrays and tracers.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some details here https://jax.readthedocs.io/en/latest/jax_array_migration.html
I believe this is best practice for typing jax arrays (as of last year), but I am not sure
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am definitively not an expert in type hints, so following the recommendation from the docs seems the safest path :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should I mark this thread as resolved, as this seems to be in line with the recommendation?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
works for me!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good and works on my offline example. We can see what CI says. Thanks @samanklesaria! Great that we have a group working on this and checking one another's work.
I think next step (maybe another PR) would be to support vector-valued alpha
for the batch dims. I think it may "just work" with the code as is but it would be helpful to update type annotations, and maybe generalize @juanitorduz's code in contrib.hsgp.approximation
to sample matrix/tensor-valued beta
s in linear_approximation
when batch dimensions are detected.
@samanklesaria can you please rebase or sync with the master branch? Today we merged some fixes on the CI, see #1817 |
Yes! The alpha vectorization + docs we can do in another PR :) |
Done! |
Should sampling matrix/tensor-valued betas in |
Personally, I think the scope of this PR is fine. I like working on small iterations so any additional feature can be done in a different PR (at least from my side) |
@samanklesaria it seems there are other more places where you need to change the syntax (similarly as last commit) 😄 |
The current version might have fixed things, but I should probably install a copy of python3.9 locally to test it for sure. |
There is one test failing because the last change if isinstance(length, Union[float, int]) Union is a type hint so this won't work. I suggest you make the change as I suggested above 😄 |
Co-authored-by: Juan Orduz <juanitorduz@gmail.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks 🟢 ! Great! 🙌
Thanks for contributing @samanklesaria ! Looking hsgp is having a great momentum. |
Resolves #1805
This allows vector lengthscales in the HSGP approximations to RBF and Matern kernels. Extends brendancooley/numpyro@ef4a24b