Distrax + Flax bijector error and best practices #263

Open

JamesAllingham opened this issue Nov 17, 2023 · 0 comments
I've encountered a small error when implementing Distrax bijectors with Flax conditioner NNs, and I also have a question about best practices for using Distrax with Flax.

The error can be reproduced with the following setup (also available in this Colab notebook: https://colab.research.google.com/drive/1RLRZul_pHnglcT_-YZ7mcuKLU1qd3w5O?usp=sharing).

```python
import functools
from typing import Any, Mapping, Optional, Sequence

import distrax
import flax.linen as nn
import jax.numpy as jnp
import numpy as np
from jax import Array, random

# `KwArgs` is not defined in the original snippet; a plain mapping of
# keyword arguments is assumed here.
KwArgs = Mapping[str, Any]


class Conditioner(nn.Module):
    event_shape: Sequence[int]
    num_bijector_params: int
    hidden_dims: Sequence[int]

    @nn.compact
    def __call__(self, z: Array, h: Array) -> Array:
        h = jnp.concatenate((z.flatten(), h.flatten()), axis=0)

        for hidden_dim in self.hidden_dims:
            h = nn.Dense(hidden_dim)(h)
            h = nn.relu(h)

        y = nn.Dense(np.prod(self.event_shape) * self.num_bijector_params)(h)
        y = y.reshape(tuple(self.event_shape) + (self.num_bijector_params,))

        return y

class MyModel(nn.Module):
    hidden_dims: Sequence[int]
    num_flows: int
    num_bins: int
    event_shape: Sequence[int]
    conditioner: Optional[KwArgs] = None

    @nn.compact
    def __call__(self, x, y: Optional[Array] = None):
        # base distribution
        output_dim = np.prod(self.event_shape)
        base = distrax.Independent(
            distrax.Normal(loc=jnp.zeros((output_dim,)), scale=jnp.ones((output_dim,))),
            reinterpreted_batch_ndims=len(self.event_shape),
        )

        # bijector
        # Number of parameters for the rational-quadratic spline:
        # - `num_bins` bin widths
        # - `num_bins` bin heights
        # - `num_bins + 1` knot slopes
        # for a total of `3 * num_bins + 1` parameters.
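        # (e.g. with num_bins = 8, as in the example below, that is
        # 3 * 8 + 1 = 25 parameters per event dimension.)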
        num_bijector_params = 3 * self.num_bins + 1

        layers = []
        mask = jnp.arange(0, np.prod(self.event_shape)) % 2
        mask = jnp.reshape(mask, self.event_shape)
        mask = mask.astype(bool)

        def bijector_fn(params: Array):
            return distrax.RationalQuadraticSpline(
                params, range_min=-3.0, range_max=3.0
            )

        h = x.flatten()

        # shared feature extractor
        for hidden_dim in self.hidden_dims:
            h = nn.Dense(hidden_dim)(h)
            h = nn.relu(h)

        for i in range(self.num_flows):
            conditioner = Conditioner(
                event_shape=self.event_shape,
                num_bijector_params=num_bijector_params,
                **(self.conditioner or {}),
            )

            layer = distrax.MaskedCoupling(
                mask=mask,
                bijector=bijector_fn,
                conditioner=functools.partial(conditioner, h=h),
            )

            layers.append(layer)
            mask = ~mask

        bijector = distrax.Inverse(distrax.Chain(layers))
        transformed = distrax.Transformed(base, bijector)

        if y is not None:
            return transformed, transformed.log_prob(y)
        else:
            return transformed
            
            
model = MyModel(
    hidden_dims=[64, 32],
    num_flows=3,
    num_bins=8,
    event_shape=(6,),
    conditioner={'hidden_dims': [64, 32]},
)

variables = model.init(random.PRNGKey(0), jnp.empty((28, 28, 1)), y=jnp.empty((6,)))

dist = model.apply(variables, jnp.ones((28, 28, 1)))

dist.event_shape
```

Running this raises the following error:

```
JaxTransformError: Jax transforms and Flax models cannot be mixed. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.JaxTransformError)
```

Thankfully, evaluating log probs, i.e., `dist.log_prob(jnp.zeros((6,)))`, runs without any error.

Any idea why this is happening? Am I doing something wrong when constructing the model?
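
My best guess (unverified, so take it with a grain of salt) is that `distrax.Transformed` infers `event_shape` lazily by tracing the bijector's `forward` with `jax.eval_shape`, whereas `log_prob` just runs the inverse pass directly, so no JAX transform ever touches the captured Flax conditioners. If that's right, the following probe (names as in the snippet above) should hit the same error:

```python
import jax

dist = model.apply(variables, jnp.ones((28, 28, 1)))

# Applying a JAX transform over a function that closes over Flax modules
# is exactly what flax.errors.JaxTransformError complains about:
jax.eval_shape(dist.bijector.forward, jnp.zeros((6,)))
```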

On that note, I've also found that if I initialize the parameters like this:

```python
variables = model.init(random.PRNGKey(0), jnp.empty((28, 28, 1)))
```

The parameters for the conditioners are not instantiated. To fix this, I've used the workaround of evaluating the log prob of some dummy data when initializing the model:

```python
variables = model.init(random.PRNGKey(0), jnp.empty((28, 28, 1)), y=jnp.empty((6,)))
```

But this feels a little hacky to me and suggests that perhaps I am doing something wrong in my model definition. Do you have a set of best practices for using Flax with Distrax (now that Haiku is deprecated)?
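
For reference, the least hacky variant I've come up with so far (a sketch only; I'm assuming Flax's `Module.is_initializing()` is the right tool here) is to touch the flow once inside `__call__` during initialization, so the conditioner parameters get created even when no `y` is passed:

```python
# Inside MyModel.__call__, just before returning `transformed`:
if self.is_initializing() and y is None:
    # Evaluate a dummy log prob so that every conditioner is called
    # once and its parameters are created during `init`.
    _ = transformed.log_prob(jnp.zeros(tuple(self.event_shape)))
```

With that in place, `model.init(random.PRNGKey(0), jnp.empty((28, 28, 1)))` creates all of the parameters, but I'd still like to know what the idiomatic pattern is.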

Thanks for the help!
