38 changes: 36 additions & 2 deletions paddle/phi/core/distributed/auto_parallel/placement_types.h
@@ -83,7 +83,10 @@ class Shard : public Placement {
 
   bool operator==(const Placement& other) const override {
     const Shard* other_shard = dynamic_cast<const Shard*>(&other);
-    return other_shard && this->dim_ == other_shard->dim_;
+    if (!other_shard) return false;
+    if (other_shard->get_co_shard_order() != 0) return false;
+    return this->dim_ == other_shard->dim_ &&
+           this->split_factor_ == other_shard->split_factor_;
   }
 
   bool operator!=(const Placement& other) const override {
@@ -152,13 +155,44 @@ class CoShard : public Shard {
   }
 
   std::shared_ptr<Shard> copy() const override {
-    return std::make_shared<Shard>(*this);
+    return std::make_shared<CoShard>(*this);
   }
 
   std::shared_ptr<Shard> deepcopy() const override {
     return std::make_shared<CoShard>(*this);
   }
 
+  bool operator==(const Placement& other) const override {
+    if (const CoShard* other_coshard = dynamic_cast<const CoShard*>(&other)) {
+      return this->dim_ == other_coshard->dim_ &&
+             this->split_factor_ == other_coshard->split_factor_ &&
+             this->co_shard_order_ == other_coshard->co_shard_order_;
+    }
+    if (const Shard* other_shard = dynamic_cast<const Shard*>(&other)) {
+      return this->co_shard_order_ == 0 &&
+             this->dim_ == other_shard->get_dim() &&
+             this->split_factor_ == other_shard->get_split_factor();
+    }
+    return false;
+  }
+
+  bool operator!=(const Placement& other) const override {
+    return !(*this == other);
+  }
+
+  std::size_t hash() const override {
+    std::stringstream ss;
+    ss << "Shard(dim=" << std::to_string(dim_);
+    if (split_factor_ != 1) {
+      ss << ", split_factor=" << std::to_string(split_factor_);
+    }
+    if (co_shard_order_ != 0) {
+      ss << ", shard_order=" << std::to_string(co_shard_order_);
+    }
+    ss << ")";
+    return std::hash<std::string>{}(ss.str());
+  }
+
  private:
   int64_t co_shard_order_ = 0;
 };
2 changes: 2 additions & 0 deletions test/auto_parallel/CMakeLists.txt
@@ -9,6 +9,7 @@ add_subdirectory(pir)
 if(WITH_DISTRIBUTE AND WITH_GPU)
 
   # NOTE(zyl): unittests WITH multi cards and timeout
+  py_test_modules(test_co_shard MODULES test_co_shard)
  py_test_modules(test_converter MODULES test_converter)
  set_tests_properties(test_converter PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
                                                  TIMEOUT 50)
@@ -173,6 +174,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_api_dist_branch MODULES test_api_dist_branch)
   py_test_modules(test_shard_tensor_api MODULES test_shard_tensor_api ENVS
                   FLAGS_enable_pir_api=1)
+  py_test_modules(test_placement_types MODULES test_placement_types)
   py_test_modules(test_strategy_api MODULES test_strategy_api)
   py_test_modules(test_parallel_api MODULES test_parallel_api)
   py_test_modules(test_dtensor_to_local_api MODULES test_dtensor_to_local_api)
20 changes: 10 additions & 10 deletions test/auto_parallel/co_shard.py
@@ -21,10 +21,10 @@
 class TestCoShard:
     def basic_interface_case(self):
         shard = dist.Shard(0, shard_order=0)
-        np.testing.assert_equal(str(shard), "Shard(dim=0, shard_order=0)")
+        np.testing.assert_equal(shard, dist.Shard(dim=0, shard_order=0))
 
         shard = dist.Shard(0, split_factor=2)
-        np.testing.assert_equal(str(shard), "Shard(dim=0, split_factor=2)")
+        np.testing.assert_equal(shard, dist.Shard(dim=0, split_factor=2))
 
     def run_test_case_0(self):
         a = paddle.to_tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
@@ -157,10 +157,10 @@ def run_test_case_3(self):
             a[dist.get_rank()].numpy().flatten(),
         )
         np.testing.assert_equal(
-            str(out.placements[0]), "Shard(dim=0, shard_order=0)"
+            out.placements[0], dist.Shard(dim=0, shard_order=0)
         )
         np.testing.assert_equal(
-            str(out.placements[1]), "Shard(dim=0, shard_order=1)"
+            out.placements[1], dist.Shard(dim=0, shard_order=1)
         )
 
     def run_test_case_4(self):
@@ -172,27 +172,27 @@ def run_test_case_4(self):
         out = paddle.reshape(input, [-1])
         np.testing.assert_equal(out.shape, [8])
         np.testing.assert_equal(
-            str(out.placements[0]), "Shard(dim=0, shard_order=0)"
+            out.placements[0], dist.Shard(dim=0, shard_order=0)
         )
         np.testing.assert_equal(
-            str(out.placements[1]), "Shard(dim=0, shard_order=1)"
+            out.placements[1], dist.Shard(dim=0, shard_order=1)
         )
         np.testing.assert_equal(
             out._local_value().numpy(), a[dist.get_rank()].numpy().flatten()
         )
 
         relu_out = paddle.nn.ReLU()(out)
         np.testing.assert_equal(
-            str(relu_out.placements[0]), "Shard(dim=0, shard_order=0)"
+            relu_out.placements[0], dist.Shard(dim=0, shard_order=0)
         )
         np.testing.assert_equal(
-            str(relu_out.placements[1]), "Shard(dim=0, shard_order=1)"
+            relu_out.placements[1], dist.Shard(dim=0, shard_order=1)
         )
 
         # test fallback to shard by one dim.
         add_out = paddle.add(relu_out, relu_out)
