Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve tabulate #2316

Merged
merged 1 commit into from
Aug 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 122 additions & 52 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,20 @@
import enum
import functools
import inspect
import re
import threading
import typing
import weakref
from typing import (Any, Callable, Dict, Generic, Iterable, List, Optional,
Sequence, Set, Tuple, Type, TypeVar, Union, overload)
from typing import (Any, Callable, Dict, Iterable, List, NamedTuple, Mapping,
Optional, Set, Tuple, Type, TypeVar, Union, overload)

import jax
import numpy as np
import jax.numpy as jnp
from typing_extensions import \
dataclass_transform # pytype: disable=not-supported-yet

import flax
from flax import (config, core, errors, serialization, traceback_util,
traverse_util)
from flax.core import Scope
Expand All @@ -37,8 +40,6 @@
CollectionFilter, DenyList, FrozenVariableDict, Variable, VariableDict,
union_filters)
from flax.ids import uuid
from flax.linen import summary



traceback_util.register_exclusion(__file__)
Expand All @@ -61,6 +62,16 @@

# pylint: disable=protected-access,attribute-defined-outside-init

def _get_value_representation(x: Any) -> 'flax.linen.summary._ValueRepresentation':
from flax.linen import summary

if isinstance(x, (int, float, bool, type(None))) or (
isinstance(x, np.ndarray) and np.isscalar(x)):
return summary._ObjectRepresentation(x)
try:
return summary._ArrayRepresentation(jnp.shape(x), jnp.result_type(x))
except:
return summary._ObjectRepresentation(x)

def _indent(x: str, num_spaces: int):
indent_str = ' ' * num_spaces
Expand Down Expand Up @@ -104,6 +115,46 @@ def _module_repr(module: 'Module', num_spaces: int = 4):
else:
return f'{cls_name}()'

#
# -----------------------------------------------------------------------------

_find_non_lifted_module = re.compile(r'.*\((.*)\)')

def _fix_path_part(part: str):
"""Fixes a path part by removing transformation name and parenthesis sometimes
inserted by lifted transformations"""
match = _find_non_lifted_module.match(part)
if match:
return match.group(1)
return part

@dataclasses.dataclass
class _CallInfo:
index: int
path: Tuple[str, ...]
module_type: Type['Module']
method: str
args: Tuple[Any, ...]
kwargs: Dict[str, Any]
outputs: Any

@dataclasses.dataclass
class _CallInfoContext(threading.local):
index: int
calls: List[_CallInfo]

def get_call_index(self, module: 'Module') -> int:
index = self.index
self.index += 1
return index

@contextlib.contextmanager
def _tabulate_context():
_context.call_info_stack.append(_CallInfoContext(0, []))
try:
yield
finally:
_context.call_info_stack.pop()

# Track parent relationship across Modules.
# -----------------------------------------------------------------------------
Expand All @@ -128,6 +179,13 @@ def capture_stack(self):
self._thread_data.capture_stack = []
return self._thread_data.capture_stack

@property
def call_info_stack(self) -> List[_CallInfoContext]:
"""Keeps track of the active call_info_context."""
if not hasattr(self._thread_data, 'call_info_stack'):
self._thread_data.call_info_stack = []
return self._thread_data.call_info_stack

# The global context
_context = _DynamicContext()

Expand Down Expand Up @@ -638,6 +696,7 @@ def _call_wrapped_method(self, fun, args, kwargs):
is_compact_method = hasattr(fun, 'compact')
fun_name = getattr(fun, '__name__', 'unnamed_function')
is_setup_method = fun_name == 'setup'
add_call_info = not is_setup_method and len(_context.call_info_stack) > 0
# We lazily call setup() only when needed.
if is_setup_method:
is_recurrent = self._state.in_setup
Expand All @@ -652,15 +711,27 @@ def _call_wrapped_method(self, fun, args, kwargs):
self._state.in_compact_method = True
_context.module_stack.append(self)
try:
# get call info
if add_call_info:
call_index = _context.call_info_stack[-1].get_call_index(self)
scope_path = jax.tree_util.tree_map(_fix_path_part, self.scope.path)

# call method
if _use_named_call:
with jax.named_scope(_derive_profiling_name(self, fun)):
y = fun(self, *args, **kwargs)
else:
y = fun(self, *args, **kwargs)

