diff --git a/paddle/fluid/pybind/auto_parallel_py.cc b/paddle/fluid/pybind/auto_parallel_py.cc
index 27d6a75ba0736..ac2721e4c8cf7 100644
--- a/paddle/fluid/pybind/auto_parallel_py.cc
+++ b/paddle/fluid/pybind/auto_parallel_py.cc
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <Python.h>
 #include <pybind11/operators.h>
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
@@ -20,6 +21,8 @@
 #include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/framework/var_desc.h"
 #include "paddle/fluid/pybind/auto_parallel_py.h"
+#include "paddle/fluid/pybind/eager_utils.h"
+#include "paddle/fluid/pybind/op_function_common.h"
 #include "paddle/fluid/pybind/pybind_variant_caster.h"
 #include "paddle/phi/core/device_context.h"
 #include "paddle/phi/core/distributed/auto_parallel/device_mesh.h"
@@ -47,6 +50,14 @@ namespace py = pybind11;
 namespace paddle {
 namespace pybind {
 
+static bool PyCheckInteger(PyObject *obj) {
+#if PY_VERSION_HEX < 0x03000000
+  return (PyLong_Check(obj) || PyInt_Check(obj)) && !PyBool_Check(obj);
+#else
+  return PyLong_Check(obj) && !PyBool_Check(obj);
+#endif
+}
+
 using paddle::distributed::auto_parallel::DistTensorSpec;
 using paddle::distributed::auto_parallel::kDefault;
 using paddle::distributed::auto_parallel::OperatorDistAttr;
@@ -66,6 +77,8 @@ using phi::distributed::auto_parallel::LinkCapability;
 using phi::distributed::auto_parallel::Machine;
 
 PyTypeObject *g_tensor_dist_attr_pytype = nullptr;
+PyTypeObject *g_dist_tensor_spec_pytype = nullptr;
+constexpr const char *infer_spmd_string = "infer_spmd";
 
 static inline const ProcessMesh *get_tensor_process_mesh(
     const TensorDistAttr &self) {
@@ -122,6 +135,11 @@ static inline void reset_operator_dist_attr(OperatorDistAttr *dist_attr) {
   dist_attr->clear_annotated();
 }
 
+static std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
+infer_forward(const phi::distributed::SpmdRule &self, const py::args &args);
+static std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
+infer_backward(const phi::distributed::SpmdRule &self, const py::args &args);
+
 void BindAutoParallel(py::module *m) {
   auto ReshardFunction =
       py::class_<phi::distributed::ReshardFunction>(*m, "ReshardFunction")
@@ -368,118 +386,14 @@ void BindAutoParallel(py::module *m) {
   //     .def("infer_backward", &SPMDRuleBase::InferBackward) [revert in future]
 
   py::class_<phi::distributed::SpmdRule>(*m, "SpmdRule")
-      .def("infer_forward",
-           [](const phi::distributed::SpmdRule &self,
-              const std::vector<DistTensorSpec> &input_specs,
-              const std::vector<phi::Attribute> &attrs) {
-             phi::distributed::InferSpmdContext ctx;
-             for (auto &spec : input_specs) {
-               ctx.EmplaceBackInput(phi::distributed::DistMetaTensor(
-                   phi::make_ddim(spec.shape()), spec.dist_attr()));
-             }
-             for (auto &attr : attrs) {
-               ctx.EmplaceBackAttr(attr);
-             }
-             return self.InferForward(ctx);
-           })
-      .def("infer_forward",  // for op that have vector argument
-           [](const phi::distributed::SpmdRule &self,
-              const std::vector<std::pair<int, int>> &input_ranges,
-              const std::vector<DistTensorSpec> &input_specs,
-              const std::vector<phi::Attribute> &attrs) {
-             /*
-              to distingish between single tensor argument and vector argument of
-              one tensor: start - end == 0: single tensor start - end == 1:
-              vector containing one tensor input_ranges: [(0, 0), (1, 3), (3, 4)]
-              + input_specs: [t0, t1, t2, t3] --> t0, [t1, t2], [t3]
-             */
-             phi::distributed::InferSpmdContext ctx;
-             paddle::small_vector<phi::distributed::DistMetaTensor,
-                                  phi::kInputSmallVectorSize>
-                 ins;
-             for (auto &range : input_ranges) {
-               if (range.second - range.first == 0) {
-                 auto &in = input_specs.at(range.first);
-                 ctx.EmplaceBackInput(phi::distributed::DistMetaTensor(
-                     phi::make_ddim(in.shape()), in.dist_attr()));
-               } else {
-                 int start = range.first;
-                 int end = range.second;
-                 ins.reserve(end - start);
-                 for (int i = start; i < end; ++i) {
-                   auto &in = input_specs.at(i);
-                   ins.emplace_back(phi::distributed::DistMetaTensor(
-                       phi::make_ddim(in.shape()), in.dist_attr()));
-                 }
-                 ctx.EmplaceBackInputs(ins);
-                 ins.clear();
-               }
-             }
-             for (auto &attr : attrs) {
-               ctx.EmplaceBackAttr(attr);
-             }
-             return self.InferForward(ctx);
-           })
-      .def("infer_backward",
-           [](const phi::distributed::SpmdRule &self,
-              const std::vector<DistTensorSpec> &input_specs,
-              const std::vector<DistTensorSpec> &output_specs,
-              const std::vector<phi::Attribute> &attrs) {
-             phi::distributed::InferSpmdContext ctx;
-             for (auto &spec : input_specs) {
-               ctx.EmplaceBackInput(phi::distributed::DistMetaTensor(
-                   phi::make_ddim(spec.shape()), spec.dist_attr()));
-             }
-             for (auto &spec : output_specs) {
-               ctx.EmplaceBackInput(phi::distributed::DistMetaTensor(
-                   phi::make_ddim(spec.shape()), spec.dist_attr()));
-             }
-             for (auto &attr : attrs) {
-               ctx.EmplaceBackAttr(attr);
-             }
-             return self.InferBackward(ctx);
-           })
-      .def("infer_backward",  // for op that have vector argument
-           [](const phi::distributed::SpmdRule &self,
-              const std::vector<std::pair<int, int>> &input_ranges,
-              const std::vector<DistTensorSpec> &input_specs,
-              const std::vector<phi::Attribute> &attrs) {
-             /*
-              to distingish between single tensor argument and vector argument of
-              one tensor: start - end == 0: single tensor start - end == 1:
-              vector containing one tensor input_ranges: [(0, 0), (1, 3), (3, 4)]
-              + input_specs: [t0, t1, t2, t3] --> t0, [t1, t2], [t3]
-             */
-             phi::distributed::InferSpmdContext ctx;
-             paddle::small_vector<phi::distributed::DistMetaTensor,
-                                  phi::kInputSmallVectorSize>
-                 ins;
-             for (auto &range : input_ranges) {
-               if (range.second - range.first == 0) {
-                 auto &in = input_specs.at(range.first);
-                 ctx.EmplaceBackInput(phi::distributed::DistMetaTensor(
-                     phi::make_ddim(in.shape()), in.dist_attr()));
-               } else {
-                 int start = range.first;
-                 int end = range.second;
-                 ins.reserve(end - start);
-                 for (int i = start; i < end; ++i) {
-                   auto &in = input_specs.at(i);
-                   ins.emplace_back(phi::distributed::DistMetaTensor(
-                       phi::make_ddim(in.shape()), in.dist_attr()));
-                 }
-                 ctx.EmplaceBackInputs(ins);
-                 ins.clear();
-               }
-             }
-             for (auto &attr : attrs) {
-               ctx.EmplaceBackAttr(attr);
-             }
-             return self.InferBackward(ctx);
-           });
-
-  py::class_<DistTensorSpec>(*m, "DistTensorSpec")
-      .def(py::init<>())
+      .def("infer_forward", &infer_forward)
+      .def("infer_backward", &infer_backward);
+
+  py::class_<DistTensorSpec> py_dist_tensor_spec(
+      *m, "DistTensorSpec");  // TODO(ljz) remove and unify to DistTensor
+  g_dist_tensor_spec_pytype =
+      reinterpret_cast<PyTypeObject *>(py_dist_tensor_spec.ptr());
+  py_dist_tensor_spec.def(py::init<>())
       .def(py::init<const DistTensorSpec &>())
       .def(py::init<const std::vector<int64_t> &, const TensorDistAttr &>())
       .def("dims_mapping", &DistTensorSpec::dims_mapping)
@@ -634,5 +548,131 @@ void BindAutoParallel(py::module *m) {
       });
 }
 
+static void parse_tensors(PyObject *obj,
+                          phi::distributed::InferSpmdContext *ctx,
+                          const size_t arg_pos) {
+  Py_ssize_t len = PyList_Size(obj);
+  VLOG(6) << "args index: [" << arg_pos << "] input vector of ["
+          << static_cast<int64_t>(len) << "] tensors.";
+  paddle::small_vector<phi::distributed::DistMetaTensor,
+                       phi::kInputSmallVectorSize>
+      ins;
+  ins.reserve(static_cast<size_t>(len));
+  for (Py_ssize_t i = 0; i < len; i++) {
+    DistTensorSpec in = py::cast<DistTensorSpec>(PyList_GetItem(obj, i));
+    VLOG(6) << "Vector emplace_back DistTensorSpec: " << in.to_string();
+    ins.emplace_back(phi::distributed::DistMetaTensor(
+        phi::make_ddim(in.shape()), in.dist_attr()));
+  }
+  ctx->EmplaceBackInputs(ins);
+}
+
+static void parse_tensor(PyObject *obj,
+                         phi::distributed::InferSpmdContext *ctx,
+                         const size_t arg_pos) {
+  VLOG(6) << "args index: [" << arg_pos << "] input one tensor.";
+  DistTensorSpec in = py::cast<DistTensorSpec>(obj);
+  VLOG(6) << "DistTensorSpec: " << in.to_string();
+  ctx->EmplaceBackInput(phi::distributed::DistMetaTensor(
+      phi::make_ddim(in.shape()), in.dist_attr()));
+}
+
+// TODO(ljz) support other types
+static void parse_attrs(PyObject *obj,
+                        PyObject *first_item,
+                        phi::distributed::InferSpmdContext *ctx,
+                        const size_t arg_pos) {
+  if (PyBool_Check(first_item)) {
+    auto attrs = CastPyArg2Booleans(obj, infer_spmd_string, arg_pos);
+    ctx->EmplaceBackAttr(attrs);
+  } else if (PyCheckInteger(first_item)) {
+    auto attrs = CastPyArg2Ints(obj, infer_spmd_string, arg_pos);
+    ctx->EmplaceBackAttr(attrs);
+  } else if (PyLong_Check(first_item)) {
+    auto attrs = CastPyArg2Longs(obj, infer_spmd_string, arg_pos);
+    ctx->EmplaceBackAttr(attrs);
+  } else if (PyFloat_Check(first_item)) {
+    auto attrs = CastPyArg2Floats(obj, infer_spmd_string, arg_pos);
+    ctx->EmplaceBackAttr(attrs);
+  } else {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "%s(): argument (position %d) must be "
+        "list of int, float, bool or Tensor, but got %s",
+        infer_spmd_string,
+        arg_pos,
+        ((PyTypeObject *)first_item->ob_type)->tp_name));  // NOLINT
+  }
+}
+
+// TODO(ljz) support other types
+static void parse_attr(PyObject *obj,
+                       phi::distributed::InferSpmdContext *ctx,
+                       const size_t arg_pos) {
+  if (PyBool_Check(obj)) {
+    auto attr = CastPyArg2Boolean(obj, infer_spmd_string, arg_pos);
+    ctx->EmplaceBackAttr(attr);
+  } else if (PyCheckInteger(obj)) {
+    auto attr = CastPyArg2Int(obj, infer_spmd_string, arg_pos);
+    ctx->EmplaceBackAttr(attr);
+  } else if (PyLong_Check(obj)) {
+    auto attr = CastPyArg2Long(obj, infer_spmd_string, arg_pos);
+    ctx->EmplaceBackAttr(attr);
+  } else if (PyFloat_Check(obj)) {
+    auto attr = CastPyArg2Float(obj, infer_spmd_string, arg_pos);
+    ctx->EmplaceBackAttr(attr);
+  } else {  // TODO(ljz) support other types
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "%s(): argument (position %d) must be "
+        "int, float, bool or Tensor, but got %s",
+        infer_spmd_string,
+        arg_pos,
+        ((PyTypeObject *)obj->ob_type)->tp_name));  // NOLINT
+  }
+}
+
+static void parse_single_pyobject(PyObject *obj,
+                                  phi::distributed::InferSpmdContext *ctx,
+                                  const size_t arg_pos) {
+  if (PyList_Check(obj)) {  // list inputs; SPMD rules do not accept tuples
+    PyObject *first_item = PyList_GetItem(obj, 0);
+    if (PyObject_TypeCheck(first_item, g_dist_tensor_spec_pytype)) {
+      parse_tensors(obj, ctx, arg_pos);
+    } else {
+      parse_attrs(obj, first_item, ctx, arg_pos);
+    }
+  } else {
+    if (PyObject_TypeCheck(obj, g_dist_tensor_spec_pytype)) {
+      parse_tensor(obj, ctx, arg_pos);
+    } else {
+      parse_attr(obj, ctx, arg_pos);
+    }
+  }
+}
+
+static void prepare_ctx(phi::distributed::InferSpmdContext *ctx,
+                        const py::args &args) {
+  VLOG(6) << "prepare_ctx ";
+  size_t inputs_size = args.size();
+  for (size_t i = 0; i < inputs_size; ++i) {
+    PyObject *obj = args[i].ptr();
+    parse_single_pyobject(obj, ctx, i);
+  }
+}
+
+static std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
+infer_forward(const phi::distributed::SpmdRule &self, const py::args &args) {
+  VLOG(6) << "infer_forward ";
+  phi::distributed::InferSpmdContext ctx;
+  prepare_ctx(&ctx, args);
+  return self.InferForward(ctx);
+}
+
+static std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
+infer_backward(const phi::distributed::SpmdRule &self, const py::args &args) {
+  VLOG(6) << "infer_backward ";
+  phi::distributed::InferSpmdContext ctx;
+  prepare_ctx(&ctx, args);
+  return self.InferBackward(ctx);
+}
+
 }  // namespace pybind
 }  // namespace paddle
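With this patch, `SpmdRule.infer_forward` and `SpmdRule.infer_backward` take plain positional arguments and dispatch on each argument's Python type: a `DistTensorSpec` becomes one tensor input, a list of `DistTensorSpec` becomes one vector-of-tensors input, and a `bool`/`int`/`float` (or a list of them) becomes an attribute. A minimal sketch of the new calling convention, assuming the `core.get_phi_spmd_rule` lookup helper and the import paths used by this test suite (neither is shown in this diff):

```python
# Hedged sketch: the import paths, ProcessMesh setup, and the
# `get_phi_spmd_rule` helper are assumptions, not part of this diff.
from paddle.distributed.auto_parallel.static.dist_attribute import (
    DistTensorSpec,
    TensorDistAttr,
)
from paddle.distributed.fleet import auto
from paddle.framework import core

mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])

x_dist_attr = TensorDistAttr()
x_dist_attr.process_mesh = mesh
x_dist_attr.dims_mapping = [1, -1]  # shard rows of x over mesh dim 1
x_spec = DistTensorSpec([64, 32], x_dist_attr)

y_dist_attr = TensorDistAttr()
y_dist_attr.process_mesh = mesh
y_dist_attr.dims_mapping = [-1, -1]  # y fully replicated
y_spec = DistTensorSpec([32, 48], y_dist_attr)

rule = core.get_phi_spmd_rule("matmul")  # assumed lookup helper

# Positional args, parsed by type on the C++ side:
# DistTensorSpec -> tensor input; bool -> attribute (trans_x, trans_y).
inferred_inputs, inferred_outputs = rule.infer_forward(
    x_spec, y_spec, False, False
)
```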
diff --git a/test/auto_parallel/spmd_rules/test_default_data_parallel_rule.py b/test/auto_parallel/spmd_rules/test_default_data_parallel_rule.py
index d9c4c3f7d7528..8d69da185246e 100644
--- a/test/auto_parallel/spmd_rules/test_default_data_parallel_rule.py
+++ b/test/auto_parallel/spmd_rules/test_default_data_parallel_rule.py
@@ -61,11 +61,8 @@ def test_default_dp_infer_forward(self):
         # 2 inputs 2 outputs, sharded batch axis
         in_vec = [self.x_dist_tensor_spec, self.y_dist_tensor_spec]
         out_vec = [self.out1_dist_tensor_spec, self.out2_dist_tensor_spec]
-        result_dist_attrs = self.rule.infer_forward(
-            [(0, len(in_vec)), (len(in_vec), len(in_vec) + len(out_vec))],
-            in_vec + out_vec,
-            [],
-        )
+        result_dist_attrs = self.rule.infer_forward(in_vec, out_vec)
+
         self.assertEqual(len(result_dist_attrs), 2)
         self.assertEqual(len(result_dist_attrs[0]), 2)
         self.assertEqual(len(result_dist_attrs[1]), 2)
@@ -82,11 +79,9 @@ def test_default_dp_infer_forward(self):
             self.out1_dist_tensor_spec,
             self.out2_dist_tensor_spec,
         ]
-        result_dist_attrs = self.rule.infer_forward(
-            [(0, len(in_vec)), (len(in_vec), len(in_vec) + len(out_vec))],
-            in_vec + out_vec,
-            [],
-        )
+
+        result_dist_attrs = self.rule.infer_forward(in_vec, out_vec)
+
         self.assertEqual(len(result_dist_attrs), 2)
         self.assertEqual(len(result_dist_attrs[0]), 1)
         self.assertEqual(len(result_dist_attrs[1]), 3)
@@ -103,11 +98,7 @@ def test_default_dp_infer_forward(self):
         in_vec = [self.x_dist_tensor_spec, self.y_dist_tensor_spec]
         out_vec = [self.out1_dist_tensor_spec, self.out2_dist_tensor_spec]
         with self.assertRaises(NotImplementedError):
-            result_dist_attrs = self.rule.infer_forward(
-                [(0, len(in_vec)), (len(in_vec), len(in_vec) + len(out_vec))],
-                in_vec + out_vec,
-                [],
-            )
+            result_dist_attrs = self.rule.infer_forward(in_vec, out_vec)
@@ -117,11 +108,9 @@ def test_default_dp_infer_backward(self):
         # replicated out1 from [-1, 0, 1, -1] to [0, -1, -1, -1]
         in_vec = [self.x_dist_tensor_spec, self.y_dist_tensor_spec]
         out_vec = [self.out1_dist_tensor_spec, self.out2_dist_tensor_spec]
-        result_dist_attrs = self.rule.infer_backward(
-            [(0, len(in_vec)), (len(in_vec), len(in_vec) + len(out_vec))],
-            in_vec + out_vec,
-            [],
-        )
+
+        result_dist_attrs = self.rule.infer_backward(in_vec, out_vec)
+
         self.assertEqual(len(result_dist_attrs), 2)
         self.assertEqual(len(result_dist_attrs[0]), 2)
         self.assertEqual(len(result_dist_attrs[1]), 2)
@@ -137,11 +126,9 @@ def test_default_dp_infer_backward(self):
         in_vec = [self.x_dist_tensor_spec, self.y_dist_tensor_spec]
         out_vec = [self.out1_dist_tensor_spec, self.out2_dist_tensor_spec]
-        result_dist_attrs = self.rule.infer_backward(
-            [(0, len(in_vec)), (len(in_vec), len(in_vec) + len(out_vec))],
-            in_vec + out_vec,
-            [],
-        )
+
+        result_dist_attrs = self.rule.infer_backward(in_vec, out_vec)
+
         self.assertEqual(len(result_dist_attrs), 2)
         self.assertEqual(len(result_dist_attrs[0]), 2)
         self.assertEqual(len(result_dist_attrs[1]), 2)
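The test updates above show the shape of the migration for rules whose arguments are vectors of tensors: the removed entry point needed explicit `(start, end)` ranges into a single flattened spec list plus a trailing attribute list, while the new entry point takes each vector as an ordinary Python list. Side by side, with the names from the test above:

```python
# Old calling convention (removed by this patch), kept for comparison:
#
#     result_dist_attrs = rule.infer_forward(
#         [(0, len(in_vec)), (len(in_vec), len(in_vec) + len(out_vec))],
#         in_vec + out_vec,
#         [],
#     )
#
# New calling convention: each Python list maps to one vector input on
# the C++ side (parse_tensors), so no index bookkeeping is needed.
result_dist_attrs = rule.infer_forward(in_vec, out_vec)
input_dist_attrs, output_dist_attrs = result_dist_attrs
```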
diff --git a/test/auto_parallel/spmd_rules/test_matmul_rule.py b/test/auto_parallel/spmd_rules/test_matmul_rule.py
index 1cf2f49860b33..d909bd2cfd3a1 100644
--- a/test/auto_parallel/spmd_rules/test_matmul_rule.py
+++ b/test/auto_parallel/spmd_rules/test_matmul_rule.py
@@ -49,9 +49,9 @@ def test_matmul_infer_forward(self):
         # TODO test partial: mk[1, 0],kn[0, -1] --> mk[1, 0],kn[0, -1] = nm[1, -1] partial[0]
 
         result_dist_attrs = self.rule.infer_forward(
-            [self.x_dist_tensor_spec, self.y_dist_tensor_spec],
-            list(self.attrs.values()),
+            self.x_dist_tensor_spec, self.y_dist_tensor_spec, False, False
         )
+
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -68,10 +68,11 @@ def test_matmul_infer_forward(self):
         # test row parallel: mk[1, -1],kn[-1, -1] --> mk[1, -1],kn[-1, -1] = nm[1, -1] partial[]
         self.x_dist_tensor_spec.set_dims_mapping([1, -1])
         self.y_dist_tensor_spec.set_dims_mapping([-1, -1])
+
         result_dist_attrs = self.rule.infer_forward(
-            [self.x_dist_tensor_spec, self.y_dist_tensor_spec],
-            list(self.attrs.values()),
+            self.x_dist_tensor_spec, self.y_dist_tensor_spec, False, False
         )
+
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -83,9 +84,9 @@ def test_matmul_infer_forward(self):
         # test row parallel: mk[1, -1],kn[-1, -1] --> mk[1, -1],kn[-1, -1] = nm[1, -1] partial[]
         self.x_dist_tensor_spec.set_dims_mapping([1, -1])
         self.y_dist_tensor_spec.set_dims_mapping([-1, -1])
+
         result_dist_attrs = self.rule.infer_forward(
-            [self.x_dist_tensor_spec, self.y_dist_tensor_spec],
-            list(self.attrs.values()),
+            self.x_dist_tensor_spec, self.y_dist_tensor_spec, False, False
         )
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -97,9 +98,9 @@ def test_matmul_infer_forward(self):
         # test n parallel: mk[-1, -1],kn[-1, 0] --> mk[-1, -1],kn[-1, 0] = nm[-1, 0] partial[]
         self.x_dist_tensor_spec.set_dims_mapping([-1, -1])
         self.y_dist_tensor_spec.set_dims_mapping([-1, 0])
+
         result_dist_attrs = self.rule.infer_forward(
-            [self.x_dist_tensor_spec, self.y_dist_tensor_spec],
-            list(self.attrs.values()),
+            self.x_dist_tensor_spec, self.y_dist_tensor_spec, False, False
         )
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -111,9 +112,9 @@ def test_matmul_infer_forward(self):
         # test partial with propogation: mk[1, 0],kn[-1,-1] --> mk[1, 0],kn[0, -1] = nm[1, -1] partial[0]
         self.x_dist_tensor_spec.set_dims_mapping([1, 0])
         self.y_dist_tensor_spec.set_dims_mapping([-1, -1])
+
         result_dist_attrs = self.rule.infer_forward(
-            [self.x_dist_tensor_spec, self.y_dist_tensor_spec],
-            list(self.attrs.values()),
+            self.x_dist_tensor_spec, self.y_dist_tensor_spec, False, False
         )
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -126,9 +127,9 @@ def test_matmul_infer_forward(self):
         # mk[-1,-1],kn[1,0] --> mk[-1, 1],kn[1, 0] = nm[-1, 0] partial[1]:
         self.x_dist_tensor_spec.set_dims_mapping([-1, -1])
         self.y_dist_tensor_spec.set_dims_mapping([1, 0])
+
         result_dist_attrs = self.rule.infer_forward(
-            [self.x_dist_tensor_spec, self.y_dist_tensor_spec],
-            list(self.attrs.values()),
+            self.x_dist_tensor_spec, self.y_dist_tensor_spec, False, False
         )
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -142,9 +143,9 @@ def test_matmul_infer_forward(self):
         self.x_dist_tensor_spec.shape = [512, 48, 64, 32]
         self.x_dist_tensor_spec.set_dims_mapping([0, 1, -1, -1])
         self.y_dist_tensor_spec.set_dims_mapping([-1, -1])
+
         result_dist_attrs = self.rule.infer_forward(
-            [self.x_dist_tensor_spec, self.y_dist_tensor_spec],
-            list(self.attrs.values()),
+            self.x_dist_tensor_spec, self.y_dist_tensor_spec, False, False
         )
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -160,10 +161,11 @@ def test_matmul_infer_forward(self):
         # abcmk[1, -1, -1, 0],kn[-1, -1] --> abcmk[1, -1, -1, 0],kn[0, -1] = abcmn[1,-1, -1, -1] partial[0]
         self.x_dist_tensor_spec.set_dims_mapping([1, -1, -1, 0])
         self.y_dist_tensor_spec.set_dims_mapping([-1, -1])
+
         result_dist_attrs = self.rule.infer_forward(
-            [self.x_dist_tensor_spec, self.y_dist_tensor_spec],
-            list(self.attrs.values()),
+            self.x_dist_tensor_spec, self.y_dist_tensor_spec, False, False
         )
+
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
         self.assertEqual(
@@ -179,11 +181,11 @@ def test_matmul_infer_forward(self):
         # trans_x = True, abcmk[1, -1, -1, 0], kn[-1, -1] --> abcmk[1, -1, -1, 0],kn[-1, -1] = abcmn[1, -1, 0, -1] partial[]
         self.x_dist_tensor_spec.set_dims_mapping([1, -1, -1, 0])
         self.y_dist_tensor_spec.set_dims_mapping([-1, -1])
-        self.attrs['trans_x'] = True
+
         result_dist_attrs = self.rule.infer_forward(
-            [self.x_dist_tensor_spec, self.y_dist_tensor_spec],
-            list(self.attrs.values()),
+            self.x_dist_tensor_spec, self.y_dist_tensor_spec, True, False
         )
+
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
         self.assertEqual(
@@ -198,12 +200,11 @@ def test_matmul_infer_forward(self):
         # trans_y = True, abcmk[-1, -1, -1, -1], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] = abcmn[-1, -1, -1, 1] partial[0]: done
         self.x_dist_tensor_spec.set_dims_mapping([-1, -1, -1, -1])
         self.y_dist_tensor_spec.set_dims_mapping([1, 0])
-        self.attrs['trans_x'] = False
-        self.attrs['trans_y'] = True
+
         result_dist_attrs = self.rule.infer_forward(
-            [self.x_dist_tensor_spec, self.y_dist_tensor_spec],
-            list(self.attrs.values()),
+            self.x_dist_tensor_spec, self.y_dist_tensor_spec, False, True
         )
+
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
         self.assertEqual(
@@ -222,12 +223,11 @@ def test_matmul_infer_forward(self):
         # multiple mesh dim shard same tensor axis
         self.x_dist_tensor_spec.set_dims_mapping([-1, -1, 0, 1])
         self.y_dist_tensor_spec.set_dims_mapping([1, 0])
-        self.attrs['trans_x'] = True
-        self.attrs['trans_y'] = True
+
         result_dist_attrs = self.rule.infer_forward(
-            [self.x_dist_tensor_spec, self.y_dist_tensor_spec],
-            list(self.attrs.values()),
+            self.x_dist_tensor_spec, self.y_dist_tensor_spec, True, True
         )
+
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
         self.assertEqual(
@@ -249,9 +249,11 @@ def test_matmul_infer_forward(self):
         self.attrs['trans_x'] = True
         self.attrs['trans_y'] = True
         with self.assertRaises(NotImplementedError):
-            self.rule.infer_forward(
-                [self.x_dist_tensor_spec, self.y_dist_tensor_spec],
-                list(self.attrs.values()),
+            result_dist_attrs = self.rule.infer_forward(
+                self.x_dist_tensor_spec,
+                self.y_dist_tensor_spec,
+                self.attrs['trans_x'],
+                self.attrs['trans_y'],
             )
 
     def test_matmul_infer_backward(self):
@@ -280,10 +282,13 @@ def test_matmul_infer_backward(self):
         # mn[1, 0] --> mk[1, -1],kn[-1, 0]
         result_dist_attrs = self.rule.infer_backward(
-            [self.x_dist_tensor_spec, self.y_dist_tensor_spec],
-            [self.out_dist_tensor_spec],
-            list(self.attrs.values()),
+            self.x_dist_tensor_spec,
+            self.y_dist_tensor_spec,
+            self.out_dist_tensor_spec,
+            self.attrs['trans_x'],
+            self.attrs['trans_y'],
         )
+
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -315,12 +320,14 @@ def test_matmul_infer_backward(self):
             ]
         )  # dims mapping of input should not influence inferbackward
         self.out_dist_tensor_spec.set_dims_mapping([1, 0, -1, -1])
-
         result_dist_attrs = self.rule.infer_backward(
-            [self.x_dist_tensor_spec, self.y_dist_tensor_spec],
-            [self.out_dist_tensor_spec],
-            list(self.attrs.values()),
+            self.x_dist_tensor_spec,
+            self.y_dist_tensor_spec,
+            self.out_dist_tensor_spec,
+            self.attrs['trans_x'],
+            self.attrs['trans_y'],
         )
+
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -339,10 +346,13 @@ def test_matmul_infer_backward(self):
         self.out_dist_tensor_spec.set_dims_mapping([-1, 0, -1, 1])
 
         result_dist_attrs = self.rule.infer_backward(
-            [self.x_dist_tensor_spec, self.y_dist_tensor_spec],
-            [self.out_dist_tensor_spec],
-            list(self.attrs.values()),
+            self.x_dist_tensor_spec,
+            self.y_dist_tensor_spec,
+            self.out_dist_tensor_spec,
+            self.attrs['trans_x'],
+            self.attrs['trans_y'],
         )
+
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -363,11 +373,15 @@ def test_matmul_infer_backward(self):
         self.out_dist_tensor_spec.set_dims_mapping([-1, -1, 0, 1])
         self.attrs['trans_x'] = True
         self.attrs['trans_y'] = True
+
         result_dist_attrs = self.rule.infer_backward(
-            [self.x_dist_tensor_spec, self.y_dist_tensor_spec],
-            [self.out_dist_tensor_spec],
-            list(self.attrs.values()),
+            self.x_dist_tensor_spec,
+            self.y_dist_tensor_spec,
+            self.out_dist_tensor_spec,
+            self.attrs['trans_x'],
+            self.attrs['trans_y'],
         )
+
         infered_input_dist_attrs = result_dist_attrs[0]
         infered_output_dist_attrs = result_dist_attrs[1]
@@ -385,10 +399,12 @@ def test_matmul_infer_backward(self):
         # one mesh dim shard multiple tensor axes
         self.out_dist_tensor_spec.set_dims_mapping([-1, 1, 0, 1])
         with self.assertRaises(RuntimeError):
-            self.rule.infer_backward(
-                [self.x_dist_tensor_spec, self.y_dist_tensor_spec],
-                [self.out_dist_tensor_spec],
-                list(self.attrs.values()),
+            result_dist_attrs = self.rule.infer_backward(
+                self.x_dist_tensor_spec,
+                self.y_dist_tensor_spec,
+                self.out_dist_tensor_spec,
+                self.attrs['trans_x'],
+                self.attrs['trans_y'],
             )
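As the matmul test above shows, `infer_backward` follows the same convention: input specs first, then output specs, then scalar attributes, all as plain positional arguments. A sketch mirroring the updated test (the spec names stand in for the `self.*_dist_tensor_spec` objects built in its `setUp`):

```python
# x_spec, y_spec, out_spec are DistTensorSpec objects; the output's
# dims_mapping is what drives the backward inference.
result_dist_attrs = rule.infer_backward(
    x_spec,    # input
    y_spec,    # input
    out_spec,  # output
    False,     # trans_x
    False,     # trans_y
)
input_dist_attrs = result_dist_attrs[0]
output_dist_attrs = result_dist_attrs[1]
```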
diff --git a/test/auto_parallel/spmd_rules/test_replicated_rule.py b/test/auto_parallel/spmd_rules/test_replicated_rule.py
index d09d99ffe6a1f..7b2f406627ee7 100644
--- a/test/auto_parallel/spmd_rules/test_replicated_rule.py
+++ b/test/auto_parallel/spmd_rules/test_replicated_rule.py
@@ -60,11 +60,9 @@ def test_replicated_infer_forward(self):
         # 2 inputs 2 outputs
         in_vec = [self.x_dist_tensor_spec, self.y_dist_tensor_spec]
         out_vec = [self.out1_dist_tensor_spec, self.out2_dist_tensor_spec]
-        result_dist_attrs = self.rule.infer_forward(
-            [(0, len(in_vec)), (len(in_vec), len(in_vec) + len(out_vec))],
-            in_vec + out_vec,
-            [],
-        )
+
+        result_dist_attrs = self.rule.infer_forward(in_vec, out_vec)
+
         self.assertEqual(len(result_dist_attrs), 2)
         self.assertEqual(len(result_dist_attrs[0]), 2)
         self.assertEqual(len(result_dist_attrs[1]), 2)
@@ -77,11 +75,9 @@ def test_replicated_infer_forward(self):
         # 1 inputs 2 outputs
         in_vec = [self.y_dist_tensor_spec]
         out_vec = [self.out1_dist_tensor_spec, self.out2_dist_tensor_spec]
-        result_dist_attrs = self.rule.infer_forward(
-            [(0, len(in_vec)), (len(in_vec), len(in_vec) + len(out_vec))],
-            in_vec + out_vec,
-            [],
-        )
+
+        result_dist_attrs = self.rule.infer_forward(in_vec, out_vec)
+
         self.assertEqual(len(result_dist_attrs), 2)
         self.assertEqual(len(result_dist_attrs[0]), 1)
         self.assertEqual(len(result_dist_attrs[1]), 2)
@@ -99,11 +95,9 @@ def test_replicated_infer_backward(self):
         in_vec = [self.x_dist_tensor_spec, self.y_dist_tensor_spec]
         out_vec = [self.out1_dist_tensor_spec, self.out2_dist_tensor_spec]
-        result_dist_attrs = self.rule.infer_backward(
-            [(0, len(in_vec)), (len(in_vec), len(in_vec) + len(out_vec))],
-            in_vec + out_vec,
-            [],
-        )
+
+        result_dist_attrs = self.rule.infer_backward(in_vec, out_vec)
+
         self.assertEqual(len(result_dist_attrs), 2)
         self.assertEqual(len(result_dist_attrs[0]), 2)
         self.assertEqual(len(result_dist_attrs[1]), 2)
@@ -120,11 +114,9 @@ def test_replicated_infer_backward(self):
             self.out1_dist_tensor_spec,
             self.out2_dist_tensor_spec,
         ]
-        result_dist_attrs = self.rule.infer_backward(
-            [(0, len(in_vec)), (len(in_vec), len(in_vec) + len(out_vec))],
-            in_vec + out_vec,
-            [],
-        )
+
+        result_dist_attrs = self.rule.infer_backward(in_vec, out_vec)
+
         self.assertEqual(len(result_dist_attrs), 2)
         self.assertEqual(len(result_dist_attrs[0]), 1)
         self.assertEqual(len(result_dist_attrs[1]), 3)
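For reference, the argument dispatch implemented by `parse_single_pyobject` can be modeled in Python as below. The ordering matters: the bool check must run before the integer check, because a Python `bool` is a subtype of `int` (`PyBool` is a `PyLong` subtype), which is exactly why the C++ helper `PyCheckInteger` rejects bools.

```python
# Illustrative model of the C++ dispatch; DistTensorSpec is imported as
# in the tests above. Tuples are deliberately unsupported, matching the
# PyList_Check-only branch in the binding.
def classify_arg(arg):
    if isinstance(arg, list):
        first = arg[0]
        if isinstance(first, DistTensorSpec):
            return "vector-of-tensors input"  # parse_tensors
        return "list attribute"               # parse_attrs (bool/int/float)
    if isinstance(arg, DistTensorSpec):
        return "single tensor input"          # parse_tensor
    if isinstance(arg, bool):                 # must precede the int check
        return "bool attribute"
    if isinstance(arg, (int, float)):
        return "scalar attribute"             # parse_attr
    raise TypeError(f"unsupported argument type: {type(arg).__name__}")
```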