diff --git a/paddle/fluid/pybind/auto_parallel_py.cc b/paddle/fluid/pybind/auto_parallel_py.cc
index 785e80a3abeaab..dec9c75216fc42 100644
--- a/paddle/fluid/pybind/auto_parallel_py.cc
+++ b/paddle/fluid/pybind/auto_parallel_py.cc
@@ -186,6 +186,10 @@ void BindAutoParallel(py::module *m) {
       *m, "SToRReshardFunction", ReshardFunction)
       .def(py::init<>());
 
+  py::class_<phi::distributed::SToRReshardFunctionCrossMesh>(
+      *m, "SToRReshardFunctionCrossMesh", ReshardFunction)
+      .def(py::init<>());
+
   py::class_<phi::distributed::RToPReshardFunction>(
       *m, "RToPReshardFunction", ReshardFunction)
       .def(py::init<>());
diff --git a/paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.cc
index a5f8ce455871d7..d4d62844be34f1 100644
--- a/paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.cc
+++ b/paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.cc
@@ -115,7 +115,52 @@ void SToRReshardFunction::Eval(DeviceContext* dev_ctx,
   }
 }
 
+bool SToRReshardFunctionCrossMesh::IsSuitable(
+    const DistTensor& in, const TensorDistAttr& out_dist_attr) {
+  const auto& in_dist_attr = in.dist_attr();
+  const auto& in_dims_mapping = in_dist_attr.dims_mapping();
+
+  RESHARD_SHORTCUT_IF_FALSE(in_dist_attr.is_shard());
+  RESHARD_SHORTCUT_IF_FALSE(out_dist_attr.is_replicated());
+
+  const auto& in_process_mesh = in_dist_attr.process_mesh();
+  const auto& out_process_mesh = out_dist_attr.process_mesh();
+
+  int split_axis = GetSplitAxisWithDimsMapping(in_dims_mapping).begin()->first;
+  int64_t num_of_process = in_process_mesh.size();
+  RESHARD_SHORTCUT_IF_FALSE(in.local_dims()[static_cast<int>(split_axis)] *
+                                num_of_process ==
+                            in.dims()[static_cast<int>(split_axis)]);
+
+  RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.ndim() == 1);
+  RESHARD_SHORTCUT_IF_FALSE(out_process_mesh.ndim() == 1);
+  RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.shape() ==
+                            out_process_mesh.shape());
+  RESHARD_SHORTCUT_IF_FALSE(in_process_mesh != out_process_mesh);
+
+  return true;
+}
+
+void SToRReshardFunctionCrossMesh::Eval(DeviceContext* dev_ctx,
+                                        const DistTensor& in,
+                                        const TensorDistAttr& out_dist_attr,
+                                        DistTensor* out) {
+  VLOG(3) << "Call SToRReshardFunctionCrossMesh Eval";
+  const auto& out_process_mesh = out_dist_attr.process_mesh();
+
+  SameStatusReshardFunction same_status_func;
+  DistTensor tmp_result;
+
+  TensorDistAttr tmp_dist_attr = in.dist_attr();
+  tmp_dist_attr.set_process_mesh(out_process_mesh);
+  same_status_func.Eval(dev_ctx, in, tmp_dist_attr, &tmp_result);
+
+  SToRReshardFunction s_to_r_func;
+  s_to_r_func.Eval(dev_ctx, tmp_result, out_dist_attr, out);
+}
+
 REGISTER_RESHARD_FUNC(SToRReshardFunction);
+REGISTER_RESHARD_FUNC(SToRReshardFunctionCrossMesh);
 
 }  // namespace distributed
 }  // namespace phi
diff --git a/paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h b/paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h
index 869b4ed9178deb..4360c616595be4 100644
--- a/paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h
+++ b/paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h
@@ -14,6 +14,7 @@
 #pragma once
 
 #include "paddle/phi/core/distributed/auto_parallel/reshard_function.h"
