Skip to content

Commit

Permalink
[PRIM][IR]support add vjp (#56163)
Browse files Browse the repository at this point in the history
* [prim][newir] add basic framework for primitive

* support desctensor in new ir

* add vjp interface

* support vjp in new ir

* support vjp in new ir

* polish vjp interface

* fix stop_gradients set

* fix vjp dispatch

* add comment

* add vjp test for new ir

* add test for tanh vjp

* [prim][newir] add basic framework for primitive

* support desctensor in new ir

* support vjp in new ir

* support vjp in new ir

* polish vjp interface

* fix stop_gradients set

* fix vjp dispatch

* add comment

* add vjp test for new ir

* add test for tanh vjp

* add eager and static backend for warp lower level api

* support call_vjp pybind

* polish code and add test for vjp

* remove useless code

* polish code

* remove useless code

* support mean vjp

* add test for mean vjp and support has_vjp function

* fix call_vjp

* polish code

* add primitive ops set for backend

* add vjp test for tanh_

* fix inference CI

* fix inference ci

* modify fluid cmake

* remove useless deps

* add cmake

* fix comment

* fix test

* polish code

* modify backward stop_gradients

* modify static_backend.cc

* support add and add_inplace vjp

* remove useless code

* remove useless code

* remove cout

* remove cout

* fix add_grad

* fix add test exe

---------

Co-authored-by: cxxly <chenxx_id@163.com>
Co-authored-by: zhangbo9674 <zhangbo54@baidu.com>
  • Loading branch information
3 people authored Aug 16, 2023
1 parent d0224b8 commit 84482da
Show file tree
Hide file tree
Showing 7 changed files with 259 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@
# TODO(wanghao107)
# remove this file and support Vjp methods
# code gen.
vjp_interface_gen_op_list = ["tanh", "mean"]
vjp_interface_gen_op_list = ["tanh", "mean", "add"]
50 changes: 50 additions & 0 deletions paddle/fluid/ir/dialect/pd_op_vjp_manual.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,5 +98,55 @@ std::vector<std::vector<ir::OpResult>> MeanOp::Vjp(
}
return res;
}

std::vector<std::vector<ir::OpResult>> AddOp::Vjp(
ir::Operation* op,
const std::vector<std::vector<ir::OpResult>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients) {
AddOp op_obj = op->dyn_cast<AddOp>();
Tensor x(std::make_shared<primitive::experimental::DescTensor>(op_obj.x()));
Tensor y(std::make_shared<primitive::experimental::DescTensor>(op_obj.y()));
Tensor out_grad(
std::make_shared<primitive::experimental::DescTensor>(out_grads[0][0]));
int axis = -1;

std::vector<std::vector<Tensor>> tensor_res =
primitive::experimental::add_vjp(x, y, out_grad, axis, stop_gradients);
std::vector<std::vector<ir::OpResult>> res(2, std::vector<ir::OpResult>(1));
for (size_t i = 0; i < 2; ++i) {
if (tensor_res[i][0].defined()) {
res[i][0] = std::static_pointer_cast<primitive::experimental::DescTensor>(
tensor_res[i][0].impl())
->getValue()
.dyn_cast<ir::OpResult>();
}
}
return res;
}

std::vector<std::vector<ir::OpResult>> Add_Op::Vjp(
ir::Operation* op,
const std::vector<std::vector<ir::OpResult>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients) {
Add_Op op_obj = op->dyn_cast<Add_Op>();
Tensor x(std::make_shared<primitive::experimental::DescTensor>(op_obj.x()));
Tensor y(std::make_shared<primitive::experimental::DescTensor>(op_obj.y()));
Tensor out_grad(
std::make_shared<primitive::experimental::DescTensor>(out_grads[0][0]));
int axis = -1;

std::vector<std::vector<Tensor>> tensor_res =
primitive::experimental::add_vjp(x, y, out_grad, axis, stop_gradients);
std::vector<std::vector<ir::OpResult>> res(2, std::vector<ir::OpResult>(1));
for (size_t i = 0; i < 2; ++i) {
if (tensor_res[i][0].defined()) {
res[i][0] = std::static_pointer_cast<primitive::experimental::DescTensor>(
tensor_res[i][0].impl())
->getValue()
.dyn_cast<ir::OpResult>();
}
}
return res;
}
} // namespace dialect
} // namespace paddle
25 changes: 25 additions & 0 deletions paddle/fluid/primitive/backend/static_backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,31 @@ Tensor mean_grad<DescTensor>(const Tensor& x,
return Tensor(std::make_shared<primitive::experimental::DescTensor>(op_res));
}

