diff --git a/python/paddle/distributed/auto_parallel/process_mesh.py b/python/paddle/distributed/auto_parallel/process_mesh.py
index 897ee0be36a1e7..c4ccd43b12619c 100644
--- a/python/paddle/distributed/auto_parallel/process_mesh.py
+++ b/python/paddle/distributed/auto_parallel/process_mesh.py
@@ -21,6 +21,8 @@
 import numpy as np
 
 import paddle
+from paddle.distributed import fleet
+from paddle.distributed.collective import _get_group_map
 from paddle.distributed.communication.group import is_initialized
 from paddle.framework import core
 
@@ -442,8 +444,29 @@ def get_group(
                     f"{dim_name} not in the dimension names {self._dim_names}"
                 )
             else:
-                pg = paddle.distributed.new_group(self._process_ids)
-                return pg
+                if hasattr(fleet.fleet, "_hcg"):
+                    hcg = fleet.get_hybrid_communicate_group()
+                    if hcg is not None:
+
+                        parallel_group_map = {
+                            "pp": hcg.get_pipe_parallel_group,
+                            "dp": hcg.get_data_parallel_group,
+                            "mp": hcg.get_model_parallel_group,
+                            "sep": hcg.get_sep_parallel_group,
+                            "sharding": hcg.get_sharding_parallel_group,
+                        }
+
+                        if dim_name not in parallel_group_map:
+                            raise ValueError(
+                                f"{dim_name} is not a valid dim name."
+                            )
+
+                        return parallel_group_map[dim_name]()
+                group_map = _get_group_map()
+                for group in group_map.values():
+                    if set(group.ranks) == set(self._process_ids):
+                        return group
+                return paddle.distributed.new_group(self._process_ids)
         else:
             if dim_name not in self._dim_names:
                 raise ValueError(
diff --git a/test/auto_parallel/hybrid_strategy/CMakeLists.txt b/test/auto_parallel/hybrid_strategy/CMakeLists.txt
index 006f95249ecd33..ce31d06d0ab42f 100644
--- a/test/auto_parallel/hybrid_strategy/CMakeLists.txt
+++ b/test/auto_parallel/hybrid_strategy/CMakeLists.txt
@@ -173,6 +173,14 @@ if((WITH_GPU) AND (LINUX))
   py_test_modules(
     test_process_mesh MODULES test_process_mesh ENVS
     "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python")
-  set_tests_properties(test_process_mesh PROPERTIES TIMEOUT "60" LABELS
+  set_tests_properties(test_process_mesh PROPERTIES TIMEOUT "150" LABELS
                        "RUN_TYPE=HYBRID")
 endif()
+if((WITH_GPU) AND (LINUX))
+  py_test_modules(
+    test_get_group_in_different_hybrid_configs MODULES
+    test_get_group_in_different_hybrid_configs ENVS
+    "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python")
+  set_tests_properties(test_get_group_in_different_hybrid_configs
+                       PROPERTIES TIMEOUT "150" LABELS "RUN_TYPE=HYBRID")
+endif()
diff --git a/test/auto_parallel/hybrid_strategy/process_mesh_demo_unittest.py b/test/auto_parallel/hybrid_strategy/process_mesh_demo_unittest.py
index 8d2bf4609d3654..fb48af746e8899 100644
--- a/test/auto_parallel/hybrid_strategy/process_mesh_demo_unittest.py
+++ b/test/auto_parallel/hybrid_strategy/process_mesh_demo_unittest.py
@@ -99,7 +99,7 @@ def test_get_group(self):
         assert isinstance(
             group_1d_with_name, dist.communication.group.Group
         )
-
+        assert group_1d_with_name.id == group_1d.id
         # Test case 3: Single dimension mesh with wrong dim_name
         try:
             mesh_1d.get_group(dim_name="wrong_name")
diff --git a/test/auto_parallel/hybrid_strategy/test_get_group_in_different_hybrid_configs.py b/test/auto_parallel/hybrid_strategy/test_get_group_in_different_hybrid_configs.py
new file mode 100644
index 00000000000000..a7834d3e13d470
--- /dev/null
+++ b/test/auto_parallel/hybrid_strategy/test_get_group_in_different_hybrid_configs.py
@@ -0,0 +1,162 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import collective.test_communication_api_base as test_base
+
+
+class TestProcessMeshDPGroupConsistency(test_base.CommunicationTestDistBase):
+    def setUp(self):
+        super().setUp(num_of_devices=2, timeout=200, nnode=1)
+
+    def test_dp_parallel(self):
+        """Test data parallel group creation and consistency"""
+        _default_envs = {
+            "dp": "2",
+            "mp": "1",
+            "pp": "1",
+            "parallel_type": "dp",
+            "FLAGS_embedding_deterministic": "1",
+            "FLAGS_cudnn_deterministic": "1",
+        }
+        _changeable_envs = {
+            "backend": ["gpu"],
+        }
+        envs_list = test_base.gen_product_envs_list(
+            _default_envs, _changeable_envs
+        )
+        for envs in envs_list:
+            self.run_test_case(
+                "test_process_mesh_group_consistency.py",
+                user_defined_envs=envs,
+            )
+
+
+class TestProcessMeshMPGroupConsistency(test_base.CommunicationTestDistBase):
+    def setUp(self):
+        super().setUp(num_of_devices=2, timeout=200, nnode=1)
+
+    def test_mp_parallel(self):
+        """Test model parallel group creation and consistency"""
+        _default_envs = {
+            "dp": "1",
+            "mp": "2",
+            "pp": "1",
+            "parallel_type": "mp",
+            "FLAGS_embedding_deterministic": "1",
+            "FLAGS_cudnn_deterministic": "1",
+        }
+        _changeable_envs = {
+            "backend": ["gpu"],
+        }
+        envs_list = test_base.gen_product_envs_list(
+            _default_envs, _changeable_envs
+        )
+        for envs in envs_list:
+            self.run_test_case(
+                "test_process_mesh_group_consistency.py",
+                user_defined_envs=envs,
+            )
+
+
+class TestProcessMeshPPGroupConsistency(test_base.CommunicationTestDistBase):
+    def setUp(self):
+        super().setUp(num_of_devices=2, timeout=200, nnode=1)
+
+    def test_pp_parallel(self):
+        """Test pipeline parallel group creation and consistency"""
+        _default_envs = {
+            "dp": "1",
+            "mp": "1",
+            "pp": "2",
+            "parallel_type": "pp",
+            "FLAGS_embedding_deterministic": "1",
+            "FLAGS_cudnn_deterministic": "1",
+        }
+        _changeable_envs = {
+            "backend": ["gpu"],
+        }
+        envs_list = test_base.gen_product_envs_list(
+            _default_envs, _changeable_envs
+        )
+        for envs in envs_list:
+            self.run_test_case(
+                "test_process_mesh_group_consistency.py",
+                user_defined_envs=envs,
+            )
+
+
+class TestProcessMeshSEPGroupConsistency(test_base.CommunicationTestDistBase):
+    def setUp(self):
+        super().setUp(num_of_devices=2, timeout=200, nnode=1)
+
+    def test_sep_parallel(self):
+        """Test sequence parallel group creation and consistency"""
+        _default_envs = {
+            "dp": "1",
+            "mp": "1",
+            "pp": "1",
+            "sep": "2",
+            "sharding": "1",
+            "parallel_type": "sep",
+            "FLAGS_embedding_deterministic": "1",
+            "FLAGS_cudnn_deterministic": "1",
+        }
+        _changeable_envs = {
+            "backend": ["gpu"],
+        }
+        envs_list = test_base.gen_product_envs_list(
+            _default_envs, _changeable_envs
+        )
+        for envs in envs_list:
+            self.run_test_case(
+                "test_process_mesh_group_consistency.py",
+                user_defined_envs=envs,
+            )
+
+
+class TestProcessMeshShardingGroupConsistency(
+    test_base.CommunicationTestDistBase
+):
+    def setUp(self):
+        super().setUp(num_of_devices=2, timeout=200, nnode=1)
+
+    def test_sharding_parallel(self):
+ """Test sharding parallel group creation and consistency""" + _default_envs = { + "dp": "1", + "mp": "1", + "pp": "1", + "sep": "1", + "sharding": "2", + "parallel_type": "sharding", + "FLAGS_embedding_deterministic": "1", + "FLAGS_cudnn_deterministic": "1", + } + _changeable_envs = { + "backend": ["gpu"], + } + envs_list = test_base.gen_product_envs_list( + _default_envs, _changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "test_process_mesh_group_consistency.py", + user_defined_envs=envs, + ) + + +if __name__ == "__main__": + unittest.main() # python run diff --git a/test/auto_parallel/hybrid_strategy/test_process_mesh_group_consistency.py b/test/auto_parallel/hybrid_strategy/test_process_mesh_group_consistency.py new file mode 100644 index 00000000000000..7dd15d44405637 --- /dev/null +++ b/test/auto_parallel/hybrid_strategy/test_process_mesh_group_consistency.py @@ -0,0 +1,107 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import paddle.distributed as dist +from paddle.distributed import fleet + + +class TestProcessMeshGroupConsistency: + def __init__(self): + # Get configuration from environment variables + self.dp = int(os.getenv("dp", "1")) + self.mp = int(os.getenv("mp", "1")) + self.pp = int(os.getenv("pp", "1")) + self.sep = int(os.getenv("sep", "1")) + self.sharding = int(os.getenv("sharding", "1")) + + # Determine which parallel type to test + self.parallel_type = os.getenv("parallel_type", "dp") + + def init_dist_env(self): + """Initialize distributed environment""" + # Configure distributed strategy + dist_strategy = fleet.DistributedStrategy() + dist_strategy.hybrid_configs = { + "dp_degree": self.dp, + "mp_degree": self.mp, + "pp_degree": self.pp, + "sep_degree": self.sep, + "sharding_degree": self.sharding, + } + + # Add corresponding configuration based on parallel type + if self.sep > 1: + dist_strategy.hybrid_configs["sep_degree"] = self.sep + if self.sharding > 1: + dist_strategy.hybrid_configs["sharding_degree"] = self.sharding + + fleet.init(is_collective=True, strategy=dist_strategy) + + def test_process_mesh_group_consistency(self): + """Test consistency between ProcessMesh created groups and HCG created groups""" + + # Create corresponding ProcessMesh and get corresponding HCG group based on parallel type + if self.parallel_type == "dp": + mesh = dist.ProcessMesh([0, 1], dim_names=["dp"]) + hcg = fleet.get_hybrid_communicate_group() + group = mesh.get_group(dim_name="dp") + hcg_group = hcg.get_data_parallel_group() + + elif self.parallel_type == "mp": + mesh = dist.ProcessMesh([0, 1], dim_names=["mp"]) + hcg = fleet.get_hybrid_communicate_group() + group = mesh.get_group(dim_name="mp") + hcg_group = hcg.get_model_parallel_group() + + elif self.parallel_type == "pp": + mesh = dist.ProcessMesh([0, 1], dim_names=["pp"]) + hcg = fleet.get_hybrid_communicate_group() + group = mesh.get_group(dim_name="pp") + hcg_group = hcg.get_pipe_parallel_group() + + elif 
self.parallel_type == "sep": + mesh = dist.ProcessMesh([0, 1], dim_names=["sep"]) + hcg = fleet.get_hybrid_communicate_group() + group = mesh.get_group(dim_name="sep") + hcg_group = hcg.get_sep_parallel_group() + + elif self.parallel_type == "sharding": + mesh = dist.ProcessMesh([0, 1], dim_names=["sharding"]) + hcg = fleet.get_hybrid_communicate_group() + group = mesh.get_group(dim_name="sharding") + hcg_group = hcg.get_sharding_parallel_group() + + else: + raise ValueError(f"Unsupported parallel type: {self.parallel_type}") + + # Verify that group ranks are consistent + group_ranks = group.ranks + hcg_group_ranks = hcg_group.ranks + assert set(group_ranks) == set(hcg_group_ranks) + + # Verify that group IDs are consistent + group_id = group.id + hcg_group_id = hcg_group.id + assert group_id == hcg_group_id + + def run_test_cases(self): + """Run test cases""" + self.init_dist_env() + self.test_process_mesh_group_consistency() + + +if __name__ == "__main__": + TestProcessMeshGroupConsistency().run_test_cases()
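Illustrative usage sketch (not part of the diff above): a minimal example of the behavior this change targets, assuming a 2-GPU collective launch (e.g. via paddle.distributed.launch). It mirrors the "dp" case exercised by test_process_mesh_group_consistency.py; all APIs used here appear in the patch itself.

    import paddle.distributed as dist
    from paddle.distributed import fleet

    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {"dp_degree": 2, "mp_degree": 1, "pp_degree": 1}
    fleet.init(is_collective=True, strategy=strategy)

    # With the patch, ProcessMesh.get_group(dim_name=...) reuses the group already
    # built by the hybrid communicate group (HCG), or an existing group with the
    # same ranks, instead of always calling paddle.distributed.new_group().
    mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])
    group = mesh.get_group(dim_name="dp")
    hcg_group = fleet.get_hybrid_communicate_group().get_data_parallel_group()
    assert group.id == hcg_group.id
    assert set(group.ranks) == set(hcg_group.ranks)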