Skip to content

Commit

Permalink
✨ Enable random feature order in autoregressive flows
Browse files Browse the repository at this point in the history
  • Loading branch information
francois-rozet committed Aug 11, 2022
1 parent ebc60ba commit 4858fd7
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 11 deletions.
3 changes: 2 additions & 1 deletion lampe/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ class MonotonicRQSTransform(Transform):
r"""Creates a monotonic rational-quadratic spline (RQS) transformation.
References:
Neural Spline Flows (Durkan et al., 2019)
Neural Spline Flows
(Durkan et al., 2019)
https://arxiv.org/abs/1906.04032
Arguments:
Expand Down
37 changes: 27 additions & 10 deletions lampe/nn/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,9 @@ class MaskedAutoregressiveTransform(TransformModule):
Arguments:
features: The number of features.
context: The number of context features.
passes: The number of passes for the inverse transformation. If :py:`None`,
use the number of features instead.
order: The feature ordering. If :py:`None`, use :py:`range(features)` instead.
passes: The number of passes for the inverse transformation.
univariate: A univariate transformation constructor.
shapes: The shapes of the univariate transformation parameters.
kwargs: Keyword arguments passed to :class:`lampe.nn.MaskedMLP`.
Expand All @@ -174,8 +175,8 @@ def __init__(
self,
features: int,
context: int = 0,
passes: int = None,
order: LongTensor = None,
passes: int = -1,
univariate: Callable[..., Transform] = MonotonicAffineTransform,
shapes: List[Size] = [(), ()],
**kwargs,
Expand All @@ -186,10 +187,13 @@ def __init__(
self.shapes = list(map(Size, shapes))
self.sizes = [s.numel() for s in self.shapes]

if passes is None:
passes = features

if order is None:
order = torch.arange(features)

self.passes = passes if passes > 0 else features
self.passes = min(max(passes, 1), features)
self.order = torch.div(order, ceil(features / self.passes), rounding_mode='floor')

in_order = torch.cat((self.order, torch.full((context,), -1)))
Expand Down Expand Up @@ -234,6 +238,9 @@ class MAF(FlowModule):
features: The number of features.
context: The number of context features.
transforms: The number of autoregressive transforms.
randperm: Whether features are randomly permuted between transforms or not.
If :py:`False`, features are in ascending (descending) order for even
(odd) transforms.
kwargs: Keyword arguments passed to :class:`MaskedAutoregressiveTransform`.
"""

Expand All @@ -242,16 +249,19 @@ def __init__(
features: int,
context: int = 0,
transforms: int = 3,
randperm: bool = False,
**kwargs,
):
increasing = torch.arange(features)
decreasing = torch.flipud(increasing)
orders = [
torch.arange(features),
torch.flipud(torch.arange(features)),
]

transforms = [
MaskedAutoregressiveTransform(
features=features,
context=context,
order=decreasing if i % 2 else increasing,
order=torch.randperm(features) if randperm else orders[i % 2],
**kwargs,
)
for i in range(transforms)
Expand All @@ -266,7 +276,8 @@ class NSF(MAF):
r"""Creates a neural spline flow (NSF).
References:
Neural Spline Flows (Durkan et al., 2019)
Neural Spline Flows
(Durkan et al., 2019)
https://arxiv.org/abs/1906.04032
Arguments:
Expand Down Expand Up @@ -345,6 +356,9 @@ class NAF(FlowModule):
features: The number of features.
context: The number of context features.
transforms: The number of autoregressive transforms.
randperm: Whether features are randomly permuted between transforms or not.
If :py:`False`, features are in ascending (descending) order for even
(odd) transforms.
kwargs: Keyword arguments passed to :class:`NeuralAutoregressiveTransform`.
"""

Expand All @@ -353,16 +367,19 @@ def __init__(
features: int,
context: int = 0,
transforms: int = 3,
randperm: bool = False,
**kwargs,
):
increasing = torch.arange(features)
decreasing = torch.flipud(increasing)
orders = [
torch.arange(features),
torch.flipud(torch.arange(features)),
]

transforms = [
NeuralAutoregressiveTransform(
features=features,
context=context,
order=decreasing if i % 2 else increasing,
order=torch.randperm(features) if randperm else orders[i % 2],
**kwargs,
)
for i in range(transforms)
Expand Down

0 comments on commit 4858fd7

Please sign in to comment.