Coupling transform module #22
-
Hello @simonschnake, thanks for the issue. There is interest in pure coupling transformations. Some hyper networks are hard or impossible to write as masked networks; this is notably the case for convolutional networks. I implemented a
-
Okay, very nice. Here are my

One difference in my implementation, compared to other implementations, is that both halves of the features are transformed sequentially. In most implementations, only one half is transformed. There is no significant difference between the two; my rationale was to ensure that all features are transformed an equal number of times. I can also provide a vanilla implementation of a hyper network.

What do you think of it? I haven't implemented any tests yet, but I can do that and provide a pull request.

```python
from typing import Callable, Sequence, Tuple

import torch
import torch.nn as nn
from torch import LongTensor, Size, Tensor
from torch.distributions import Transform, constraints

# Assumed import paths; they may differ between Zuko versions.
from zuko.flows import TransformModule
from zuko.transforms import MonotonicAffineTransform


class CouplingTransformModule(TransformModule):
    r"""Creates a coupling transformation module.

    References:
        | NICE: Non-linear Independent Components Estimation (Dinh et al., 2014)
        | https://arxiv.org/abs/1410.8516v6

    Arguments:
        features: The number of features.
        context: The number of context features.
        order: The feature ordering. If :py:`None`, use :py:`range(features)` instead.
        univariate: The univariate transformation constructor.
        shapes: The shapes of the univariate transformation parameters.
        hyper_network: The conditioner (hyper network) constructor.
        kwargs: Keyword arguments passed to conditioner networks.
    """

    def __init__(
        self,
        features: int,
        context: int = 0,
        order: LongTensor = None,
        univariate: Callable[..., Transform] = MonotonicAffineTransform,
        shapes: Sequence[Size] = ((), ()),
        hyper_network: nn.Module = None,  # TODO <- hyper_network needs to be sketched out
        **kwargs,
    ):
        super().__init__()

        # Univariate transformation
        self.univariate = univariate
        self.shapes = list(map(Size, shapes))
        self.sizes = [s.numel() for s in self.shapes]

        if order is None:
            order = torch.arange(features)
        else:
            order = torch.as_tensor(order)

        # Split the ordering into two halves
        self.register_buffer('first_features', order[:len(order) // 2])
        self.register_buffer('second_features', order[len(order) // 2:])

        # One hyper network per half; each is conditioned on the other half
        self.hyper_first = hyper_network(
            first_features=self.first_features,
            second_features=self.second_features,
            num_params=sum(self.sizes),
            context_features=context,
            **kwargs,
        )
        self.hyper_second = hyper_network(
            first_features=self.second_features,
            second_features=self.first_features,
            num_params=sum(self.sizes),
            context_features=context,
            **kwargs,
        )


class CouplingTransform(Transform):
    r"""Transform via a coupling scheme.

    .. math::
        y_1 & = f_1(x_1 | x_2) \\
        y_2 & = f_2(x_2 | y_1)

    Arguments:
        meta_first: A meta function which returns the transformation :math:`f_1` of the first half.
        meta_second: A meta function which returns the transformation :math:`f_2` of the second half.
        first_features: The indices of the first half of the features.
        second_features: The indices of the second half of the features.
    """

    domain = constraints.real_vector
    codomain = constraints.real_vector
    bijective = True

    def __init__(
        self,
        meta_first: Callable[[Tensor], Transform],
        meta_second: Callable[[Tensor], Transform],
        first_features: Tensor,
        second_features: Tensor,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        self.meta_first = meta_first
        self.meta_second = meta_second
        self.first_features = first_features
        self.second_features = second_features

    def _call(self, x: Tensor) -> Tensor:
        x_first = x[..., self.first_features]
        x_second = x[..., self.second_features]

        # First half is conditioned on x_second, second half on the new y_first
        y_first = self.meta_first(x_second)(x_first)
        y_second = self.meta_second(y_first)(x_second)

        y = torch.empty_like(x)
        y[..., self.first_features] = y_first
        y[..., self.second_features] = y_second

        return y

    def _inverse(self, y: Tensor) -> Tensor:
        y_first = y[..., self.first_features]
        y_second = y[..., self.second_features]

        # Undo the passes in reverse order, using the inverse transformations
        x_second = self.meta_second(y_first).inv(y_second)
        x_first = self.meta_first(x_second).inv(y_first)

        x = torch.empty_like(y)
        x[..., self.first_features] = x_first
        x[..., self.second_features] = x_second

        return x

    def log_abs_det_jacobian(self, x: Tensor, y: Tensor) -> Tensor:
        x_first = x[..., self.first_features]
        x_second = x[..., self.second_features]
        y_first = y[..., self.first_features]
        y_second = y[..., self.second_features]

        ladj_first = self.meta_first(x_second).log_abs_det_jacobian(x_first, y_first).sum(dim=-1)
        ladj_second = self.meta_second(y_first).log_abs_det_jacobian(x_second, y_second).sum(dim=-1)

        return ladj_first + ladj_second

    def call_and_ladj(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        x_first = x[..., self.first_features]
        x_second = x[..., self.second_features]

        y_first, ladj_first = self.meta_first(x_second).call_and_ladj(x_first)
        y_second, ladj_second = self.meta_second(y_first).call_and_ladj(x_second)

        y = torch.empty_like(x)
        y[..., self.first_features] = y_first
        y[..., self.second_features] = y_second

        # Sum over features, consistent with log_abs_det_jacobian
        return y, ladj_first.sum(dim=-1) + ladj_second.sum(dim=-1)
```
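As a sanity check on the two-pass scheme in the code above, here is a minimal, dependency-free sketch in plain Python (deliberately not PyTorch, and not Zuko's API). The toy conditioners `cond_first` and `cond_second` and the helper `numerical_log_abs_det` are hypothetical names introduced only for illustration. The sketch applies the two sequential affine passes, undoes them in reverse order, and compares the summed log-determinant with a finite-difference Jacobian.

```python
import math


def cond_first(cond):
    """Toy conditioner for the first pass: returns (shift, log_scale)."""
    s = sum(cond)
    return math.sin(s), 0.5 * math.tanh(s)


def cond_second(cond):
    """A different toy conditioner for the second pass."""
    s = sum(v * v for v in cond)
    return math.cos(s), 0.3 * math.sin(s)


def forward(x, first, second):
    y = list(x)
    # Pass 1: transform the first half, conditioned on the untouched second half.
    shift, log_scale = cond_first([x[i] for i in second])
    for i in first:
        y[i] = x[i] * math.exp(log_scale) + shift
    ladj = len(first) * log_scale
    # Pass 2: transform the second half, conditioned on the new first half.
    shift, log_scale = cond_second([y[i] for i in first])
    for i in second:
        y[i] = x[i] * math.exp(log_scale) + shift
    ladj += len(second) * log_scale
    return y, ladj


def inverse(y, first, second):
    x = list(y)
    # Undo pass 2 first: its conditioner only needs y_first, which is known.
    shift, log_scale = cond_second([y[i] for i in first])
    for i in second:
        x[i] = (y[i] - shift) * math.exp(-log_scale)
    # Undo pass 1 using the recovered second half.
    shift, log_scale = cond_first([x[i] for i in second])
    for i in first:
        x[i] = (y[i] - shift) * math.exp(-log_scale)
    return x


def numerical_log_abs_det(f, x, eps=1e-6):
    """log|det J| of f at x, with J estimated by central differences."""
    n = len(x)
    J = [[0.0] * n for _ in range(n)]
    for j in range(n):
        xp, xm = list(x), list(x)
        xp[j] += eps
        xm[j] -= eps
        yp, ym = f(xp), f(xm)
        for i in range(n):
            J[i][j] = (yp[i] - ym[i]) / (2 * eps)
    # Determinant by Gaussian elimination with partial pivoting.
    det = 1.0
    for c in range(n):
        p = max(range(c, n), key=lambda r: abs(J[r][c]))
        if J[p][c] == 0.0:
            return float("-inf")
        if p != c:
            J[c], J[p] = J[p], J[c]
            det = -det
        det *= J[c][c]
        for r in range(c + 1, n):
            factor = J[r][c] / J[c][c]
            for k in range(c, n):
                J[r][k] -= factor * J[c][k]
    return math.log(abs(det))


first, second = [0, 1], [2, 3]
x = [0.3, -1.2, 0.7, 2.0]

y, ladj = forward(x, first, second)
x_rec = inverse(y, first, second)
ladj_fd = numerical_log_abs_det(lambda v: forward(v, first, second)[0], x)

print("round-trip error:", max(abs(a - b) for a, b in zip(x, x_rec)))
print("ladj error:", abs(ladj - ladj_fd))
```

Each pass has a triangular Jacobian, so its log-determinant is just the sum of the log-scales of the half it transforms, and the two passes add up.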
-
PR #23 adds coupling transformations to Zuko 🔥
-
Hey,

I am currently sketching out what a dedicated `CouplingTransformModule` would look like. I know that you can get something similar by using the `MaskedAffineTransformModule` with `passes=2`. My use case is different: I want to use different conditioner networks or, in the language of `zuko`, hyper networks. I wanted to ask if there is interest in adding this to `zuko`.

Cheers
Simon
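To illustrate the point about using different conditioner networks, here is a small plain-Python sketch (no PyTorch, and not Zuko's API). All names in it (`conv1d_same`, `conv_conditioner`, `dense_conditioner`, `apply_affine`, `invert_affine`) are hypothetical and chosen only for this example. One half is conditioned through a convolution-like function and the other through a dense-like function; the coupling stays invertible because the conditioner itself is never inverted.

```python
import math


def conv1d_same(signal, kernel):
    """1-D cross-correlation with zero padding; output length equals input length."""
    half = len(kernel) // 2
    out = []
    for i in range(len(signal)):
        acc = 0.0
        for k, w in enumerate(kernel):
            j = i + k - half
            if 0 <= j < len(signal):
                acc += w * signal[j]
        out.append(acc)
    return out


def conv_conditioner(cond):
    """Convolution-like conditioner: per-element (shifts, log_scales)."""
    shifts = conv1d_same(cond, [0.25, 0.5, 0.25])
    log_scales = [math.tanh(v) for v in conv1d_same(cond, [-1.0, 0.0, 1.0])]
    return shifts, log_scales


def dense_conditioner(cond):
    """Dense-like conditioner: every output depends on every input."""
    total = sum(cond)
    n = len(cond)  # assumes both halves have the same size
    shifts = [math.sin(total + i) for i in range(n)]
    log_scales = [0.5 * math.cos(total - i) for i in range(n)]
    return shifts, log_scales


def apply_affine(target, cond, conditioner):
    shifts, log_scales = conditioner(cond)
    return [t * math.exp(a) + b for t, a, b in zip(target, log_scales, shifts)]


def invert_affine(target, cond, conditioner):
    shifts, log_scales = conditioner(cond)
    return [(t - b) * math.exp(-a) for t, a, b in zip(target, log_scales, shifts)]


def coupling_forward(x_first, x_second):
    y_first = apply_affine(x_first, x_second, conv_conditioner)
    y_second = apply_affine(x_second, y_first, dense_conditioner)
    return y_first, y_second


def coupling_inverse(y_first, y_second):
    x_second = invert_affine(y_second, y_first, dense_conditioner)
    x_first = invert_affine(y_first, x_second, conv_conditioner)
    return x_first, x_second


x_first, x_second = [0.1, -0.4, 1.3], [2.0, 0.5, -0.7]
y_first, y_second = coupling_forward(x_first, x_second)
r_first, r_second = coupling_inverse(y_first, y_second)

error = max(abs(a - b) for a, b in zip(x_first + x_second, r_first + r_second))
print("round-trip error:", error)
```

Only the element-wise affine map is inverted; each conditioner is simply re-evaluated on values that are already known at that step of the inverse.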