Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add complex constraint and real Fourier transform. #1762

Merged
merged 10 commits into from
Mar 16, 2024
11 changes: 11 additions & 0 deletions numpyro/distributions/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
__all__ = [
"boolean",
"circular",
"complex",
"corr_cholesky",
"corr_matrix",
"dependent",
Expand Down Expand Up @@ -629,6 +630,15 @@ def feasible_like(self, prototype):
)


class _Complex(_SingletonConstraint):
def __call__(self, x):
# XXX: consider to relax this condition to [-inf, inf] interval
return (x == x) & (x != float("inf")) & (x != float("-inf"))

def feasible_like(self, prototype):
return jax.numpy.zeros_like(prototype)


class _Real(_SingletonConstraint):
def __call__(self, x):
# XXX: consider to relax this condition to [-inf, inf] interval
Expand Down Expand Up @@ -692,6 +702,7 @@ def feasible_like(self, prototype):

boolean = _Boolean()
circular = _Circular()
complex = _Complex()
corr_cholesky = _CorrCholesky()
corr_matrix = _CorrMatrix()
dependent = _Dependent()
Expand Down
88 changes: 87 additions & 1 deletion numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
"LowerCholeskyAffine",
"PermuteTransform",
"PowerTransform",
"RealFastFourierTransform",
"ReshapeTransform",
"SigmoidTransform",
"SimplexToOrderedTransform",
Expand Down Expand Up @@ -1190,7 +1191,7 @@ def _inverse(self, y):
return jnp.reshape(y, self.inverse_shape(jnp.shape(y)))

def log_abs_det_jacobian(self, x, y, intermediates=None):
return 0.0
return jnp.zeros_like(x, shape=x.shape[: x.ndim - len(self._inverse_shape)])

def tree_flatten(self):
aux_data = {
Expand All @@ -1207,6 +1208,86 @@ def __eq__(self, other):
)


def _normalize_rfft_shape(input_shape, shape):
if shape is None:
return input_shape
return input_shape[: len(input_shape) - len(shape)] + shape


class RealFastFourierTransform(Transform):
"""
N-dimensional discrete fast Fourier transform for real input.

:param transform_shape: Length of each transformed axis to use from the input,
defaults to the input size.
:param transform_ndims: Number of trailing dimensions to transform.
"""

def __init__(
self,
transform_shape=None,
transform_ndims=1,
) -> None:
if isinstance(transform_shape, int):
transform_shape = (transform_shape,)
if transform_shape is not None and len(transform_shape) != transform_ndims:
raise ValueError(
f"Length of transform shape ({transform_shape}) does not match number "
f"of dimensions to transform ({transform_ndims})."
)
self.transform_shape = transform_shape
self.transform_ndims = transform_ndims

def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
axes = tuple(range(-self.transform_ndims, 0))
return jnp.fft.rfftn(x, self.transform_shape, axes)

def _inverse(self, y: jnp.ndarray) -> jnp.ndarray:
axes = tuple(range(-self.transform_ndims, 0))
return jnp.fft.irfftn(y, self.transform_shape, axes)

