Skip to content

Commit

Permalink
fix inplace bug when the first grad_var(loss_grad) is inplace var (#3…
Browse files Browse the repository at this point in the history
…7420)

* fix inplace bug

* fix custom grad input error

* add unittest

* fix inplace bug
  • Loading branch information
pangyoki authored Nov 23, 2021
1 parent 7980097 commit ee1e164
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 6 deletions.
15 changes: 9 additions & 6 deletions paddle/fluid/imperative/basic_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ void BasicEngine::Init(
platform::errors::AlreadyExists(
"Accumulators are not empty before preparing it for "
"backward network execution."));
PADDLE_ENFORCE_EQ(accumulators_with_grad_node_.empty(), true,
platform::errors::AlreadyExists(
"Accumulators with grad_node as the key are not empty "
"before preparing it for backward network execution."));

for (size_t i = 0; i < tensors.size(); ++i) {
auto var = tensors[i];
Expand All @@ -73,7 +77,6 @@ void BasicEngine::Init(
VLOG(5) << "Clear the auto-grad graph from grad var " << var->Name()
<< " because of retain_graph=False when calling backward";
var->GradVarBase()->SetGraphIsFreed(true);
var->GradVarBase()->ClearGradNode();
}

if (init_node == nullptr || var->OverridedStopGradient()) {
Expand Down Expand Up @@ -108,14 +111,18 @@ void BasicEngine::Init(
}

VariableWrapper* init_grad_var = var->GradVarBase()->SharedVar().get();
auto& accumulator = accumulators_[init_grad_var];
auto& accumulator =
accumulators_with_grad_node_[init_grad_var->GetGradNode()]
[init_grad_var];
if (!accumulator) {
if (FLAGS_sort_sum_gradient) {
accumulator.reset(new SortedGradientAccumulator(init_grad_var));
} else {
accumulator.reset(new EagerGradientAccumulator(init_grad_var));
}
}
accumulator->IncreaseRefCnt();
accumulator->IncreaseCurCnt();

init_nodes_.push_back(init_node);
}
Expand Down Expand Up @@ -253,10 +260,6 @@ void BasicEngine::PrepareDeps() {
node_deps_.empty(), true,
platform::errors::AlreadyExists("Op deps are not empty before preparing "
"it for backward network execution."));
PADDLE_ENFORCE_EQ(accumulators_with_grad_node_.empty(), true,
platform::errors::AlreadyExists(
"Accumulators with grad_node as the key are not empty "
"before preparing it for backward network execution."));

std::queue<GradOpNode*> q;
std::unordered_set<GradOpNode*> visited;
Expand Down
25 changes: 25 additions & 0 deletions python/paddle/fluid/tests/unittests/test_inplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,5 +409,30 @@ def inplace_api_processing(self, var):
return var.subtract_(self.input_var_2)


class TestLossIsInplaceVar(unittest.TestCase):
def test_loss_is_inplace_var(self):
with paddle.fluid.dygraph.guard():
var_a = paddle.ones((2, 2))
var_a.stop_gradient = False

var_b = var_a * 2
loss = var_b.tanh_()

loss.backward()
inplace_grad_var_a = var_a.grad.numpy()

with paddle.fluid.dygraph.guard():
var_a = paddle.ones((2, 2))
var_a.stop_gradient = False

var_b = var_a * 2
loss = var_b.tanh()

loss.backward()
grad_var_a = var_a.grad.numpy()

self.assertTrue(np.array_equal(inplace_grad_var_a, grad_var_a))


if __name__ == '__main__':
unittest.main()

0 comments on commit ee1e164

Please sign in to comment.