diff --git a/paddle/fluid/framework/details/scale_loss_grad_op_handle.h b/paddle/fluid/framework/details/scale_loss_grad_op_handle.h index 88c87472d0318..d447d0284d83a 100644 --- a/paddle/fluid/framework/details/scale_loss_grad_op_handle.h +++ b/paddle/fluid/framework/details/scale_loss_grad_op_handle.h @@ -44,6 +44,8 @@ struct ScaleLossGradOpHandle : public OpHandleBase { ~ScaleLossGradOpHandle() final; + proto::VarType::Type DType() const { return out_dtype_; } + std::string Name() const override; platform::Place GetPlace() const { return place_; } diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index d31555bf7247c..321317035243b 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -67,7 +67,7 @@ cc_library( cc_library( graph_helper SRCS graph_helper.cc - DEPS graph) + DEPS graph scale_loss_grad_op_handle) cc_library( pass SRCS pass.cc diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index 8afd20d6a00b5..d6bd2e4a80d0b 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/details/multi_devices_helper.h" +#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" #include "paddle/fluid/framework/op_proto_maker.h" DECLARE_bool(convert_all_blocks); @@ -469,11 +470,23 @@ void RemoveControlDepInputAndOuput(OpDesc *op_desc) { static OpDesc *ReplaceScaleLossGradOp(const Node &node, OpDesc *desc) { desc->SetType("fill_constant"); + desc->SetAttr("shape", std::vector({1})); + desc->SetAttr("value", 1.0f); + + if (node.IsWrappedBy()) { + details::OpHandleBase &op_hander = + const_cast(&node)->Wrapper(); + desc->SetAttr( + "dtype", + dynamic_cast(&op_hander)->DType()); + } + + desc->SetAttr("force_cpu", false); desc->SetAttr( OpProtoAndCheckerMaker::OpRoleAttrName(), (static_cast(OpRole::kBackward) | static_cast(OpRole::kLoss))); - desc->SetAttr("value", 1.0f); - desc->SetAttr("shape", std::vector({1})); + // TODO(Ruibiao) : Set OpDeviceAttrName when needed + std::vector output_names; for (auto out : node.outputs) { output_names.emplace_back(out->Name()); @@ -503,6 +516,7 @@ static void GetGraphOpDesc(const std::vector &nodes, // create fill_constant op if (n->Name() == "scale_loss_grad") { + VLOG(4) << "convert op node scale_loss_grad to desc fill_constant"; ops->emplace_back(); auto &desc = ops->back(); ReplaceScaleLossGradOp(*n, &desc); diff --git a/python/paddle/fluid/tests/custom_op/CMakeLists.txt b/python/paddle/fluid/tests/custom_op/CMakeLists.txt index f21fc730fc8de..2addead40fc22 100644 --- a/python/paddle/fluid/tests/custom_op/CMakeLists.txt +++ b/python/paddle/fluid/tests/custom_op/CMakeLists.txt @@ -7,6 +7,8 @@ if(WITH_GPU OR APPLE) # Compiling shared library will cost some time, but running process is very fast. set_tests_properties(test_custom_relu_op_setup PROPERTIES TIMEOUT 250) + set_tests_properties(test_custom_relu_op_setup + PROPERTIES ENVIRONMENT FLAGS_CONVERT_GRAPH_TO_PROGRAM=1) set_tests_properties(test_custom_relu_op_jit PROPERTIES TIMEOUT 180) set_tests_properties(test_custom_relu_model PROPERTIES TIMEOUT 180) set_tests_properties(test_context_pool PROPERTIES TIMEOUT 180)