[AutoParallel] Add reduce_all spmd rule (#59411)
* [AutoParallel] Add reduce_all spmd rule

* fix compile error on Windows

* add c++ unittest for reduce_all spmd
deepllz authored Nov 29, 2023
1 parent 3b40279 commit 87bf502
Showing 9 changed files with 104 additions and 8 deletions.
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/legacy_ops.yaml
@@ -34,6 +34,7 @@
output : Tensor(out)
infer_meta :
func : ReduceInferMeta
spmd_rule : ReductionAllInferSpmdDynamic
kernel :
func : all

@@ -76,14 +76,14 @@ CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx,
do { \
if (phi::CPUContext::classof(dev_ctx)) { \
VLOG(4) << "Call `" << #fn_name << "` in Resharding on CPU."; \
PD_VISIT_FLOATING_AND_INTEGRAL_TYPES( \
PD_VISIT_BOOL_AND_FLOATING_AND_INTEGRAL_TYPES_CPU( \
dtype, #fn_name, ([&] { \
fn_name<data_t>(static_cast<const CPUContext&>(*dev_ctx), \
__VA_ARGS__); \
})); \
} else if (phi::GPUContext::classof(dev_ctx)) { \
VLOG(4) << "Call `" << #fn_name << "` in Resharding on GPU."; \
PD_VISIT_BOOL_AND_FLOATING_AND_INTEGRAL_TYPES( \
PD_VISIT_BOOL_AND_FLOATING_AND_INTEGRAL_TYPES_GPU( \
dtype, #fn_name, ([&] { \
fn_name<data_t>(static_cast<const GPUContext&>(*dev_ctx), \
__VA_ARGS__); \
@@ -99,7 +99,7 @@ CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx,
do { \
if (phi::CPUContext::classof(dev_ctx)) { \
VLOG(4) << "Call `" << #fn_name << "` in Resharding on CPU."; \
PD_VISIT_FLOATING_AND_INTEGRAL_TYPES( \
PD_VISIT_BOOL_AND_FLOATING_AND_INTEGRAL_TYPES_CPU( \
dtype, #fn_name, ([&] { \
fn_name<data_t>(static_cast<const CPUContext&>(*dev_ctx), \
__VA_ARGS__); \
7 changes: 7 additions & 0 deletions paddle/phi/core/distributed/gloo_utils.h
@@ -142,6 +142,13 @@ void SetReduceFunc(P* opts, int reduce_type) {
static_cast<void (*)(void*, const void*, const void*, size_t)>(
&gloo::product<T>));
break;
case ReduceType::kRedAll:
// NOTE(zhonghui): gloo provides no reduce_all (logical AND) function, so
// min is used instead; for boolean operands min is equivalent to AND.
opts->setReduceFunction(
static_cast<void (*)(void*, const void*, const void*, size_t)>(
&gloo::min<T>));
break;
default:
PADDLE_THROW(
errors::InvalidArgument("Unsupport reduce type: %d.", reduce_type));
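Using min as a stand-in for reduce_all is valid for boolean data, since min over values in {0, 1} coincides with logical AND. A minimal standalone check of that assumption (plain C++, not part of the commit):

#include <algorithm>
#include <cassert>

int main() {
  // min over {0, 1} coincides with logical AND, which is why kRedAll can be
  // lowered to a min reduction for boolean data.
  for (int a = 0; a <= 1; ++a) {
    for (int b = 0; b <= 1; ++b) {
      assert(std::min(a, b) == ((a && b) ? 1 : 0));
    }
  }
  return 0;
}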
29 changes: 27 additions & 2 deletions paddle/phi/core/visit_type.h
@@ -151,7 +151,7 @@ namespace phi {
///////// BOOL and Floating and Integral Dispatch Marco ///////////

#if NCCL_VERSION_CODE >= 21000
#define PD_VISIT_BOOL_AND_FLOATING_AND_INTEGRAL_TYPES(TYPE, NAME, ...) \
#define PD_VISIT_BOOL_AND_FLOATING_AND_INTEGRAL_TYPES_GPU(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
@@ -180,7 +180,7 @@ namespace phi {
} \
}()
#else
#define PD_VISIT_BOOL_AND_FLOATING_AND_INTEGRAL_TYPES(TYPE, NAME, ...) \
#define PD_VISIT_BOOL_AND_FLOATING_AND_INTEGRAL_TYPES_GPU(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
@@ -208,6 +208,31 @@ namespace phi {
}()
#endif

#define PD_VISIT_BOOL_AND_FLOATING_AND_INTEGRAL_TYPES_CPU(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::BOOL, bool, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE( \
NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE( \
NAME, ::paddle::DataType::FLOAT64, double, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT32, int, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE( \
NAME, ::paddle::DataType::INT64, int64_t, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE( \
NAME, ::paddle::DataType::INT8, int8_t, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE( \
NAME, ::paddle::DataType::UINT8, uint8_t, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE( \
NAME, ::paddle::DataType::INT16, int16_t, __VA_ARGS__) \
default: \
PD_THROW("function " #NAME " is not implemented for data type `", \
__dtype__, \
"`"); \
} \
}()

///////// Floating and Complex Dispatch Marco ///////////

#define PD_VISIT_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
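The PD_VISIT_* macros turn a runtime DataType into a compile-time template instantiation via PD_PRIVATE_CASE_TYPE. A much-simplified sketch of that dispatch pattern follows; VisitDtype and the reduced DataType enum are illustrative only, not the actual phi API:

#include <cstdint>
#include <stdexcept>

enum class DataType { BOOL, FLOAT32, INT64 };

// Dispatch a runtime dtype to a generic callable, mirroring what the
// PD_VISIT_* macros expand to: one switch case per supported type.
template <typename Functor>
void VisitDtype(DataType dtype, Functor&& f) {
  switch (dtype) {
    case DataType::BOOL:
      f(bool{});
      break;
    case DataType::FLOAT32:
      f(float{});
      break;
    case DataType::INT64:
      f(int64_t{});
      break;
    default:
      throw std::runtime_error("unsupported data type");
  }
}

// Usage, mirroring the fn_name<data_t>(ctx, ...) call in the reshard macros:
// VisitDtype(dtype, [&](auto tag) {
//   using data_t = decltype(tag);
//   // fn_name<data_t>(static_cast<const CPUContext&>(*dev_ctx), ...);
// });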
7 changes: 7 additions & 0 deletions paddle/phi/infermeta/spmd_rules/reduction.cc
@@ -159,6 +159,13 @@ SpmdInfo ReductionMaxInferSpmdDynamic(const DistMetaTensor& x,
x, axis.GetData(), keep_dim, static_cast<int>(ReduceType::kRedMax));
}

SpmdInfo ReductionAllInferSpmdDynamic(const DistMetaTensor& x,
const IntArray& axis,
bool keep_dim) {
return ReductionInferSpmdBase(
x, axis.GetData(), keep_dim, static_cast<int>(ReduceType::kRedAll));
}

SpmdInfo ReductionInferSpmdReverse(const DistMetaTensor& x,
const DistMetaTensor& out,
const std::vector<int64_t>& axis,
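For reference, a simplified sketch of the mapping that ReductionInferSpmdBase is expected to produce for the output: reduced axes drop their shard mapping, and any mesh dimension that was sharding a reduced axis becomes a partial dimension, later finalized with the reduction op (min/AND for kRedAll). InferReductionMapping below is a hypothetical illustration, not the actual implementation:

#include <cstdint>
#include <set>
#include <vector>

// Illustrative only: derive the output dims_mapping and partial mesh dims
// for a reduction over `axes`, given the input dims_mapping.
void InferReductionMapping(const std::vector<int64_t>& x_dims_mapping,
                           const std::set<int64_t>& axes,
                           bool keep_dim,
                           std::vector<int64_t>* out_dims_mapping,
                           std::set<int64_t>* partial_dims) {
  out_dims_mapping->clear();
  partial_dims->clear();
  for (int64_t i = 0; i < static_cast<int64_t>(x_dims_mapping.size()); ++i) {
    if (axes.count(i)) {
      // The reduced axis was sharded: the result becomes partial on that mesh dim.
      if (x_dims_mapping[i] != -1) partial_dims->insert(x_dims_mapping[i]);
      if (keep_dim) out_dims_mapping->push_back(-1);
    } else {
      out_dims_mapping->push_back(x_dims_mapping[i]);
    }
  }
}

// e.g. x dims_mapping {-1, 0, -1}, axes {1}, keep_dim = false
//   -> out dims_mapping {-1, -1}, partial on mesh dim 0 (matches the C++ unit test below).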
4 changes: 4 additions & 0 deletions paddle/phi/infermeta/spmd_rules/reduction.h
@@ -44,6 +44,10 @@ SpmdInfo ReductionMaxInferSpmdDynamic(const DistMetaTensor& x,
const IntArray& axis,
bool keep_dim);

SpmdInfo ReductionAllInferSpmdDynamic(const DistMetaTensor& x,
const IntArray& axis,
bool keep_dim);

SpmdInfo ReductionInferSpmdReverse(const DistMetaTensor& x,
const DistMetaTensor& out,
const std::vector<int64_t>& axis,
5 changes: 5 additions & 0 deletions paddle/phi/kernels/gpu/all_reduce_kernel.cu
@@ -57,6 +57,11 @@ void AllReduceKernel(const Context& dev_ctx,
case ReduceType::kRedProd:
red_type = ncclProd;
break;
case ReduceType::kRedAll:
// NOTE(zhonghui): ncclRedOp_t has no reduce_all (logical AND) op, so
// ncclMin is used instead; for boolean operands min is equivalent to AND.
red_type = ncclMin;
break;
}
comm_ctx->AllReduce(out, x, red_type, stream);
#else
32 changes: 29 additions & 3 deletions test/auto_parallel/semi_auto_parallel_for_reduction.py
@@ -37,8 +37,11 @@ def test_body(
):
paddle.seed(self._seed)
np.random.seed(self._seed)
is_op_func_all = op_func == paddle.all

x = paddle.randn(x_shape, self._dtype)
if is_op_func_all:
x = x > 0
x.stop_gradient = False

dist_x = dist.shard_tensor(x, self._mesh, x_placements)
@@ -49,9 +52,10 @@ def test_body(
self.check_tensor_eq(out, dist_out)
np.testing.assert_equal(dist_out.shape, out_shape, verbose=True)

dist_out.backward()
out.backward()
self.check_tensor_eq(x.grad, dist_x.grad)
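# the boolean output of paddle.all is not differentiable, so the backward
# check is skipped in that case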
if not is_op_func_all:
dist_out.backward()
out.backward()
self.check_tensor_eq(x.grad, dist_x.grad)

def test_sum_x_shard(self):
self.test_body(
@@ -113,6 +117,26 @@ def test_max_x_shard_on_axis(self):
op_func=paddle.max,
)

def test_all_x_shard(self):
self.test_body(
x_shape=[4, 8, 6],
out_shape=[4, 6],
x_placements=[dist.Shard(0)],
axis=1,
keepdim=False,
op_func=paddle.all,
)

def test_all_x_shard_on_axis(self):
self.test_body(
x_shape=[4, 8, 6],
out_shape=[4, 6],
x_placements=[dist.Shard(1)],
axis=1,
keepdim=False,
op_func=paddle.all,
)

def run_test_case(self):
if self._backend == "cpu":
paddle.set_device("cpu")
@@ -127,6 +151,8 @@ def run_test_case(self):
self.test_mean_x_shard()
self.test_max_x_shard()
self.test_max_x_shard_on_axis()
self.test_all_x_shard()
self.test_all_x_shard_on_axis()


if __name__ == '__main__':
21 changes: 21 additions & 0 deletions test/cpp/auto_parallel/spmd_rule_test.cc
@@ -1006,6 +1006,27 @@ TEST(ReduceMaxRule, Ctor) {
check_partial_dims(backward_info.second[0], {});
}

TEST(ReduceAllRule, Ctor) {
std::vector<int64_t> mesh_shape = {2};
std::vector<int64_t> process_ids = {0, 1};
std::vector<std::string> dim_names = {"x"};
ProcessMesh process_mesh(mesh_shape, process_ids, dim_names);

// test forward
auto t_dist_attr = TensorDistAttr();
t_dist_attr.set_process_mesh(process_mesh);
t_dist_attr.set_dims_mapping({-1, 0, -1});
t_dist_attr.set_dynamic_dims({false, false, false});
phi::distributed::DistMetaTensor x =
phi::distributed::DistMetaTensor(phi::make_ddim({4, 6, 8}), t_dist_attr);
IntArray axis = {1};
bool keep_dim = false;
phi::distributed::SpmdInfo forward_info =
phi::distributed::ReductionAllInferSpmdDynamic(x, axis, keep_dim);
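// The input is sharded on mesh dim 0 along tensor axis 1, which is the
// reduced axis, so the output is expected to be replicated ({-1, -1}) and
// partial on mesh dim 0.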
check_dim_mapping(forward_info.second[0], {-1, -1});
check_partial_dims(forward_info.second[0], {0});
}

TEST(Numel, Ctor) {
std::vector<int64_t> mesh_shape = {2, 2};
std::vector<int64_t> process_ids = {0, 1, 2, 3};
