[Typing][C-4][C-7][C-9][C-18] Add type annotations for python/paddle/distributed/auto_parallel/interface.py+strategy.py+placement_type.py and python/paddle/distributed/communication/group.py #66710

Closed · wants to merge 6 commits
39 changes: 27 additions & 12 deletions python/paddle/distributed/auto_parallel/interface.py
@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from functools import reduce
from typing import List, Tuple
from typing import TYPE_CHECKING, Any, Callable

import numpy as np

@@ -30,8 +32,15 @@
verify_shard_spec,
)

if TYPE_CHECKING:
from paddle import Tensor


def shard_tensor(x, process_mesh=None, shard_spec=None):
def shard_tensor(
x: Tensor,
process_mesh: ProcessMesh | None = None,
shard_spec: list[str | None] | None = None,  # None, or a list whose items are str or None
) -> Tensor:
"""
Shard a tensor on a process mesh according to the shard specification.

@@ -118,8 +127,12 @@ def shard_tensor(x, process_mesh=None, shard_spec=None):


def shard_op(
op, process_mesh=None, in_shard_specs=None, out_shard_specs=None, **kwargs
):
op: Callable,  # assuming Callable is the appropriate annotation for the op/layer callable
process_mesh: ProcessMesh | None = None,
in_shard_specs: list[list[str | None] | None] | None = None,  # nested Optional and list
out_shard_specs: list[list[str | None] | None] | None = None,
**kwargs,
) -> Any:
"""
Shard an operation on a process mesh according to its input and output shard specification.

@@ -207,7 +220,7 @@ def shard_op(
_g_recompute_idx = -1


def recompute(op):
def recompute(op: Callable) -> Callable:
global _g_recompute_idx
_g_recompute_idx += 1

@@ -233,7 +246,7 @@ def __call__(self, *args, **kwargs):
return RecomputeOperator(op)


def exclude_ops_in_recompute(run_function):
def exclude_ops_in_recompute(run_function: Callable) -> Callable:
"""
Exclude some operators in recompute segments.
Args:
@@ -272,15 +285,17 @@ class CollectionNames:
LOGGING = "logging"


def get_collection(name):
def get_collection(name: str) -> list:
collection = _g_collections.get(name, None)
if collection is None:
collection = []
_g_collections[name] = collection
return _g_collections[name]


def add_to_collection(collection_name, value, name=None):
def add_to_collection(
collection_name: str, value: Any, name: str | None = None
) -> None:
if collection_name not in _g_collections:
_g_collections[collection_name] = []
if name is not None:
@@ -295,7 +310,7 @@ def add_to_collection(collection_name, value, name=None):
_g_collections[collection_name].append((None, value))


def fetch(tensor, name=None, logging=False):
def fetch(tensor: Tensor | str, name: str | None = None, logging: bool = False) -> None:
if isinstance(tensor, paddle.static.Variable):
tensor = tensor.name
elif isinstance(tensor, str):
@@ -312,17 +327,17 @@ def fetch(tensor, name=None, logging=False):
_g_mesh = None


def get_mesh():
def get_mesh() -> ProcessMesh | None:
global _g_mesh
return _g_mesh


def set_mesh(mesh):
def set_mesh(mesh: ProcessMesh) -> None:
global _g_mesh
_g_mesh = mesh


def create_mesh(mesh_dims: List[Tuple[str, int]]):
def create_mesh(mesh_dims: list[tuple[str, int]]) -> ProcessMesh:
"""
Create a global process_mesh for auto parallel.

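
For reference, a minimal usage sketch of the signatures annotated above (not part of the PR). It assumes the helpers are importable from the module path shown in this diff and that the script runs under an auto-parallel launch with four ranks; the mesh shape and tensor size are invented for illustration.

import paddle
from paddle.distributed.auto_parallel.interface import create_mesh, shard_tensor

# Build a 2x2 global mesh; mesh_dims is list[tuple[str, int]] per the annotation.
mesh = create_mesh([("dp", 2), ("mp", 2)])

w = paddle.ones([4, 8])
# shard_spec is list[str | None]: shard dim 0 along mesh axis "dp", replicate dim 1.
w_dist = shard_tensor(w, process_mesh=mesh, shard_spec=["dp", None])
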
19 changes: 14 additions & 5 deletions python/paddle/distributed/auto_parallel/placement_type.py
@@ -11,13 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import cast
from __future__ import annotations

from typing import cast

import paddle
from paddle.base.core import Partial, Replicate, Shard
from paddle.distributed import Placement

from .process_mesh import ProcessMesh

def to_placements(dim_map, mesh, partial_idx=[]):

def to_placements(
dim_map: list[int], mesh: ProcessMesh, partial_idx: list[int] = []
) -> list[Placement]:
"""
convert dim_map to placements.

@@ -55,7 +60,7 @@ def to_placements(dim_map, mesh, partial_idx=[]):
return placements


def check_placements_equal(this, that):
def check_placements_equal(this: list[Placement], that: list[Placement]) -> bool:
assert isinstance(this, list) and isinstance(that, list)
small_placements = this if len(this) < len(that) else that
large_placements = that if len(this) < len(that) else this
@@ -69,7 +74,9 @@ def check_placements_equal(this, that):
return True


def to_dim_map(placements, tensor_dims):
def to_dim_map(
placements: list[Placement], tensor_dims: int
) -> tuple[list[int], dict[int, str]]:
"""
convert placements to dim_map.

@@ -97,7 +104,9 @@ def to_dim_map(placements, tensor_dims):
return dim_map, partial_status


def get_shard_spec(mesh, placements, tensor_dims):
def get_shard_spec(
mesh: ProcessMesh, placements: list[Placement], tensor_dims: int
) -> list[str]:
"""to get shard_spec for construct DistAttr for static API."""
dim_map, _ = to_dim_map(placements, tensor_dims)
mesh_dim_names = mesh.dim_names
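
A hypothetical round-trip through the helpers typed above, assuming they keep the semantics visible in this diff (dim_map[i] names the mesh axis that shards tensor dim i, with -1 meaning replicated); the mesh and shapes are made up for illustration.

import paddle.distributed as dist
from paddle.distributed.auto_parallel.placement_type import (
    get_shard_spec,
    to_dim_map,
    to_placements,
)

mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])
dim_map = [0, -1]  # shard tensor dim 0 over mesh axis 0, replicate tensor dim 1
placements = to_placements(dim_map, mesh)                      # -> list[Placement]
recovered_map, partial_status = to_dim_map(placements, tensor_dims=2)
shard_spec = get_shard_spec(mesh, placements, tensor_dims=2)   # e.g. ["dp", None]
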
14 changes: 7 additions & 7 deletions python/paddle/distributed/auto_parallel/strategy.py
@@ -18,7 +18,7 @@


class BaseConfig:
def __init__(self, category, config_dict=None):
def __init__(self, category: str, config_dict: dict | None = None) -> None:
self._category = category
self._config_dict = None
if config_dict is not None:
@@ -37,15 +37,15 @@ def __init__(self, category, config_dict=None):
if self._config_dict:
self.from_dict(self._config_dict)

def from_dict(self, config_dict):
def from_dict(self, config_dict: dict) -> None:
config = constants.get_category_default_config(self._category)
for field in config.keys():
value = config_dict.get(field, constants.NOT_FOUND)
# Use the default value if we cannot found the value
if value != constants.NOT_FOUND:
setattr(self, field, value)

def to_dict(self):
def to_dict(self) -> dict:
result_dict = {}
config = constants.get_category_default_config(self._category)
for field in config.keys():
@@ -56,22 +56,22 @@ def to_dict(self):
result_dict[field] = value.to_dict()
return result_dict

def __repr__(self):
def __repr__(self) -> str:
result_dict = self.to_dict()
string = "{"
for k, v in result_dict.items():
string += f"\"{k}\":\"{v}\","
return string + "}"

def __deepcopy__(self, memo):
def __deepcopy__(self, memo: dict) -> "BaseConfig":
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
setattr(result, k, copy.deepcopy(v, memo))
return result

def get(self, k, d=None):
def get(self, k: str, d: Any = None) -> Any:
result_dict = self.to_dict()
return result_dict.get(k, d)

@@ -186,7 +186,7 @@ class Strategy(BaseConfig):

"""

def __init__(self, config=None):
def __init__(self, config: dict | None = None) -> None:
if config is not None:
if isinstance(config, dict):
self._config_dict = copy.deepcopy(config)
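
Illustrative only: building a Strategy from a plain dict to exercise the `dict | None` annotations above. The module path is taken from this diff; the "recompute" and "gradient_merge" categories and their fields are assumed from Paddle's documented auto-parallel strategy options.

from paddle.distributed.auto_parallel.strategy import Strategy

config = {
    "recompute": {"enable": True},
    "gradient_merge": {"enable": True, "k_steps": 4},
}
strategy = Strategy(config)         # config: dict | None = None
print(strategy.to_dict())           # nested BaseConfig values come back as dicts
print(strategy.get("recompute"))    # BaseConfig.get mirrors dict.get
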
30 changes: 16 additions & 14 deletions python/paddle/distributed/communication/group.py
@@ -16,7 +16,7 @@

import paddle
import paddle.distributed as dist
from paddle import framework
from paddle import Tensor, framework


class Group:
@@ -92,23 +92,23 @@ class _GroupManager:
group_map_by_id = {}


def _get_global_group():
def _get_global_group() -> Group:
if _GroupManager.global_group_id not in _GroupManager.group_map_by_id:
raise RuntimeError("The global group is not initialized.")
return _GroupManager.group_map_by_id[_GroupManager.global_group_id]


def _add_new_group(group):
def _add_new_group(group: Group) -> None:
if group.id in _GroupManager.group_map_by_id:
raise RuntimeError(f"The group with id {group.id} already exist.")
_GroupManager.group_map_by_id[group.id] = group


def _is_global_group(group):
def _is_global_group(group: Group) -> bool:
return group.id == _GroupManager.global_group_id


def _warn_cur_rank_not_in_group(group):
def _warn_cur_rank_not_in_group(group: Group | None) -> bool:
global_rank = dist.get_rank()
if group and not group.is_member():
warnings.warn(
@@ -118,15 +118,15 @@ def _warn_cur_rank_not_in_group(group):
return False


def _get_or_throw_group_rank(global_rank, group):
def _get_or_throw_group_rank(global_rank: int, group: Group) -> int:
group_rank = group.get_group_rank(global_rank)
assert (
group_rank >= 0
), f"The input rank {global_rank} can not be found inside the group {group.name}"
return group_rank


def is_initialized():
def is_initialized() -> bool:
"""

Check whether the distributed environment has been initialized
@@ -154,7 +154,7 @@ def is_initialized():
return _GroupManager.global_group_id in _GroupManager.group_map_by_id


def destroy_process_group(group=None):
def destroy_process_group(group: Group | None = None) -> None:
"""
Destroy a given group for communication

@@ -196,7 +196,7 @@ def get_group(id=0):
del _GroupManager.group_map_by_id[group.id]


def get_group(id=0):
def get_group(id: int = 0) -> Group | None:
"""

Get group instance by group id.
@@ -226,7 +226,7 @@ def _sync_calc_stream(tensor):
return None


def _sync_calc_stream(tensor):
def _sync_calc_stream(tensor: Tensor) -> None:
if framework.in_dynamic_mode():
return paddle._legacy_C_ops.c_sync_calc_stream(tensor, tensor)
else:
@@ -239,7 +239,7 @@ def _sync_comm_stream(tensor, ring_id=0):
)


def _sync_comm_stream(tensor, ring_id=0):
def _sync_comm_stream(tensor: Tensor, ring_id: int = 0) -> None:
if framework.in_dynamic_mode():
return paddle._legacy_C_ops.c_sync_comm_stream(
[tensor], [tensor], 'ring_id', ring_id
@@ -255,7 +255,7 @@ def _sync_comm_stream(tensor, ring_id=0):
)


def wait(tensor, group=None, use_calc_stream=True):
def wait(
tensor: Tensor, group: Group | None = None, use_calc_stream: bool = True
) -> None:
"""

wait to sync stream for group.
@@ -291,7 +293,7 @@ def wait(tensor, group=None, use_calc_stream=True):
_sync_comm_stream(tensor, ring_id)


def barrier(group=None):
def barrier(group: Group | None = None) -> None:
"""

Barrier among all participators in the group.
@@ -347,7 +349,7 @@ def barrier(group=None):
)


def get_backend(group=None):
def get_backend(group: Group | None = None) -> str:
"""
Get the backend of given group.

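
A sketch of the typed communication entry points above. It assumes a script started with `python -m paddle.distributed.launch` on at least two devices; the collective calls and the "NCCL" backend value are the usual Paddle APIs, not something this PR adds.

import paddle
import paddle.distributed as dist

dist.init_parallel_env()
assert dist.is_initialized()                     # -> bool

group = dist.new_group([0, 1])                   # returns a Group
print(dist.get_backend(group))                   # -> str, e.g. "NCCL"

x = paddle.ones([2, 2])
dist.all_reduce(x, group=group)
dist.wait(x, group=group, use_calc_stream=True)  # wait(tensor, group, use_calc_stream)
dist.barrier(group=group)                        # group: Group | None = None
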