Skip to content

Commit

Permalink
[DoubleGrad PR PaddlePaddle#7] paddle.grad() to copy backward graph b…
Browse files Browse the repository at this point in the history
…efore backward run (PaddlePaddle#41306)

* [Refactor] refactored eager_gen.py PR #2

* [DoubleGrad PR #1] Decoupled code generation logics for Dygraph ForwardFunctions and GradNodes

* Fixed minor issue

* Adjusted logics of GenerateNodeCreationCodes and GenerateForwardDefinition

* Fixed issues

* Supported higher-order grad node generation

* [DoubleGrad PR PaddlePaddle#4] Supported higher-order GradNode generation

* [DoubleGrad PaddlePaddle#4] Bug Fixes to Double Grad Node Generation

* Fixed yaml typo

* Fixed yaml typo

* fixed minor issues

* [DoubleGrad PR PaddlePaddle#5] Enabled gradient computations for grad_tensors passed to paddle.grad()

* Fixed minor issue

* Fixed CI-Inference issue

* Fixed CI-inference issues

* [DoubleGrad PR PaddlePaddle#7] paddle.grad() to copy backward graph before backward run

* Fixed minor issues

* Fixed issue with backward graph construction logic

* Fixed implementation issues with backward graph reconstruction

* Fixed unittest issue

* Fixed issues
  • Loading branch information
jim19930609 authored Apr 4, 2022
1 parent 5936fa6 commit a2b8014
Show file tree
Hide file tree
Showing 12 changed files with 237 additions and 54 deletions.
15 changes: 9 additions & 6 deletions paddle/fluid/eager/accumulation/accumulation_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ class GradNodeAccumulation : public GradNodeBase {
// Constructor: configure fwd input tensors to grad node
explicit GradNodeAccumulation(AutogradMeta* meta) : GradNodeBase(1, 1) {
VLOG(6) << "Construct GradNodeAccumulation";
weak_grad_ = meta->WeakGrad();
if (meta) {
weak_grad_ = meta->WeakGrad();
}

SetDefaultGradInOutMeta();
}

Expand All @@ -40,11 +43,6 @@ class GradNodeAccumulation : public GradNodeBase {

void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }

bool IsTensorWrappersCleared() override {
VLOG(6) << "Do nothing here now";
return false;
}

std::string name() { return "GradNodeAccumulation"; }

/**
Expand All @@ -58,6 +56,11 @@ class GradNodeAccumulation : public GradNodeBase {
inline bool ReduceHooksRegistered() { return reduce_hooks_.size() != 0; }
void ApplyReduceHooks();

std::shared_ptr<GradNodeBase> Copy() const override {
return std::shared_ptr<GradNodeAccumulation>(
new GradNodeAccumulation(nullptr));
}

private:
std::weak_ptr<paddle::experimental::Tensor> weak_grad_;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,19 @@ class GradNodeScale : public GradNodeBase {

void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }

bool IsTensorWrappersCleared() override {
VLOG(6) << "Do nothing here now";
return false;
}

void SetTensorWrappers_X(
const std::vector<paddle::experimental::Tensor>& tensors);

void SetAttributes_scale(float scale);
std::string name() override { return ""; }
// Members: define fwd input tensors
// For Scale there is no fwd input tensor needed

std::shared_ptr<GradNodeBase> Copy() const override {
auto copied_node = std::make_shared<GradNodeScale>(*this);
return copied_node;
}

private:
float scale_{1.0};
};
Expand Down
18 changes: 10 additions & 8 deletions paddle/fluid/eager/auto_code_generator/eager_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2479,22 +2479,23 @@ static std::string GenerateGradNodeHeaderContents(
"\n"
" void ClearTensorWrappers() override { \n"
"%s\n"
" is_tensor_wrappers_cleared = true;\n"
" SetIsTensorWrappersCleared(true);\n"
" }\n"
" std::string name() override { return \" GradNode%s \"; } \n "
"\n"
"std::shared_ptr<GradNodeBase> Copy() const override {{\n "
" auto copied_node = std::shared_ptr<GradNode%s>(new "
"GradNode%s(*this));\n "
" return copied_node;\n "
"}}\n "
"\n"
" // SetX, SetY, ...\n"
"%s\n"
" // SetAttrMap\n"
"%s\n"
" bool IsTensorWrappersCleared() override { \n"
" return is_tensor_wrappers_cleared;\n"
" }\n"
" private:\n"
" // TensorWrappers\n"
"%s\n"
" bool is_tensor_wrappers_cleared = false;\n"
"\n"
" // Attribute Map\n"
"%s\n"
"};";
Expand Down Expand Up @@ -2601,8 +2602,9 @@ static std::string GenerateGradNodeHeaderContents(

std::string grad_node_str = paddle::string::Sprintf(
GRAD_NODE_TEMPLATE, op_type, op_type, op_type, op_type, op_type, op_type,
op_type, clear_tensor_wrappers_str, op_type, set_tensor_wrappers_str,
set_attr_map_str, tensor_wrapper_members_str, attr_members_str);
op_type, clear_tensor_wrappers_str, op_type, op_type, op_type,
set_tensor_wrappers_str, set_attr_map_str, tensor_wrapper_members_str,
attr_members_str);

return grad_node_str;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,23 +125,24 @@ class {} : public egr::GradNodeBase {{
void ClearTensorWrappers() override {{
{}
is_tensor_wrappers_cleared = true;
SetIsTensorWrappersCleared(true);
}}
std::shared_ptr<GradNodeBase> Copy() const override {{
auto copied_node = std::shared_ptr<{}>(new {}(*this));
return copied_node;
}}
// SetTensorWrapperX, SetTensorWrapperY, ...
{}
// SetAttributes
{}
bool IsTensorWrappersCleared() override {{
return is_tensor_wrappers_cleared;
}}
private:
// TensorWrappers
{}
bool is_tensor_wrappers_cleared = false;
// Attributes
{}
}};
Expand Down Expand Up @@ -1218,9 +1219,10 @@ def GenerateNodeDeclaration(self):
grad_node_name = GetGradNodeName(forward_op_name)
self.node_declaration_str = NODE_DECLARATION_TEMPLATE.format(
grad_node_name, grad_node_name, grad_node_name, grad_node_name,
grad_node_name, clear_tensor_wrapper_str,
set_tensor_wrapper_methods_str, set_attribute_methods_str,
tensor_wrapper_members_str, attribute_members_str)
grad_node_name, clear_tensor_wrapper_str, grad_node_name,
grad_node_name, set_tensor_wrapper_methods_str,
set_attribute_methods_str, tensor_wrapper_members_str,
attribute_members_str)

logging.info(f"Generated Node Declaration: {self.node_declaration_str}")

Expand Down
116 changes: 111 additions & 5 deletions paddle/fluid/eager/backward.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,16 @@ class GeneralGrad {
for (size_t i = 0; i < num_inputs; i++) {
AutogradMeta* auto_grad_meta =
EagerUtils::unsafe_autograd_meta(inputs[i]);
auto target_node = auto_grad_meta->GetMutableGradNode().get();
auto* target_node = auto_grad_meta->GetMutableGradNode().get();

if (orig_to_copied_node_mapping_.count(target_node)) {
target_node = orig_to_copied_node_mapping_[target_node];
} else {
VLOG(6) << "Unable to find target node in "
"orig_to_copied_node_mapping_, likely indicating an "
"unused input";
}

PADDLE_ENFORCE_NOT_NULL(target_node,
paddle::platform::errors::Fatal(
"There is no grad op for %s:[%d] or it's"
Expand Down Expand Up @@ -249,7 +258,15 @@ class GeneralGrad {
for (size_t i = 0; i < inputs.size(); ++i) {
auto& input = inputs[i];
AutogradMeta* auto_grad_meta = EagerUtils::unsafe_autograd_meta(input);
auto target_node = auto_grad_meta->GetMutableGradNode().get();

auto* target_node = auto_grad_meta->GetMutableGradNode().get();
if (orig_to_copied_node_mapping_.count(target_node)) {
target_node = orig_to_copied_node_mapping_[target_node];
} else {
VLOG(6) << "Unable to find target node in "
"orig_to_copied_node_mapping_, likely indicating an unused "
"input";
}

auto iter = results_map.find(target_node);
if (iter != results_map.end()) {
Expand Down Expand Up @@ -326,6 +343,78 @@ class GeneralGrad {
potential_stop_nodes.clear();
depending_nodes.clear();
results_map.clear();
copied_grad_nodes_.clear();
orig_to_copied_node_mapping_.clear();
}

GradNodeBase* CopyGradNode(const std::shared_ptr<GradNodeBase>& orig_node) {
if (orig_to_copied_node_mapping_.count(orig_node.get())) {
return orig_to_copied_node_mapping_[orig_node.get()];
}
std::shared_ptr<GradNodeBase> copied_node = orig_node->Copy();

// Save node and update mapping
orig_to_copied_node_mapping_[orig_node.get()] = copied_node.get();
copied_grad_nodes_.push_back(copied_node);

return copied_node.get();
}

void ReconstructBackwardGraph(
const std::queue<GradNodeBase*>& orig_init_queue) {
std::queue<GradNodeBase*> queue = orig_init_queue;
std::unordered_set<GradNodeBase*> visited;

// BFS and recursively copy the grad nodes
while (!queue.empty()) {
GradNodeBase* orig_node = queue.front();
queue.pop();
if (visited.count(orig_node)) {
continue;
}
visited.insert(orig_node);

PADDLE_ENFORCE(
orig_to_copied_node_mapping_.count(orig_node),
paddle::platform::errors::Fatal(
"Cannot reconstruct backward graph,"
"unable to find copied target for certain grad node."));
GradNodeBase* copied_node = orig_to_copied_node_mapping_[orig_node];

const std::vector<std::vector<Edge>>& orig_edges = orig_node->GetEdges();
std::vector<std::vector<Edge>>& copied_edges =
copied_node->GetMutableEdges();
for (size_t i = 0; i < orig_edges.size(); i++) {
for (size_t j = 0; j < orig_edges[i].size(); j++) {
const Edge& orig_edge = orig_edges[i][j];
Edge& copied_edge = copied_edges[i][j];

std::shared_ptr<GradNodeBase> orig_next_node =
orig_edge.GetMutableGradNode();
if (!orig_next_node) continue;

// Copy Next Node
std::shared_ptr<GradNodeBase> copied_next_node;
if (orig_to_copied_node_mapping_.count(orig_next_node.get())) {
copied_next_node =
orig_to_copied_node_mapping_[orig_next_node.get()]
->shared_from_this();

} else {
copied_next_node = orig_next_node->Copy();
orig_to_copied_node_mapping_[orig_next_node.get()] =
copied_next_node.get();
copied_grad_nodes_.push_back(copied_next_node);
}

// Update Edge's Grad Node
copied_edge.SetGradNode(copied_next_node);

// Update BFS queue
queue.push(orig_next_node.get());
}
}
}
}

private:
Expand All @@ -345,6 +434,10 @@ class GeneralGrad {
std::unordered_set<GradNodeBase*> /* pre nodes */>
depending_nodes;
std::unordered_map<GradNodeBase*, paddle::experimental::Tensor> results_map;

std::vector<std::shared_ptr<GradNodeBase>> copied_grad_nodes_;
std::unordered_map<GradNodeBase*, GradNodeBase*> orig_to_copied_node_mapping_;

DISABLE_COPY_AND_ASSIGN(GeneralGrad);
};

Expand Down Expand Up @@ -444,6 +537,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
// 1. Init queue with starting nodes
// 2. Prepare initial input buffers
std::queue<GradNodeBase*> queue;
std::queue<GradNodeBase*> orig_queue;
std::unordered_map<GradNodeBase*, std::unique_ptr<GradTensorHolder>>
node_input_buffers_dict;
for (size_t i = 0; i < tensors.size(); i++) {
Expand All @@ -468,6 +562,16 @@ std::vector<paddle::experimental::Tensor> RunBackward(

// TODO(zhanlve): Copy and Modify GradNode if is_general_grad
GradNodeBase* grad_node = shared_grad_node.get();
if (is_general_grad) {
// Save orig grad node
orig_queue.push(grad_node);

// Replace grad_node with copied grad_node
grad_node = GeneralGrad::Instance().CopyGradNode(shared_grad_node);

// Record potential startup grad node
GeneralGrad::Instance().GetPotentialStartupNodes()->insert(grad_node);
}

// Prepare GradTensorHolder
if (!node_input_buffers_dict.count(grad_node)) {
Expand Down Expand Up @@ -504,9 +608,11 @@ std::vector<paddle::experimental::Tensor> RunBackward(

// Prepare queue, potential startup_nodes
queue.push(grad_node);
if (is_general_grad) {
GeneralGrad::Instance().GetPotentialStartupNodes()->emplace(grad_node);
}
}

if (is_general_grad) {
// Copy Backward Graph
GeneralGrad::Instance().ReconstructBackwardGraph(orig_queue);
}

VLOG(6) << "Update In degree Map for backward";
Expand Down
17 changes: 10 additions & 7 deletions paddle/fluid/eager/custom_operator/custom_operator_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@ class RunCustomOpNode : public GradNodeBase {
}

// Functor: perform backward computations
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
std::vector<std::vector<paddle::experimental::Tensor>>& grads,
bool create_graph = false) // NOLINT
virtual std::vector<std::vector<paddle::experimental::Tensor>>
operator()( // NOLINT
std::vector<std::vector<paddle::experimental::Tensor>>& grads, // NOLINT
bool create_graph = false) // NOLINT
override;

std::string name() {
Expand All @@ -64,13 +65,15 @@ class RunCustomOpNode : public GradNodeBase {
}

void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
bool IsTensorWrappersCleared() override {
VLOG(6) << "Do nothing here now";
return false;
}

void SetAttrs(const std::vector<paddle::any>& attr) { attrs_ = attr; }

std::shared_ptr<GradNodeBase> Copy() const override {
auto copied_node =
std::shared_ptr<RunCustomOpNode>(new RunCustomOpNode(*this));
return copied_node;
}

public:
std::unordered_map<int, std::vector<egr::TensorWrapper>> fwd_outs;
std::unordered_map<int, std::vector<egr::TensorWrapper>> fwd_ins;
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/eager/grad_node_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,10 @@ const std::vector<std::vector<Edge>>& GradNodeBase::GetEdges() const {
return adj_edges_;
}

std::vector<std::vector<Edge>>& GradNodeBase::GetMutableEdges() {
return adj_edges_;
}

std::vector<std::vector<paddle::experimental::Tensor>>
GradNodeBase::ApplyGradientHooks(
const std::vector<std::vector<paddle::experimental::Tensor>>& tensors) {
Expand Down
Loading

0 comments on commit a2b8014

Please sign in to comment.