template <>
std::tuple<Tensor, Tensor> add_grad<DescTensor>(const Tensor& x,
const Tensor& y,
const Tensor& out_grad,
int axis) {
ir::OpResult x_res = std::static_pointer_cast<DescTensor>(x.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult y_res = std::static_pointer_cast<DescTensor>(y.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult out_grad_res =
std::static_pointer_cast<DescTensor>(out_grad.impl())
->getValue()
.dyn_cast<ir::OpResult>();

std::tuple<ir::OpResult, ir::OpResult> op_res =
paddle::dialect::add_grad(x_res, y_res, out_grad_res, axis);

return std::make_tuple(
Tensor(std::make_shared<primitive::experimental::DescTensor>(
std::get<0>(op_res))),
Tensor(std::make_shared<primitive::experimental::DescTensor>(
std::get<1>(op_res))));
}
} // namespace experimental
} // namespace backend
} // namespace primitive
Expand Down
6 changes: 6 additions & 0 deletions paddle/fluid/primitive/backend/static_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ Tensor mean_grad(const Tensor& x,
const IntArray& axis = {},
bool keepdim = false,
bool reduce_all = false);

template <typename T>
std::tuple<Tensor, Tensor> add_grad(const Tensor& x,
const Tensor& y,
const Tensor& out_grad,
int axis);
} // namespace experimental
} // namespace backend
} // namespace primitive
Expand Down
41 changes: 41 additions & 0 deletions paddle/fluid/primitive/rule/vjp/vjp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,47 @@ std::vector<std::vector<paddle::Tensor>> mean_vjp(
return vjp_res;
}

std::vector<std::vector<paddle::Tensor>> add_vjp(
const Tensor& x,
const Tensor& y,
const Tensor& out_grad,
int axis,
const std::vector<std::vector<bool>>& stop_gradients) {
std::vector<std::vector<paddle::Tensor>> vjp_res(
2, std::vector<paddle::Tensor>(1));
// get mean_grad res.
std::tuple<Tensor, Tensor> op_res =
backend::experimental::add_grad<primitive::experimental::DescTensor>(
x, y, out_grad, axis);

// set op stop_gradient info
// TODO(wanghao107): Replace with more generic code.
// Support set stop_gradients for all ops.
ir::Operation* grad_op =
std::static_pointer_cast<primitive::experimental::DescTensor>(
std::get<0>(op_res).impl())
->getValue()
.dyn_cast<ir::OpResult>()
.owner();
std::vector<ir::Attribute> ir_stop_gradients(2);
for (size_t i = 0; i < 2; i++) {
if (stop_gradients[i][0]) {
ir_stop_gradients[i] =
ir::BoolAttribute::get(ir::IrContext::Instance(), true);
} else {
ir_stop_gradients[i] =
ir::BoolAttribute::get(ir::IrContext::Instance(), false);
}
}
grad_op->set_attribute(
"stop_gradient",
ir::ArrayAttribute::get(ir::IrContext::Instance(), ir_stop_gradients));

// construct vjp result by op result and stop_gradients info
vjp_res[0][0] = !stop_gradients[0][0] ? std::get<0>(op_res) : vjp_res[0][0];
vjp_res[1][0] = !stop_gradients[1][0] ? std::get<1>(op_res) : vjp_res[1][0];
return vjp_res;
}
} // namespace experimental
} // namespace primitive
} // namespace paddle
7 changes: 7 additions & 0 deletions paddle/fluid/primitive/rule/vjp/vjp.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ std::vector<std::vector<paddle::Tensor>> mean_vjp(
bool reduce_all,
const std::vector<std::vector<bool>>& stop_gradients);

std::vector<std::vector<paddle::Tensor>> add_vjp(
const Tensor& x,
const Tensor& y,
const Tensor& out_grad,
int axis,
const std::vector<std::vector<bool>>& stop_gradients);

