
Commit

Cherry-pick PR 37420, fix inplace bug when the first grad_var(loss_grad) is inplace var (#37420)

* fix inplace bug

* fix custom grad input error

* add unittest

* fix inplace bug
pangyoki committed Nov 23, 2021
1 parent f873d3a commit 5fd5c1c
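
Context for the fix: when the tensor passed to backward() is itself the output of an inplace op (e.g. tanh_), its grad var already carries a grad_node. Judging from the diff, the initial accumulator used to be registered in the flat accumulators_ map (keyed by variable alone), while execution looks it up in accumulators_with_grad_node_ when the grad var has a grad_node, so the two never met. A minimal repro sketch, assuming Paddle 2.x dygraph mode and mirroring the unit test added in this commit:

    import paddle

    var_a = paddle.ones((2, 2))
    var_a.stop_gradient = False

    var_b = var_a * 2
    loss = var_b.tanh_()  # inplace op: the loss itself is an inplace var

    loss.backward()
    # before this fix, this gradient did not match the out-of-place tanh() path
    print(var_a.grad.numpy())
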
Showing 2 changed files with 34 additions and 6 deletions.
15 changes: 9 additions & 6 deletions paddle/fluid/imperative/basic_engine.cc
@@ -53,6 +53,10 @@ void BasicEngine::Init(
platform::errors::AlreadyExists(
"Accumulators are not empty before preparing it for "
"backward network execution."));
PADDLE_ENFORCE_EQ(accumulators_with_grad_node_.empty(), true,
platform::errors::AlreadyExists(
"Accumulators with grad_node as the key are not empty "
"before preparing it for backward network execution."));

for (size_t i = 0; i < tensors.size(); ++i) {
auto var = tensors[i];
@@ -73,7 +77,6 @@ void BasicEngine::Init(
VLOG(5) << "Clear the auto-grad graph from grad var " << var->Name()
<< " because of retain_graph=False when calling backward";
var->GradVarBase()->SetGraphIsFreed(true);
var->GradVarBase()->ClearGradNode();
}

if (init_node == nullptr || var->OverridedStopGradient()) {
@@ -108,14 +111,18 @@
}

VariableWrapper* init_grad_var = var->GradVarBase()->SharedVar().get();
auto& accumulator = accumulators_[init_grad_var];
auto& accumulator =
accumulators_with_grad_node_[init_grad_var->GetGradNode()]
[init_grad_var];
if (!accumulator) {
if (FLAGS_sort_sum_gradient) {
accumulator.reset(new SortedGradientAccumulator(init_grad_var));
} else {
accumulator.reset(new EagerGradientAccumulator(init_grad_var));
}
}
accumulator->IncreaseRefCnt();
accumulator->IncreaseCurCnt();

init_nodes_.push_back(init_node);
}
@@ -253,10 +260,6 @@ void BasicEngine::PrepareDeps() {
node_deps_.empty(), true,
platform::errors::AlreadyExists("Op deps are not empty before preparing "
"it for backward network execution."));
PADDLE_ENFORCE_EQ(accumulators_with_grad_node_.empty(), true,
platform::errors::AlreadyExists(
"Accumulators with grad_node as the key are not empty "
"before preparing it for backward network execution."));

std::queue<GradOpNode*> q;
std::unordered_set<GradOpNode*> visited;
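
What changed in basic_engine.cc: Init() now registers the initial (loss) gradient's accumulator in accumulators_with_grad_node_, a two-level map keyed first by the grad var's GradOpNode and then by the variable, instead of the flat accumulators_ map, and it bumps both the ref count and the cur count (IncreaseCurCnt(), presumably because the init gradient is already produced). The diff also appears to drop var->GradVarBase()->ClearGradNode() from the retain_graph=False path, which would otherwise invalidate the new key, and the emptiness check on accumulators_with_grad_node_ moves from PrepareDeps() to Init() since Init() is now the first writer of that map. A self-contained Python sketch of the two-level keying with hypothetical names, not the C++ API:

    # accumulators_with_grad_node maps grad_node -> {grad_var: accumulator}

    class Accumulator:  # hypothetical stand-in for Eager/SortedGradientAccumulator
        def __init__(self, var):
            self.var = var
            self.ref_cnt = 0
            self.cur_cnt = 0

    accumulators_with_grad_node = {}

    def register_init_grad(grad_node, grad_var):
        per_node = accumulators_with_grad_node.setdefault(grad_node, {})
        acc = per_node.setdefault(grad_var, Accumulator(grad_var))
        acc.ref_cnt += 1  # IncreaseRefCnt()
        acc.cur_cnt += 1  # IncreaseCurCnt(): the init grad already exists
        return acc
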
25 changes: 25 additions & 0 deletions python/paddle/fluid/tests/unittests/test_inplace.py
@@ -409,5 +409,30 @@ def inplace_api_processing(self, var):
return var.subtract_(self.input_var_2)


class TestLossIsInplaceVar(unittest.TestCase):
def test_loss_is_inplace_var(self):
with paddle.fluid.dygraph.guard():
var_a = paddle.ones((2, 2))
var_a.stop_gradient = False

var_b = var_a * 2
loss = var_b.tanh_()

loss.backward()
inplace_grad_var_a = var_a.grad.numpy()

with paddle.fluid.dygraph.guard():
var_a = paddle.ones((2, 2))
var_a.stop_gradient = False

var_b = var_a * 2
loss = var_b.tanh()

loss.backward()
grad_var_a = var_a.grad.numpy()

self.assertTrue(np.array_equal(inplace_grad_var_a, grad_var_a))


if __name__ == '__main__':
unittest.main()
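
The new TestLossIsInplaceVar runs the same graph twice, once with tanh_() (inplace) and once with tanh(), and asserts that the gradients of var_a agree. As an independent sanity check of the value both paths should produce (a numpy sketch, not part of the test): loss = tanh(2a) elementwise, so dloss/da = 2 * (1 - tanh(2a)^2), about 0.1413 per entry for a = 1.

    import numpy as np

    a = np.ones((2, 2))
    expected_grad = 2.0 * (1.0 - np.tanh(2.0 * a) ** 2)
    print(expected_grad)  # ~0.14130165 in every entry
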

1 comment on commit 5fd5c1c

@paddle-bot-old
Congratulations! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉
