From cbddb33a7d42b8b3b58285a3d2fda85e6e211dbf Mon Sep 17 00:00:00 2001
From: lwkhahaha
Date: Mon, 29 Jul 2024 14:01:15 +0800
Subject: [PATCH 1/6] [Typing][C-4] Add type annotations for
 python/paddle/distributed/auto_parallel/ interface.py

---
 .../distributed/auto_parallel/interface.py | 39 +++++++++++++------
 1 file changed, 27 insertions(+), 12 deletions(-)

diff --git a/python/paddle/distributed/auto_parallel/interface.py b/python/paddle/distributed/auto_parallel/interface.py
index e8aa51563ad77..92607cd066d24 100644
--- a/python/paddle/distributed/auto_parallel/interface.py
+++ b/python/paddle/distributed/auto_parallel/interface.py
@@ -12,8 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 from functools import reduce
-from typing import List, Tuple
+from typing import TYPE_CHECKING, Any, Callable
 
 import numpy as np
 
@@ -30,8 +32,15 @@
     verify_shard_spec,
 )
 
+if TYPE_CHECKING:
+    from paddle import Tensor
+
 
-def shard_tensor(x, process_mesh=None, shard_spec=None):
+def shard_tensor(
+    x: Tensor,
+    process_mesh: ProcessMesh | None = None,
+    shard_spec: list[str | None] | None = None,  # items may be str or None
+) -> Tensor:
     """
     Shard a tensor on a process mesh according to the shard specification.
 
@@ -118,8 +127,12 @@ def shard_tensor(x, process_mesh=None, shard_spec=None):
 
 
 def shard_op(
-    op, process_mesh=None, in_shard_specs=None, out_shard_specs=None, **kwargs
-):
+    op: Callable,  # assuming Callable is the appropriate annotation
+    process_mesh: ProcessMesh | None = None,
+    in_shard_specs: list[list[str]] | None = None,  # nested optional lists
+    out_shard_specs: list[list[str]] | None = None,
+    **kwargs,
+) -> Any:
     """
     Shard an operation on a process mesh according to its input and output shard specification.
 
@@ -207,7 +220,7 @@ def shard_op(
 _g_recompute_idx = -1
 
 
-def recompute(op):
+def recompute(op: Callable) -> Callable:
     global _g_recompute_idx
     _g_recompute_idx += 1
 
@@ -233,7 +246,7 @@ def __call__(self, *args, **kwargs):
     return RecomputeOperator(op)
 
 
-def exclude_ops_in_recompute(run_function):
+def exclude_ops_in_recompute(run_function: Callable) -> Callable:
     """
     Exclude some operators in recompute segments.
     Args:
@@ -272,7 +285,7 @@ class CollectionNames:
     LOGGING = "logging"
 
 
-def get_collection(name):
+def get_collection(name: str) -> list:
     collection = _g_collections.get(name, None)
     if collection is None:
         collection = []
@@ -280,7 +293,9 @@ def get_collection(name):
     return _g_collections[name]
 
 
-def add_to_collection(collection_name, value, name=None):
+def add_to_collection(
+    collection_name: str, value: Any, name: str | None = None
+) -> None:
     if collection_name not in _g_collections:
         _g_collections[collection_name] = []
     if name is not None:
@@ -295,7 +310,7 @@ def add_to_collection(collection_name, value, name=None):
         _g_collections[collection_name].append((None, value))
 
 
-def fetch(tensor, name=None, logging=False):
+def fetch(tensor: Tensor, name: str = None, logging: bool = False) -> None:
     if isinstance(tensor, paddle.static.Variable):
         tensor = tensor.name
     elif isinstance(tensor, str):
@@ -312,17 +327,17 @@ def fetch(tensor, name=None, logging=False):
 _g_mesh = None
 
 
-def get_mesh():
+def get_mesh() -> ProcessMesh:
     global _g_mesh
     return _g_mesh
 
 
-def set_mesh(mesh):
+def set_mesh(mesh: ProcessMesh) -> None:
     global _g_mesh
     _g_mesh = mesh
 
 
-def create_mesh(mesh_dims: List[Tuple[str, int]]):
+def create_mesh(mesh_dims: list[tuple[str, int]]) -> ProcessMesh:
     """
     Create a global process_mesh for auto parallel.
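A side note on the pattern this patch relies on (an illustrative sketch, not Paddle code): `from __future__ import annotations` postpones evaluation of annotations, so names imported under `if TYPE_CHECKING:` such as `Tensor` are only ever seen by the type checker, and spellings like `ProcessMesh | None` cost nothing at runtime. The names below are standalone stand-ins chosen only for the demo:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by type checkers, never at runtime, so there is no
    # import cost and no risk of a circular import.
    from decimal import Decimal  # stand-in for `from paddle import Tensor`


def describe(value: Decimal | None = None) -> str:
    # With postponed evaluation (PEP 563) the annotation above is stored as
    # a plain string, so neither `Decimal` nor the `| None` union is
    # evaluated when this module is imported.
    return "empty" if value is None else str(value)


print(describe())  # -> "empty"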
From 7dec9d731fbd6195829834fe854a980364630af0 Mon Sep 17 00:00:00 2001
From: lwkhahaha
Date: Mon, 29 Jul 2024 15:34:37 +0800
Subject: [PATCH 2/6] [Typing][C-4] Add type annotations for
 python/paddle/distributed/auto_parallel/ interface.py

---
 .../auto_parallel/placement_type.py | 19 ++++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)

diff --git a/python/paddle/distributed/auto_parallel/placement_type.py b/python/paddle/distributed/auto_parallel/placement_type.py
index 36f7c9b43d1fd..030a8ffd3ea05 100644
--- a/python/paddle/distributed/auto_parallel/placement_type.py
+++ b/python/paddle/distributed/auto_parallel/placement_type.py
@@ -11,13 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import cast
+from typing import Any, List, cast
 
 import paddle
 from paddle.base.core import Partial, Replicate, Shard
+from paddle.distributed import Placement
 
+from .process_mesh import ProcessMesh
 
-def to_placements(dim_map, mesh, partial_idx=[]):
+
+def to_placements(
+    dim_map: list[int], mesh: ProcessMesh, partial_idx: list[int] = None
+) -> list:
     """
     convert dim_map to placements.
 
@@ -55,7 +60,7 @@ def to_placements(dim_map, mesh, partial_idx=[]):
     return placements
 
 
-def check_placements_equal(this, that):
+def check_placements_equal(this: list[Any], that: list[Any]) -> bool:
     assert isinstance(this, list) and isinstance(that, list)
     small_placements = this if len(this) < len(that) else that
     large_placements = that if len(this) < len(that) else this
@@ -69,7 +74,9 @@ def check_placements_equal(this, that):
     return True
 
 
-def to_dim_map(placements, tensor_dims):
+def to_dim_map(
+    placements: list[Placement], tensor_dims: int
+) -> tuple[list[int], dict[int, str]]:
     """
     convert placements to dim_map.
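For readers unfamiliar with the convention these helpers translate between, dim_map is usually read as: dim_map[i] is the mesh dimension along which tensor dimension i is sharded, and -1 means that dimension is replicated. A self-contained sketch under that assumption (the string labels are stand-ins for the real Shard/Replicate placements):

from __future__ import annotations


def dim_map_to_labels(dim_map: list[int], n_mesh_dims: int) -> list[str]:
    # One label per mesh dimension: "Shard(i)" if some tensor dim i is
    # mapped onto that mesh dimension, otherwise "Replicate()".
    labels = ["Replicate()"] * n_mesh_dims
    for tensor_dim, mesh_dim in enumerate(dim_map):
        if mesh_dim >= 0:
            labels[mesh_dim] = f"Shard({tensor_dim})"
    return labels


# A 2-D tensor on a 2-D mesh: dim 0 split along mesh dim 0, dim 1 replicated.
print(dim_map_to_labels([0, -1], n_mesh_dims=2))  # ['Shard(0)', 'Replicate()']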
@@ -97,7 +104,9 @@ def to_dim_map(placements, tensor_dims):
     return dim_map, partial_status
 
 
-def get_shard_spec(mesh, placements, tensor_dims):
+def get_shard_spec(
+    mesh: ProcessMesh, placements: list[Placement], tensor_dims: int
+) -> List[str]:
     """to get shard_spec for construct DistAttr for static API."""
     dim_map, _ = to_dim_map(placements, tensor_dims)
     mesh_dim_names = mesh.dim_names

From 2231325ea57c27cfcbe978cfcd90b4bebca00f05 Mon Sep 17 00:00:00 2001
From: lwkhahaha
Date: Mon, 29 Jul 2024 16:05:01 +0800
Subject: [PATCH 3/6] [Typing][C-4] Add type annotations for
 python/paddle/distributed/auto_parallel/ interface.py

---
 .../paddle/distributed/communication/group.py | 30 ++++++++++---------
 1 file changed, 16 insertions(+), 14 deletions(-)

diff --git a/python/paddle/distributed/communication/group.py b/python/paddle/distributed/communication/group.py
index d73e3ce90cbd2..0871f324a8f55 100644
--- a/python/paddle/distributed/communication/group.py
+++ b/python/paddle/distributed/communication/group.py
@@ -16,7 +16,7 @@
 
 import paddle
 import paddle.distributed as dist
-from paddle import framework
+from paddle import Tensor, framework
 
 
 class Group:
@@ -92,23 +92,23 @@ class _GroupManager:
     group_map_by_id = {}
 
 
-def _get_global_group():
+def _get_global_group() -> dict[str, any]:
     if _GroupManager.global_group_id not in _GroupManager.group_map_by_id:
         raise RuntimeError("The global group is not initialized.")
     return _GroupManager.group_map_by_id[_GroupManager.global_group_id]
 
 
-def _add_new_group(group):
+def _add_new_group(group: Group) -> None:
     if group.id in _GroupManager.group_map_by_id:
         raise RuntimeError(f"The group with id {group.id} already exist.")
     _GroupManager.group_map_by_id[group.id] = group
 
 
-def _is_global_group(group):
+def _is_global_group(group: Group) -> bool:
     return group.id == _GroupManager.global_group_id
 
 
-def _warn_cur_rank_not_in_group(group):
+def _warn_cur_rank_not_in_group(group: Group) -> bool:
     global_rank = dist.get_rank()
     if group and not group.is_member():
         warnings.warn(
@@ -118,7 +118,7 @@ def _warn_cur_rank_not_in_group(group):
     return False
 
 
-def _get_or_throw_group_rank(global_rank, group):
+def _get_or_throw_group_rank(global_rank: int, group: Group) -> int:
     group_rank = group.get_group_rank(global_rank)
     assert (
         group_rank >= 0
@@ -126,7 +126,7 @@ def _get_or_throw_group_rank(global_rank, group):
     return group_rank
 
 
-def is_initialized():
+def is_initialized() -> bool:
     """
     Check whether the distributed environment has been initialized
 
@@ -154,7 +154,7 @@ def is_initialized():
     return _GroupManager.global_group_id in _GroupManager.group_map_by_id
 
 
-def destroy_process_group(group=None):
+def destroy_process_group(group: Group = None) -> None:
     """
     Destroy a given group for communication
 
@@ -196,7 +196,7 @@ def destroy_process_group(group=None):
         del _GroupManager.group_map_by_id[group.id]
 
 
-def get_group(id=0):
+def get_group(id: int = 0) -> Group | None:
     """
     Get group instance by group id.
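Several of the group.py signatures take a group that may legitimately be None (meaning "use the global group") or look a Group up from a registry, so the precise spellings are `Group | None` rather than a bare `-> None` or an implicit-optional `Group = None`. A minimal sketch of that shape, with a toy registry instead of Paddle's _GroupManager:

from __future__ import annotations


class Group:
    def __init__(self, gid: int) -> None:
        self.id = gid


_registry: dict[int, Group] = {0: Group(0)}


def get_group(gid: int = 0) -> Group | None:
    # May return None for an unknown id, hence `Group | None`, not `-> None`.
    return _registry.get(gid)


def barrier(group: Group | None = None) -> None:
    # None means "fall back to the global group", so the parameter is an
    # explicit optional rather than the implicit `group: Group = None`.
    group = group if group is not None else _registry[0]
    print(f"barrier on group {group.id}")


barrier()            # barrier on group 0
print(get_group(7))  # None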
@@ -226,7 +226,7 @@ def get_group(id=0):
     return None
 
 
-def _sync_calc_stream(tensor):
+def _sync_calc_stream(tensor: Tensor) -> None:
     if framework.in_dynamic_mode():
         return paddle._legacy_C_ops.c_sync_calc_stream(tensor, tensor)
     else:
@@ -239,7 +239,7 @@ def _sync_calc_stream(tensor):
         )
 
 
-def _sync_comm_stream(tensor, ring_id=0):
+def _sync_comm_stream(tensor: Tensor, ring_id: int = 0) -> None:
     if framework.in_dynamic_mode():
         return paddle._legacy_C_ops.c_sync_comm_stream(
             [tensor], [tensor], 'ring_id', ring_id
@@ -255,7 +255,9 @@ def _sync_comm_stream(tensor, ring_id=0):
         )
 
 
-def wait(tensor, group=None, use_calc_stream=True):
+def wait(
+    tensor: Tensor, group: Group = None, use_calc_stream: bool = True
+) -> None:
     """
     wait to sync stream for group.
 
@@ -291,7 +293,7 @@ def wait(tensor, group=None, use_calc_stream=True):
         _sync_comm_stream(tensor, ring_id)
 
 
-def barrier(group=None):
+def barrier(group: Group = None) -> None:
     """
     Barrier among all participators in the group.
 
@@ -347,7 +349,7 @@ def barrier(group=None):
         )
 
 
-def get_backend(group=None):
+def get_backend(group: Group = None) -> str:
     """
     Get the backend of given group.

From d8c0e7166fcbea6f23f9529b67ffc0d7655920d5 Mon Sep 17 00:00:00 2001
From: lwkhahaha
Date: Mon, 29 Jul 2024 16:43:50 +0800
Subject: [PATCH 4/6] [Typing][C-4] Add type annotations for
 python/paddle/distributed/auto_parallel/ interface.py

---
 .../paddle/distributed/auto_parallel/strategy.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/python/paddle/distributed/auto_parallel/strategy.py b/python/paddle/distributed/auto_parallel/strategy.py
index 7f026827d3446..1d7455b804d47 100644
--- a/python/paddle/distributed/auto_parallel/strategy.py
+++ b/python/paddle/distributed/auto_parallel/strategy.py
@@ -18,7 +18,7 @@
 
 
 class BaseConfig:
-    def __init__(self, category, config_dict=None):
+    def __init__(self, category: str, config_dict: dict = None) -> None:
         self._category = category
         self._config_dict = None
         if config_dict is not None:
@@ -37,7 +37,7 @@ def __init__(self, category, config_dict=None):
         if self._config_dict:
             self.from_dict(self._config_dict)
 
-    def from_dict(self, config_dict):
+    def from_dict(self, config_dict: dict) -> None:
         config = constants.get_category_default_config(self._category)
         for field in config.keys():
             value = config_dict.get(field, constants.NOT_FOUND)
@@ -45,7 +45,7 @@ def from_dict(self, config_dict):
             if value != constants.NOT_FOUND:
                 setattr(self, field, value)
 
-    def to_dict(self):
+    def to_dict(self) -> dict:
         result_dict = {}
         config = constants.get_category_default_config(self._category)
         for field in config.keys():
@@ -56,14 +56,14 @@ def to_dict(self):
                 result_dict[field] = value.to_dict()
         return result_dict
 
-    def __repr__(self):
+    def __repr__(self) -> str:
        result_dict = self.to_dict()
        string = "{"
        for k, v in result_dict.items():
            string += f"\"{k}\":\"{v}\","
        return string + "}"
 
-    def __deepcopy__(self, memo):
+    def __deepcopy__(self, memo: dict) -> "BaseConfig":
         cls = self.__class__
         result = cls.__new__(cls)
         memo[id(self)] = result
@@ -71,7 +71,7 @@ def __deepcopy__(self, memo):
             setattr(result, k, copy.deepcopy(v, memo))
         return result
 
-    def get(self, k, d=None):
+    def get(self, k: str, d=None):
         result_dict = self.to_dict()
         return result_dict.get(k, d)
 
@@ -186,7 +186,7 @@ class Strategy(BaseConfig):
     """
 
-    def __init__(self, config=None):
+    def __init__(self, config: dict = None) -> None:
         if config is not None:
             if isinstance(config, dict):
                 self._config_dict = copy.deepcopy(config)
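The BaseConfig accessors are thin wrappers over a plain dict whose keys are field names and whose values can be anything, so the natural return types are `dict[str, Any]` for to_dict and `Any` for get. A small sketch of that shape, independent of Paddle's constants module:

from __future__ import annotations

from typing import Any


class ToyConfig:
    def __init__(self, config_dict: dict[str, Any] | None = None) -> None:
        # Copy the caller's dict so later mutation does not leak in.
        self._fields: dict[str, Any] = dict(config_dict or {})

    def to_dict(self) -> dict[str, Any]:
        # Field names are strings; values may be of any type.
        return dict(self._fields)

    def get(self, k: str, d: Any = None) -> Any:
        # Mirrors dict.get: the key is a field name, the result is whatever
        # that field holds, or the supplied default.
        return self._fields.get(k, d)


cfg = ToyConfig({"enable": True, "micro_batch_size": 4})
print(cfg.get("micro_batch_size"))  # 4
print(cfg.get("missing", "n/a"))    # n/a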
From 9003d15bf6f9421dd41a3f9a4b8feb1bae7f55d9 Mon Sep 17 00:00:00 2001
From: lwkhahaha
Date: Tue, 30 Jul 2024 15:42:43 +0800
Subject: [PATCH 5/6] [Typing][C-4][C-7][C-9][C-18] Add type annotations for
 python/paddle/distributed/auto_parallel/interface.py+strategy.py+placement_type.py
 and python/paddle/distributed/communication/group.py

---
 python/paddle/distributed/auto_parallel/placement_type.py | 2 +-
 python/paddle/distributed/communication/group.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/paddle/distributed/auto_parallel/placement_type.py b/python/paddle/distributed/auto_parallel/placement_type.py
index 030a8ffd3ea05..083726d75ffd7 100644
--- a/python/paddle/distributed/auto_parallel/placement_type.py
+++ b/python/paddle/distributed/auto_parallel/placement_type.py
@@ -22,7 +22,7 @@
 
 def to_placements(
     dim_map: list[int], mesh: ProcessMesh, partial_idx: list[int] = None
-) -> list:
+) -> List[Placement]:
     """
     convert dim_map to placements.
 
diff --git a/python/paddle/distributed/communication/group.py b/python/paddle/distributed/communication/group.py
index 0871f324a8f55..61ad792a14c02 100644
--- a/python/paddle/distributed/communication/group.py
+++ b/python/paddle/distributed/communication/group.py
@@ -92,7 +92,7 @@ class _GroupManager:
     group_map_by_id = {}
 
 
-def _get_global_group() -> dict[str, any]:
+def _get_global_group() -> Group:
     if _GroupManager.global_group_id not in _GroupManager.group_map_by_id:
         raise RuntimeError("The global group is not initialized.")
     return _GroupManager.group_map_by_id[_GroupManager.global_group_id]

From 1ab8a5dee18ed8932226f53828515f30b0c126e4 Mon Sep 17 00:00:00 2001
From: lwkhahaha <124662571+lwkhahaha@users.noreply.github.com>
Date: Tue, 30 Jul 2024 22:25:26 +0800
Subject: [PATCH 6/6] Update placement_type.py

---
 python/paddle/distributed/auto_parallel/placement_type.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/paddle/distributed/auto_parallel/placement_type.py b/python/paddle/distributed/auto_parallel/placement_type.py
index 083726d75ffd7..5a140cf42d397 100644
--- a/python/paddle/distributed/auto_parallel/placement_type.py
+++ b/python/paddle/distributed/auto_parallel/placement_type.py
@@ -21,7 +21,7 @@
 
 
 def to_placements(
-    dim_map: list[int], mesh: ProcessMesh, partial_idx: list[int] = None
+    dim_map: list[int], mesh: ProcessMesh, partial_idx: list[int] = []
 ) -> List[Placement]:
     """
     convert dim_map to placements.
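One note on the `partial_idx: list[int] = []` default that [PATCH 6/6] restores: it matches the signature the file had before the series, but a mutable default is created once and shared across calls, so the common alternative is a None sentinel replaced inside the function. A small sketch of the difference, with made-up function names:

from __future__ import annotations


def risky(item: int, acc: list[int] = []) -> list[int]:
    # The single default list is created once and reused by every call.
    acc.append(item)
    return acc


def safe(item: int, acc: list[int] | None = None) -> list[int]:
    # None as the sentinel: a fresh list per call, while the annotation
    # still documents the intended element type.
    acc = [] if acc is None else acc
    acc.append(item)
    return acc


print(risky(1), risky(2))  # [1, 2] [1, 2] -- state leaks between calls
print(safe(1), safe(2))    # [1] [2]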