-        np.testing.assert_equal(str(add_out.placements[0]), "Shard(dim=0)")
-        np.testing.assert_equal(str(add_out.placements[1]), "Replicate()")
+        np.testing.assert_equal(add_out.placements[0], dist.Shard(dim=0))
+        np.testing.assert_equal(add_out.placements[1], dist.Replicate())
 
     def run_test_case_main(self):
         self.basic_interface_case()
162 changes: 162 additions & 0 deletions test/auto_parallel/test_placement_types.py
@@ -0,0 +1,162 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import unittest
+
+import paddle.distributed as dist
+
+
+class TestPlacementTypes(unittest.TestCase):
+    def test_shard_eq_with_co_shard_order_zero(self):
+        """
+        Tests that a Shard is equal to a CoShard with shard_order=0.
+        This confirms the "semantic equality" philosophy.
+        """
+        s1 = dist.Shard(0)
+        s2 = dist.Shard(dim=0, shard_order=0)
+
+        # 1. Test for symmetric equality
+        self.assertEqual(
+            s1, s2, "Shard(0) should be equal to Shard(dim=0, shard_order=0)"
+        )
+        self.assertEqual(s2, s1, "Equality should be symmetric")
+
+        # 2. Test hash consistency
+        self.assertEqual(
+            hash(s1), hash(s2), "Hashes must be equal for equal objects"
+        )
+
+        # 3. Test behavior in a set
+        placement_set = {s1, s2}
+        self.assertEqual(
+            len(placement_set),
+            1,
+            "A set should only contain one of the two equal objects",
+        )
+
+        # 4. Test behavior in a dict
+        placement_dict = {s1: "value1"}
+        self.assertIn(
+            s2, placement_dict, "s2 should be found in a dict keyed by s1"
+        )
+        self.assertEqual(placement_dict[s2], "value1")
+
+    def test_shard_neq_with_co_shard_order_non_zero(self):
+        """
+        Tests that a Shard is NOT equal to a CoShard with a non-zero shard_order.
+        """
+        s1 = dist.Shard(0)
+        s2 = dist.Shard(dim=0, shard_order=1)
+
+        # 1. Test for symmetric inequality
+        self.assertNotEqual(
+            s1,
+            s2,
+            "Shard(0) should NOT be equal to Shard(dim=0, shard_order=1)",
+        )
+        self.assertNotEqual(s2, s1, "Inequality should be symmetric")
+
+        # 2. Test hash difference
+        # Note: While not a strict requirement for non-equal objects to have different hashes,
+        # a good hash function should minimize collisions. We test for non-collision here.
+        self.assertNotEqual(
+            hash(s1), hash(s2), "Hashes should be different for unequal objects"
+        )
+
+        # 3. Test behavior in a set
+        placement_set = {s1, s2}
+        self.assertEqual(
+            len(placement_set), 2, "A set should contain two distinct objects"
+        )
+
+    def test_co_shard_eq(self):
+        """
+        Tests equality for two CoShard objects.
+        """
+        s1 = dist.Shard(dim=0, shard_order=1)
+        s2 = dist.Shard(dim=0, shard_order=1)
+        s3 = dist.Shard(dim=0, shard_order=2)
+
+        self.assertEqual(s1, s2)
+        self.assertNotEqual(s1, s3)
+
+    def test_replicate_placement(self):
+        """
+        Tests equality and hash for Replicate placement.
+        """
+        r1 = dist.Replicate()
+        r2 = dist.Replicate()
+        s1 = dist.Shard(0)
+
+        # 1. Test equality
+        self.assertEqual(r1, r2, "Two Replicate objects should be equal")
+        self.assertNotEqual(r1, s1, "Replicate should not be equal to Shard")
+
+        # 2. Test hash consistency
+        self.assertEqual(
+            hash(r1),
+            hash(r2),
+            "Hashes of two Replicate objects should be equal",
+        )
+
+        # 3. Test behavior in a set
+        placement_set: set[dist.Placement] = {r1, r2}
+        self.assertEqual(
+            len(placement_set),
+            1,
+            "A set should only contain one Replicate object",
+        )
+        placement_set.add(s1)
+        self.assertEqual(
+            len(placement_set),
+            2,
+            "The set should now contain two distinct objects",
+        )
+
+    def test_partial_placement(self):
+        """
+        Tests equality and hash for Partial placement.
+        """
+        p_sum1 = dist.Partial(dist.ReduceType.kRedSum)
+        p_sum2 = dist.Partial(dist.ReduceType.kRedSum)
+        p_avg = dist.Partial(dist.ReduceType.kRedAvg)
+        r1 = dist.Replicate()
+
+        # 1. Test equality
+        self.assertEqual(
+            p_sum1, p_sum2, "Two Partial(kRedSum) objects should be equal"
+        )
+        self.assertNotEqual(
+            p_sum1,
+            p_avg,
+            "Partial(kRedSum) should not be equal to Partial(kRedAvg)",
+        )
+        self.assertNotEqual(
+            p_sum1, r1, "Partial should not be equal to Replicate"
+        )
+
+        # 2. Test hash consistency
+        self.assertEqual(hash(p_sum1), hash(p_sum2))
+        self.assertNotEqual(hash(p_sum1), hash(p_avg))
+
+        # 3. Test behavior in a set
+        placement_set = {p_sum1, p_sum2}
+        self.assertEqual(len(placement_set), 1)
+        placement_set.add(p_avg)
+        self.assertEqual(len(placement_set), 2)
+
+
+if __name__ == '__main__':
+    unittest.main()