Skip to content

Commit

Permalink
simplify _DynamicContext
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Sep 2, 2022
1 parent 0451a55 commit 9ff36a1
Showing 1 changed file with 4 additions and 22 deletions.
26 changes: 4 additions & 22 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,33 +158,15 @@ def _tabulate_context():

# Track parent relationship across Modules.
# -----------------------------------------------------------------------------
class _DynamicContext:
class _DynamicContext(threading.local):
"""Dynamic context."""
# TODO(marcvanzee): switch to using contextvars once minimum python version is
# 3.7

def __init__(self):
self._thread_data = threading.local()

@property
def module_stack(self):
if not hasattr(self._thread_data, 'module_stack'):
self._thread_data.module_stack = [None,]
return self._thread_data.module_stack

@property
def capture_stack(self):
"""Keeps track of the active capture_intermediates filter functions."""
if not hasattr(self._thread_data, 'capture_stack'):
self._thread_data.capture_stack = []
return self._thread_data.capture_stack

@property
def call_info_stack(self) -> List[_CallInfoContext]:
"""Keeps track of the active call_info_context."""
if not hasattr(self._thread_data, 'call_info_stack'):
self._thread_data.call_info_stack = []
return self._thread_data.call_info_stack
self.module_stack = [None,]
self.capture_stack = []
self.call_info_stack = []

# The global context
_context = _DynamicContext()
Expand Down

0 comments on commit 9ff36a1

Please sign in to comment.