Commit 6e31ae0
Support Shard and CoShard comparison (#74565)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 8321bbb commit 6e31ae0

File tree: 4 files changed, +210 −12 lines

paddle/phi/core/distributed/auto_parallel/placement_types.h

Lines changed: 36 additions & 2 deletions
@@ -83,7 +83,10 @@ class Shard : public Placement {
 
   bool operator==(const Placement& other) const override {
     const Shard* other_shard = dynamic_cast<const Shard*>(&other);
-    return other_shard && this->dim_ == other_shard->dim_;
+    if (!other_shard) return false;
+    if (other_shard->get_co_shard_order() != 0) return false;
+    return this->dim_ == other_shard->dim_ &&
+           this->split_factor_ == other_shard->split_factor_;
   }
 
   bool operator!=(const Placement& other) const override {
@@ -152,13 +155,44 @@ class CoShard : public Shard {
   }
 
   std::shared_ptr<Shard> copy() const override {
-    return std::make_shared<Shard>(*this);
+    return std::make_shared<CoShard>(*this);
   }
 
   std::shared_ptr<Shard> deepcopy() const override {
     return std::make_shared<CoShard>(*this);
   }
 
+  bool operator==(const Placement& other) const override {
+    if (const CoShard* other_coshard = dynamic_cast<const CoShard*>(&other)) {
+      return this->dim_ == other_coshard->dim_ &&
+             this->split_factor_ == other_coshard->split_factor_ &&
+             this->co_shard_order_ == other_coshard->co_shard_order_;
+    }
+    if (const Shard* other_shard = dynamic_cast<const Shard*>(&other)) {
+      return this->co_shard_order_ == 0 &&
+             this->dim_ == other_shard->get_dim() &&
+             this->split_factor_ == other_shard->get_split_factor();
+    }
+    return false;
+  }
+
+  bool operator!=(const Placement& other) const override {
+    return !(*this == other);
+  }
+
+  std::size_t hash() const override {
+    std::stringstream ss;
+    ss << "Shard(dim=" << std::to_string(dim_);
+    if (split_factor_ != 1) {
+      ss << ", split_factor=" << std::to_string(split_factor_);
+    }
+    if (co_shard_order_ != 0) {
+      ss << ", shard_order=" << std::to_string(co_shard_order_);
+    }
+    ss << ")";
+    return std::hash<std::string>{}(ss.str());
+  }
+
  private:
   int64_t co_shard_order_ = 0;
 };
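
With these overloads, equality is semantic rather than structural: a CoShard whose co_shard_order_ is 0 compares equal to a plain Shard on the same dim and split_factor, and hash() is built from the same fields so that equal placements are intended to hash identically. A minimal Python-side sketch of this behavior, assuming the bindings expose operator== and hash() as __eq__ and __hash__ (which the new unit test below relies on):

import paddle.distributed as dist

# A plain Shard and a CoShard with shard_order=0 describe the same
# partitioning, so they compare equal and must hash identically.
assert dist.Shard(0) == dist.Shard(dim=0, shard_order=0)
assert hash(dist.Shard(0)) == hash(dist.Shard(dim=0, shard_order=0))

# A non-zero shard_order changes the meaning, so equality fails in
# both directions (Shard::operator== checks get_co_shard_order() too).
assert dist.Shard(0) != dist.Shard(dim=0, shard_order=1)

# split_factor now participates in Shard equality as well.
assert dist.Shard(0) != dist.Shard(0, split_factor=2)

Note the symmetry requirement: Shard::operator== rejects any counterpart with a non-zero co-shard order precisely so that the Shard-vs-CoShard comparison gives the same answer regardless of which side dispatches.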

test/auto_parallel/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,7 @@ add_subdirectory(pir)
 if(WITH_DISTRIBUTE AND WITH_GPU)
 
   # NOTE(zyl): unittests WITH multi cards and timeout
+  py_test_modules(test_co_shard MODULES test_co_shard)
   py_test_modules(test_converter MODULES test_converter)
   set_tests_properties(test_converter PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
                                                  TIMEOUT 50)
@@ -173,6 +174,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_api_dist_branch MODULES test_api_dist_branch)
   py_test_modules(test_shard_tensor_api MODULES test_shard_tensor_api ENVS
                   FLAGS_enable_pir_api=1)
+  py_test_modules(test_placement_types MODULES test_placement_types)
   py_test_modules(test_strategy_api MODULES test_strategy_api)
   py_test_modules(test_parallel_api MODULES test_parallel_api)
   py_test_modules(test_dtensor_to_local_api MODULES test_dtensor_to_local_api)

test/auto_parallel/co_shard.py

Lines changed: 10 additions & 10 deletions
@@ -21,10 +21,10 @@
 class TestCoShard:
     def basic_interface_case(self):
         shard = dist.Shard(0, shard_order=0)
-        np.testing.assert_equal(str(shard), "Shard(dim=0, shard_order=0)")
+        np.testing.assert_equal(shard, dist.Shard(dim=0, shard_order=0))
 
         shard = dist.Shard(0, split_factor=2)
-        np.testing.assert_equal(str(shard), "Shard(dim=0, split_factor=2)")
+        np.testing.assert_equal(shard, dist.Shard(dim=0, split_factor=2))
 
     def run_test_case_0(self):
         a = paddle.to_tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
@@ -157,10 +157,10 @@ def run_test_case_3(self):
             a[dist.get_rank()].numpy().flatten(),
         )
         np.testing.assert_equal(
-            str(out.placements[0]), "Shard(dim=0, shard_order=0)"
+            out.placements[0], dist.Shard(dim=0, shard_order=0)
         )
         np.testing.assert_equal(
-            str(out.placements[1]), "Shard(dim=0, shard_order=1)"
+            out.placements[1], dist.Shard(dim=0, shard_order=1)
         )
 
     def run_test_case_4(self):
@@ -172,27 +172,27 @@ def run_test_case_4(self):
         out = paddle.reshape(input, [-1])
         np.testing.assert_equal(out.shape, [8])
         np.testing.assert_equal(
-            str(out.placements[0]), "Shard(dim=0, shard_order=0)"
+            out.placements[0], dist.Shard(dim=0, shard_order=0)
         )
         np.testing.assert_equal(
-            str(out.placements[1]), "Shard(dim=0, shard_order=1)"
+            out.placements[1], dist.Shard(dim=0, shard_order=1)
         )
         np.testing.assert_equal(
            out._local_value().numpy(), a[dist.get_rank()].numpy().flatten()
         )
 
         relu_out = paddle.nn.ReLU()(out)
         np.testing.assert_equal(
-            str(relu_out.placements[0]), "Shard(dim=0, shard_order=0)"
+            relu_out.placements[0], dist.Shard(dim=0, shard_order=0)
         )
         np.testing.assert_equal(
-            str(relu_out.placements[1]), "Shard(dim=0, shard_order=1)"
+            relu_out.placements[1], dist.Shard(dim=0, shard_order=1)
         )
 
         # test fallback to shard by one dim.
         add_out = paddle.add(relu_out, relu_out)
-        np.testing.assert_equal(str(add_out.placements[0]), "Shard(dim=0)")
-        np.testing.assert_equal(str(add_out.placements[1]), "Replicate()")
+        np.testing.assert_equal(add_out.placements[0], dist.Shard(dim=0))
+        np.testing.assert_equal(add_out.placements[1], dist.Replicate())
 
     def run_test_case_main(self):
         self.basic_interface_case()
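
The assertions above now compare Placement objects directly instead of their string forms, leaning on the new operator== overloads. A short sketch of why this is more robust, using only constructors that already appear in this test:

import paddle.distributed as dist

# Object comparison checks semantics, not rendering: these two
# placements describe the same layout and now compare equal, even
# though their string forms differed under the old assertions.
plain = dist.Shard(0)
ordered = dist.Shard(dim=0, shard_order=0)
assert plain == ordered

# A string-based check would break whenever the repr changes, and it
# could never express the Shard/CoShard equivalence asserted here.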
test/auto_parallel/test_placement_types.py (new file)

Lines changed: 162 additions & 0 deletions
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest

import paddle.distributed as dist


class TestPlacementTypes(unittest.TestCase):
    def test_shard_eq_with_co_shard_order_zero(self):
        """
        Tests that a Shard is equal to a CoShard with shard_order=0.
        This confirms the "semantic equality" philosophy.
        """
        s1 = dist.Shard(0)
        s2 = dist.Shard(dim=0, shard_order=0)

        # 1. Test for symmetric equality
        self.assertEqual(
            s1, s2, "Shard(0) should be equal to Shard(dim=0, shard_order=0)"
        )
        self.assertEqual(s2, s1, "Equality should be symmetric")

        # 2. Test hash consistency
        self.assertEqual(
            hash(s1), hash(s2), "Hashes must be equal for equal objects"
        )

        # 3. Test behavior in a set
        placement_set = {s1, s2}
        self.assertEqual(
            len(placement_set),
            1,
            "A set should only contain one of the two equal objects",
        )

        # 4. Test behavior in a dict
        placement_dict = {s1: "value1"}
        self.assertIn(
            s2, placement_dict, "s2 should be found in a dict keyed by s1"
        )
        self.assertEqual(placement_dict[s2], "value1")

    def test_shard_neq_with_co_shard_order_non_zero(self):
        """
        Tests that a Shard is NOT equal to a CoShard with a non-zero shard_order.
        """
        s1 = dist.Shard(0)
        s2 = dist.Shard(dim=0, shard_order=1)

        # 1. Test for symmetric inequality
        self.assertNotEqual(
            s1,
            s2,
            "Shard(0) should NOT be equal to Shard(dim=0, shard_order=1)",
        )
        self.assertNotEqual(s2, s1, "Inequality should be symmetric")

        # 2. Test hash difference
        # Note: While not a strict requirement for non-equal objects to have
        # different hashes, a good hash function should minimize collisions.
        # We test for non-collision here.
        self.assertNotEqual(
            hash(s1), hash(s2), "Hashes should be different for unequal objects"
        )

        # 3. Test behavior in a set
        placement_set = {s1, s2}
        self.assertEqual(
            len(placement_set), 2, "A set should contain two distinct objects"
        )

    def test_co_shard_eq(self):
        """
        Tests equality for two CoShard objects.
        """
        s1 = dist.Shard(dim=0, shard_order=1)
        s2 = dist.Shard(dim=0, shard_order=1)
        s3 = dist.Shard(dim=0, shard_order=2)

        self.assertEqual(s1, s2)
        self.assertNotEqual(s1, s3)

    def test_replicate_placement(self):
        """
        Tests equality and hash for Replicate placement.
        """
        r1 = dist.Replicate()
        r2 = dist.Replicate()
        s1 = dist.Shard(0)

        # 1. Test equality
        self.assertEqual(r1, r2, "Two Replicate objects should be equal")
        self.assertNotEqual(r1, s1, "Replicate should not be equal to Shard")

        # 2. Test hash consistency
        self.assertEqual(
            hash(r1),
            hash(r2),
            "Hashes of two Replicate objects should be equal",
        )

        # 3. Test behavior in a set
        placement_set: set[dist.Placement] = {r1, r2}
        self.assertEqual(
            len(placement_set),
            1,
            "A set should only contain one Replicate object",
        )
        placement_set.add(s1)
        self.assertEqual(
            len(placement_set),
            2,
            "The set should now contain two distinct objects",
        )

    def test_partial_placement(self):
        """
        Tests equality and hash for Partial placement.
        """
        p_sum1 = dist.Partial(dist.ReduceType.kRedSum)
        p_sum2 = dist.Partial(dist.ReduceType.kRedSum)
        p_avg = dist.Partial(dist.ReduceType.kRedAvg)
        r1 = dist.Replicate()

        # 1. Test equality
        self.assertEqual(
            p_sum1, p_sum2, "Two Partial(kRedSum) objects should be equal"
        )
        self.assertNotEqual(
            p_sum1,
            p_avg,
            "Partial(kRedSum) should not be equal to Partial(kRedAvg)",
        )
        self.assertNotEqual(
            p_sum1, r1, "Partial should not be equal to Replicate"
        )

        # 2. Test hash consistency
        self.assertEqual(hash(p_sum1), hash(p_sum2))
        self.assertNotEqual(hash(p_sum1), hash(p_avg))

        # 3. Test behavior in a set
        placement_set = {p_sum1, p_sum2}
        self.assertEqual(len(placement_set), 1)
        placement_set.add(p_avg)
        self.assertEqual(len(placement_set), 2)


if __name__ == '__main__':
    unittest.main()
