diff --git a/paddle/fluid/prim/utils/utils.cc b/paddle/fluid/prim/utils/utils.cc index 87a2fef943f41..fcee9301a9aa7 100644 --- a/paddle/fluid/prim/utils/utils.cc +++ b/paddle/fluid/prim/utils/utils.cc @@ -17,13 +17,15 @@ #include "paddle/fluid/prim/utils/static/static_global_utils.h" PADDLE_DEFINE_EXPORTED_bool(prim_enabled, false, "enable_prim or not"); -PADDLE_DEFINE_EXPORTED_string(prim_blacklist, "", "prim ops blacklist"); +PADDLE_DEFINE_EXPORTED_bool(prim_all, false, "enable prim_all or not"); +PADDLE_DEFINE_EXPORTED_bool(prim_forward, false, "enable prim_forward or not"); +PADDLE_DEFINE_EXPORTED_bool(prim_backward, false, "enable prim_backward not"); namespace paddle { namespace prim { - bool PrimCommonUtils::IsBwdPrimEnabled() { - return StaticCompositeContext::Instance().IsBwdPrimEnabled(); + bool res = StaticCompositeContext::Instance().IsBwdPrimEnabled(); + return res || FLAGS_prim_all || FLAGS_prim_backward; } void PrimCommonUtils::SetBwdPrimEnabled(bool enable_prim) { @@ -39,16 +41,15 @@ void PrimCommonUtils::SetEagerPrimEnabled(bool enable_prim) { } bool PrimCommonUtils::IsFwdPrimEnabled() { - return StaticCompositeContext::Instance().IsFwdPrimEnabled(); + bool res = StaticCompositeContext::Instance().IsFwdPrimEnabled(); + return res || FLAGS_prim_all || FLAGS_prim_forward; } void PrimCommonUtils::SetFwdPrimEnabled(bool enable_prim) { - VLOG(0) << "FLAGS_prim_enabled ====================== " << FLAGS_prim_enabled; StaticCompositeContext::Instance().SetFwdPrimEnabled(enable_prim); } void PrimCommonUtils::SetAllPrimEnabled(bool enable_prim) { - VLOG(0) << "FLAGS_prim_enabled ====================== " << FLAGS_prim_enabled; StaticCompositeContext::Instance().SetAllPrimEnabled(enable_prim); }