【Prim】Refactor prim flags system (#49930)
JiabinYang authored Jan 20, 2023
1 parent 44855da commit 23d20e3
Showing 49 changed files with 339 additions and 206 deletions.
4 changes: 2 additions & 2 deletions paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
@@ -1841,7 +1841,7 @@ def GenerateHigherOrderNodeCreationCode(self):

if is_composite_grad_api and next_grad_node_creation_str != '':
next_grad_node_creation_str = f"""
-if (!paddle::prim::PrimCommonUtils::IsPrimEnabled()) {{
+if (!paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) {{
{next_grad_node_creation_str}
}}
"""
@@ -2261,7 +2261,7 @@ def GenerateNodeDefinition(
# TODO(Ruting):using composite only when we don't have backward kernel in the future.
elif is_composite_grad_api:
grad_function_call_str = f"""
-if (paddle::prim::PrimCommonUtils::IsPrimEnabled()) {{
+if (paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) {{
{indent}{composite_grad_api_namespace}{composite_grad_api_name}{composite_template_name}({composite_grad_api_args_str});
VLOG(4) << "Composite api {composite_grad_api_name} is called ";
}}else{{
12 changes: 11 additions & 1 deletion paddle/fluid/prim/api/generated/prim_api/static_prim_api.cc
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

+#include <string.h>
#include <memory>
#include <sstream>
#include <string>
@@ -166,7 +167,16 @@ Tensor full<DescTensor>(const IntArray& shape,
phi::errors::InvalidArgument(
"We only support float32/float16 for full, but we got data type: %s",
phi::DataTypeToString(dtype)));
op->SetAttr("value", value.to<float>());
if (dtype == phi::DataType::FLOAT32) {
op->SetAttr("value", value.to<float>());
} else if (dtype == phi::DataType::FLOAT64) {
op->SetAttr("str_value", std::to_string(value.to<double>()));
} else if (dtype == phi::DataType::FLOAT16) {
op->SetAttr("str_value", std::to_string(value.to<float>()));
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"We only support float64/float32/float16 for full"));
}
op->SetAttr("dtype", paddle::framework::TransToProtoVarType(dtype));
op->SetOutput(
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
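The hunk above replaces the single float "value" attribute with a dtype-dependent encoding: FP32 still sets "value", while FP64 and FP16 are routed through the string attribute "str_value". A minimal standalone sketch of that dispatch; Op, DType, and SetFillValue are illustrative stand-ins for framework::OpDesc, phi::DataType, and the logic inside full<DescTensor>, not Paddle APIs:

```cpp
// Illustrative only: Op and DType stand in for framework::OpDesc and
// phi::DataType; they are not Paddle types.
#include <iostream>
#include <stdexcept>
#include <string>

enum class DType { FLOAT16, FLOAT32, FLOAT64 };

struct Op {
  void SetAttr(const std::string& name, float v) {
    std::cout << name << " = " << v << "  (float attribute)\n";
  }
  void SetAttr(const std::string& name, const std::string& v) {
    std::cout << name << " = " << v << "  (string attribute)\n";
  }
};

void SetFillValue(Op* op, DType dtype, double value) {
  if (dtype == DType::FLOAT32) {
    op->SetAttr("value", static_cast<float>(value));  // plain float attr
  } else if (dtype == DType::FLOAT64) {
    op->SetAttr("str_value", std::to_string(value));  // encoded as a string
  } else if (dtype == DType::FLOAT16) {
    op->SetAttr("str_value", std::to_string(static_cast<float>(value)));
  } else {
    throw std::invalid_argument("only float64/float32/float16 are supported");
  }
}

int main() {
  Op op;
  SetFillValue(&op, DType::FLOAT64, 1.0);  // prints: str_value = 1.000000 ...
}
```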
@@ -192,7 +192,7 @@ void divide_grad(const Tensor& x,
} // indicate we will compute dy
if (dx) {
// dx = (1/y) * dout
-auto one_tensor = full<T>(phi::vectorize(y.dims()), 1.0);
+auto one_tensor = full<T>(phi::vectorize(y.dims()), 1.0, y.dtype());
auto tmp0 = divide<T>(one_tensor, y);
auto dx_res = multiply<T>(tmp0, out_grad);
if (y.dims() != x.dims()) {
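For reference, the composite rule implemented here is the ordinary quotient-rule gradient; the new y.dtype() argument only keeps the constant ones tensor in the same dtype as y. With out = x / y and upstream gradient g = ∂L/∂out:

\[
\frac{\partial L}{\partial x} = \frac{1}{y} \odot g, \qquad
\frac{\partial L}{\partial y} = -\frac{x}{y^{2}} \odot g,
\]

followed by a sum-reduction over the broadcast axes when the operand shapes differ, which is what the y.dims() != x.dims() branch above handles for dx.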
16 changes: 8 additions & 8 deletions paddle/fluid/prim/tests/test_eager_prim.cc
@@ -68,16 +68,16 @@ TEST(EagerPrim, TanhBackwardTest) {
paddle::experimental::Tensor out0 = tanh_ad_func(tensor0);
std::vector<paddle::experimental::Tensor> outs0 = {out0};
// Disable prim
-PrimCommonUtils::SetPrimEnabled(false);
-ASSERT_FALSE(PrimCommonUtils::IsPrimEnabled());
+PrimCommonUtils::SetBwdPrimEnabled(false);
+ASSERT_FALSE(PrimCommonUtils::IsBwdPrimEnabled());
// 4. Run Backward
egr::Backward(outs0, {}, false);

paddle::experimental::Tensor out1 = tanh_ad_func(tensor1);
std::vector<paddle::experimental::Tensor> outs1 = {out1};
// Disable prim
-PrimCommonUtils::SetPrimEnabled(true);
-ASSERT_TRUE(PrimCommonUtils::IsPrimEnabled());
+PrimCommonUtils::SetBwdPrimEnabled(true);
+ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled());
// 4. Run Backward
::egr::Backward(outs1, {}, false);
VLOG(7)
@@ -99,10 +99,10 @@ TEST(EagerPrim, TanhBackwardTest) {
}

TEST(EagerPrim, TestFlags) {
-PrimCommonUtils::SetPrimEnabled(true);
-ASSERT_TRUE(PrimCommonUtils::IsPrimEnabled());
-PrimCommonUtils::SetPrimEnabled(false);
-ASSERT_FALSE(PrimCommonUtils::IsPrimEnabled());
+PrimCommonUtils::SetBwdPrimEnabled(true);
+ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled());
+PrimCommonUtils::SetBwdPrimEnabled(false);
+ASSERT_FALSE(PrimCommonUtils::IsBwdPrimEnabled());
}

} // namespace prim
8 changes: 4 additions & 4 deletions paddle/fluid/prim/tests/test_static_prim.cc
@@ -341,10 +341,10 @@ TEST(StaticCompositeGradMaker, TestMutiOutputMethod) {
}

TEST(StaticPrim, TestFlags) {
-PrimCommonUtils::SetPrimEnabled(true);
-ASSERT_TRUE(PrimCommonUtils::IsPrimEnabled());
-PrimCommonUtils::SetPrimEnabled(false);
-ASSERT_FALSE(PrimCommonUtils::IsPrimEnabled());
+PrimCommonUtils::SetBwdPrimEnabled(true);
+ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled());
+PrimCommonUtils::SetBwdPrimEnabled(false);
+ASSERT_FALSE(PrimCommonUtils::IsBwdPrimEnabled());
}

} // namespace prim
3 changes: 2 additions & 1 deletion paddle/fluid/prim/utils/static/static_global_utils.cc
@@ -18,6 +18,7 @@ namespace paddle {
namespace prim {
StaticCompositeContext* StaticCompositeContext::static_composite_context_ =
new StaticCompositeContext();
-thread_local bool StaticCompositeContext::enable_prim_ = false;
+thread_local bool StaticCompositeContext::enable_bwd_prim_ = false;
+thread_local bool StaticCompositeContext::enable_fwd_prim_ = false;
} // namespace prim
} // namespace paddle
16 changes: 13 additions & 3 deletions paddle/fluid/prim/utils/static/static_global_utils.h
@@ -56,17 +56,27 @@ class StaticCompositeContext {
return generator_->Generate(key);
}

-void SetPrimEnabled(bool enable_prim) { enable_prim_ = enable_prim; }
+void SetBwdPrimEnabled(bool enable_prim) { enable_bwd_prim_ = enable_prim; }

-bool IsPrimEnabled() { return enable_prim_; }
+bool IsBwdPrimEnabled() { return enable_bwd_prim_; }

+void SetFwdPrimEnabled(bool enable_prim) { enable_fwd_prim_ = enable_prim; }

+bool IsFwdPrimEnabled() { return enable_fwd_prim_; }

+void SetAllPrimEnabled(bool enable_prim) {
+enable_fwd_prim_ = enable_prim;
+enable_bwd_prim_ = enable_prim;
+}

private:
StaticCompositeContext()
: current_block_desc_(nullptr), generator_(new UniqueNameGenerator()) {}

framework::BlockDesc* current_block_desc_;
std::unique_ptr<UniqueNameGenerator> generator_;
-static thread_local bool enable_prim_;
+static thread_local bool enable_bwd_prim_;
+static thread_local bool enable_fwd_prim_;
static StaticCompositeContext* static_composite_context_;
DISABLE_COPY_AND_ASSIGN(StaticCompositeContext);
};
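The paired changes to static_global_utils.h and static_global_utils.cc follow the usual C++ pattern for static thread_local data members: they are declared inside the class but still need exactly one out-of-class definition, which is why each new flag touches both files. A condensed, standalone sketch of the pattern; CompositeContext is an illustrative name, and the real class also carries block-desc and name-generator state:

```cpp
// flags.h -- illustrative sketch, not the real StaticCompositeContext.
class CompositeContext {
 public:
  void SetBwdPrimEnabled(bool on) { enable_bwd_prim_ = on; }
  bool IsBwdPrimEnabled() const { return enable_bwd_prim_; }
  void SetFwdPrimEnabled(bool on) { enable_fwd_prim_ = on; }
  bool IsFwdPrimEnabled() const { return enable_fwd_prim_; }
  void SetAllPrimEnabled(bool on) { enable_fwd_prim_ = enable_bwd_prim_ = on; }

 private:
  // Declared here; each still needs a single definition in one .cc file.
  static thread_local bool enable_bwd_prim_;
  static thread_local bool enable_fwd_prim_;
};

// flags.cc -- the out-of-class definitions, mirroring static_global_utils.cc.
thread_local bool CompositeContext::enable_bwd_prim_ = false;
thread_local bool CompositeContext::enable_fwd_prim_ = false;
```

Because the flags are thread_local, enabling forward or backward decomposition on one thread does not affect other threads.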
20 changes: 16 additions & 4 deletions paddle/fluid/prim/utils/utils.cc
@@ -19,12 +19,24 @@
PADDLE_DEFINE_EXPORTED_bool(prim_enabled, false, "enable_prim or not");
namespace paddle {
namespace prim {
-bool PrimCommonUtils::IsPrimEnabled() {
-return StaticCompositeContext::Instance().IsPrimEnabled();
+bool PrimCommonUtils::IsBwdPrimEnabled() {
+return StaticCompositeContext::Instance().IsBwdPrimEnabled();
}

-void PrimCommonUtils::SetPrimEnabled(bool enable_prim) {
-return StaticCompositeContext::Instance().SetPrimEnabled(enable_prim);
+void PrimCommonUtils::SetBwdPrimEnabled(bool enable_prim) {
+return StaticCompositeContext::Instance().SetBwdPrimEnabled(enable_prim);
}

+bool PrimCommonUtils::IsFwdPrimEnabled() {
+return StaticCompositeContext::Instance().IsFwdPrimEnabled();
+}

+void PrimCommonUtils::SetFwdPrimEnabled(bool enable_prim) {
+return StaticCompositeContext::Instance().SetFwdPrimEnabled(enable_prim);
+}

+void PrimCommonUtils::SetAllPrimEnabled(bool enable_prim) {
+return StaticCompositeContext::Instance().SetAllPrimEnabled(enable_prim);
+}
} // namespace prim
} // namespace paddle
7 changes: 5 additions & 2 deletions paddle/fluid/prim/utils/utils.h
@@ -18,8 +18,11 @@ namespace paddle {
namespace prim {
class PrimCommonUtils {
public:
-static bool IsPrimEnabled();
-static void SetPrimEnabled(bool enabled);
+static bool IsBwdPrimEnabled();
+static void SetBwdPrimEnabled(bool enabled);
+static bool IsFwdPrimEnabled();
+static void SetFwdPrimEnabled(bool enabled);
+static void SetAllPrimEnabled(bool enabled);
};
} // namespace prim
} // namespace paddle
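Taken together, utils.h and utils.cc now expose independent switches for forward and backward decomposition plus a convenience setter for both. A usage sketch based on the header above; the include path is taken from this diff and build wiring is omitted:

```cpp
#include "paddle/fluid/prim/utils/utils.h"

// Sketch: turn on composite backward only, leaving forward decomposition off.
void EnableCompositeBackwardOnly() {
  using paddle::prim::PrimCommonUtils;
  PrimCommonUtils::SetBwdPrimEnabled(true);
  PrimCommonUtils::SetFwdPrimEnabled(false);
  // PrimCommonUtils::SetAllPrimEnabled(true);  // would flip both at once

  if (PrimCommonUtils::IsBwdPrimEnabled()) {
    // Code generated by eager_gen.py and the static grad-op construction in
    // pybind.cc below will now prefer composite grad rules where they exist.
  }
}
```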
15 changes: 12 additions & 3 deletions paddle/fluid/pybind/pybind.cc
@@ -660,8 +660,16 @@ PYBIND11_MODULE(libpaddle, m) {
return oss.str();
});

m.def("set_prim_enabled", &paddle::prim::PrimCommonUtils::SetPrimEnabled);
m.def("is_prim_enabled", &paddle::prim::PrimCommonUtils::IsPrimEnabled);
m.def("__set_bwd_prim_enabled",
&paddle::prim::PrimCommonUtils::SetBwdPrimEnabled);
m.def("_is_bwd_prim_enabled",
&paddle::prim::PrimCommonUtils::IsBwdPrimEnabled);
m.def("__set_fwd_prim_enabled",
&paddle::prim::PrimCommonUtils::SetFwdPrimEnabled);
m.def("_is_fwd_prim_enabled",
&paddle::prim::PrimCommonUtils::IsFwdPrimEnabled);
m.def("__set_all_prim_enabled",
&paddle::prim::PrimCommonUtils::SetAllPrimEnabled);
m.def("set_num_threads", &platform::SetNumThreads);

m.def("disable_signal_handler", &DisableSignalHandler);
@@ -1264,8 +1272,9 @@ All parameter, weight, gradient are variables in Paddle.
// priority of GradCompOpMaker is less than GradCompMaker for better
// performance.
std::vector<std::unique_ptr<OpDesc>> grad_op_descs;
-if (paddle::prim::PrimCommonUtils::IsPrimEnabled()) {
+if (paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) {
if (grad_comp_op_maker != nullptr) {
+VLOG(3) << "Runing composite fun for " << op_desc.Type();
grad_op_descs = grad_comp_op_maker(op_desc,
no_grad_set,
&grad_to_var,
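On the Python side these C++ entry points surface as private module-level functions on libpaddle (backward.py below calls core._is_bwd_prim_enabled()). A minimal, self-contained pybind11 sketch of binding flag setters and getters this way; the module name and the global flag are made up for illustration:

```cpp
// demo_bindings.cc -- illustrative only; "prim_flags_demo" is not a Paddle module.
#include <pybind11/pybind11.h>

namespace {
bool g_bwd_prim_enabled = false;  // stand-in for the real thread-local flag
void SetBwdPrimEnabled(bool on) { g_bwd_prim_enabled = on; }
bool IsBwdPrimEnabled() { return g_bwd_prim_enabled; }
}  // namespace

PYBIND11_MODULE(prim_flags_demo, m) {
  // The real bindings point at PrimCommonUtils' static methods; the leading
  // underscores mark the functions as internal API.
  m.def("__set_bwd_prim_enabled", &SetBwdPrimEnabled);
  m.def("_is_bwd_prim_enabled", &IsBwdPrimEnabled);
}
```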
6 changes: 3 additions & 3 deletions paddle/phi/api/yaml/legacy_backward.yaml
@@ -42,7 +42,7 @@
kernel :
func : add_grad
no_need_buffer : x, y
-composite : add_grad(Tensor x, Tensor y, Tensor out_grad, int axis)
+composite : add_grad(x, y, out_grad, axis)
backward : add_double_grad
inplace : (out_grad -> x_grad)

@@ -390,7 +390,7 @@
param : [x, y]
kernel :
func : divide_grad
-composite : divide_grad(Tensor x, Tensor y, Tensor out, Tensor out_grad, int axis = -1)
+composite : divide_grad(x, y, out, out_grad, -1)
backward : divide_double_grad

- backward_op : dropout_grad
@@ -1319,7 +1319,7 @@
kernel :
func : subtract_grad
no_need_buffer : x, y
-composite : subtract_grad(Tensor x, Tensor y, Tensor out_grad, int axis)
+composite : subtract_grad(x, y, out_grad, axis)
backward : subtract_double_grad
inplace : (out_grad -> x_grad)

5 changes: 3 additions & 2 deletions python/paddle/fluid/backward.py
@@ -1493,14 +1493,15 @@ def update_distop_context(

# remove some backward ops
# TODO(Jiabin): Support this in prime later, it will prune add_grad, fix this problem
-if not core.is_prim_enabled():
+if not core._is_bwd_prim_enabled():
not_need_ops = _find_not_need_ops(
grad_op_descs, ops, input_grad_names_set
)

grad_op_descs = [
op_desc for op_desc in grad_op_descs if op_desc not in not_need_ops
]
+else:
+logging.debug("Runing backward composite and disable find_not_need_ops")

# append op_desc in grad_op_descs to target_block
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
Expand Down