AttributeError isn't raised correctly from properties of linen modules #2537

Closed
lucaslingle opened this issue Oct 15, 2022 · 2 comments · Fixed by #2541
Labels
Priority: P2 - no schedule Best effort response and resolution. We have no plan to work on this at the moment.

Comments

@lucaslingle

System information

  • OS Platform and Distribution: Any
  • Flax, jax, jaxlib versions: flax==0.6.1, jax==0.3.23, jaxlib==0.3.22
  • Python version: 3.10
  • GPU/TPU model and memory: N/A
  • CUDA version (if applicable): N/A

Problem you have encountered:

Linen modules do not seem to report AttributeErrors correctly when a property accesses a non-existent attribute. Instead of naming the missing attribute, the raised error claims that the property itself does not exist.

This issue doesn't seem to occur in Linen module methods, and, to the best of my knowledge, it doesn't occur with other types of errors raised inside Linen properties, such as mismatched einsum indices.

What you expected to happen:

Properties in linen modules would ideally perform error reporting the same way as properties of ordinary classes. E.g.,

class Foo:
    def __init__(self):
        self.bar = 0
    @property
    def prop(self):
        return self.baz
    def __call__(self):
        return self.prop

foo = Foo()
foo()

which gives

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<stdin>", line 8, in __call__
  File "<stdin>", line 6, in prop
AttributeError: 'Foo' object has no attribute 'baz'

Steps to reproduce:

However, Linen modules do not perform error reporting correctly for properties:

import flax.linen as nn
import jax
import jax.numpy as jnp

class Foo(nn.Module):
    def setup(self):
        self.bar = self.param("bar", jax.nn.initializers.normal(0.01), [10, 10], jnp.float32)
    @property
    def prop(self):
        return self.baz
    def __call__(self, inputs):
        return self.prop

foo = Foo()
params = foo.init(
    {"params": jax.random.PRNGKey(0)}, 
    inputs=jnp.ones(shape=[64, 128], dtype=jnp.float32)
)["params"]

Logs, error messages, etc:

Running the code immediately above gives

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/lucaslingle/opt/miniconda3/envs/some_project/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/some_project/lib/python3.10/site-packages/flax/linen/module.py", line 1333, in init
    _, v_out = self.init_with_output(
  File "/Users/lucaslingle/opt/miniconda3/envs/some_project/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/some_project/lib/python3.10/site-packages/flax/linen/module.py", line 1289, in init_with_output
    return init_with_output(
  File "/Users/lucaslingle/opt/miniconda3/envs/some_project/lib/python3.10/site-packages/flax/core/scope.py", line 897, in wrapper
    return apply(fn, mutable=mutable, flags=init_flags)({}, *args, rngs=rngs,
  File "/Users/lucaslingle/opt/miniconda3/envs/some_project/lib/python3.10/site-packages/flax/core/scope.py", line 865, in wrapper
    y = fn(root, *args, **kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/some_project/lib/python3.10/site-packages/flax/linen/module.py", line 1750, in scope_fn
    return fn(module.clone(parent=scope), *args, **kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/some_project/lib/python3.10/site-packages/flax/linen/module.py", line 402, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/Users/lucaslingle/opt/miniconda3/envs/some_project/lib/python3.10/site-packages/flax/linen/module.py", line 705, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "<stdin>", line 8, in __call__
  File "/Users/lucaslingle/opt/miniconda3/envs/some_project/lib/python3.10/site-packages/flax/linen/module.py", line 783, in __getattr__
    raise AttributeError(msg)
jax._src.traceback_util.UnfilteredStackTrace: AttributeError: "Foo" object has no attribute "prop".

@cgarciae added the "Priority: P2 - no schedule" label on Oct 17, 2022
@cgarciae self-assigned this on Oct 17, 2022

@cgarciae
Collaborator

@lucaslingle thanks for reporting this. It's not that we are failing to report an AttributeError for the incorrect attribute; it seems that we currently don't support properties at all. I'll look into this.

@cgarciae
Collaborator

Update: after some digging, it seems this error happens in any class where both of the following hold:

  • The class implements __getattr__
  • A property raises an AttributeError
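The two conditions above can be reproduced without Flax. The sketch below is plain Python; its `__getattr__` only imitates the message format seen in the traceback and is not Flax's implementation. Because the property's getter raises AttributeError (for `baz`), Python treats the lookup of `prop` itself as failed and falls back to `__getattr__("prop")`, so the original error about `baz` is lost.

```python
class Foo:
    def __getattr__(self, name):
        # Python calls this only after normal attribute lookup has raised
        # AttributeError. The message format imitates the traceback above.
        raise AttributeError(f'"Foo" object has no attribute "{name}".')

    @property
    def prop(self):
        return self.baz  # "baz" does not exist, so this raises AttributeError


try:
    Foo().prop
except AttributeError as e:
    message = str(e)

print(message)  # "Foo" object has no attribute "prop".
```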

We will look for a way to wrap properties so that they catch this AttributeError and raise a custom error instead.
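A minimal sketch of the wrapping idea, assuming a plain Python class decorator rather than anything that exists in Flax (`wrap_properties` and `PropertyAttributeError` are hypothetical names). The key design choice is that the replacement error must not subclass AttributeError; otherwise Python would still fall back to `__getattr__` and the message would again name the property.

```python
import functools


class PropertyAttributeError(RuntimeError):
    """Hypothetical error type. Deliberately NOT an AttributeError subclass,
    so raising it does not trigger the fallback to __getattr__."""


def _wrap_property(prop: property) -> property:
    fget = prop.fget

    @functools.wraps(fget)
    def wrapped_fget(self):
        try:
            return fget(self)
        except AttributeError as e:
            # Re-raise as a non-AttributeError, chaining the original error.
            raise PropertyAttributeError(
                f"AttributeError raised inside property {fget.__name__!r}: {e}"
            ) from e

    return property(wrapped_fget, prop.fset, prop.fdel, prop.__doc__)


def wrap_properties(cls):
    """Hypothetical class decorator: re-wrap every property defined on cls."""
    for name, value in list(vars(cls).items()):
        if isinstance(value, property) and value.fget is not None:
            setattr(cls, name, _wrap_property(value))
    return cls


@wrap_properties
class Foo:
    def __getattr__(self, name):
        raise AttributeError(f'"Foo" object has no attribute "{name}".')

    @property
    def prop(self):
        return self.baz  # "baz" does not exist


try:
    Foo().prop
except PropertyAttributeError as e:
    # Now mentions both the property and the attribute that is really missing.
    print(e)
```

Chaining with `raise ... from e` keeps the original AttributeError available as `__cause__`, so the traceback still shows where `baz` was accessed.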
