From 4a5b4c4e4e5c032f359cd37eebbdb619a7d2f89e Mon Sep 17 00:00:00 2001
From: LiYuRio
Date: Mon, 6 Nov 2023 10:57:43 +0800
Subject: [PATCH] support p to s

---
 paddle/fluid/pybind/auto_parallel_py.cc       |   5 +
 .../distributed/auto_parallel/CMakeLists.txt  |   1 +
 .../auto_parallel/p_to_s_reshard_function.cc  | 102 ++++++++++++++++++
 .../auto_parallel/p_to_s_reshard_function.h   |  33 ++++++
 paddle/phi/infermeta/unary.cc                 |   3 +
 .../phi/kernels/cpu/reduce_scatter_kernel.cc  |   1 +
 .../phi/kernels/gpu/reduce_scatter_kernel.cu  |   2 +
 paddle/phi/kernels/reduce_scatter_kernel.h    |  15 +++
 test/auto_parallel/CMakeLists.txt             |   3 +
 test/auto_parallel/reshard_p_to_s.py          |  78 ++++++++++++++
 test/auto_parallel/test_reshard_p_to_s.py     |  45 ++++++++
 11 files changed, 288 insertions(+)
 create mode 100644 paddle/phi/core/distributed/auto_parallel/p_to_s_reshard_function.cc
 create mode 100644 paddle/phi/core/distributed/auto_parallel/p_to_s_reshard_function.h
 create mode 100644 test/auto_parallel/reshard_p_to_s.py
 create mode 100644 test/auto_parallel/test_reshard_p_to_s.py

diff --git a/paddle/fluid/pybind/auto_parallel_py.cc b/paddle/fluid/pybind/auto_parallel_py.cc
index ac02cd0fc87ac7..0e40350d2f4b16 100644
--- a/paddle/fluid/pybind/auto_parallel_py.cc
+++ b/paddle/fluid/pybind/auto_parallel_py.cc
@@ -42,6 +42,7 @@
 #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
 #include "paddle/phi/core/distributed/auto_parallel/nd_mesh_reshard_function.h"
 #include "paddle/phi/core/distributed/auto_parallel/p_to_r_reshard_function.h"
+#include "paddle/phi/core/distributed/auto_parallel/p_to_s_reshard_function.h"
 #include "paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.h"
 #include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h"
 #include "paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h"
@@ -199,6 +200,10 @@ void BindAutoParallel(py::module *m) {
       *m, "SToSReshardFunction", ReshardFunction)
       .def(py::init<>());
 
+  py::class_<phi::distributed::PToSReshardFunction>(
+      *m, "PToSReshardFunction", ReshardFunction)
+      .def(py::init<>());
+
   py::class_<phi::distributed::SameNdMeshReshardFunction>(
       *m, "SameNdMeshReshardFunction", ReshardFunction)
       .def(py::init<>());
diff --git a/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt b/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt
index 92e69e0dc7657d..60f8a060c4b8cb 100644
--- a/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt
+++ b/paddle/phi/core/distributed/auto_parallel/CMakeLists.txt
@@ -17,5 +17,6 @@ collect_srcs(
   r_to_p_reshard_function.cc
   p_to_r_reshard_function.cc
   s_to_s_reshard_function.cc
+  p_to_s_reshard_function.cc
   nd_mesh_reshard_function.cc
   same_status_reshard_function.cc)
diff --git a/paddle/phi/core/distributed/auto_parallel/p_to_s_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/p_to_s_reshard_function.cc
new file mode 100644
index 00000000000000..eda8dbcecd29ad
--- /dev/null
+++ b/paddle/phi/core/distributed/auto_parallel/p_to_s_reshard_function.cc
@@ -0,0 +1,102 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/distributed/auto_parallel/p_to_s_reshard_function.h"
+
+#include "glog/logging.h"
+#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
+#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
+#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
+#include "paddle/phi/kernels/reduce_scatter_kernel.h"
+#include "paddle/phi/kernels/transpose_kernel.h"
+
+namespace phi {
+namespace distributed {
+
+bool PToSReshardFunction::IsSuitable(const DistTensor& in,
+                                     const TensorDistAttr& out_dist_attr) {
+  const auto& in_dist_attr = in.dist_attr();
+
+  RESHARD_SHORTCUT_IF_FALSE(in_dist_attr.is_partial());
+  RESHARD_SHORTCUT_IF_FALSE(out_dist_attr.is_shard());
+
+  const auto& in_process_mesh = in_dist_attr.process_mesh();
+  const auto& out_process_mesh = out_dist_attr.process_mesh();
+
+  RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.ndim() == 1);
+  RESHARD_SHORTCUT_IF_FALSE(out_process_mesh.ndim() == 1);
+  RESHARD_SHORTCUT_IF_FALSE(in_process_mesh == out_process_mesh);
+
+  return true;
+}
+
+void PToSReshardFunction::Eval(DeviceContext* dev_ctx,
+                               const DistTensor& in,
+                               const TensorDistAttr& out_dist_attr,
+                               DistTensor* out) {
+  VLOG(3) << "Call PToSReshardFunction Eval";
+  const auto& in_dist_attr = in.dist_attr();
+  const auto& in_process_mesh = in_dist_attr.process_mesh();
+  const auto& in_process_ids = in_process_mesh.process_ids();
+  auto dtype = in.dtype();
+  const auto& logical_ddim = in.dims();
+
+  int out_split_axis =
+      GetSplitAxisWithDimsMapping(out_dist_attr.dims_mapping()).begin()->first;
+
+  DenseTensor in_reduce_scatter = in.value();
+  if (out_split_axis != 0) {
+    std::vector<int> axis;
+    axis.emplace_back(out_split_axis);
+    for (size_t i = 0; i < vectorize(logical_ddim).size(); ++i) {
+      if (static_cast<int>(i) != out_split_axis) {
+        axis.emplace_back(i);
+      }
+    }
+    RESHARD_FUNCTOR(
+        dev_ctx, Transpose, dtype, in.value(), axis, &in_reduce_scatter);
+  }
+
+  DenseTensor out_reduce_scatter;
+  RESHARD_FUNCTOR_WITH_COMM(dev_ctx,
+                            ReduceScatter,
+                            dtype,
+                            in_process_ids,
+                            in_reduce_scatter,
+                            static_cast<int64_t>(in_process_ids.size()),
+                            &out_reduce_scatter);
+
+  if (out_split_axis != 0) {
+    std::vector<int> axis;
+    for (size_t i = 1; i < vectorize(logical_ddim).size(); ++i) {
+      axis.emplace_back(i);
+    }
+    axis.insert(axis.begin() + out_split_axis, 0);
+    RESHARD_FUNCTOR(dev_ctx,
+                    Transpose,
+                    dtype,
+                    out_reduce_scatter,
+                    axis,
+                    GetMutableTensor(out));
+  } else {
+    SetValue(out, out_reduce_scatter);
+  }
+
+  SetDistProps(out, in.dims(), out_dist_attr);
+}
+
+REGISTER_RESHARD_FUNC(PToSReshardFunction);
+
+}  // namespace distributed
+}  // namespace phi
diff --git a/paddle/phi/core/distributed/auto_parallel/p_to_s_reshard_function.h b/paddle/phi/core/distributed/auto_parallel/p_to_s_reshard_function.h
new file mode 100644
index 00000000000000..079dce0da2efa4
--- /dev/null
+++ b/paddle/phi/core/distributed/auto_parallel/p_to_s_reshard_function.h
@@ -0,0 +1,33 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h"
+
+namespace phi {
+namespace distributed {
+
+class PToSReshardFunction final : public ReshardFunction {
+ public:
+  bool IsSuitable(const DistTensor& in,
+                  const TensorDistAttr& out_dist_attr) override;
+
+  void Eval(DeviceContext* dev_ctx,
+            const DistTensor& in,
+            const TensorDistAttr& out_dist_attr,
+            DistTensor* out) override;
+};
+
+}  // namespace distributed
+}  // namespace phi
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index 8873a617ef303f..747a13a3d2b482 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -17,6 +17,8 @@ limitations under the License. */
 #include <algorithm>
 #include <set>
 
+#include "glog/logging.h"
+
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/common/type_traits.h"
 #include "paddle/phi/core/enforce.h"
@@ -3334,6 +3336,7 @@ void ReduceIntArrayAxisInferMeta(const MetaTensor& x,
 
 void ReduceScatterInferMeta(const MetaTensor& x, int nranks, MetaTensor* out) {
   auto dim = x.dims();
+  VLOG(0) << "dim " << dim << " nrank " << nranks;
   if (dim[0] > 0 || dim[0] < -1) {
     PADDLE_ENFORCE_EQ(
         dim[0] % nranks,
diff --git a/paddle/phi/kernels/cpu/reduce_scatter_kernel.cc b/paddle/phi/kernels/cpu/reduce_scatter_kernel.cc
index 6f0a73dd106e1f..03b54c34113584 100644
--- a/paddle/phi/kernels/cpu/reduce_scatter_kernel.cc
+++ b/paddle/phi/kernels/cpu/reduce_scatter_kernel.cc
@@ -40,5 +40,6 @@ PD_REGISTER_KERNEL(reduce_scatter,
                    bool,
                    int8_t,
                    uint8_t,
+                   int16_t,
                    int64_t,
                    phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/gpu/reduce_scatter_kernel.cu b/paddle/phi/kernels/gpu/reduce_scatter_kernel.cu
index 28689ecc92f587..68cf339ada75b8 100644
--- a/paddle/phi/kernels/gpu/reduce_scatter_kernel.cu
+++ b/paddle/phi/kernels/gpu/reduce_scatter_kernel.cu
@@ -70,6 +70,7 @@ PD_REGISTER_KERNEL(reduce_scatter,
                    bool,
                    int8_t,
                    uint8_t,
+                   int16_t,
                    int64_t,
                    phi::dtype::bfloat16,
                    phi::dtype::float16) {}
@@ -84,6 +85,7 @@ PD_REGISTER_KERNEL(reduce_scatter,
                    bool,
                    int8_t,
                    uint8_t,
+                   int16_t,
                    int64_t,
                    phi::dtype::float16) {}
 #endif
diff --git a/paddle/phi/kernels/reduce_scatter_kernel.h b/paddle/phi/kernels/reduce_scatter_kernel.h
index ee889c0f172901..76bf5f02781462 100644
--- a/paddle/phi/kernels/reduce_scatter_kernel.h
+++ b/paddle/phi/kernels/reduce_scatter_kernel.h
@@ -15,6 +15,7 @@
 #pragma once
 
 #include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/infermeta/unary.h"
 
 namespace phi {
 
@@ -24,4 +25,18 @@ void ReduceScatterKernel(const Context& dev_ctx,
                          int nranks,
                          DenseTensor* out);
 
+template <typename T, typename Context>
+void ReduceScatter(const Context& dev_ctx,
+                   const DenseTensor& x,
+                   int nranks,
+                   DenseTensor* out) {
+  MetaTensor out_meta(*out);
+  MetaTensor* out_meta_ptr = &out_meta;
+
+  ReduceScatterInferMeta(phi::MetaTensor(x), nranks, out_meta_ptr);
+  if (x.initialized()) {
+    ReduceScatterKernel<T, Context>(dev_ctx, x, nranks, out);
+  }
+}
+
 }  // namespace phi
diff --git a/test/auto_parallel/CMakeLists.txt b/test/auto_parallel/CMakeLists.txt
index aa1212c9ecc11d..5cbd8f0a86e758 100644
--- a/test/auto_parallel/CMakeLists.txt
+++ b/test/auto_parallel/CMakeLists.txt
@@ -95,6 +95,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_reshard_s_to_r MODULES test_reshard_s_to_r)
   set_tests_properties(test_reshard_s_to_r
                        PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100)
+  py_test_modules(test_reshard_p_to_s MODULES test_reshard_p_to_s)
+  set_tests_properties(test_reshard_p_to_s
+                       PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100)
   py_test_modules(test_reshard_s_to_s MODULES test_reshard_s_to_s)
   set_tests_properties(test_reshard_s_to_s
                        PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100)
diff --git a/test/auto_parallel/reshard_p_to_s.py b/test/auto_parallel/reshard_p_to_s.py
new file mode 100644
index 00000000000000..0c7b6d189fe233
--- /dev/null
+++ b/test/auto_parallel/reshard_p_to_s.py
@@ -0,0 +1,78 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import numpy as np
+
+import paddle
+import paddle.distributed as dist
+from paddle.base import core
+
+
+class TestReshardPToS:
+    def __init__(self):
+        self._shape = eval(os.getenv("shape"))
+        self._dtype = os.getenv("dtype")
+        self._seeds = eval(os.getenv("seeds"))
+        self._shard = eval(os.getenv("shard"))
+        self._backend = os.getenv("backend")
+        self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
+
+    def run_test_case(self):
+        if self._backend == "gpu":
+            place = paddle.CUDAPlace(dist.get_rank())
+            dev_ctx = core.DeviceContext.create(place)
+
+        paddle.seed(self._seeds)
+        value = paddle.uniform(self._shape, self._dtype)
+
+        in_shard_specs = [None for i in range(len(self._shape))]
+        out_shard_specs = [None for i in range(len(self._shape))]
+        out_shard_specs[self._shard] = "x"
+
+        dist_attr = dist.DistAttr(
+            mesh=self._mesh, sharding_specs=in_shard_specs
+        )
+        dist_attr._set_partial_dims([0])
+        out_dist_attr = dist.DistAttr(
+            mesh=self._mesh, sharding_specs=out_shard_specs
+        )
+
+        input_tensor = dist.shard_tensor(value, dist_attr=dist_attr)
+
+        reshard_func = core.PToSReshardFunction()
+        assert reshard_func.is_suitable(input_tensor, out_dist_attr)
+
+        out_shape = list(self._shape)
+        out_shape[self._shard] = out_shape[self._shard] // 2
+        out_expected_local_tensor_list = paddle.split(
+            value, num_or_sections=self._mesh.shape[0], axis=self._shard
+        )
+
+        out = reshard_func.eval(dev_ctx, input_tensor, out_dist_attr)
+
+        np.testing.assert_equal(
+            out._local_value().numpy(),
+            out_expected_local_tensor_list[0].numpy()
+            if dist.get_rank() == 0
+            else out_expected_local_tensor_list[1].numpy(),
+        )
+
+        assert np.equal(out.shape, input_tensor.shape).all()
+        assert np.equal(out._local_shape, out_shape).all()
+
+
+if __name__ == '__main__':
+    TestReshardPToS().run_test_case()
diff --git a/test/auto_parallel/test_reshard_p_to_s.py b/test/auto_parallel/test_reshard_p_to_s.py
new file mode 100644
index 00000000000000..bd55e5355479ba
--- /dev/null
+++ b/test/auto_parallel/test_reshard_p_to_s.py
@@ -0,0 +1,45 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import collective.test_communication_api_base as test_base
+
+
+class TestReshardPToS(test_base.CommunicationTestDistBase):
+    def setUp(self):
+        super().setUp(num_of_devices=2, timeout=120)
+        self._default_envs = {
+            "shape": "(10, 20)",
+            "dtype": "float32",
+            "seeds": "2023",
+        }
+        self._changeable_envs = {
+            "shard": ["0", "1"],
+            "backend": ["gpu"],
+        }
+
+    def test_reshard_p_to_s(self):
+        envs_list = test_base.gen_product_envs_list(
+            self._default_envs, self._changeable_envs
+        )
+        for envs in envs_list:
+            self.run_test_case(
+                "reshard_p_to_s.py",
+                user_defined_envs=envs,
+            )
+
+
+if __name__ == "__main__":
+    unittest.main()
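
The reshard path added above boils down to a simple recipe: sum the partial copies held by the ranks, then hand each rank the slice of the reduced tensor that the output dist_attr assigns to it; the transposes in Eval only exist so the reduce_scatter kernel, which always splits dim 0, can be reused for any split axis. A minimal single-process numpy sketch of that semantics (the helper name p_to_s and the two-rank inputs are invented for illustration and are not part of the patch):

import numpy as np


def p_to_s(partial_locals, split_axis):
    # Reduce: sum the partial copies that each rank holds.
    total = np.sum(partial_locals, axis=0)
    # Scatter: each rank keeps one slice of the reduced tensor along split_axis.
    return np.split(total, len(partial_locals), axis=split_axis)


rank0_local = np.arange(12, dtype=np.float32).reshape(3, 4)
rank1_local = np.ones((3, 4), dtype=np.float32)
shards = p_to_s([rank0_local, rank1_local], split_axis=1)
assert shards[0].shape == (3, 2) and shards[1].shape == (3, 2)
assert np.array_equal(np.concatenate(shards, axis=1), rank0_local + rank1_local)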