Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support vector lengthscales for RBF and Matern kernels #1819

Merged
merged 6 commits into from
Jun 25, 2024

Conversation

samanklesaria
Copy link
Contributor

Resolves #1805

This allows vector lengthscales in the HSGP approximations to RBF and Matern kernels. Extends brendancooley/numpyro@ef4a24b

Copy link
Contributor

@juanitorduz juanitorduz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! At first glance, this looks good! But let's see that all tests pass :) @brendancooley do you want to take a look :) ?

Comment on lines 221 to 224
if isinstance(length, float | int):
exact = _exact_matern(length)
elif length.ndim == 1:
exact = _exact_matern(length)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do you need two conditions for exact = _exact_matern(length)?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(this is just a question, no need to use the OR statement in view of readability)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was copying the code from test_kernel_approx_squared_exponential. Yes, I believe this is in service to readability. If we use an 'or' statement here, we should probably use it there as well.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

either way fine by me


import jax

ARRAY_TYPE = Union[jax.Array, np.ndarray] # jax.Array covers tracers
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if this type can be used in other NumPyro modules.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using ArrayImpl is only an issue when a model gets compiled and the arrays turn into tracers. isinstance(X, jax.Array) will work for both jax arrays and tracers.

Copy link
Contributor

@brendancooley brendancooley Jun 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some details here https://jax.readthedocs.io/en/latest/jax_array_migration.html

I believe this is best practice for typing jax arrays (as of last year), but I am not sure

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am definitively not an expert in type hints, so following the recommendation from the docs seems the safest path :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should I mark this thread as resolved, as this seems to be in line with the recommendation?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

works for me!

Copy link
Contributor

@brendancooley brendancooley left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good and works on my offline example. We can see what CI says. Thanks @samanklesaria! Great that we have a group working on this and checking one another's work.

I think next step (maybe another PR) would be to support vector-valued alpha for the batch dims. I think it may "just work" with the code as is but it would be helpful to update type annotations, and maybe generalize @juanitorduz's code in contrib.hsgp.approximation to sample matrix/tensor-valued betas in linear_approximation when batch dimensions are detected.

@juanitorduz
Copy link
Contributor

@samanklesaria can you please rebase or sync with the master branch? Today we merged some fixes on the CI, see #1817

@juanitorduz
Copy link
Contributor

Yes! The alpha vectorization + docs we can do in another PR :)

@samanklesaria
Copy link
Contributor Author

@samanklesaria can you please rebase or sync with the master branch? Today we merged some fixes on the CI, see #1817

Done!

@samanklesaria
Copy link
Contributor Author

Should sampling matrix/tensor-valued betas in linear_approximation be done in a separate PR?

@juanitorduz
Copy link
Contributor

Should sampling matrix/tensor-valued betas in linear_approximation be done in a separate PR?

Personally, I think the scope of this PR is fine. I like working on small iterations so any additional feature can be done in a different PR (at least from my side)

@juanitorduz
Copy link
Contributor

@samanklesaria it seems there are other more places where you need to change the syntax (similarly as last commit) 😄

@samanklesaria
Copy link
Contributor Author

The current version might have fixed things, but I should probably install a copy of python3.9 locally to test it for sure.

@juanitorduz
Copy link
Contributor

There is one test failing because the last change

if isinstance(length, Union[float, int])

Union is a type hint so this won't work. I suggest you make the change as I suggested above 😄

Copy link
Contributor

@juanitorduz juanitorduz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks 🟢 ! Great! 🙌

@fehiepsi fehiepsi merged commit 2984b9b into pyro-ppl:master Jun 25, 2024
4 checks passed
@fehiepsi
Copy link
Member

Thanks for contributing @samanklesaria ! Looking hsgp is having a great momentum.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

contrib.hsgp: support vector-valued kernel hyperparameters
4 participants