if _context.capture_stack:
filter_fn = _context.capture_stack[-1]
if filter_fn and filter_fn(self, fun_name):
self.sow('intermediates', fun_name, y)
if add_call_info:
_args, _kwargs, _y = jax.tree_util.tree_map(
_get_value_representation, (args, kwargs, y), is_leaf=lambda x: x is None)
_context.call_info_stack[-1].calls.append(
_CallInfo(call_index, scope_path, type(self), fun.__name__, _args, _kwargs, _y))
return y
finally:
_context.module_stack.pop()
Expand Down Expand Up @@ -1410,17 +1481,17 @@ def tabulate(
self,
rngs: Union[PRNGKey, RNGSequences],
*args,
method: Optional[Callable[..., Any]] = None,
mutable: CollectionFilter = True,
depth: Optional[int] = None,
exclude_methods: Sequence[str] = (),
show_repeated: bool = False,
mutable: CollectionFilter = True,
console_kwargs: Optional[Mapping[str, Any]] = None,
**kwargs) -> str:
"""Creates a summary of the Module represented as a table.

This method has the same signature as `init`, but instead of returning
the variables, it returns the string summarizing the Module in a table.
`tabulate` uses `jax.eval_shape` to run the forward computation without
consuming any FLOPs or allocating memory.
This method has the same signature and internally calls `Module.init`,
but instead of returning the variables, it returns the string summarizing
the Module in a table. `tabulate` uses `jax.eval_shape` to run the forward
computation without consuming any FLOPs or allocating memory.

Example::

Expand All @@ -1441,61 +1512,60 @@ def __call__(self, x):

This gives the following output::

Foo Summary
┏━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃ path ┃ outputs ┃ params ┃
┡━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
│ Inputs │ float32[16,9] │ │
├─────────┼───────────────┼──────────────────────┤
│ Dense_0 │ float32[16,4] │ bias: float32[4] │
│ │ │ kernel: float32[9,4] │
│ │ │ │
│ │ │ 40 (160 B) │
├─────────┼───────────────┼──────────────────────┤
│ Dense_1 │ float32[16,2] │ bias: float32[2] │
│ │ │ kernel: float32[4,2] │
│ │ │ │
│ │ │ 10 (40 B) │
├─────────┼───────────────┼──────────────────────┤
│ Foo │ float32[16,2] │ │
├─────────┼───────────────┼──────────────────────┤
│ │ Total │ 50 (200 B) │
└─────────┴───────────────┴──────────────────────┘

Total Parameters: 50 (200 B)
Foo Summary
┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃ path ┃ module ┃ inputs ┃ outputs ┃ params ┃
┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
│ │ Foo │ float32[16,9] │ float32[16,2] │ │
├─────────┼────────┼───────────────┼───────────────┼──────────────────────┤
│ Dense_0 │ Dense │ float32[16,9] │ float32[16,4] │ bias: float32[4] │
│ │ │ │ │ kernel: float32[9,4] │
│ │ │ │ │ │
│ │ │ │ │ 40 (160 B) │
├─────────┼────────┼───────────────┼───────────────┼──────────────────────┤
│ Dense_1 │ Dense │ float32[16,4] │ float32[16,2] │ bias: float32[2] │
│ │ │ │ │ kernel: float32[4,2] │
│ │ │ │ │ │
│ │ │ │ │ 10 (40 B) │
├─────────┼────────┼───────────────┼───────────────┼──────────────────────┤
│ │ │ │ Total │ 50 (200 B) │
└─────────┴────────┴───────────────┴───────────────┴──────────────────────┘

Total Parameters: 50 (200 B)

**Note**: rows order in the table does not represent execution order,
instead it aligns with the order of keys in `variables` which are sorted
alphabetically.

Args:
rngs: The rngs for the variable collections.
rngs: The rngs for the variable collections as passed to `Module.init`.
*args: The arguments to the forward computation.
method: An optional method. If provided, applies this method. If not
provided, applies the ``__call__`` method.
mutable: Can be bool, str, or list. Specifies which collections should be
treated as mutable: ``bool``: all/no collections are mutable.
``str``: The name of a single mutable collection. ``list``: A
list of names of mutable collections. By default all collections
except 'intermediates' are mutable.
depth: controls how many submodule deep the summary can go. By default its
`None` which means no limit. If a submodule is not shown because of the
depth limit, its parameter count and bytes will be added to the row of
its first shown ancestor such that the sum of all rows always adds up to
the total number of parameters of the Module.
exclude_methods: A sequence of strings that specifies which methods should
be ignored. In case a module calls a helper method from its main method,
use this argument to exclude the helper method from the summary to avoid
ambiguity.
depth limit, its parameter count and bytes will be added to the row of its
first shown ancestor such that the sum of all rows always adds up to the
total number of parameters of the Module.
show_repeated: If `True`, repeated calls to the same module will be shown
in the table, otherwise only the first call will be shown. Default is
`False`.
mutable: Can be bool, str, or list. Specifies which collections should be
treated as mutable: ``bool``: all/no collections are mutable. ``str``: The
name of a single mutable collection. ``list``: A list of names of mutable
collections. By default all collections except 'intermediates' are
mutable.
console_kwargs: An optional dictionary with additional keyword arguments that
are passed to `rich.console.Console` when rendering the table. Default arguments
are `{'force_terminal': True, 'force_jupyter': False}`.
**kwargs: keyword arguments to pass to the forward computation.

Returns:
A string summarizing the Module.
"""

tabulate_fn = summary.tabulate(self, rngs, method=method,
mutable=mutable, depth=depth,
exclude_methods=exclude_methods)
from flax.linen import summary

tabulate_fn = summary.tabulate(self, rngs, depth=depth,
show_repeated=show_repeated, mutable=mutable,
console_kwargs=console_kwargs)
return tabulate_fn(*args, **kwargs)


Expand Down
Loading