-
-
Notifications
You must be signed in to change notification settings - Fork 984
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added type assertion for better code clarity #3036
Conversation
Can you assert isinstance(X, torch.Tensor) |
I thought that it was recommended to use |
Sure I will do that and resubmit |
I have modified all the files (added assertion checks) having this problem in the pyro/contrib/gp/models folder Below is an example that illustrates the assertion error. import pyro
import pyro.contrib.gp as gp
import numpy as np
import torch
X = np.random.randn(200, 2)
Y = np.logical_xor(X[:, 0] > 0, X[:, 1] > 0).astype(np.float64)
kernel = gp.kernels.RBF(input_dim=2)
likelihood = gp.likelihoods.Binary()
model = gp.models.VariationalGP(X, Y, kernel, likelihood=likelihood, whiten=True)
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
Input In [6], in <cell line: 11>()
9 kernel = gp.kernels.RBF(input_dim=2)
10 likelihood = gp.likelihoods.Binary()
---> 11 model = gp.models.VariationalGP(X, Y, kernel, likelihood=likelihood, whiten=True)
File C:\work\Pyro\pyro\pyro\contrib\gp\models\vgp.py:74, in VariationalGP.__init__(self, X, y, kernel, likelihood, mean_function, latent_shape, whiten, jitter) 63 def __init__(
64 self,
65 X,
(...)
72 jitter=1e-6,
73 ):
---> 74 assert isinstance(
75 X, torch.Tensor
76 ), "X needs to be a torch Tensor instead of a {}".format(type(X))
77 assert isinstance(
78 y, torch.Tensor
79 ), "y needs to be a torch Tensor instead of a {}".format(type(y))
80 super().__init__(X, y, kernel, mean_function, jitter)
AssertionError: X needs to be a torch Tensor instead of a <class 'numpy.ndarray'> Please let me know if some tests need to be added pertaining to the updated code. |
Thanks for adding the checks! I think no need to add tests. Pyro is based on PyTorch, not numpy, so we would expect all inputs are torch tensors. Could you fix the lint issue? You can run |
Run |
Thanks for the help! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, @GautamV234 !
I am working with Prof. @nipunbatra and I am currently working on this PR to solve #3026.