From eccc5ec106e7de53e1a3c82084c43aaf008a188c Mon Sep 17 00:00:00 2001
From: ooooo <3164076421@qq.com>
Date: Tue, 12 Aug 2025 16:24:52 +0800
Subject: [PATCH] support Shard and CoShard compare

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
---
 .../auto_parallel/placement_types.h        |  38 +++-
 test/auto_parallel/CMakeLists.txt          |   2 +
 test/auto_parallel/co_shard.py             |  20 +--
 test/auto_parallel/test_placement_types.py | 162 ++++++++++++++++++
 4 files changed, 210 insertions(+), 12 deletions(-)
 create mode 100644 test/auto_parallel/test_placement_types.py

diff --git a/paddle/phi/core/distributed/auto_parallel/placement_types.h b/paddle/phi/core/distributed/auto_parallel/placement_types.h
index e0042dfd4a4458..b5e5586967e43f 100644
--- a/paddle/phi/core/distributed/auto_parallel/placement_types.h
+++ b/paddle/phi/core/distributed/auto_parallel/placement_types.h
@@ -83,7 +83,10 @@ class Shard : public Placement {
 
   bool operator==(const Placement& other) const override {
     const Shard* other_shard = dynamic_cast<const Shard*>(&other);
-    return other_shard && this->dim_ == other_shard->dim_;
+    if (!other_shard) return false;
+    if (other_shard->get_co_shard_order() != 0) return false;
+    return this->dim_ == other_shard->dim_ &&
+           this->split_factor_ == other_shard->split_factor_;
   }
 
   bool operator!=(const Placement& other) const override {
@@ -152,13 +155,44 @@ class CoShard : public Shard {
   }
 
   std::shared_ptr<Placement> copy() const override {
-    return std::make_shared<Shard>(*this);
+    return std::make_shared<CoShard>(*this);
   }
 
   std::shared_ptr<Placement> deepcopy() const override {
     return std::make_shared<CoShard>(*this);
   }
 
+  bool operator==(const Placement& other) const override {
+    if (const CoShard* other_coshard = dynamic_cast<const CoShard*>(&other)) {
+      return this->dim_ == other_coshard->dim_ &&
+             this->split_factor_ == other_coshard->split_factor_ &&
+             this->co_shard_order_ == other_coshard->co_shard_order_;
+    }
+    if (const Shard* other_shard = dynamic_cast<const Shard*>(&other)) {
+      return this->co_shard_order_ == 0 &&
+             this->dim_ == other_shard->get_dim() &&
+             this->split_factor_ == other_shard->get_split_factor();
+    }
+    return false;
+  }
+
+  bool operator!=(const Placement& other) const override {
+    return !(*this == other);
+  }
+
+  std::size_t hash() const override {
+    std::stringstream ss;
+    ss << "Shard(dim=" << std::to_string(dim_);
+    if (split_factor_ != 1) {
+      ss << ", split_factor=" << std::to_string(split_factor_);
+    }
+    if (co_shard_order_ != 0) {
+      ss << ", shard_order=" << std::to_string(co_shard_order_);
+    }
+    ss << ")";
+    return std::hash<std::string>{}(ss.str());
+  }
+
  private:
   int64_t co_shard_order_ = 0;
 };
diff --git a/test/auto_parallel/CMakeLists.txt b/test/auto_parallel/CMakeLists.txt
index 19080bf6ed2a44..9dcededcfcfc92 100644
--- a/test/auto_parallel/CMakeLists.txt
+++ b/test/auto_parallel/CMakeLists.txt
@@ -9,6 +9,7 @@ add_subdirectory(pir)
 
 if(WITH_DISTRIBUTE AND WITH_GPU)
   # NOTE(zyl): unittests WITH multi cards and timeout
+  py_test_modules(test_co_shard MODULES test_co_shard)
   py_test_modules(test_converter MODULES test_converter)
   set_tests_properties(test_converter PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
                                                  TIMEOUT 50)
@@ -173,6 +174,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_api_dist_branch MODULES test_api_dist_branch)
   py_test_modules(test_shard_tensor_api MODULES test_shard_tensor_api ENVS
                   FLAGS_enable_pir_api=1)
+  py_test_modules(test_placement_types MODULES test_placement_types)
   py_test_modules(test_strategy_api MODULES test_strategy_api)
   py_test_modules(test_parallel_api MODULES test_parallel_api)
   py_test_modules(test_dtensor_to_local_api MODULES test_dtensor_to_local_api)
diff --git a/test/auto_parallel/co_shard.py b/test/auto_parallel/co_shard.py
index 5c58cca74079c9..25836b44f6ab23 100644
--- a/test/auto_parallel/co_shard.py
+++ b/test/auto_parallel/co_shard.py
@@ -21,10 +21,10 @@ class TestCoShard:
 
     def basic_interface_case(self):
         shard = dist.Shard(0, shard_order=0)
-        np.testing.assert_equal(str(shard), "Shard(dim=0, shard_order=0)")
+        np.testing.assert_equal(shard, dist.Shard(dim=0, shard_order=0))
 
         shard = dist.Shard(0, split_factor=2)
-        np.testing.assert_equal(str(shard), "Shard(dim=0, split_factor=2)")
+        np.testing.assert_equal(shard, dist.Shard(dim=0, split_factor=2))
 
     def run_test_case_0(self):
         a = paddle.to_tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
@@ -157,10 +157,10 @@ def run_test_case_3(self):
             a[dist.get_rank()].numpy().flatten(),
         )
         np.testing.assert_equal(
-            str(out.placements[0]), "Shard(dim=0, shard_order=0)"
+            out.placements[0], dist.Shard(dim=0, shard_order=0)
         )
         np.testing.assert_equal(
-            str(out.placements[1]), "Shard(dim=0, shard_order=1)"
+            out.placements[1], dist.Shard(dim=0, shard_order=1)
         )
 
     def run_test_case_4(self):
@@ -172,10 +172,10 @@ def run_test_case_4(self):
         out = paddle.reshape(input, [-1])
         np.testing.assert_equal(out.shape, [8])
         np.testing.assert_equal(
-            str(out.placements[0]), "Shard(dim=0, shard_order=0)"
+            out.placements[0], dist.Shard(dim=0, shard_order=0)
         )
         np.testing.assert_equal(
-            str(out.placements[1]), "Shard(dim=0, shard_order=1)"
+            out.placements[1], dist.Shard(dim=0, shard_order=1)
         )
         np.testing.assert_equal(
             out._local_value().numpy(), a[dist.get_rank()].numpy().flatten()
@@ -183,16 +183,16 @@ def run_test_case_4(self):
         )
         relu_out = paddle.nn.ReLU()(out)
         np.testing.assert_equal(
-            str(relu_out.placements[0]), "Shard(dim=0, shard_order=0)"
+            relu_out.placements[0], dist.Shard(dim=0, shard_order=0)
         )
         np.testing.assert_equal(
-            str(relu_out.placements[1]), "Shard(dim=0, shard_order=1)"
+            relu_out.placements[1], dist.Shard(dim=0, shard_order=1)
         )
 
         # test fallback to shard by one dim.
         add_out = paddle.add(relu_out, relu_out)
-        np.testing.assert_equal(str(add_out.placements[0]), "Shard(dim=0)")
-        np.testing.assert_equal(str(add_out.placements[1]), "Replicate()")
+        np.testing.assert_equal(add_out.placements[0], dist.Shard(dim=0))
+        np.testing.assert_equal(add_out.placements[1], dist.Replicate())
 
     def run_test_case_main(self):
         self.basic_interface_case()
diff --git a/test/auto_parallel/test_placement_types.py b/test/auto_parallel/test_placement_types.py
new file mode 100644
index 00000000000000..b82612116c0b85
--- /dev/null
+++ b/test/auto_parallel/test_placement_types.py
@@ -0,0 +1,162 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import unittest
+
+import paddle.distributed as dist
+
+
+class TestPlacementTypes(unittest.TestCase):
+    def test_shard_eq_with_co_shard_order_zero(self):
+        """
+        Tests that a Shard is equal to a CoShard with shard_order=0.
+        This confirms the "semantic equality" philosophy.
+        """
+        s1 = dist.Shard(0)
+        s2 = dist.Shard(dim=0, shard_order=0)
+
+        # 1. Test for symmetric equality
+        self.assertEqual(
+            s1, s2, "Shard(0) should be equal to Shard(dim=0, shard_order=0)"
+        )
+        self.assertEqual(s2, s1, "Equality should be symmetric")
+
+        # 2. Test hash consistency
+        self.assertEqual(
+            hash(s1), hash(s2), "Hashes must be equal for equal objects"
+        )
+
+        # 3. Test behavior in a set
+        placement_set = {s1, s2}
+        self.assertEqual(
+            len(placement_set),
+            1,
+            "A set should only contain one of the two equal objects",
+        )
+
+        # 4. Test behavior in a dict
+        placement_dict = {s1: "value1"}
+        self.assertIn(
+            s2, placement_dict, "s2 should be found in a dict keyed by s1"
+        )
+        self.assertEqual(placement_dict[s2], "value1")
+
+    def test_shard_neq_with_co_shard_order_non_zero(self):
+        """
+        Tests that a Shard is NOT equal to a CoShard with a non-zero shard_order.
+        """
+        s1 = dist.Shard(0)
+        s2 = dist.Shard(dim=0, shard_order=1)
+
+        # 1. Test for symmetric inequality
+        self.assertNotEqual(
+            s1,
+            s2,
+            "Shard(0) should NOT be equal to Shard(dim=0, shard_order=1)",
+        )
+        self.assertNotEqual(s2, s1, "Inequality should be symmetric")
+
+        # 2. Test hash difference
+        # Note: While not a strict requirement for non-equal objects to have different hashes,
+        # a good hash function should minimize collisions. We test for non-collision here.
+        self.assertNotEqual(
+            hash(s1), hash(s2), "Hashes should be different for unequal objects"
+        )
+
+        # 3. Test behavior in a set
+        placement_set = {s1, s2}
+        self.assertEqual(
+            len(placement_set), 2, "A set should contain two distinct objects"
+        )
+
+    def test_co_shard_eq(self):
+        """
+        Tests equality for two CoShard objects.
+        """
+        s1 = dist.Shard(dim=0, shard_order=1)
+        s2 = dist.Shard(dim=0, shard_order=1)
+        s3 = dist.Shard(dim=0, shard_order=2)
+
+        self.assertEqual(s1, s2)
+        self.assertNotEqual(s1, s3)
+
+    def test_replicate_placement(self):
+        """
+        Tests equality and hash for Replicate placement.
+        """
+        r1 = dist.Replicate()
+        r2 = dist.Replicate()
+        s1 = dist.Shard(0)
+
+        # 1. Test equality
+        self.assertEqual(r1, r2, "Two Replicate objects should be equal")
+        self.assertNotEqual(r1, s1, "Replicate should not be equal to Shard")
+
+        # 2. Test hash consistency
+        self.assertEqual(
+            hash(r1),
+            hash(r2),
+            "Hashes of two Replicate objects should be equal",
+        )
+
+        # 3. Test behavior in a set
+        placement_set: set[dist.Placement] = {r1, r2}
+        self.assertEqual(
+            len(placement_set),
+            1,
+            "A set should only contain one Replicate object",
+        )
+        placement_set.add(s1)
+        self.assertEqual(
+            len(placement_set),
+            2,
+            "The set should now contain two distinct objects",
+        )
+
+    def test_partial_placement(self):
+        """
+        Tests equality and hash for Partial placement.
+        """
+        p_sum1 = dist.Partial(dist.ReduceType.kRedSum)
+        p_sum2 = dist.Partial(dist.ReduceType.kRedSum)
+        p_avg = dist.Partial(dist.ReduceType.kRedAvg)
+        r1 = dist.Replicate()
+
+        # 1. Test equality
+        self.assertEqual(
+            p_sum1, p_sum2, "Two Partial(kRedSum) objects should be equal"
+        )
+        self.assertNotEqual(
+            p_sum1,
+            p_avg,
+            "Partial(kRedSum) should not be equal to Partial(kRedAvg)",
+        )
+        self.assertNotEqual(
+            p_sum1, r1, "Partial should not be equal to Replicate"
+        )
+
+        # 2. Test hash consistency
+        self.assertEqual(hash(p_sum1), hash(p_sum2))
+        self.assertNotEqual(hash(p_sum1), hash(p_avg))
+
+        # 3. Test behavior in a set
+        placement_set = {p_sum1, p_sum2}
+        self.assertEqual(len(placement_set), 1)
+        placement_set.add(p_avg)
+        self.assertEqual(len(placement_set), 2)
+
+
+if __name__ == '__main__':
+    unittest.main()
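
A minimal usage sketch of the equality semantics this patch introduces (not part
of the patch itself; it assumes the public paddle.distributed API exposes Shard,
Replicate, and the shard_order keyword exactly as exercised by the tests above):

    import paddle.distributed as dist

    # A plain Shard(0) and a CoShard on the same dim with shard_order=0 describe
    # the same placement, so they compare equal and hash identically.
    assert dist.Shard(0) == dist.Shard(dim=0, shard_order=0)
    assert hash(dist.Shard(0)) == hash(dist.Shard(dim=0, shard_order=0))

    # A non-zero shard_order is a genuinely different placement.
    assert dist.Shard(0) != dist.Shard(dim=0, shard_order=1)

    # Because __eq__ and hash agree, equal placements deduplicate in sets
    # and can be looked up interchangeably as dict keys.
    assert len({dist.Shard(0), dist.Shard(dim=0, shard_order=0)}) == 1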