support p to s
LiYuRio committed Nov 6, 2023
1 parent a827e97 commit 4a5b4c4
Showing 11 changed files with 288 additions and 0 deletions.
5 changes: 5 additions & 0 deletions paddle/fluid/pybind/auto_parallel_py.cc
@@ -42,6 +42,7 @@
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/nd_mesh_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/p_to_r_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/p_to_s_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h"
@@ -199,6 +200,10 @@ void BindAutoParallel(py::module *m) {
*m, "SToSReshardFunction", ReshardFunction)
.def(py::init<>());

py::class_<phi::distributed::PToSReshardFunction>(
*m, "PToSReshardFunction", ReshardFunction)
.def(py::init<>());

py::class_<phi::distributed::SameNdMeshReshardFunction>(
*m, "SameNdMeshReshardFunction", ReshardFunction)
.def(py::init<>());
1 change: 1 addition & 0 deletions paddle/phi/core/distributed/auto_parallel/CMakeLists.txt
@@ -17,5 +17,6 @@ collect_srcs(
r_to_p_reshard_function.cc
p_to_r_reshard_function.cc
s_to_s_reshard_function.cc
p_to_s_reshard_function.cc
nd_mesh_reshard_function.cc
same_status_reshard_function.cc)
102 changes: 102 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/p_to_s_reshard_function.cc
@@ -0,0 +1,102 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/distributed/auto_parallel/p_to_s_reshard_function.h"

#include "glog/logging.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/kernels/reduce_scatter_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"

namespace phi {
namespace distributed {

bool PToSReshardFunction::IsSuitable(const DistTensor& in,
const TensorDistAttr& out_dist_attr) {
const auto& in_dist_attr = in.dist_attr();

RESHARD_SHORTCUT_IF_FALSE(in_dist_attr.is_partial());
RESHARD_SHORTCUT_IF_FALSE(out_dist_attr.is_shard());

const auto& in_process_mesh = in_dist_attr.process_mesh();
const auto& out_process_mesh = out_dist_attr.process_mesh();

RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.ndim() == 1);
RESHARD_SHORTCUT_IF_FALSE(out_process_mesh.ndim() == 1);
RESHARD_SHORTCUT_IF_FALSE(in_process_mesh == out_process_mesh);

return true;
}

void PToSReshardFunction::Eval(DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr,
DistTensor* out) {
VLOG(3) << "Call PToSReshardFunction Eval";
const auto& in_dist_attr = in.dist_attr();
const auto& in_process_mesh = in_dist_attr.process_mesh();
const auto& in_process_ids = in_process_mesh.process_ids();
auto dtype = in.dtype();
const auto& logical_ddim = in.dims();

int out_split_axis =
GetSplitAxisWithDimsMapping(out_dist_attr.dims_mapping()).begin()->first;

DenseTensor in_reduce_scatter = in.value();
if (out_split_axis != 0) {
std::vector<int> axis;
axis.emplace_back(out_split_axis);
for (size_t i = 0; i < vectorize(logical_ddim).size(); ++i) {
if (static_cast<int>(i) != out_split_axis) {
axis.emplace_back(i);
}
}
RESHARD_FUNCTOR(
dev_ctx, Transpose, dtype, in.value(), axis, &in_reduce_scatter);
}

DenseTensor out_reduce_scatter;
RESHARD_FUNCTOR_WITH_COMM(dev_ctx,
ReduceScatter,
dtype,
in_process_ids,
in_reduce_scatter,
static_cast<int64_t>(in_process_ids.size()),
&out_reduce_scatter);

if (out_split_axis != 0) {
std::vector<int> axis;
for (size_t i = 1; i < vectorize(logical_ddim).size(); ++i) {
axis.emplace_back(i);
}
axis.insert(axis.begin() + out_split_axis, 0);
RESHARD_FUNCTOR(dev_ctx,
Transpose,
dtype,
out_reduce_scatter,
axis,
GetMutableTensor(out));
} else {
SetValue(out, out_reduce_scatter);
}

SetDistProps(out, in.dims(), out_dist_attr);
}

REGISTER_RESHARD_FUNC(PToSReshardFunction);

} // namespace distributed
} // namespace phi
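
For intuition, here is a minimal NumPy sketch (not part of the commit; names are illustrative) of what the partial-to-shard (P-to-S) reshard above computes: every rank starts with a full-shape partial sum, and after the reshard each rank holds one chunk of the reduced tensor along the output split axis. The kernel reaches the same result with a transpose, a reduce_scatter along dim 0, and a transpose back when the split axis is not 0.

import numpy as np

def p_to_s_reshard_sim(partial_values, out_split_axis):
    # partial_values: one full-shape partial-sum array per rank (the "partial" state).
    # Returns the per-rank local shards after a partial -> shard reshard.
    nranks = len(partial_values)
    # Reduce step: the logical tensor is the elementwise sum of all partial values.
    total = np.sum(partial_values, axis=0)
    # Scatter step: split the reduced tensor along out_split_axis, one chunk per rank.
    # (The kernel instead moves out_split_axis to dim 0, calls reduce_scatter,
    # and transposes back, which yields the same chunks.)
    return np.array_split(total, nranks, axis=out_split_axis)

# Two ranks, each holding a partial 4x6 tensor, resharded along axis 1.
parts = [np.ones((4, 6)), 2 * np.ones((4, 6))]
shards = p_to_s_reshard_sim(parts, out_split_axis=1)
assert shards[0].shape == (4, 3) and np.all(shards[0] == 3.0)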
33 changes: 33 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/p_to_s_reshard_function.h
@@ -0,0 +1,33 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h"

namespace phi {
namespace distributed {

class PToSReshardFunction final : public ReshardFunction {
public:
bool IsSuitable(const DistTensor& in,
const TensorDistAttr& out_dist_attr) override;

void Eval(DeviceContext* dev_ctx,
const DistTensor& in,
const TensorDistAttr& out_dist_attr,
DistTensor* out) override;
};

} // namespace distributed
} // namespace phi
3 changes: 3 additions & 0 deletions paddle/phi/infermeta/unary.cc
@@ -17,6 +17,8 @@ limitations under the License. */
#include <algorithm>
#include <set>

#include "glog/logging.h"

#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/enforce.h"
@@ -3334,6 +3336,7 @@ void ReduceIntArrayAxisInferMeta(const MetaTensor& x,

void ReduceScatterInferMeta(const MetaTensor& x, int nranks, MetaTensor* out) {
auto dim = x.dims();
VLOG(0) << "dim " << dim << " nrank " << nranks;
if (dim[0] > 0 || dim[0] < -1) {
PADDLE_ENFORCE_EQ(
dim[0] % nranks,
1 change: 1 addition & 0 deletions paddle/phi/kernels/cpu/reduce_scatter_kernel.cc
@@ -40,5 +40,6 @@ PD_REGISTER_KERNEL(reduce_scatter,
bool,
int8_t,
uint8_t,
int16_t,
int64_t,
phi::dtype::float16) {}
2 changes: 2 additions & 0 deletions paddle/phi/kernels/gpu/reduce_scatter_kernel.cu
@@ -70,6 +70,7 @@ PD_REGISTER_KERNEL(reduce_scatter,
bool,
int8_t,
uint8_t,
int16_t,
int64_t,
phi::dtype::bfloat16,
phi::dtype::float16) {}
@@ -84,6 +85,7 @@ PD_REGISTER_KERNEL(reduce_scatter,
bool,
int8_t,
uint8_t,
int16_t,
int64_t,
phi::dtype::float16) {}
#endif
15 changes: 15 additions & 0 deletions paddle/phi/kernels/reduce_scatter_kernel.h
@@ -15,6 +15,7 @@
#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/unary.h"

namespace phi {

@@ -24,4 +25,18 @@ void ReduceScatterKernel(const Context& dev_ctx,
int nranks,
DenseTensor* out);

template <typename T, typename Context>
void ReduceScatter(const Context& dev_ctx,
const DenseTensor& x,
int nranks,
DenseTensor* out) {
MetaTensor out_meta(*out);
MetaTensor* out_meta_ptr = &out_meta;

ReduceScatterInferMeta(phi::MetaTensor(x), nranks, out_meta_ptr);
if (x.initialized()) {
ReduceScatterKernel<T, Context>(dev_ctx, x, nranks, out);
}
}

} // namespace phi
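
As a reading aid (not part of the commit; names are illustrative), the new ReduceScatter helper above follows the usual infer-meta-then-kernel pattern: ReduceScatterInferMeta checks that dim 0 is divisible by nranks and sets the output's dim 0 to dim[0] / nranks, and the kernel then produces each rank's chunk of the summed input. A NumPy sketch of that contract:

import numpy as np

def reduce_scatter_sim(per_rank_inputs):
    # Every rank contributes a full array; rank r receives the r-th dim-0 chunk
    # of the elementwise sum, so dim 0 must be divisible by the number of ranks.
    nranks = len(per_rank_inputs)
    total = np.sum(per_rank_inputs, axis=0)
    assert total.shape[0] % nranks == 0, "dim 0 must divide evenly across ranks"
    return np.split(total, nranks, axis=0)

xs = [np.arange(8.0).reshape(4, 2) for _ in range(2)]
ys = reduce_scatter_sim(xs)  # each rank ends up with shape (2, 2)
assert ys[1].shape == (2, 2) and np.allclose(ys[1], 2 * xs[0][2:])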
3 changes: 3 additions & 0 deletions test/auto_parallel/CMakeLists.txt
@@ -95,6 +95,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_reshard_s_to_r MODULES test_reshard_s_to_r)
set_tests_properties(test_reshard_s_to_r
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100)
py_test_modules(test_reshard_p_to_s MODULES test_reshard_p_to_s)
set_tests_properties(test_reshard_p_to_s
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100)
py_test_modules(test_reshard_s_to_s MODULES test_reshard_s_to_s)
set_tests_properties(test_reshard_s_to_s
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100)
78 changes: 78 additions & 0 deletions test/auto_parallel/reshard_p_to_s.py
@@ -0,0 +1,78 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import numpy as np

import paddle
import paddle.distributed as dist
from paddle.base import core


class TestReshardPToS:
def __init__(self):
self._shape = eval(os.getenv("shape"))
self._dtype = os.getenv("dtype")
self._seeds = eval(os.getenv("seeds"))
self._shard = eval(os.getenv("shard"))
self._backend = os.getenv("backend")
self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

def run_test_case(self):
if self._backend == "gpu":
place = paddle.CUDAPlace(dist.get_rank())
dev_ctx = core.DeviceContext.create(place)

paddle.seed(self._seeds)
value = paddle.uniform(self._shape, self._dtype)

in_shard_specs = [None for i in range(len(self._shape))]
out_shard_specs = [None for i in range(len(self._shape))]
out_shard_specs[self._shard] = "x"

dist_attr = dist.DistAttr(
mesh=self._mesh, sharding_specs=in_shard_specs
)
dist_attr._set_partial_dims([0])
out_dist_attr = dist.DistAttr(
mesh=self._mesh, sharding_specs=out_shard_specs
)

input_tensor = dist.shard_tensor(value, dist_attr=dist_attr)

reshard_func = core.PToSReshardFunction()
assert reshard_func.is_suitable(input_tensor, out_dist_attr)

out_shape = list(self._shape)
out_shape[self._shard] = out_shape[self._shard] // 2
out_expected_local_tensor_list = paddle.split(
value, num_or_sections=self._mesh.shape[0], axis=self._shard
)

out = reshard_func.eval(dev_ctx, input_tensor, out_dist_attr)

np.testing.assert_equal(
out._local_value().numpy(),
out_expected_local_tensor_list[0].numpy()
if dist.get_rank() == 0
else out_expected_local_tensor_list[1].numpy(),
)

assert np.equal(out.shape, input_tensor.shape).all()
assert np.equal(out._local_shape, out_shape).all()


if __name__ == '__main__':
TestReshardPToS().run_test_case()
45 changes: 45 additions & 0 deletions test/auto_parallel/test_reshard_p_to_s.py
@@ -0,0 +1,45 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import collective.test_communication_api_base as test_base


class TestReshardPToS(test_base.CommunicationTestDistBase):
def setUp(self):
super().setUp(num_of_devices=2, timeout=120)
self._default_envs = {
"shape": "(10, 20)",
"dtype": "float32",
"seeds": "2023",
}
self._changeable_envs = {
"shard": ["0", "1"],
"backend": ["gpu"],
}

def test_reshard_p_to_s(self):
envs_list = test_base.gen_product_envs_list(
self._default_envs, self._changeable_envs
)
for envs in envs_list:
self.run_test_case(
"reshard_p_to_s.py",
user_defined_envs=envs,
)


if __name__ == "__main__":
unittest.main()
