Module bug #481

Closed
ASEM000 opened this issue Sep 8, 2023 · 4 comments
Labels
bug Something isn't working

Comments

@ASEM000
Contributor

ASEM000 commented Sep 8, 2023

Hello Patrick,

I think the following is a bug in Module:

import equinox as eqx

class ParentTree(eqx.Module):
    def __init_subclass__(klass: type, foo: bool = False):
        if foo:
            klass.foo = True

class Tree(ParentTree, foo=True):
    def __init__(self, a: int):
        self.a = a

Tree(a=1).foo
# _ModuleMeta.__new__() got an unexpected keyword argument 'foo'

If I omit foo, a different error comes up:

class Tree(ParentTree):
    def __init__(self, a: int):
        self.a = a

Tree(a=1)
# AttributeError: Cannot set attribute a

In comparison, the same pattern works as expected with pytreeclass or simple_pytree:

import pytreeclass as pytc

class ParentTree(pytc.TreeClass):
    def __init_subclass__(klass: type, foo: bool = False):
        if foo:
            klass.foo = True

class Tree(ParentTree, foo=True):
    def __init__(self, a: int):
        self.a = a

Tree(a=1).foo
# True

from simple_pytree import Pytree

class ParentTree(Pytree):
    def __init_subclass__(klass: type, foo: bool = False):
        if foo:
            klass.foo = True

class Tree(ParentTree, foo=True):
    def __init__(self, a: int):
        self.a = a

Tree(a=1).foo
# True

The Module logic has become more complex since the last time I checked, so I'm not sure what causes the two errors. One pointer: maybe the kwargs need to be propagated in the module metaclass's __new__.
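For illustration, here is a minimal plain-Python sketch of what propagating the keyword arguments could look like. The names `StrictMeta`, `ForwardingMeta`, `Base` and `Child` are hypothetical and this is not Equinox's actual `_ModuleMeta`; it only shows the general mechanism, assuming the first error comes from a metaclass `__new__` that does not accept class keyword arguments.

```python
class StrictMeta(type):
    # Hypothetical metaclass whose __new__ takes no class keyword arguments.
    def __new__(mcs, name, bases, namespace):
        return super().__new__(mcs, name, bases, namespace)


class ForwardingMeta(type):
    # Hypothetical metaclass that forwards class keyword arguments.
    def __new__(mcs, name, bases, namespace, **kwargs):
        # type.__new__ hands **kwargs on to the parent's __init_subclass__.
        return super().__new__(mcs, name, bases, namespace, **kwargs)


class Base(metaclass=ForwardingMeta):
    def __init_subclass__(cls, foo: bool = False, **kwargs):
        super().__init_subclass__(**kwargs)
        if foo:
            cls.foo = True


class Child(Base, foo=True):
    pass


# With StrictMeta, the same class keyword is rejected at class creation.
strict_failed = False
try:
    class StrictBase(metaclass=StrictMeta):
        def __init_subclass__(cls, foo: bool = False, **kwargs):
            super().__init_subclass__(**kwargs)

    class StrictChild(StrictBase, foo=True):
        pass
except TypeError:
    strict_failed = True

print(Child.foo)      # True
print(strict_failed)  # True
```

Note that with a non-forwarding metaclass the failure happens when the subclass is defined, before any instance is created.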

@patrick-kidger
Owner

Regarding foo=True: thanks! This'll be fixed by #482.

Regarding omitting foo: Equinox's syntax involves explicitly declaring all attributes. You want

class Tree(ParentTree):
    a: int  # this bit is new

    def __init__(self, a: int):
        self.a = a

although also note that the default __init__ method matches the one explicitly provided, so the explicit one can be skipped. Both the attribute declaration and the default __init__ behave the same as with any dataclass.
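As a rough analogy using only the standard library (not Equinox itself), the sketch below shows the dataclass behaviour being referred to: a declared field gets a generated `__init__`, so no hand-written one is needed. Using `frozen=True` here is an assumption made to mimic an immutable module; it is not a claim about how Equinox is implemented.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Tree:
    a: int  # declared field; the generated __init__ accepts it


tree = Tree(a=1)
field_names = [f.name for f in dataclasses.fields(tree)]

print(tree.a)       # 1
print(field_names)  # ['a']
```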

@ASEM000
Contributor Author

ASEM000 commented Sep 8, 2023

This was quick :D

@ASEM000
Contributor Author

ASEM000 commented Sep 8, 2023

Regarding omitting foo: Equinox's syntax involves explicitly declaring all attributes.

My bad.

@patrick-kidger
Owner

Closing! Fixed in #482, and this will appear in the upcoming v0.11.0 release.
