diff --git a/flax/experimental/nnx/nnx/variables.py b/flax/experimental/nnx/nnx/variables.py index 53d541e31..c5577d386 100644 --- a/flax/experimental/nnx/nnx/variables.py +++ b/flax/experimental/nnx/nnx/variables.py @@ -385,7 +385,10 @@ def _variable_unflatten( *, cls: type[Variable[A]], ) -> Variable[A]: - return cls(children[0], **metadata) # type: ignore + variable = object.__new__(cls) + variable.value = children[0] + vars(variable).update(metadata) + return variable jtu.register_pytree_with_keys(