Skip to content

Commit

Permalink
update primx.py
Browse files Browse the repository at this point in the history
  • Loading branch information
tongxin committed Apr 13, 2022
1 parent 386e60c commit 5d7cda7
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
4 changes: 2 additions & 2 deletions python/paddle/autograd/primops.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def _manipulation_unop(helper):

attrs = {
k: helper.kwargs[k]
for k in ('shape', 'axis', 'indexes') if k in helper.kwargs
for k in ('shape', 'axis', 'index') if k in helper.kwargs
}

if out is None:
Expand Down Expand Up @@ -140,7 +140,7 @@ def split(x, num_or_sections, axis=0, outs=None):
return outs


@REGISTER_FN('concat')
@REGISTER_FN('concat_p')
def concat(xs, axis=0, out=None):
assert isinstance(xs, (list, tuple)) and len(xs) > 0
attrs = {'axis': axis}
Expand Down
8 changes: 4 additions & 4 deletions python/paddle/autograd/primreg.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ def lookup(self, name):


_primop_fn = Registry('primop_fn')
_primop_jvp = Registry('primop_jvps')
_primop_transpose = Registry('primop_vjps')
_primop_jvp = Registry('primop_jvp')
_primop_transpose = Registry('primop_transpose')


def lookup_fn(optype):
Expand Down Expand Up @@ -63,7 +63,7 @@ def REGISTER_JVP(op_type):
Usage:
.. code-block:: python
@RegisterJVP('add')
@REGISTER_JVP('add')
def add_jvp(op, x_dot, y_dot):
return primops.add(x_dot, y_dot)
Expand All @@ -85,7 +85,7 @@ def REGISTER_TRANSPOSE(op_type):
Usage:
.. code-block:: python
@RegisterJVP('add')
@REGISTER_TRANSPOSE('add')
def add_transpose(op, z_bar):
return z_bar, z_bar
Expand Down
2 changes: 2 additions & 0 deletions python/paddle/autograd/primx.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ def erase_dots(self, vars_to_erase):
del self.vars[id(var)]
self.dot2bar.delete_keyvars(vars_to_erase)
self.var2dot.delete_valuevars(vars_to_erase)
for var in vars_to_erase:
del var.block.vars[var.name]

def is_dot(self, var):
return self.var2dot.contain_value(var)
Expand Down

0 comments on commit 5d7cda7

Please sign in to comment.