Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Require transform_metadata when variables have sharding annotation #4187

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions flax/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,25 @@ def __reduce__(self):
return (FlaxError, (str(self),))


#################################################
# NNX errors #
#################################################


class TraceContextError(FlaxError):
  """Flax NNX error related to JAX trace contexts.

  NOTE(review): no raise site is visible in this diff; the name suggests it
  guards against mutating NNX state from a different trace level — confirm
  against the raise sites in the full source before relying on this.
  """
  pass


class AxisNameMissingError(FlaxError):
  """Raised when an axis-modifying transform lacks a partition axis name.

  Triggered by ``spmd.add_axis`` / ``spmd.remove_axis`` (used by transforms
  such as ``nnx.vmap`` and ``nnx.scan``) when a variable carries a
  ``sharding`` annotation but no axis name was supplied via
  ``transform_metadata={nnx.PARTITION_NAME: ...}``.

  Args:
    x_sharding: the offending variable's sharding annotation; interpolated
      into the error message to help the user locate the variable.
  """

  def __init__(self, x_sharding):
    super().__init__(
      'You are trying to modify param dimension via transforms like `nnx.vmap` '
      f'or `nnx.scan`, while the param is partition-annotated as: {x_sharding} '
      'You need to provide the axis name of the transform via extra '
      'argument: transform_metadata={nnx.PARTITION_NAME: "your_axis_name"}'
    )


#################################################
# lazy_init.py errors #
#################################################
Expand Down
16 changes: 9 additions & 7 deletions flax/nnx/spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
PartitionSpecPytree, # pylint: disable=invalid-name
Sharding,
)
from flax import errors

A = tp.TypeVar('A')
F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])
Expand All @@ -38,13 +39,15 @@ def add_axis(tree: A, index: int, params: tp.Mapping[tp.Any, tp.Any]) -> A:
def _add_axis(x: tp.Any):
  """Insert the transform's axis into a VariableState's sharding annotation.

  ``index`` and ``axis_name`` are closure variables from the enclosing
  ``add_axis`` (the hunk header shows ``add_axis(tree, index, params)``).
  Non-VariableState leaves pass through unchanged.

  Raises:
    errors.AxisNameMissingError: the variable is sharding-annotated but no
      axis name was provided via transform_metadata.
  """
  if isinstance(x, variables.VariableState):
    if hasattr(x, 'sharding') and x.sharding is not None:
      # A sharding-annotated variable must know which mesh axis the new
      # dimension maps to; without a name the annotation would be corrupted.
      if axis_name is None:
        raise errors.AxisNameMissingError(x.sharding)
      sharding: list[str | None] = list(x.sharding)
      # Pad with None so the insert lands at exactly position `index`.
      while len(sharding) < index:
        sharding.append(None)
      sharding.insert(index, axis_name)
      x.sharding = tuple(sharding)  # type: ignore
    # Post-change argument order: (index, name), with the name optional —
    # this diff view contained both the old and new call; only the new
    # `x.add_axis(index, axis_name)` form matches VariableState.add_axis.
    x.add_axis(index, axis_name)
  return x

return jax.tree.map(
Expand All @@ -58,10 +61,12 @@ def remove_axis(tree: A, index: int, params: tp.Mapping[tp.Any, tp.Any]) -> A:
def _remove_axis(x: tp.Any):
  """Remove the transform's axis from a VariableState's sharding annotation.

  ``index`` and ``axis_name`` are closure variables from the enclosing
  ``remove_axis``. Non-VariableState leaves pass through unchanged.

  Raises:
    errors.AxisNameMissingError: the variable is sharding-annotated but no
      axis name was provided via transform_metadata.
  """
  if isinstance(x, variables.VariableState):
    if hasattr(x, 'sharding') and x.sharding is not None:
      # Cannot verify/strip the annotation entry without knowing the name.
      if axis_name is None:
        raise errors.AxisNameMissingError(x.sharding)
      sharding = list(x.sharding)
      # The entry being removed must be the transform's own axis name.
      assert sharding.pop(index) == axis_name
      x.sharding = tuple(sharding)
    # Post-change argument order: (index, name) — this diff view contained
    # both the old and new call; only the new form matches
    # VariableState.remove_axis.
    x.remove_axis(index, axis_name)
  return x

return jax.tree.map(
Expand All @@ -71,12 +76,9 @@ def _remove_axis(x: tp.Any):
)


def _get_partition_name(params: tp.Mapping[tp.Any, tp.Any]) -> str | None:
  """Return the partition axis name from transform metadata, or None.

  This diff span interleaved the pre-change version (which raised
  ValueError when PARTITION_NAME was absent) with the post-change one;
  only the post-change version is kept. Returning None defers the error:
  add/remove_axis raise AxisNameMissingError only when a
  sharding-annotated variable actually needs the name.
  """
  if PARTITION_NAME not in params:
    return None
  return params[PARTITION_NAME]


Expand Down
47 changes: 21 additions & 26 deletions flax/nnx/transforms/iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,10 @@ def __post_init__(self):
functools.update_wrapper(self, self.f)

def __call__(self, *pure_args: tuple[tp.Any, ...]):
if spmd.PARTITION_NAME in self.transform_metadata:
pure_args = _update_variable_sharding_metadata(
pure_args, self.transform_metadata, spmd.remove_axis
)
print(self.transform_metadata)
pure_args = _update_variable_sharding_metadata(
pure_args, self.transform_metadata, spmd.remove_axis
)
args = extract.from_tree(pure_args, ctxtag='vmap')

out = self.f(*args)
Expand All @@ -167,10 +167,9 @@ def __call__(self, *pure_args: tuple[tp.Any, ...]):
split_fn=_vmap_split_fn,
ctxtag='vmap',
)
if spmd.PARTITION_NAME in self.transform_metadata:
pure_args_out, pure_out = _update_variable_sharding_metadata(
(pure_args_out, pure_out), self.transform_metadata, spmd.add_axis
)
pure_args_out, pure_out = _update_variable_sharding_metadata(
(pure_args_out, pure_out), self.transform_metadata, spmd.add_axis
)
return pure_args_out, pure_out