def forward_shape(self, shape: tuple) -> tuple:
# Dimensions remain unchanged except the last transformed dimension.
shape = _normalize_rfft_shape(shape, self.transform_shape)
return shape[:-1] + (shape[-1] // 2 + 1,)

def inverse_shape(self, shape: tuple) -> tuple:
if self.transform_shape:
return _normalize_rfft_shape(shape, self.transform_shape)
size = 2 * (shape[-1] - 1)
return shape[:-1] + (size,)

def log_abs_det_jacobian(
self, x: jnp.ndarray, y: jnp.ndarray, intermediates: None = None
) -> jnp.ndarray:
shape = jnp.broadcast_shapes(
x.shape[: -self.transform_ndims], y.shape[: -self.transform_ndims]
)
return jnp.zeros_like(x, shape=shape)

def tree_flatten(self):
aux_data = {
"transform_shape": self.transform_shape,
"transform_ndims": self.transform_ndims,
}
return (), ((), aux_data)

@property
def domain(self) -> constraints.Constraint:
return constraints.independent(constraints.real, self.transform_ndims)

@property
def codomain(self) -> constraints.Constraint:
return constraints.independent(constraints.complex, self.transform_ndims)

def __eq__(self, other):
return (
isinstance(other, RealFastFourierTransform)
and self.transform_ndims == other.transform_ndims
and self.transform_shape == other.transform_shape
)


##########################################################
# CONSTRAINT_REGISTRY
##########################################################
Expand Down Expand Up @@ -1334,6 +1415,11 @@ def _transform_to_positive_ordered_vector(constraint):
return ComposeTransform([OrderedTransform(), ExpTransform()])


@biject_to.register(constraints.complex)
def _transform_to_complex(constraint):
return IdentityTransform()


@biject_to.register(constraints.real)
def _transform_to_real(constraint):
return IdentityTransform()
Expand Down
1 change: 1 addition & 0 deletions test/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
SINGLETON_CONSTRAINTS = {
"boolean": constraints.boolean,
"circular": constraints.circular,
"complex": constraints.complex,
"corr_cholesky": constraints.corr_cholesky,
"corr_matrix": constraints.corr_matrix,
"l1_ball": constraints.l1_ball,
Expand Down
85 changes: 85 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
OrderedTransform,
PermuteTransform,
PowerTransform,
RealFastFourierTransform,
ReshapeTransform,
ScaledUnitLowerCholeskyTransform,
SigmoidTransform,
Expand All @@ -37,6 +38,7 @@
SoftplusTransform,
StickBreakingTransform,
UnpackTransform,
biject_to,
)


Expand Down Expand Up @@ -83,6 +85,11 @@ class T(namedtuple("TestCase", ["transform_cls", "params", "kwargs"])):
(_a(2.0),),
dict(),
),
"rfft": T(
RealFastFourierTransform,
(),
dict(transform_shape=(3, 4, 5), transform_ndims=3),
),
"simplex_to_ordered": T(
SimplexToOrderedTransform,
(_a(1.0),),
Expand Down Expand Up @@ -228,3 +235,81 @@ def test_reshape_transform_invalid():

with pytest.raises(TypeError, match="cannot reshape array"):
ReshapeTransform((2, 3), (6,))(jnp.arange(2))


@pytest.mark.parametrize(
"input_shape, shape, ndims",
[
((10,), None, 1),
((11,), 11, 1),
((10, 18), None, 2),
((10, 19), (7, 8), 2),
],
)
def test_real_fast_fourier_transform(input_shape, shape, ndims):
x1 = random.normal(random.key(17), input_shape)
transform = RealFastFourierTransform(shape, ndims)
y = transform(x1)
assert transform.codomain(y).all()
assert y.shape == transform.forward_shape(x1.shape)
x2 = transform.inv(y)
assert transform.domain(x2).all()
if x1.shape == x2.shape:
assert jnp.allclose(x2, x1, atol=1e-6)


@pytest.mark.parametrize(
"transform, shape",
[
(AffineTransform(3, 2.5), ()),
(CholeskyTransform(), (10,)),
(ComposeTransform([SoftplusTransform(), SigmoidTransform()]), ()),
(CorrCholeskyTransform(), (15,)),
(CorrMatrixCholeskyTransform(), (15,)),
(ExpTransform(), ()),
(IdentityTransform(), ()),
(IndependentTransform(ExpTransform(), 2), (3, 4)),
(L1BallTransform(), (9,)),
(LowerCholeskyAffine(jnp.ones(3), jnp.eye(3)), (3,)),
(LowerCholeskyTransform(), (10,)),
(OrderedTransform(), (5,)),
(PermuteTransform(jnp.roll(jnp.arange(7), 2)), (7,)),
(PowerTransform(2.5), ()),
(RealFastFourierTransform(7), (7,)),
(RealFastFourierTransform((8, 9), 2), (8, 9)),
(ReshapeTransform((5, 2), (10,)), (10,)),
(ReshapeTransform((15,), (3, 5)), (3, 5)),
(ScaledUnitLowerCholeskyTransform(), (6,)),
(SigmoidTransform(), ()),
(SimplexToOrderedTransform(), (5,)),
(SoftplusLowerCholeskyTransform(), (10,)),
(SoftplusTransform(), ()),
(StickBreakingTransform(), (11,)),
],
)
def test_bijective_transforms(transform, shape):
if isinstance(transform, type):
pytest.skip()
# Get a sample from the support of the distribution.
batch_shape = (13,)
unconstrained = random.normal(random.key(17), batch_shape + shape)
x1 = biject_to(transform.domain)(unconstrained)

# Transform forward and backward, checking shapes, values, and Jacobian shape.
y = transform(x1)
assert y.shape == transform.forward_shape(x1.shape)

x2 = transform.inv(y)
assert x2.shape == transform.inverse_shape(y.shape)
# Some transforms are a bit less stable; we give them larger tolerances.
atol = 1e-6
less_stable_transforms = (
CorrCholeskyTransform,
L1BallTransform,
StickBreakingTransform,
)
if isinstance(transform, less_stable_transforms):
atol = 1e-2
assert jnp.allclose(x1, x2, atol=atol)

assert transform.log_abs_det_jacobian(x1, y).shape == batch_shape
Loading