I've encountered a small error when implementing Distrax bijectors with Flax conditioner NNs, and I also have a question about best practices for using Distrax with Flax. The error can be reproduced with the following setup:
```python
import functools
from typing import Any, Mapping, Optional, Sequence

import distrax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax import random

# Type aliases used in the annotations below.
Array = jax.Array
KwArgs = Mapping[str, Any]


class Conditioner(nn.Module):
    event_shape: Sequence[int]
    num_bijector_params: int
    hidden_dims: Sequence[int]

    @nn.compact
    def __call__(self, z: Array, h: Array) -> Array:
        # Condition on both the masked input `z` and the shared features `h`.
        h = jnp.concatenate((z.flatten(), h.flatten()), axis=0)
        for hidden_dim in self.hidden_dims:
            h = nn.Dense(hidden_dim)(h)
            h = nn.relu(h)
        y = nn.Dense(np.prod(self.event_shape) * self.num_bijector_params)(h)
        # One set of spline parameters per event dimension.
        y = y.reshape(tuple(self.event_shape) + (self.num_bijector_params,))
        return y


class MyModel(nn.Module):
    hidden_dims: Sequence[int]
    num_flows: int
    num_bins: int
    event_shape: Sequence[int]
    conditioner: Optional[KwArgs] = None

    @nn.compact
    def __call__(self, x, y: Optional[Array] = None):
        # Base distribution
        output_dim = np.prod(self.event_shape)
        base = distrax.Independent(
            distrax.Normal(loc=jnp.zeros(output_dim), scale=jnp.ones(output_dim)),
            len(self.event_shape),
        )

        # Bijector
        # Number of parameters for the rational-quadratic spline:
        # - `num_bins` bin widths
        # - `num_bins` bin heights
        # - `num_bins + 1` knot slopes
        # for a total of `3 * num_bins + 1` parameters.
        num_bijector_params = 3 * self.num_bins + 1

        layers = []
        mask = jnp.arange(0, np.prod(self.event_shape)) % 2
        mask = jnp.reshape(mask, self.event_shape)
        mask = mask.astype(bool)

        def bijector_fn(params: Array):
            return distrax.RationalQuadraticSpline(
                params, range_min=-3.0, range_max=3.0
            )

        # Shared feature extractor
        h = x.flatten()
        for hidden_dim in self.hidden_dims:
            h = nn.Dense(hidden_dim)(h)
            h = nn.relu(h)

        for i in range(self.num_flows):
            conditioner = Conditioner(
                event_shape=self.event_shape,
                num_bijector_params=num_bijector_params,
                **(self.conditioner or {}),
            )
            layer = distrax.MaskedCoupling(
                mask=mask,
                bijector=bijector_fn,
                conditioner=functools.partial(conditioner, h=h),
            )
            layers.append(layer)
            # Flip the mask so the next layer transforms the other half.
            mask = ~mask

        bijector = distrax.Inverse(distrax.Chain(layers))
        transformed = distrax.Transformed(base, bijector)

        if y is not None:
            return transformed, transformed.log_prob(y)
        else:
            return transformed


model = MyModel(
    hidden_dims=[64, 32],
    num_flows=3,
    num_bins=8,
    event_shape=(6,),
    conditioner={'hidden_dims': [64, 32]},
)

# A dummy `y` is passed so that the conditioners are called (and initialized) during init.
variables = model.init(random.PRNGKey(0), jnp.empty((28, 28, 1)), y=jnp.empty((6,)))
dist = model.apply(variables, jnp.ones((28, 28, 1)))
dist.event_shape  # raises JaxTransformError
```
Which raises the following error:

```
JaxTransformError: Jax transforms and Flax models cannot be mixed. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.JaxTransformError)
```
Thankfully, evaluating log probs, i.e., `dist.log_prob(jnp.zeros(6,))`, runs without any error.
Any idea why this is happening? Am I doing something wrong when constructing the model?
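For reference, here is how the parameter-count and mask lines in `MyModel` work out for the configuration used in the reproduction (`event_shape=(6,)`, `num_bins=8`, `num_flows=3`). The sketch below mirrors those lines outside Flax using only NumPy. The helper name `coupling_shapes` is hypothetical and is not part of Distrax or Flax.

```python
import numpy as np


def coupling_shapes(event_shape, num_bins, num_flows):
    """Mirror the shape bookkeeping in MyModel (hypothetical helper, NumPy only)."""
    # Rational-quadratic spline: num_bins widths + num_bins heights + (num_bins + 1) slopes.
    num_bijector_params = 3 * num_bins + 1
    # The conditioner's last Dense layer emits this many values before the reshape.
    dense_out = int(np.prod(event_shape)) * num_bijector_params
    reshaped = tuple(event_shape) + (num_bijector_params,)

    # Alternating mask, flipped after every coupling layer (as in `mask = ~mask`).
    mask = (np.arange(int(np.prod(event_shape))) % 2).reshape(event_shape).astype(bool)
    masks = []
    for _ in range(num_flows):
        masks.append(mask.copy())
        mask = ~mask
    return num_bijector_params, dense_out, reshaped, masks


params, dense_out, reshaped, masks = coupling_shapes((6,), num_bins=8, num_flows=3)
print(params, dense_out, reshaped)  # 25 150 (6, 25)
for m in masks:
    print(m.astype(int))
# [0 1 0 1 0 1]
# [1 0 1 0 1 0]
# [0 1 0 1 0 1]
```

As a result, each coupling layer swaps which three dimensions pass through unchanged and which three are transformed by the spline.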
On that note, I've also found that if I initialize the parameters like this:
The parameters for the conditioner are not instantiated. To fix this, I've used the workaround of evaluating the log prob of some dummy data when initializing the model:
But this feels a little hacky to me and suggests that perhaps I am doing something wrong in my model definition. Do you have a set of best practices for using Flax with Distrax (now that Haiku is deprecated)?
Thanks for the help!
The setup above is also available in this Colab notebook: https://colab.research.google.com/drive/1RLRZul_pHnglcT_-YZ7mcuKLU1qd3w5O?usp=sharing