-
Notifications
You must be signed in to change notification settings - Fork 68
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Removed no_grad from solver #19
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,9 +23,10 @@ def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor: | |
class ConstantVelocityModel(ModelWrapper): | ||
def __init__(self): | ||
super().__init__(None) | ||
self.a = torch.nn.Parameter(torch.tensor(1.0)) | ||
|
||
def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor: | ||
return x * 0.0 + 1.0 | ||
return x * 0.0 + self.a | ||
|
||
|
||
class TestODESolver(unittest.TestCase): | ||
|
@@ -75,6 +76,27 @@ def test_sample_with_different_methods(self): | |
"The solution to the velocity field 3t^3 from 0 to 1 is incorrect.", | ||
) | ||
|
||
def test_gradients(self): | ||
x_init = torch.tensor([1.0, 0.0]) | ||
step_size = 0.001 | ||
time_grid = torch.tensor([0.0, 1.0]) | ||
|
||
for method in ["euler", "dopri5", "midpoint", "heun3"]: | ||
with self.subTest(method=method): | ||
self.constant_velocity_model.zero_grad() | ||
result = self.constant_velocity_solver.sample( | ||
x_init=x_init, | ||
step_size=step_size if method != "dopri5" else None, | ||
time_grid=time_grid, | ||
method=method, | ||
enable_grad=True, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Check grads are not computed without this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added test |
||
) | ||
loss = result.sum() | ||
loss.backward() | ||
self.assertAlmostEqual( | ||
self.constant_velocity_model.a.grad, 2.0, delta=1e-4 | ||
) | ||
|
||
def test_compute_likelihood(self): | ||
x_1 = torch.tensor([[0.0, 0.0]]) | ||
step_size = 0.1 | ||
|
@@ -93,7 +115,7 @@ def dummy_log_p(x: Tensor) -> Tensor: | |
self.assertEqual(x_1.shape[0], log_likelihood.shape[0]) | ||
|
||
def test_compute_likelihood_exact_divergence(self): | ||
x_1 = torch.tensor([[0.0, 0.0]]) | ||
x_1 = torch.tensor([[0.0, 0.0]], requires_grad=True) | ||
step_size = 0.1 | ||
|
||
# Define a dummy log probability function | ||
|
@@ -105,6 +127,7 @@ def dummy_log_p(x: Tensor) -> Tensor: | |
log_p0=dummy_log_p, | ||
step_size=step_size, | ||
exact_divergence=True, | ||
enable_grad=True, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Check grads not computed without this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added test |
||
) | ||
self.assertIsInstance(log_likelihood, Tensor) | ||
self.assertEqual(x_1.shape[0], log_likelihood.shape[0]) | ||
|
@@ -114,6 +137,10 @@ def dummy_log_p(x: Tensor) -> Tensor: | |
self.assertTrue( | ||
torch.allclose(x_1 - 1.0, x_0, atol=1e-2), | ||
) | ||
log_likelihood.backward() | ||
self.assertTrue( | ||
torch.allclose(x_1.grad, torch.tensor([1.0, 1.0]), atol=1e-2), | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Was this no_grad unnecessary previously?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was unnecessary yes.