Skip to content

Commit

Permalink
remove _call_wrapped_get_descriptor
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Nov 2, 2022
1 parent 4eb15b2 commit d209640
Showing 1 changed file with 20 additions and 30 deletions.
50 changes: 20 additions & 30 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,10 +447,7 @@ def wrap_descriptor_once(descriptor) -> "DescriptorWrapper":
if isinstance(descriptor, DescriptorWrapper):
return descriptor

def get_fn(descriptor, self):
return self._call_wrapped_get_descriptor(descriptor)

return create_descriptor_wrapper(descriptor, get_fn=get_fn)
return create_descriptor_wrapper(descriptor)


def _wrap_hash(hash_fn: Callable[..., Any]) -> Callable[..., Any]:
Expand Down Expand Up @@ -602,20 +599,24 @@ def __set_name__(self, owner, name) -> None: ...
class DescriptorWrapper:
pass

def create_descriptor_wrapper(descriptor: Descriptor, get_fn):
def create_descriptor_wrapper(descriptor: Descriptor):
"""Creates a descriptor wrapper that calls a get_fn on the descriptor."""

class _DescriptorWrapper(DescriptorWrapper):
"""A descriptor that can wrap any descriptor"""

def __init__(self):
self.get_fn = get_fn
self.wrapped = descriptor

# conditionally define descriptor methods
if hasattr(descriptor, '__get__'):
def __get__(self, obj, objtype=None):
return self.get_fn(self.wrapped, obj)
# here we will catch internal AttributeError and re-raise it as a
# more informative and correct error message.
try:
return self.wrapped.__get__(self)
except AttributeError as e:
raise errors.DescriptorAttributeError() from e
if hasattr(descriptor, '__set__'):
def __set__(self, obj, value):
return self.wrapped.__set__(obj, value)
Expand Down Expand Up @@ -753,37 +754,26 @@ def _wrap_module_attributes(cls):
"""Wraps user-defined non-inherited methods and descriptors with state
management functions.
"""
exclusions = ([f.name for f in dataclasses.fields(cls)] +
# wrap methods
method_exclusions = ([f.name for f in dataclasses.fields(cls)] +
['__eq__', '__repr__', '__init__', '__hash__',
'__post_init__', '__dict__'])
for key in _get_local_method_names(cls, exclude=exclusions):
'__post_init__'])
for key in _get_local_method_names(cls, exclude=method_exclusions):
method = getattr(cls, key)
if hasattr(method, 'nowrap'):
continue
setattr(cls, key, wrap_method_once(method))
for key in _get_local_descriptor_names(cls, exclusions):
prop = getattr(cls, key)
if hasattr(prop, 'nowrap'):

# wrap descriptors
descriptor_exclusions = ([f.name for f in dataclasses.fields(cls)] +
['parent', '__dict__'])
for key in _get_local_descriptor_names(cls, descriptor_exclusions):
descriptor = getattr(cls, key)
if hasattr(descriptor, 'nowrap'):
continue
setattr(cls, key, wrap_descriptor_once(prop))
setattr(cls, key, wrap_descriptor_once(descriptor))
return cls

def _call_wrapped_get_descriptor(self, descriptor: Descriptor):
"""Calls a wrapped property.
This function wraps all properties, its purpose is to ensure that
AttributeErrors are raised as a custom PropertyAttributeError so that
users get the correct error message instead of the default AttributeError
telling them that the itself property is not defined.
Args:
fun: the wrapped property function.
"""
try:
return descriptor.__get__(self)
except AttributeError as e:
raise errors.DescriptorAttributeError() from e

def _call_wrapped_method(self, fun, args, kwargs):
""""Calls a wrapped method.
Expand Down

0 comments on commit d209640

Please sign in to comment.