Skip to content

Commit

Permalink
Add check for non-dispensable input (PaddlePaddle#28666)
Browse files Browse the repository at this point in the history
* Add check for non-dispensable input

* fix typo
  • Loading branch information
zhiqiu authored Nov 18, 2020
1 parent 19226ba commit 3d09929
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
16 changes: 14 additions & 2 deletions paddle/fluid/pybind/op_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,15 @@ namespace pybind {

static inline std::shared_ptr<imperative::VarBase> CastPyHandleToVarBase(
const std::string& op_type, const std::string& arg_name, int arg_idx,
const py::handle& handle) {
const py::handle& handle, bool dispensable = false) {
PyObject* py_obj = handle.ptr(); // get underlying PyObject
if (!py_obj || py_obj == Py_None) {
if (!dispensable) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be Tensor, but got "
"%s",
op_type, arg_name, arg_idx, Py_TYPE(py_obj)->tp_name));
}
return nullptr;
}
try {
Expand All @@ -54,9 +60,15 @@ static inline std::shared_ptr<imperative::VarBase> CastPyHandleToVarBase(
static inline std::vector<std::shared_ptr<imperative::VarBase>>
CastPyHandleToVarBaseList(const std::string& op_type,
const std::string& arg_name, int arg_idx,
const py::handle& handle) {
const py::handle& handle, bool dispensable = false) {
PyObject* py_obj = handle.ptr(); // get underlying PyObject
if (!py_obj || py_obj == Py_None) {
if (!dispensable) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be Tensor, but got "
"%s",
op_type, arg_name, arg_idx, Py_TYPE(py_obj)->tp_name));
}
return {};
}
std::vector<std::shared_ptr<imperative::VarBase>> result;
Expand Down
7 changes: 4 additions & 3 deletions paddle/fluid/pybind/op_function_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,10 @@ const char* OUT_VAR_TYPE = R"(std::shared_ptr<imperative::VarBase>)";
const char* OUT_VAR_LIST_TYPE = R"(std::vector<std::shared_ptr<imperative::VarBase>>)";

const char* CAST_VAR_TEMPLATE = R"(
auto %s = CastPyHandleToVarBase("%s", "%s", %d, %s);)";
auto %s = CastPyHandleToVarBase("%s", "%s", %d, %s, %s);)";

const char* CAST_VAR_LIST_TEMPLATE = R"(
auto %s = CastPyHandleToVarBaseList("%s", "%s", %d, %s);)";
auto %s = CastPyHandleToVarBaseList("%s", "%s", %d, %s, %s);)";


const char* ARG_TEMPLATE = R"(const %s& %s)";
Expand Down Expand Up @@ -263,9 +263,10 @@ GenerateOpFunctions(const std::string& module_name) {
input_args_num++;
const auto in_cast_type =
input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
auto dispensable = input.dispensable() ? "true" : "false";
ins_cast_str +=
paddle::string::Sprintf(in_cast_type, in_name, op_type, in_name,
arg_idx++, TempName(in_name));
arg_idx++, TempName(in_name), dispensable);

if (input.dispensable()) {
const auto in_template = input.duplicable()
Expand Down

0 comments on commit 3d09929

Please sign in to comment.