Expand Down Expand Up @@ -356,10 +355,9 @@ def __post_init__(self):
functools.update_wrapper(self, self.f)

def __call__(self, *pure_args: tuple[tp.Any, ...]):
if spmd.PARTITION_NAME in self.transform_metadata:
pure_args = _update_variable_sharding_metadata(
pure_args, self.transform_metadata, spmd.remove_axis
)
pure_args = _update_variable_sharding_metadata(
pure_args, self.transform_metadata, spmd.remove_axis
)
args = extract.from_tree(pure_args, ctxtag='pmap')

out = self.f(*args)
Expand All @@ -371,10 +369,9 @@ def __call__(self, *pure_args: tuple[tp.Any, ...]):
split_fn=_vmap_split_fn,
ctxtag='pmap',
)
if spmd.PARTITION_NAME in self.transform_metadata:
pure_args_out, pure_out = _update_variable_sharding_metadata(
(pure_args_out, pure_out), self.transform_metadata, spmd.add_axis
)
pure_args_out, pure_out = _update_variable_sharding_metadata(
(pure_args_out, pure_out), self.transform_metadata, spmd.add_axis
)
return pure_args_out, pure_out


Expand Down Expand Up @@ -994,10 +991,9 @@ def __call__(
assert self.input_carry_argnum is None
assert pure_carry_arg is None

if spmd.PARTITION_NAME in self.transform_metadata:
pure_args = _update_variable_sharding_metadata(
pure_args, self.transform_metadata, spmd.remove_axis
)
pure_args = _update_variable_sharding_metadata(
pure_args, self.transform_metadata, spmd.remove_axis
)

args: tuple = extract.from_tree(
pure_args,
Expand Down Expand Up @@ -1065,12 +1061,11 @@ def __call__(
map_non_graph_nodes=True,
ctxtag='scan',
)
if spmd.PARTITION_NAME in self.transform_metadata:
pure_args_out, pure_out = _update_variable_sharding_metadata(
(pure_args_out, pure_out),
self.transform_metadata,
spmd.add_axis,
)
pure_args_out, pure_out = _update_variable_sharding_metadata(
(pure_args_out, pure_out),
self.transform_metadata,
spmd.add_axis,
)

# extract the pure carry from the pure args
if self.input_carry_argnum == 'all':
Expand Down
20 changes: 9 additions & 11 deletions flax/nnx/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,17 +870,15 @@ def get_metadata(self) -> dict[str, tp.Any]:
del metadata['value']
return metadata

def add_axis(self, axis_index: AxisIndex, axis_name: AxisName | None = None):
  """Run any registered ``add_axis_hooks``; no-op when none are present.

  This diff span contained both the old definition (name-first arguments,
  raising when no hooks exist) and the new one; only the new definition is
  kept: arguments are (index, name) with the name optional, matching the
  updated spmd.add_axis call site, and missing hooks are silently ignored.
  Hooks still receive (state, name, index) for backward compatibility.
  """
  if hasattr(self, 'add_axis_hooks'):
    for hook in self.add_axis_hooks:
      hook(self, axis_name, axis_index)

def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None = None):
  """Run any registered ``remove_axis_hooks``; no-op when none are present.

  Mirrors ``add_axis``: (index, name) argument order, optional name,
  silent no-op without hooks; hooks receive (state, name, index).
  """
  if hasattr(self, 'remove_axis_hooks'):
    for hook in self.remove_axis_hooks:
      hook(self, axis_name, axis_index)


def _variable_state_flatten(x: VariableState[tp.Any], *, with_keys: bool):
Expand Down
Loading
You are viewing a condensed version of this merge commit. You can view the full changes here.