From b473704c11530c186817269f7f8537b7888d89e2 Mon Sep 17 00:00:00 2001 From: Mark Sandler Date: Mon, 22 Apr 2024 15:17:26 -0700 Subject: [PATCH] Fix LogicalRules type annotation. (Tuple[str] is a tuple with single element string, Tuple[str, ...] is the intended type here) PiperOrigin-RevId: 627171032 --- flax/linen/spmd.py | 8 ++++---- flax/typing.py | 6 ++++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/flax/linen/spmd.py b/flax/linen/spmd.py index fbfd3fd3b..1610d7f87 100644 --- a/flax/linen/spmd.py +++ b/flax/linen/spmd.py @@ -107,9 +107,9 @@ def _mesh_assignment_free(new_assignment, existing_assignments): def _logical_to_mesh_axes( - array_dim_names: Optional[Sequence[Optional[str]]], - rules: Optional[LogicalRules] = None, -) -> Optional[List[Union[_UnassignedAxis, None, str, Tuple[str]]]]: + array_dim_names: Optional[Sequence[Optional[str]]], + rules: Optional[LogicalRules] = None, +) -> Optional[List[Union[_UnassignedAxis, None, str, Tuple[str, ...]]]]: """Same as logical_to_mesh_axes, but doesn't fill in _unassigned_axis.""" if array_dim_names is None: return None @@ -126,7 +126,7 @@ def _logical_to_mesh_axes( if not isinstance(rules, (tuple, list)): raise ValueError('Unknown axis rule specification type.') # We assign mesh axes using a priority based ruleset over logical axis names. - result: List[Union[_UnassignedAxis, None, str, Tuple[str]]] + result: List[Union[_UnassignedAxis, None, str, Tuple[str, ...]]] result = [_unassigned_axis] * len(array_dim_names) for rule_model_name, rule_mesh_names in rules: if rule_model_name in array_dim_names: diff --git a/flax/typing.py b/flax/typing.py index d566a31c8..d6ecf02ca 100644 --- a/flax/typing.py +++ b/flax/typing.py @@ -119,10 +119,12 @@ class Out(Generic[T]): LogicalNames = Tuple[Union[str, None], ...] -LogicalRules = Sequence[Tuple[str, Union[str, Tuple[str], None]]] +# Maps each logical axis to physical mesh, can be either None (replicated), +# one physical axis or a tuple of physical axes. +LogicalRules = Sequence[Tuple[str, Union[str, Tuple[str, ...], None]]] ArrayPytree = Any # pylint: disable=invalid-name LogicalPartitionSpec = Any # pylint: disable=invalid-name LogicalPartitionSpecPytree = Any # pylint: disable=invalid-name PartitionSpecPytree = Any # pylint: disable=invalid-name -Sharding = Tuple[Optional[str], ...] \ No newline at end of file +Sharding = Tuple[Optional[str], ...]