namespace details {
// NOTE: this namespace will store
// primitive ops grad composite rules.
Expand Down
129 changes: 129 additions & 0 deletions test/cpp/prim/test_vjp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ PD_DECLARE_KERNEL(tanh, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(mean, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(tanh_grad, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(mean_grad, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add_grad, CPU, ALL_LAYOUT);

namespace paddle {
namespace framework {
Expand Down Expand Up @@ -204,5 +205,133 @@ TEST(VJP, MeanBackwardTest) {
ASSERT_EQ(grad_out_tensor.data<float>()[3], 0.25);
}

TEST(VJP, AddBackwardTest) {
ir::IrContext* ctx = ir::IrContext::Instance();
ir::Program program((ctx));
paddle::dialect::APIBuilder::Instance().SetProgram(&program);

std::shared_ptr<ir::Builder> builder =
paddle::dialect::APIBuilder::Instance().GetBuilder();
paddle::dialect::FullOp op1 = builder->Build<paddle::dialect::FullOp>(
std::vector<int64_t>{1}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace());
paddle::dialect::FullOp op2 = builder->Build<paddle::dialect::FullOp>(
std::vector<int64_t>{1}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace());
paddle::dialect::AddOp op3 =
builder->Build<paddle::dialect::AddOp>(op1.out(), op2.out());

paddle::dialect::FullOp op4 = builder->Build<paddle::dialect::FullOp>(
std::vector<int64_t>{1}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace());

std::vector<std::vector<bool>> stop_gradients{{false}, {false}};
std::vector<std::vector<ir::OpResult>> out_grads{{op4.out()}};

ir::OpInfo op3_info = ctx->GetRegisteredOpInfo("pd.add");
auto add_vjp_interface_impl =
op3_info.GetInterfaceImpl<paddle::dialect::VjpInterface>();
add_vjp_interface_impl->vjp_(op3.operation(), out_grads, stop_gradients);

auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);

auto place = platform::CPUPlace();
Scope scope;

ProgramDesc prog_desc;
InterpreterCore test_core(place, {}, std::move(kernel_program), &scope);
std::stringstream os;
os << reinterpret_cast<NewIRInterpreter*>(
const_cast<InterpreterBaseImpl*>(test_core.Impl()));
std::string prefix_str = os.str();
test_core.SetSkipGcVars({prefix_str + "_inner_var_2",
prefix_str + "_inner_var_4",
prefix_str + "_inner_var_5"});
test_core.Run({});
auto out_tensor =
test_core.local_scope() == nullptr
? scope.FindVar(prefix_str + "_inner_var_2")->Get<phi::DenseTensor>()
: test_core.local_scope()
->FindVar(prefix_str + "_inner_var_2")
->Get<phi::DenseTensor>();
auto dx =
test_core.local_scope() == nullptr
? scope.FindVar(prefix_str + "_inner_var_4")->Get<phi::DenseTensor>()
: test_core.local_scope()
->FindVar(prefix_str + "_inner_var_4")
->Get<phi::DenseTensor>();

auto dy =
test_core.local_scope() == nullptr
? scope.FindVar(prefix_str + "_inner_var_5")->Get<phi::DenseTensor>()
: test_core.local_scope()
->FindVar(prefix_str + "_inner_var_5")
->Get<phi::DenseTensor>();
ASSERT_EQ(out_tensor.data<float>()[0], 4.0);
ASSERT_EQ(dx.data<float>()[0], 1.0);
ASSERT_EQ(dy.data<float>()[0], 1.0);
}

TEST(VJP, Add_BackwardTest) {
ir::IrContext* ctx = ir::IrContext::Instance();
ir::Program program((ctx));
paddle::dialect::APIBuilder::Instance().SetProgram(&program);

std::shared_ptr<ir::Builder> builder =
paddle::dialect::APIBuilder::Instance().GetBuilder();
paddle::dialect::FullOp op1 = builder->Build<paddle::dialect::FullOp>(
std::vector<int64_t>{1}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace());
paddle::dialect::FullOp op2 = builder->Build<paddle::dialect::FullOp>(
std::vector<int64_t>{1}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace());
paddle::dialect::Add_Op op3 =
builder->Build<paddle::dialect::Add_Op>(op1.out(), op2.out());

paddle::dialect::FullOp op4 = builder->Build<paddle::dialect::FullOp>(
std::vector<int64_t>{1}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace());

std::vector<std::vector<bool>> stop_gradients{{false}, {false}};
std::vector<std::vector<ir::OpResult>> out_grads{{op4.out()}};

ir::OpInfo op3_info = ctx->GetRegisteredOpInfo("pd.add_");
auto add_inplace_vjp_interface_impl =
op3_info.GetInterfaceImpl<paddle::dialect::VjpInterface>();
add_inplace_vjp_interface_impl->vjp_(
op3.operation(), out_grads, stop_gradients);

auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);

auto place = platform::CPUPlace();
Scope scope;

ProgramDesc prog_desc;
InterpreterCore test_core(place, {}, std::move(kernel_program), &scope);
std::stringstream os;
os << reinterpret_cast<NewIRInterpreter*>(
const_cast<InterpreterBaseImpl*>(test_core.Impl()));
std::string prefix_str = os.str();
test_core.SetSkipGcVars({prefix_str + "_inner_var_0",
prefix_str + "_inner_var_3",
prefix_str + "_inner_var_4"});
test_core.Run({});
auto out_tensor =
test_core.local_scope() == nullptr
? scope.FindVar(prefix_str + "_inner_var_0")->Get<phi::DenseTensor>()
: test_core.local_scope()
->FindVar(prefix_str + "_inner_var_0")
->Get<phi::DenseTensor>();
auto dx =
test_core.local_scope() == nullptr
? scope.FindVar(prefix_str + "_inner_var_3")->Get<phi::DenseTensor>()
: test_core.local_scope()
->FindVar(prefix_str + "_inner_var_3")
->Get<phi::DenseTensor>();

auto dy =
test_core.local_scope() == nullptr
? scope.FindVar(prefix_str + "_inner_var_4")->Get<phi::DenseTensor>()
: test_core.local_scope()
->FindVar(prefix_str + "_inner_var_4")
->Get<phi::DenseTensor>();
ASSERT_EQ(out_tensor.data<float>()[0], 4.0);
ASSERT_EQ(dx.data<float>()[0], 1.0);
ASSERT_EQ(dy.data<float>()[0], 1.0);
}
} // namespace framework
} // namespace paddle

0 comments on commit 84482da

Please sign in to comment.