From b2419b6c0c83363f04a386f5c689f14d7369cbef Mon Sep 17 00:00:00 2001 From: Tongxin Bai Date: Mon, 18 Apr 2022 12:03:54 +0000 Subject: [PATCH 1/3] Update gather_p pywrapper. --- python/paddle/autograd/primops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/autograd/primops.py b/python/paddle/autograd/primops.py index b997148cbeaa3..3965754990531 100644 --- a/python/paddle/autograd/primops.py +++ b/python/paddle/autograd/primops.py @@ -263,7 +263,7 @@ def slice_assign(x, y, axis, starts, ends, strides, out=None): return out -@REGISTER_FN('gather_p', 'X', 'Y') +@REGISTER_FN('gather_p', 'X', 'IndexTensor', 'Y') def gather(x, indextensor, axis, out=None): attrs = {'axis': axis} helper = LayerHelper('gather_p', **locals()) From bf103dee59c915b77ae8da0c9eb1cb0bd0a7b0ec Mon Sep 17 00:00:00 2001 From: Tongxin Bai Date: Mon, 18 Apr 2022 13:50:20 +0000 Subject: [PATCH 2/3] polish code. --- python/paddle/autograd/primrules.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/python/paddle/autograd/primrules.py b/python/paddle/autograd/primrules.py index 782b64b110969..87c67af44b3b9 100644 --- a/python/paddle/autograd/primrules.py +++ b/python/paddle/autograd/primrules.py @@ -625,7 +625,9 @@ def gather_transpose(op, check_dot, y_bar): assert check_dot(x) axis = op.attr('axis') zeros = fill_const(0.0, x.shape, x.dtype) - return scatter_add(zeros, y_bar, indextensor, axis=axis), None + x_bar = scatter_add(zeros, y_bar, indextensor, axis=axis) + indextensor_bar = None + return x_bar, indextensor_bar @REGISTER_TRANSPOSE('scatter_add_p') @@ -634,6 +636,7 @@ def scatter_add_transpose(op, check_dot, z_bar): assert check_dot(x) and check_dot(y) axis = op.attr('axis') zeros = fill_const(value=0.0, shape=y.shape, dtype=y.dtype) - return scatter_add( - z_bar, zeros, indextensor, axis=axis), gather( - z_bar, indextensor, axis=axis), None + x_bar = scatter_add(z_bar, zeros, indextensor, axis=axis) + y_bar = gather(z_bar, indextensor, axis=axis) + indextensor_bar = None + return x_bar, y_bar, indextensor_bar From 0d382360be1c127c301eee7c3980771a8dead26f Mon Sep 17 00:00:00 2001 From: Tongxin Bai Date: Tue, 19 Apr 2022 03:18:56 +0000 Subject: [PATCH 3/3] delete vars after block.desc._remove --- python/paddle/autograd/primx.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/paddle/autograd/primx.py b/python/paddle/autograd/primx.py index 9b23cdc4f4fb0..61f1835c2f760 100644 --- a/python/paddle/autograd/primx.py +++ b/python/paddle/autograd/primx.py @@ -193,8 +193,10 @@ def erase_dots(self, vars_to_erase): block = self.block for var in vars_to_erase: block.desc._remove_var(cpt.to_bytes(var.name)) + del block.vars[var.name] block._sync_with_cpp() + def var2dot_rec(self, vars, defaults=None): if isinstance(vars, paddle.fluid.framework.Variable):