Skip to content

Commit

Permalink
ran notebooks, added unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
Marton Havasi committed Dec 17, 2024
1 parent be2a367 commit 1498339
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 65 deletions.
48 changes: 28 additions & 20 deletions examples/2d_discrete_flow_matching.ipynb

Large diffs are not rendered by default.

23 changes: 13 additions & 10 deletions examples/2d_flow_matching.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
{
"data": {
"text/plain": [
"<torch._C.Generator at 0x7f87769e0ef0>"
"<torch._C.Generator at 0x7f987dfe2ef0>"
]
},
"execution_count": 3,
Expand Down Expand Up @@ -193,16 +193,16 @@
"name": "stdout",
"output_type": "stream",
"text": [
"| iter 2000 | 3.59 ms/step | loss 3.772 \n",
"| iter 4000 | 3.45 ms/step | loss 3.684 \n",
"| iter 2000 | 3.61 ms/step | loss 3.772 \n",
"| iter 4000 | 3.46 ms/step | loss 3.684 \n",
"| iter 6000 | 3.45 ms/step | loss 3.780 \n",
"| iter 8000 | 3.45 ms/step | loss 3.729 \n",
"| iter 8000 | 3.46 ms/step | loss 3.729 \n",
"| iter 10000 | 3.45 ms/step | loss 3.705 \n",
"| iter 12000 | 3.45 ms/step | loss 3.661 \n",
"| iter 14000 | 3.45 ms/step | loss 3.625 \n",
"| iter 14000 | 3.46 ms/step | loss 3.625 \n",
"| iter 16000 | 3.45 ms/step | loss 3.837 \n",
"| iter 18000 | 3.45 ms/step | loss 3.796 \n",
"| iter 20000 | 3.45 ms/step | loss 3.872 \n"
"| iter 18000 | 3.46 ms/step | loss 3.796 \n",
"| iter 20000 | 3.46 ms/step | loss 3.872 \n"
]
}
],
Expand Down Expand Up @@ -291,7 +291,8 @@
"\n",
"x_init = torch.randn((batch_size, 2), dtype=torch.float32, device=device)\n",
"solver = ODESolver(velocity_model=wrapped_vf) # create an ODESolver class\n",
"sol = solver.sample(time_grid=T, x_init=x_init, method='midpoint', step_size=step_size, return_intermediates=True) # sample from the model"
"with torch.no_grad():\n",
" sol = solver.sample(time_grid=T, x_init=x_init, method='midpoint', step_size=step_size, return_intermediates=True) # sample from the model"
]
},
{
Expand Down Expand Up @@ -387,13 +388,15 @@
"log_p_acc = 0\n",
"\n",
"for i in range(num_acc):\n",
" _, log_p = solver.compute_likelihood(x_1=x_1, method='midpoint', step_size=step_size, exact_divergence=False, log_p0=gaussian_log_density)\n",
" with torch.no_grad():\n",
" _, log_p = solver.compute_likelihood(x_1=x_1, method='midpoint', step_size=step_size, exact_divergence=False, log_p0=gaussian_log_density)\n",
" log_p_acc += log_p\n",
"\n",
"log_p_acc /= num_acc\n",
"\n",
"# compute with exact divergence\n",
"_, exact_log_p = solver.compute_likelihood(x_1=x_1, method='midpoint', step_size=step_size, exact_divergence=True, log_p0=gaussian_log_density)"
"with torch.no_grad():\n",
" _, exact_log_p = solver.compute_likelihood(x_1=x_1, method='midpoint', step_size=step_size, exact_divergence=True, log_p0=gaussian_log_density)"
]
},
{
Expand Down
42 changes: 25 additions & 17 deletions examples/2d_riemannian_flow_matching_flat_torus.ipynb

Large diffs are not rendered by default.

42 changes: 25 additions & 17 deletions examples/2d_riemannian_flow_matching_sphere.ipynb

Large diffs are not rendered by default.

23 changes: 22 additions & 1 deletion tests/solver/test_ode_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
class ConstantVelocityModel(ModelWrapper):
def __init__(self):
super().__init__(None)
self.a = torch.nn.Parameter(torch.tensor(1.0))

def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
return x * 0.0 + 1.0
return x * 0.0 + self.a


class TestODESolver(unittest.TestCase):
Expand Down Expand Up @@ -75,6 +76,26 @@ def test_sample_with_different_methods(self):
"The solution to the velocity field 3t^3 from 0 to 1 is incorrect.",
)

def test_gradients(self):
x_init = torch.tensor([1.0, 0.0])
step_size = 0.001
time_grid = torch.tensor([0.0, 1.0])

for method in ["euler", "dopri5", "midpoint", "heun3"]:
with self.subTest(method=method):
self.constant_velocity_model.zero_grad()
result = self.constant_velocity_solver.sample(
x_init=x_init,
step_size=step_size if method != "dopri5" else None,
time_grid=time_grid,
method=method,
)
loss = result.sum()
loss.backward()
self.assertAlmostEqual(
self.constant_velocity_model.a.grad, 2.0, delta=1e-4
)

def test_compute_likelihood(self):
x_1 = torch.tensor([[0.0, 0.0]])
step_size = 0.1
Expand Down

0 comments on commit 1498339

Please sign in to comment.