diff --git a/python/paddle/autograd/primrules.py b/python/paddle/autograd/primrules.py index 3f44ebf1c51b1..b3e0085b29f3c 100644 --- a/python/paddle/autograd/primrules.py +++ b/python/paddle/autograd/primrules.py @@ -649,7 +649,9 @@ def gather_transpose(op, check_dot, y_bar): assert check_dot(x) axis = op.attr('axis') zeros = fill_const(0.0, x.shape, x.dtype) - return scatter_add(zeros, y_bar, indextensor, axis=axis), None + x_bar = scatter_add(zeros, y_bar, indextensor, axis=axis) + indextensor_bar = None + return x_bar, indextensor_bar @REGISTER_TRANSPOSE('scatter_add_p') @@ -658,6 +660,7 @@ def scatter_add_transpose(op, check_dot, z_bar): assert check_dot(x) and check_dot(y) axis = op.attr('axis') zeros = fill_const(value=0.0, shape=y.shape, dtype=y.dtype) - return scatter_add( - z_bar, zeros, indextensor, axis=axis), gather( - z_bar, indextensor, axis=axis), None + x_bar = scatter_add(z_bar, zeros, indextensor, axis=axis) + y_bar = gather(z_bar, indextensor, axis=axis) + indextensor_bar = None + return x_bar, y_bar, indextensor_bar diff --git a/python/paddle/autograd/primx.py b/python/paddle/autograd/primx.py index 614bd98374485..67b9293e65d74 100644 --- a/python/paddle/autograd/primx.py +++ b/python/paddle/autograd/primx.py @@ -194,8 +194,10 @@ def erase_dots(self, vars_to_erase): block = self.block for var in vars_to_erase: block.desc._remove_var(cpt.to_bytes(var.name)) + del block.vars[var.name] block._sync_with_cpp() + def var2dot_rec(self, vars, defaults=None): if isinstance(vars, paddle.fluid.framework.Variable):