
Fail to cloudpickle nn.Module when property is set: "maximum recursion depth exceeded" #3283

Closed
fdzzy opened this issue Aug 17, 2023 · 0 comments · Fixed by #3286


fdzzy commented Aug 17, 2023

Colab link

I'm trying to use cloudpickle to dump and reload a `flax.linen.Module` object. When the module class defines a `@property`, reloading fails with `RecursionError: maximum recursion depth exceeded`.

This is the minimal code to repro:

```python
import pickle
from cloudpickle import cloudpickle_fast
from flax import linen as nn

# A Flax module whose class body defines a plain Python @property.
class NNModuleWithProperty(nn.Module):
  @property
  def my_property(self):
    print('Calling my_property')
    return "my_property"

# Serialize `obj` with the given pickle-compatible module, then load it back.
def dump_and_reload(pickle_cls, obj, filename='/tmp/test_file.pkl'):
  with open(filename, 'wb') as f:
    pickle_cls.dump(obj, f)

  with open(filename, 'rb') as f:
    obj_loaded = pickle_cls.load(f)
  return obj_loaded

dump_and_reload(cloudpickle_fast, NNModuleWithProperty())
```

It fails with the following stack trace (the error is raised in `pickle_cls.load`, i.e. while reloading, not while dumping):

```
---------------------------------------------------------------------------
RecursionError                            Traceback (most recent call last)
<ipython-input-6-2df3a34fadbf> in <cell line: 1>()
----> 1 dump_and_reload(cloudpickle_fast, NNModuleWithProperty())

<ipython-input-3-c689b3a91da9> in dump_and_reload(pickle_cls, obj, filename)
     19 
     20   with open(filename, 'rb') as f:
---> 21     obj_loaded = pickle_cls.load(f)
     22   return obj_loaded

/usr/local/lib/python3.10/dist-packages/flax/linen/module.py in __getattr__(self, name)
    711 
    712     def __getattr__(self, name):
--> 713       return getattr(self.wrapped, name)
    714 
    715   return _DescriptorWrapper(descriptor)

... last 1 frames repeated, from the frame below ...

/usr/local/lib/python3.10/dist-packages/flax/linen/module.py in __getattr__(self, name)
    711 
    712     def __getattr__(self, name):
--> 713       return getattr(self.wrapped, name)
    714 
    715   return _DescriptorWrapper(descriptor)

RecursionError: maximum recursion depth exceeded
```
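The traceback points at Flax's `_DescriptorWrapper.__getattr__`, which forwards every unknown attribute to `self.wrapped`. A likely mechanism, offered as an assumption rather than something confirmed from the Flax source: cloudpickle serializes classes defined in `__main__` by value, so the wrapper object stored in the class body is itself pickled and rebuilt. Unpickling creates that object without calling `__init__` and then looks up `__setstate__` on it. Because `wrapped` is not set yet, `__getattr__` evaluates `self.wrapped`, which calls `__getattr__('wrapped')` again, and so on until the recursion limit. The standard-library sketch below reproduces that pattern with a hypothetical `ForwardingWrapper` class and shows one common guard in an equally hypothetical `SafeForwardingWrapper`; it is not necessarily how #3286 fixed the issue.

```python
import pickle


class ForwardingWrapper:
    """Hypothetical stand-in for a wrapper that forwards unknown attributes."""

    def __init__(self, wrapped):
        self.wrapped = wrapped

    def __getattr__(self, name):
        # Only called when normal lookup fails. During unpickling the instance
        # is created without __init__, so `self.wrapped` is missing and this
        # line re-enters __getattr__('wrapped') until RecursionError.
        return getattr(self.wrapped, name)


class SafeForwardingWrapper:
    """Same idea, but guards against the not-yet-initialized case."""

    def __init__(self, wrapped):
        self.wrapped = wrapped

    def __getattr__(self, name):
        # Read `wrapped` from the instance dict directly; if it is not set yet
        # (e.g. while pickle looks up __setstate__), raise AttributeError so the
        # lookup fails cleanly instead of recursing.
        try:
            wrapped = self.__dict__["wrapped"]
        except KeyError:
            raise AttributeError(name) from None
        return getattr(wrapped, name)


def roundtrip(obj):
    """Dump and reload `obj` in memory with the standard pickle module."""
    return pickle.loads(pickle.dumps(obj))


try:
    roundtrip(ForwardingWrapper([1, 2, 3]))
except RecursionError:
    print("ForwardingWrapper: RecursionError")

restored = roundtrip(SafeForwardingWrapper([1, 2, 3]))
print(restored.wrapped, restored.count(2))  # -> [1, 2, 3] 1
```

The guard works because pickle treats an `AttributeError` from the `__setstate__` lookup as "no custom `__setstate__`" and falls back to updating the instance `__dict__` directly.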

However, the same round trip works with the standard `pickle` module:

```python
dump_and_reload(pickle, NNModuleWithProperty())
```
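Plain `pickle` probably succeeds because it records a class only as a reference (module plus qualified name) and never serializes the class body, so a wrapper stored on the class is never rebuilt during loading. The standard-library sketch below checks this. `ForwardingWrapper` and `ModuleLike` are hypothetical stand-ins, and placing a forwarding wrapper in the class namespace is an assumption about what Flax does with a `@property`, inferred from the `_DescriptorWrapper` frames in the traceback.

```python
import pickle


class ForwardingWrapper:
    """Hypothetical wrapper that forwards unknown attributes to `wrapped`."""

    def __init__(self, wrapped):
        self.wrapped = wrapped

    def __getattr__(self, name):
        return getattr(self.wrapped, name)


class ModuleLike:
    # Hypothetical analogue of a class whose @property was replaced in the
    # class namespace by a forwarding wrapper.
    my_property = ForwardingWrapper(property(lambda self: "my_property"))


payload = pickle.dumps(ModuleLike())

# The class is stored by name; nothing from its body ends up in the payload.
print(b"ModuleLike" in payload)         # True
print(b"ForwardingWrapper" in payload)  # False

restored = pickle.loads(payload)
print(type(restored) is ModuleLike)     # True
```

cloudpickle, by contrast, embeds the body of a class defined in `__main__` in the payload, which is the path on which the wrapper would be pickled and rebuilt.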

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Linux Ubuntu 22.04.2 LTS
  • Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib):
    • flax: 0.7.2
    • jax: 0.4.14
    • jaxlib: 0.4.14+cuda11.cudnn86
  • Python version: 3.10.12