Split Bijector #103

Open · wants to merge 31 commits into base: main
Changes from 1 commit
fix tests and others
vmoens committed Apr 22, 2022
commit 60fa3679b50532976adbe9a9269f6117da6db78e
4 changes: 2 additions & 2 deletions flowtorch/bijectors/coupling.py
@@ -65,15 +65,15 @@ def __init__(

def _forward(
self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
assert self._params_fn is not None

y, ldj = super()._forward(x, params)
return y, ldj

def _inverse(
self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
assert self._params_fn is not None

x, ldj = super()._inverse(y, params)
8 changes: 4 additions & 4 deletions flowtorch/bijectors/ops/affine.py
@@ -9,7 +9,7 @@
from torch.distributions.utils import _sum_rightmost

_DEFAULT_POSITIVE_BIASES = {
"softplus": torch.expm1(torch.ones(1)).log().item(),
"softplus": 0.5413248538970947,
"exp": 0.0,
}
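Reviewer note (mine, not part of the diff): the hard-coded literal is just the old expression evaluated once, since torch.expm1(torch.ones(1)).log().item() computes log(e - 1), i.e. the bias b with softplus(b) = 1; inlining it avoids allocating a tensor at import time. A quick sanity check using only the standard library:

import math

# b solves softplus(b) = log(1 + exp(b)) = 1, i.e. b = log(e - 1);
# the literal in the diff is the float32 evaluation of the same quantity.
b = math.log(math.expm1(1.0))
assert abs(b - 0.5413248538970947) < 1e-7
assert abs(math.log1p(math.exp(b)) - 1.0) < 1e-12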

@@ -58,7 +58,7 @@ def positive_map(self, x: torch.Tensor) -> torch.Tensor:

def _forward(
self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
assert params is not None

mean, unbounded_scale = params
@@ -73,7 +73,7 @@ def _forward(

def _inverse(
self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
assert (
params is not None
), f"{self.__class__.__name__}._inverse got no parameters"
@@ -86,7 +86,7 @@ def _inverse(

if not self._exp_map:
inverse_scale = self.positive_map(unbounded_scale).reciprocal()
log_scale = inverse_scale.log()
log_scale = -inverse_scale.log()
else:
inverse_scale = torch.exp(-unbounded_scale)
log_scale = unbounded_scale
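Reviewer note (my reading): the sign flip looks right if log_scale is meant to be the log of the forward scale. Since inverse_scale = 1 / positive_map(unbounded_scale), we have log(positive_map(unbounded_scale)) = -log(inverse_scale), which also matches the exp branch where log_scale = unbounded_scale. A minimal check of that identity, assuming softplus as the positive map:

import torch
import torch.nn.functional as F

# forward scale s = softplus(u) > 0, inverse scale = 1 / s
u = torch.randn(5)
s = F.softplus(u)
inverse_scale = s.reciprocal()
# the log of the forward scale equals minus the log of the inverse scale
assert torch.allclose(-inverse_scale.log(), s.log())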
2 changes: 1 addition & 1 deletion flowtorch/bijectors/ops/spline.py
@@ -56,7 +56,7 @@ def _inverse(

# TODO: Should I invert the sign of log_detJ?
# TODO: A unit test that compares log_detJ from _forward and _inverse
return x_new, _sum_rightmost(log_detJ, self.domain.event_dim)
return x_new, _sum_rightmost(-log_detJ, self.domain.event_dim)

def _log_abs_det_jacobian(
self,
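Note (my interpretation, not stated in the PR): negating log_detJ in _inverse makes it report the log-determinant in the same forward (x -> y) orientation as _forward, which is what the updated test_inverse below relies on when it compares the two. The underlying identity, shown on a toy monotonic map rather than the spline itself:

import torch

# For a monotonic map y = f(x), log|dx/dy| = -log|dy/dx|.
# Toy example with f(x) = exp(x); nothing spline-specific is assumed.
x = torch.randn(4)
y = x.exp()
log_detJ_fwd = x          # log|dy/dx| = log(exp(x)) = x
log_detJ_inv = -y.log()   # log|dx/dy| = log(1/y) = -log(y)
assert torch.allclose(log_detJ_inv, -log_detJ_fwd)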
12 changes: 4 additions & 8 deletions flowtorch/parameters/coupling.py
@@ -4,7 +4,6 @@

import torch
import torch.nn as nn

from flowtorch.nn.made import MaskedLinear
from flowtorch.parameters.base import Parameters

@@ -178,12 +177,11 @@ def bias(self) -> torch.Tensor:

def _forward(
self,
*input: torch.Tensor,
input: torch.Tensor,
inverse: bool,
context: Optional[torch.Tensor] = None,
) -> Optional[Sequence[torch.Tensor]]:

input = input[0]
input_masked = input.masked_fill(self.mask_output, 0.0) # type: ignore
if context is not None:
input_aug = torch.cat(
@@ -203,7 +201,7 @@ def _forward(

result = h.unbind(-2)
result = tuple(
r.masked_fill(~self.mask_output.expand_as(r), 0.0)
r.masked_fill(~self.mask_output.expand_as(r), 0.0) # type: ignore
for r in result # type: ignore
)
return result
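Context (my illustration, not code from the repository): with the signature changed from *input to a single input tensor, the input = input[0] unpacking is no longer needed, and the parameter tensors returned here are zeroed wherever mask_output is False. A minimal sketch of that masking pattern with a made-up mask:

import torch

# hypothetical output mask over 4 features
mask_output = torch.tensor([True, True, False, False])
params = torch.randn(2, 4)  # fake (batch, features) parameter tensor
# keep parameters only where the mask is True, zero them elsewhere
masked = params.masked_fill(~mask_output.expand_as(params), 0.0)
assert torch.all(masked[:, ~mask_output] == 0)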
@@ -318,12 +316,11 @@ def _init_weights(self) -> None:

def _forward(
self,
*input: torch.Tensor,
input: torch.Tensor,
inverse: bool,
context: Optional[torch.Tensor] = None,
) -> Optional[Sequence[torch.Tensor]]:

input = input[0]
unsqueeze = False
if input.ndimension() == 3:
# mostly for initialization
@@ -352,8 +349,7 @@ def _forward(
result = h.chunk(2, -3)

result = tuple(
r.masked_fill(~self.mask.expand_as(r), 0.0)
for r in result # type: ignore
r.masked_fill(~self.mask.expand_as(r), 0.0) for r in result # type: ignore
)

return result
1 change: 0 additions & 1 deletion tests/test_bijectivetensor.py
@@ -15,7 +15,6 @@ def get_net() -> AffineAutoregressive:
[
AffineAutoregressive(params.DenseAutoregressive()),
AffineAutoregressive(params.DenseAutoregressive()),
AffineAutoregressive(params.DenseAutoregressive()),
]
)
ar = ar(
26 changes: 19 additions & 7 deletions tests/test_bijector.py
@@ -1,4 +1,6 @@
# Copyright (c) Meta Platforms, Inc
import math

import flowtorch.bijectors as bijectors
import numpy as np
import pytest
@@ -17,11 +19,13 @@ def test_bijector_constructor():

@pytest.fixture(params=[bij_name for _, bij_name in bijectors.standard_bijectors])
def flow(request):
torch.set_default_dtype(torch.double)
bij = request.param
event_dim = max(bij.domain.event_dim, 1)
event_shape = event_dim * [3]
base_dist = dist.Independent(
dist.Normal(torch.zeros(event_shape), torch.ones(event_shape)), event_dim
dist.Normal(torch.zeros(event_shape), torch.ones(event_shape)),
event_dim,
)

flow = Flow(base_dist, bij)
@@ -37,10 +41,12 @@ def test_jacobian(flow, epsilon=1e-2):
x = torch.randn(*flow.event_shape)
x = torch.distributions.transform_to(bij.domain)(x)
y = bij.forward(x)
if bij.domain.event_dim == 1:
analytic_ldt = bij.log_abs_det_jacobian(x, y).data
if bij.domain.event_dim == 0:
analytic_ldt = bij.log_abs_det_jacobian(x, y).data.sum(-1)
else:
analytic_ldt = bij.log_abs_det_jacobian(x, y).sum(-1).data
analytic_ldt = bij.log_abs_det_jacobian(x, y).data
for _ in range(bij.domain.event_dim - 1):
analytic_ldt = analytic_ldt.sum(-1)

# Calculate numerical Jacobian
# TODO: Better way to get all indices of array/tensor?
@@ -82,7 +88,8 @@ def test_jacobian(flow, epsilon=1e-2):
if hasattr(params, "permutation"):
numeric_ldt = torch.sum(torch.log(torch.diag(jacobian)))
else:
numeric_ldt = torch.log(torch.abs(jacobian.det()))
jacobian = jacobian.view(int(math.sqrt(jacobian.numel())), -1)
numeric_ldt = torch.log(torch.abs(jacobian.det())).sum()

ldt_discrepancy = (analytic_ldt - numeric_ldt).abs()
assert ldt_discrepancy < epsilon
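Note on the test_jacobian changes (my reading): the flow fixture now sets the default dtype to double, presumably so the finite-difference estimate is accurate enough for epsilon=1e-2, the analytic log|det J| is reduced so it is comparable to the numerical estimate, and the numerical Jacobian is reshaped into a square matrix before taking det(). A self-contained sketch of the same kind of finite-difference check, not the repository's helper:

import math

import torch

def numeric_log_abs_det_jacobian(f, x, eps=1e-4):
    # central finite differences, one column of the Jacobian per input index
    x = x.double()
    d = x.numel()
    jac = torch.zeros(d, d, dtype=torch.double)
    for i in range(d):
        delta = torch.zeros_like(x)
        delta.view(-1)[i] = eps
        jac[:, i] = ((f(x + delta) - f(x - delta)) / (2 * eps)).view(-1)
    # same reshape trick as in the diff: force a square matrix before det()
    jac = jac.view(int(math.sqrt(jac.numel())), -1)
    return torch.log(torch.abs(torch.det(jac)))

x = torch.randn(3, dtype=torch.double)
analytic = torch.log(1 - x.tanh() ** 2).sum()   # log|det J| of elementwise tanh
numeric = numeric_log_abs_det_jacobian(torch.tanh, x)
assert (analytic - numeric).abs() < 1e-5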
@@ -105,15 +112,20 @@ def test_inverse(flow, epsilon=1e-5):

# Test g^{-1}(g(x)) = x
x_true = base_dist.sample(torch.Size([10]))
assert x_true.dtype is torch.double
x_true = torch.distributions.transform_to(bij.domain)(x_true)

y = bij.forward(x_true)
J_1 = y.log_detJ
y = y.detach_from_flow()

x_calculated = bij.inverse(y)
J_2 = x_calculated.log_detJ
x_calculated = x_calculated.detach_from_flow()

assert (x_true - x_calculated).abs().max().item() < epsilon

# Test that Jacobian after inverse op is same as after forward
J_1 = bij.log_abs_det_jacobian(x_true, y)
J_2 = bij.log_abs_det_jacobian(x_calculated, y)
assert (J_1 - J_2).abs().max().item() < epsilon
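Note (my reading of the change): log_detJ and detach_from_flow appear to come from the BijectiveTensor used elsewhere in this PR (see tests/test_bijectivetensor.py), while the final consistency check calls bij.log_abs_det_jacobian directly at x_true and at the reconstructed x_calculated; those two values only agree when the inverse round-trip is accurate. A toy illustration of that invariance, not using the flowtorch API:

import torch

# For y = exp(x), log|dy/dx| = x, so the forward log-det evaluated at x_true
# and at the reconstructed x agree exactly when the inverse is accurate.
x_true = torch.randn(10, dtype=torch.double)
y = x_true.exp()
x_calculated = y.log()
assert (x_true - x_calculated).abs().max().item() < 1e-12
J_1 = x_true        # log|det J| at the original point
J_2 = x_calculated  # log|det J| at the reconstructed point
assert (J_1 - J_2).abs().max().item() < 1e-12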


7 changes: 4 additions & 3 deletions tests/test_distribution.py
@@ -15,7 +15,8 @@ def test_tdist_standalone():
def make_tdist():
# train a flow here
base_dist = torch.distributions.Independent(
torch.distributions.Normal(torch.zeros(input_dim), torch.ones(input_dim)), 1
torch.distributions.Normal(torch.zeros(input_dim), torch.ones(input_dim)),
1,
)
bijector = bijs.AffineAutoregressive()
tdist = dist.Flow(base_dist, bijector)
@@ -37,9 +38,9 @@ def test_neals_funnel_vi():
flow = dist.Flow(base_dist, bijector)
bijector = flow.bijector

opt = torch.optim.Adam(flow.parameters(), lr=2e-3)
opt = torch.optim.Adam(flow.parameters(), lr=1e-2)
num_elbo_mc_samples = 200
for _ in range(100):
for _ in range(500):
z0 = flow.base_dist.rsample(sample_shape=(num_elbo_mc_samples,))
zk = bijector.forward(z0)
ldj = zk._log_detJ