+#include "paddle/phi/core/distributed/auto_parallel/same_status_reshard_function.h"
 
 namespace phi {
 namespace distributed {
@@ -32,5 +33,16 @@ class SToRReshardFunction final : public ReshardFunction {
             DistTensor* out) override;
 };
 
+class SToRReshardFunctionCrossMesh final : public ReshardFunction {
+ public:
+  bool IsSuitable(const DistTensor& in,
+                  const TensorDistAttr& out_dist_attr) override;
+
+  void Eval(DeviceContext* dev_ctx,
+            const DistTensor& in,
+            const TensorDistAttr& out_dist_attr,
+            DistTensor* out) override;
+};
+
 }  // namespace distributed
 }  // namespace phi
diff --git a/test/auto_parallel/reshard_s_to_r_cross_mesh.py b/test/auto_parallel/reshard_s_to_r_cross_mesh.py
new file mode 100644
index 00000000000000..e1ea23f7a95d6d
--- /dev/null
+++ b/test/auto_parallel/reshard_s_to_r_cross_mesh.py
@@ -0,0 +1,85 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import numpy as np
+
+import paddle
+import paddle.distributed as dist
+from paddle.base import core
+
+
+class TestReshardSToRCrossMesh:
+    def __init__(self):
+        self._shape = eval(os.getenv("shape"))
+        self._dtype = os.getenv("dtype")
+        self._seeds = eval(os.getenv("seeds"))
+        self._shard = eval(os.getenv("shard"))
+        self._backend = os.getenv("backend")
+
+        self._in_mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
+        self._out_mesh = dist.ProcessMesh([1, 0], dim_names=["x"])
+
+    def run_test_case(self):
+        if self._backend == "cpu":
+            paddle.set_device("cpu")
+            place = paddle.CPUPlace()
+        elif self._backend == "gpu":
+            place = paddle.CUDAPlace(dist.get_rank())
+
+        dev_ctx = core.DeviceContext.create(place)
+        a = paddle.randn(self._shape)
+
+        in_shard_specs = [None for i in range(len(self._shape))]
+        in_shard_specs[self._shard] = "x"
+
+        out_shard_specs = [None for i in range(len(self._shape))]
+        dist_attr = dist.DistAttr(
+            mesh=self._in_mesh, sharding_specs=in_shard_specs
+        )
+        out_dist_attr = dist.DistAttr(
+            mesh=self._out_mesh, sharding_specs=out_shard_specs
+        )
+
+        input_tensor = dist.shard_tensor(a, dist_attr=dist_attr)
+
+        reshard_func = core.SToRReshardFunctionCrossMesh()
+        assert reshard_func.is_suitable(input_tensor, out_dist_attr)
+
+        out = reshard_func.eval(dev_ctx, input_tensor, out_dist_attr)
+
+        out_shape = list(self._shape)
+        if out_shape[self._shard] % 2 == 0:
+            split_shape = self._in_mesh.shape[0]
+        else:
+            split_shape = [
+                out_shape[self._shard] // 2 + 1,
+                out_shape[self._shard] // 2,
+            ]
+
+        in_expected_local_tensor_list = paddle.split(
+            out._local_value(), num_or_sections=split_shape, axis=self._shard
+        )
+
+        np.testing.assert_equal(
+            input_tensor._local_value().numpy(),
+            in_expected_local_tensor_list[dist.get_rank()].numpy(),
+        )
+
+        assert np.equal(out.shape, out_shape).all()
+
+
+if __name__ == '__main__':
+    TestReshardSToRCrossMesh().run_test_case()
diff --git a/test/auto_parallel/test_reshard_s_to_r.py b/test/auto_parallel/test_reshard_s_to_r.py
index fd67df648a9b0a..ec61fbb2a3358d 100644
--- a/test/auto_parallel/test_reshard_s_to_r.py
+++ b/test/auto_parallel/test_reshard_s_to_r.py
@@ -40,6 +40,18 @@ def test_reshard_s_to_r(self):
                 user_defined_envs=envs,
             )
 
+    def test_reshard_s_to_r_cross_mesh(self):
+        envs_list = test_base.gen_product_envs_list(
+            self._default_envs, self._changeable_envs
+        )
+
+        for envs in envs_list:
+            if envs["backend"] != "cpu":
+                self.run_test_case(
+                    "reshard_s_to_r_cross_mesh.py",
+                    user_defined_envs=envs,
+                )
+
 
 if __name__ == "__main__":
     unittest.main()