[Semi Auto] Revise spmd rule static mode API (PaddlePaddle#57269)
* adapt general spmd rule

* polish details

* add new rules
---------

Co-authored-by: Chen Weihang <chenweihang@baidu.com>
2 people authored and Frida-a committed Oct 14, 2023
1 parent f67a22c commit 9be5864
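
At the API level, the commit replaces the range-encoded static-mode entry points with variadic infer_forward/infer_backward bindings that accept tensor specs and attributes as plain positional arguments. A minimal before/after sketch in Python, based on the test changes in this commit (rule, in_vec, and out_vec are the names used in the tests; the sketch assumes that context):

# Old API: flatten inputs and outputs into one list and describe the
# grouping with explicit (start, end) ranges, plus a separate attrs list.
result = rule.infer_forward(
    [(0, len(in_vec)), (len(in_vec), len(in_vec) + len(out_vec))],
    in_vec + out_vec,
    [],
)

# New API: pass each argument directly; the binding inspects the Python
# type of every positional argument to build the InferSpmdContext.
result = rule.infer_forward(in_vec, out_vec)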
Showing 4 changed files with 239 additions and 204 deletions.
264 changes: 152 additions & 112 deletions paddle/fluid/pybind/auto_parallel_py.cc
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <Python.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <utility>
@@ -20,6 +21,8 @@
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/pybind/auto_parallel_py.h"
#include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/fluid/pybind/op_function_common.h"
#include "paddle/fluid/pybind/pybind_variant_caster.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/core/distributed/auto_parallel/device_mesh.h"
@@ -48,6 +51,14 @@ namespace py = pybind11;
namespace paddle {
namespace pybind {

static bool PyCheckInteger(PyObject *obj) {
#if PY_VERSION_HEX < 0x03000000
return (PyLong_Check(obj) || PyInt_Check(obj)) && !PyBool_Check(obj);
#else
return PyLong_Check(obj) && !PyBool_Check(obj);
#endif
}

using paddle::distributed::auto_parallel::DistTensorSpec;
using paddle::distributed::auto_parallel::kDefault;
using paddle::distributed::auto_parallel::OperatorDistAttr;
@@ -67,6 +78,8 @@ using phi::distributed::auto_parallel::LinkCapability;
using phi::distributed::auto_parallel::Machine;

PyTypeObject *g_tensor_dist_attr_pytype = nullptr;
PyTypeObject *g_dist_tensor_spec_pytype = nullptr;
constexpr const char *infer_spmd_string = "infer_spmd";

static inline const ProcessMesh *get_tensor_process_mesh(
const TensorDistAttr &self) {
Expand Down Expand Up @@ -123,6 +136,11 @@ static inline void reset_operator_dist_attr(OperatorDistAttr *dist_attr) {
dist_attr->clear_annotated();
}

static std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
infer_forward(const phi::distributed::SpmdRule &self, const py::args &args);
static std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
infer_backward(const phi::distributed::SpmdRule &self, const py::args &args);

void BindAutoParallel(py::module *m) {
auto ReshardFunction =
py::class_<phi::distributed::ReshardFunction>(*m, "ReshardFunction")
@@ -373,118 +391,14 @@ void BindAutoParallel(py::module *m) {
// .def("infer_backward", &SPMDRuleBase::InferBackward) [revert in future]

py::class_<phi::distributed::SpmdRule>(*m, "SpmdRule")
.def("infer_forward",
[](const phi::distributed::SpmdRule &self,
const std::vector<DistTensorSpec> &input_specs,
const std::vector<phi::Attribute> &attrs) {
phi::distributed::InferSpmdContext ctx;
for (auto &spec : input_specs) {
ctx.EmplaceBackInput(phi::distributed::DistMetaTensor(
phi::make_ddim(spec.shape()), spec.dist_attr()));
}
for (auto &attr : attrs) {
ctx.EmplaceBackAttr(attr);
}
return self.InferForward(ctx);
})
.def("infer_forward", // for op that have vector argument
[](const phi::distributed::SpmdRule &self,
const std::vector<std::pair<int, int>> &input_ranges,
const std::vector<DistTensorSpec> &input_specs,
const std::vector<phi::Attribute> &attrs) {
/*
  To distinguish a single tensor argument from a vector argument
  containing one tensor: end - start == 0 means a single tensor;
  end - start == 1 means a vector containing one tensor. Example:
  input_ranges: [(0, 0), (1, 3), (3, 4)]
  + input_specs: [t0, t1, t2, t3] --> t0, [t1, t2], [t3]
*/
phi::distributed::InferSpmdContext ctx;
paddle::small_vector<phi::distributed::DistMetaTensor,
phi::kInputSmallVectorSize>
ins;
for (auto &range : input_ranges) {
if (range.second - range.first == 0) {
auto &in = input_specs.at(range.first);
ctx.EmplaceBackInput(phi::distributed::DistMetaTensor(
phi::make_ddim(in.shape()), in.dist_attr()));
} else {
int start = range.first;
int end = range.second;
ins.reserve(end - start);
for (int i = start; i < end; ++i) {
auto &in = input_specs.at(i);
ins.emplace_back(phi::distributed::DistMetaTensor(
phi::make_ddim(in.shape()), in.dist_attr()));
}
ctx.EmplaceBackInputs(ins);
ins.clear();
}
}
for (auto &attr : attrs) {
ctx.EmplaceBackAttr(attr);
}
return self.InferForward(ctx);
})
.def("infer_backward",
[](const phi::distributed::SpmdRule &self,
const std::vector<DistTensorSpec> &input_specs,
const std::vector<DistTensorSpec> &output_specs,
const std::vector<phi::Attribute> &attrs) {
phi::distributed::InferSpmdContext ctx;
for (auto &spec : input_specs) {
ctx.EmplaceBackInput(phi::distributed::DistMetaTensor(
phi::make_ddim(spec.shape()), spec.dist_attr()));
}
for (auto &spec : output_specs) {
ctx.EmplaceBackInput(phi::distributed::DistMetaTensor(
phi::make_ddim(spec.shape()), spec.dist_attr()));
}
for (auto &attr : attrs) {
ctx.EmplaceBackAttr(attr);
}
return self.InferBackward(ctx);
})
.def("infer_backward", // for op that have vector argument
[](const phi::distributed::SpmdRule &self,
const std::vector<std::pair<int, int>> &input_ranges,
const std::vector<DistTensorSpec> &input_specs,
const std::vector<phi::Attribute> &attrs) {
/*
  To distinguish a single tensor argument from a vector argument
  containing one tensor: end - start == 0 means a single tensor;
  end - start == 1 means a vector containing one tensor. Example:
  input_ranges: [(0, 0), (1, 3), (3, 4)]
  + input_specs: [t0, t1, t2, t3] --> t0, [t1, t2], [t3]
*/
phi::distributed::InferSpmdContext ctx;
paddle::small_vector<phi::distributed::DistMetaTensor,
phi::kInputSmallVectorSize>
ins;
for (auto &range : input_ranges) {
if (range.second - range.first == 0) {
auto &in = input_specs.at(range.first);
ctx.EmplaceBackInput(phi::distributed::DistMetaTensor(
phi::make_ddim(in.shape()), in.dist_attr()));
} else {
int start = range.first;
int end = range.second;
ins.reserve(end - start);
for (int i = start; i < end; ++i) {
auto &in = input_specs.at(i);
ins.emplace_back(phi::distributed::DistMetaTensor(
phi::make_ddim(in.shape()), in.dist_attr()));
}
ctx.EmplaceBackInputs(ins);
ins.clear();
}
}
for (auto &attr : attrs) {
ctx.EmplaceBackAttr(attr);
}
return self.InferBackward(ctx);
});

py::class_<DistTensorSpec>(*m, "DistTensorSpec")
.def(py::init<>())
.def("infer_forward", &infer_forward)
.def("infer_backward", &infer_backward);

py::class_<DistTensorSpec> py_dist_tensor_spec(
*m, "DistTensorSpec"); // TODO(ljz) remove and unify to DistTensor
g_dist_tensor_spec_pytype =
reinterpret_cast<PyTypeObject *>(py_dist_tensor_spec.ptr());
py_dist_tensor_spec.def(py::init<>())
.def(py::init<const DistTensorSpec &>())
.def(py::init<const std::vector<int64_t> &, const TensorDistAttr &>())
.def("dims_mapping", &DistTensorSpec::dims_mapping)
@@ -639,5 +553,131 @@ void BindAutoParallel(py::module *m) {
});
}

static void parse_tensors(PyObject *obj,
phi::distributed::InferSpmdContext *ctx,
const size_t arg_pos) {
Py_ssize_t len = PyList_Size(obj);
VLOG(6) << "args indx: [" << arg_pos << "] input vector of ["
<< static_cast<size_t>(len) << "] tensors.";
paddle::small_vector<phi::distributed::DistMetaTensor,
phi::kInputSmallVectorSize>
ins;
ins.reserve(static_cast<size_t>(len));
for (Py_ssize_t i = 0; i < len; i++) {
DistTensorSpec in = py::cast<DistTensorSpec>(PyList_GetItem(obj, i));
VLOG(6) << "Vector emplace_back DistTensorSpec: " << in.to_string();
ins.emplace_back(phi::distributed::DistMetaTensor(
phi::make_ddim(in.shape()), in.dist_attr()));
}
ctx->EmplaceBackInputs(ins);
}

static void parse_tensor(PyObject *obj,
phi::distributed::InferSpmdContext *ctx,
const size_t arg_pos) {
VLOG(6) << "args indx: [" << arg_pos << "] input one tensor.";
DistTensorSpec in = py::cast<DistTensorSpec>(obj);
VLOG(6) << "DistTensorSpec: " << in.to_string();
ctx->EmplaceBackInput(phi::distributed::DistMetaTensor(
phi::make_ddim(in.shape()), in.dist_attr()));
}

// TODO(ljz) support other types
static void parse_attrs(PyObject *obj,
PyObject *first_item,
phi::distributed::InferSpmdContext *ctx,
const size_t arg_pos) {
if (PyBool_Check(first_item)) {
auto attrs = CastPyArg2Booleans(obj, infer_spmd_string, arg_pos);
ctx->EmplaceBackAttr(attrs);
} else if (PyCheckInteger(first_item)) {
auto attrs = CastPyArg2Ints(obj, infer_spmd_string, arg_pos);
ctx->EmplaceBackAttr(attrs);
} else if (PyLong_Check(first_item)) {
auto attrs = CastPyArg2Longs(obj, infer_spmd_string, arg_pos);
ctx->EmplaceBackAttr(attrs);
} else if (PyFloat_Check(first_item)) {
auto attrs = CastPyArg2Floats(obj, infer_spmd_string, arg_pos);
ctx->EmplaceBackAttr(attrs);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"list of int, float, bool or Tensor, but got %s",
infer_spmd_string,
arg_pos,
((PyTypeObject *)first_item->ob_type)->tp_name)); // NOLINT
}
}

// TODO(ljz) support other types
static void parse_attr(PyObject *obj,
phi::distributed::InferSpmdContext *ctx,
const size_t arg_pos) {
if (PyBool_Check(obj)) {
auto attr = CastPyArg2Boolean(obj, infer_spmd_string, arg_pos);
ctx->EmplaceBackAttr(attr);
} else if (PyCheckInteger(obj)) {
auto attr = CastPyArg2Int(obj, infer_spmd_string, arg_pos);
ctx->EmplaceBackAttr(attr);
} else if (PyLong_Check(obj)) {
auto attr = CastPyArg2Long(obj, infer_spmd_string, arg_pos);
ctx->EmplaceBackAttr(attr);
} else if (PyFloat_Check(obj)) {
auto attr = CastPyArg2Float(obj, infer_spmd_string, arg_pos);
ctx->EmplaceBackAttr(attr);
} else { // TODO(ljz) support other types
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"int, float, bool or Tensor, but got %s",
infer_spmd_string,
arg_pos,
((PyTypeObject *)obj->ob_type)->tp_name)); // NOLINT
}
}

static void parse_single_pyobject(PyObject *obj,
phi::distributed::InferSpmdContext *ctx,
const size_t arg_pos) {
if (PyList_Check(obj)) { // list inputs; spmd does not allow tuple inputs
PyObject *first_item = PyList_GetItem(obj, 0);
if (PyObject_TypeCheck(first_item, g_dist_tensor_spec_pytype)) {
parse_tensors(obj, ctx, arg_pos);
} else {
parse_attrs(obj, first_item, ctx, arg_pos);
}
} else {
if (PyObject_TypeCheck(obj, g_dist_tensor_spec_pytype)) {
parse_tensor(obj, ctx, arg_pos);
} else {
parse_attr(obj, ctx, arg_pos);
}
}
}

static void prepare_ctx(phi::distributed::InferSpmdContext *ctx,
const py::args &args) {
VLOG(6) << "prepare_ctx ";
size_t inputs_size = args.size();
for (size_t i = 0; i < inputs_size; ++i) {
PyObject *obj = args[i].ptr();
parse_single_pyobject(obj, ctx, i);
}
}
static std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
infer_forward(const phi::distributed::SpmdRule &self, const py::args &args) {
VLOG(6) << "infer_forward ";
phi::distributed::InferSpmdContext ctx;
prepare_ctx(&ctx, args);
return self.InferForward(ctx);
}

static std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
infer_backward(const phi::distributed::SpmdRule &self, const py::args &args) {
VLOG(6) << "infer_backward ";
phi::distributed::InferSpmdContext ctx;
prepare_ctx(&ctx, args);
return self.InferBackward(ctx);
}

} // namespace pybind
} // namespace paddle
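
The helpers above implement the new calling convention: prepare_ctx walks the positional arguments and parse_single_pyobject routes each one by its Python type — a DistTensorSpec becomes a single tensor input, a list whose first element is a DistTensorSpec becomes a vector input, and any other scalar or list is parsed as a bool, int, long, or float attribute. A sketch of a mixed call this enables from Python (the spec names and attribute values are hypothetical, not taken from this diff):

# Given a SpmdRule handle `rule` and DistTensorSpec objects x_spec,
# w_spec, b_spec: one single-tensor arg, one vector arg, an int attr,
# and a bool attr, each dispatched by parse_single_pyobject.
inferred = rule.infer_forward(x_spec, [w_spec, b_spec], -1, True)
input_attrs, output_attrs = inferred  # pair of TensorDistAttr lists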
37 changes: 12 additions & 25 deletions test/auto_parallel/spmd_rules/test_default_data_parallel_rule.py
@@ -61,11 +61,8 @@ def test_default_dp_infer_forward(self):
# 2 inputs 2 outputs, sharded batch axis
in_vec = [self.x_dist_tensor_spec, self.y_dist_tensor_spec]
out_vec = [self.out1_dist_tensor_spec, self.out2_dist_tensor_spec]
result_dist_attrs = self.rule.infer_forward(
[(0, len(in_vec)), (len(in_vec), len(in_vec) + len(out_vec))],
in_vec + out_vec,
[],
)
result_dist_attrs = self.rule.infer_forward(in_vec, out_vec)

self.assertEqual(len(result_dist_attrs), 2)
self.assertEqual(len(result_dist_attrs[0]), 2)
self.assertEqual(len(result_dist_attrs[1]), 2)
@@ -82,11 +79,9 @@ def test_default_dp_infer_forward(self):
self.out1_dist_tensor_spec,
self.out2_dist_tensor_spec,
]
result_dist_attrs = self.rule.infer_forward(
[(0, len(in_vec)), (len(in_vec), len(in_vec) + len(out_vec))],
in_vec + out_vec,
[],
)

result_dist_attrs = self.rule.infer_forward(in_vec, out_vec)

self.assertEqual(len(result_dist_attrs), 2)
self.assertEqual(len(result_dist_attrs[0]), 1)
self.assertEqual(len(result_dist_attrs[1]), 3)
@@ -103,11 +98,7 @@ def test_default_dp_infer_forward(self):
in_vec = [self.x_dist_tensor_spec, self.y_dist_tensor_spec]
out_vec = [self.out1_dist_tensor_spec, self.out2_dist_tensor_spec]
with self.assertRaises(NotImplementedError):
result_dist_attrs = self.rule.infer_forward(
[(0, len(in_vec)), (len(in_vec), len(in_vec) + len(out_vec))],
in_vec + out_vec,
[],
)
result_dist_attrs = self.rule.infer_forward(in_vec, out_vec)

def test_default_dp_infer_backward(self):
# replicated out1 from [-1, 0, 1, -1] to [0, -1, -1, -1]
@@ -117,11 +108,9 @@ def test_default_dp_infer_backward(self):

in_vec = [self.x_dist_tensor_spec, self.y_dist_tensor_spec]
out_vec = [self.out1_dist_tensor_spec, self.out2_dist_tensor_spec]
result_dist_attrs = self.rule.infer_backward(
[(0, len(in_vec)), (len(in_vec), len(in_vec) + len(out_vec))],
in_vec + out_vec,
[],
)

result_dist_attrs = self.rule.infer_backward(in_vec, out_vec)

self.assertEqual(len(result_dist_attrs), 2)
self.assertEqual(len(result_dist_attrs[0]), 2)
self.assertEqual(len(result_dist_attrs[1]), 2)
@@ -137,11 +126,9 @@ def test_default_dp_infer_backward(self):

in_vec = [self.x_dist_tensor_spec, self.y_dist_tensor_spec]
out_vec = [self.out1_dist_tensor_spec, self.out2_dist_tensor_spec]
result_dist_attrs = self.rule.infer_backward(
[(0, len(in_vec)), (len(in_vec), len(in_vec) + len(out_vec))],
in_vec + out_vec,
[],
)

result_dist_attrs = self.rule.infer_backward(in_vec, out_vec)

self.assertEqual(len(result_dist_attrs), 2)
self.assertEqual(len(result_dist_attrs[0]), 2)
self.assertEqual(len(result_dist_attrs[1]), 2)
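The result convention is unchanged by this commit: both methods return a pair of TensorDistAttr lists, one entry per tensor spec in the corresponding argument, which the assertions above index as result_dist_attrs[0] and result_dist_attrs[1]. A short sketch of unpacking the result inside these tests (the dims_mapping access is assumed from the surrounding test file, not shown in this hunk):

input_dist_attrs, output_dist_attrs = self.rule.infer_backward(in_vec, out_vec)
# e.g. the batch-sharded case above replicates out1 to dims mapping [0, -1, -1, -1]
print(output_dist_attrs[0].dims_mapping)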