Skip to content

Commit

Permalink
Add sqrt orig2prim rule and UT
Browse files Browse the repository at this point in the history
  • Loading branch information
0x45f committed Apr 25, 2022
1 parent 9b3609f commit 91f7fbf
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
6 changes: 6 additions & 0 deletions python/paddle/autograd/primrules.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def linear_jvp(op, *args, **kwargs):
scale
assign
elementwise_mul
sqrt
These original ops are partially supported:
Expand All @@ -123,6 +124,11 @@ def linear_jvp(op, *args, **kwargs):
"""


@REGISTER_ORIG2PRIM('sqrt')
def sqrt_orig2prim(op, x):
return sqrt(x)


@REGISTER_ORIG2PRIM('elementwise_mul')
def elementwise_mul_orig2prim(op, x, y):
if x.shape != y.shape:
Expand Down
18 changes: 18 additions & 0 deletions python/paddle/fluid/tests/unittests/test_orig2prig_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,24 @@ def test_op(self):
self.assertEqual(prim_out[k].shape, v.shape)


class TestSqrtOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'sqrt'
X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')

self.input = {'X': X, }
self.output = {
'Out':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {}

self.orig2prim_args = (X, )
self.all_ops = ['sqrt', 'sqrt_p']
# { prim_op_output_index: orig_op_output_var }
self.out_map = {0: self.output['Out']}


class TestElementWiseMulOrig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'elementwise_mul'
Expand Down

0 comments on commit 91f7fbf

Please sign in to comment.