Skip to content

Commit

Permalink
Add __copy__ and __deepcopy__ methods to the _Sentinel object used to…
Browse files Browse the repository at this point in the history
… define the default value for `parent`, to simply return the sentinel object.

PiperOrigin-RevId: 485122748
  • Loading branch information
edloper authored and Flax Authors committed Oct 31, 2022
1 parent 05988d1 commit 7d21975
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
10 changes: 9 additions & 1 deletion flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,15 @@ def __init__(self):


class _Sentinel:
pass

def __copy__(self):
return self # Do not copy singleton sentinel.

def __deepcopy__(self, memo):
del memo
return self # Do not copy singleton sentinel.


_unspecified_parent = _Sentinel()


Expand Down
13 changes: 13 additions & 0 deletions tests/linen/linen_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@

"""Tests for flax.linen."""

import copy
import dataclasses
import functools
import gc
import inspect
import operator
from typing import (Any, Callable, Generic, Mapping, NamedTuple, Sequence,
Tuple, TypeVar)
Expand Down Expand Up @@ -1790,6 +1792,17 @@ def __call__(self, input):
with self.assertRaises(errors.IncorrectPostInitOverrideError):
r.init(jax.random.PRNGKey(2), jnp.ones(3))

def test_deepcopy_unspecified_parent(self):
parent_parameter = inspect.signature(DummyModule).parameters['parent']
unspecified_parent = parent_parameter.default

self.assertIs(unspecified_parent,
copy.copy(unspecified_parent))

self.assertIs(unspecified_parent,
copy.deepcopy(unspecified_parent))


class LeakTests(absltest.TestCase):

def test_tracer_leaks(self):
Expand Down

0 comments on commit 7d21975

Please sign in to comment.