Skip to content

Commit

Permalink
Add elementwise_mul orig2prim and support p_norm when p=1
Browse files Browse the repository at this point in the history
  • Loading branch information
0x45f committed Apr 25, 2022
1 parent e5945ed commit 9b3609f
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 4 deletions.
29 changes: 26 additions & 3 deletions python/paddle/autograd/primrules.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def linear_jvp(op, *args, **kwargs):
elementwise_sub
scale
assign
elementwise_mul
These original ops are partially supported:
Expand All @@ -122,6 +123,24 @@ def linear_jvp(op, *args, **kwargs):
"""


@REGISTER_ORIG2PRIM('elementwise_mul')
def elementwise_mul_orig2prim(op, x, y):
    """Lower the original `elementwise_mul` op into primitive ops.

    Broadcasts `y` up to `x`'s shape when the shapes differ, folds in the
    op's `Scale_x`/`Scale_y`/`Scale_out` attributes via explicit constant
    multiplications, and returns the primitive-op product.

    Args:
        op: the original `elementwise_mul` operator (read via `op.attr`).
        x: left-hand input variable.
        y: right-hand input variable; broadcast to `x.shape` if needed.

    Returns:
        The product variable built from primitive ops.
    """
    if x.shape != y.shape:
        # elementwise_mul broadcasts Y toward X's shape.
        y = broadcast(y, shape=x.shape)
    # NOTE: use abs() so scales *below* 1.0 (e.g. 0.5) are honored too;
    # the plain `scale - 1.0 > 1e-5` check silently ignored them.
    if abs(op.attr('Scale_x') - 1.0) > 1e-5:
        tmp = fill_const(shape=x.shape, dtype=x.dtype, value=op.attr('Scale_x'))
        x = mul(x, tmp)
    if abs(op.attr('Scale_y') - 1.0) > 1e-5:
        tmp = fill_const(shape=y.shape, dtype=y.dtype, value=op.attr('Scale_y'))
        y = mul(y, tmp)
    z = mul(x, y)
    if abs(op.attr('Scale_out') - 1.0) > 1e-5:
        tmp = fill_const(
            shape=z.shape, dtype=z.dtype, value=op.attr('Scale_out'))
        z = mul(z, tmp)
    return z


@REGISTER_ORIG2PRIM('matmul_v2')
def matmul_v2_orig2prim(op, x, y):
def trans(shape):
Expand Down Expand Up @@ -222,13 +241,17 @@ def num_el(shape):
n = n * s
return n

assert op.attr(
'porder') - 2.0 < 1e-5, 'Only support lower l2 norm currently'
assert op.attr(
'asvector'), 'Only support lower pnorm when asvector=True currently'
if len(x.shape) > 1:
x = reshape(x, shape=[num_el(x.shape)])
return sqrt(reduce(mul(x, x), axis=[0]))

if op.attr('porder') - 2.0 < 1e-5:
return sqrt(reduce(mul(x, x), axis=[0]))
elif op.attr('porder') - 1.0 < 1e-5:
return reduce(sqrt(mul(x, x)), axis=[0])
else:
raise RuntimeError('Only support lower l2/l1 norm currently')


@REGISTER_ORIG2PRIM('index_select')
Expand Down
41 changes: 40 additions & 1 deletion python/paddle/fluid/tests/unittests/test_orig2prig_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,25 @@ def test_op(self):
self.assertEqual(prim_out[k].shape, v.shape)


class TestElementWiseMulOrig2Prim(TestElementWiseAddOrig2Prim):
    def init_data(self):
        """Configure the orig2prim test fixture for `elementwise_mul`."""
        self.op_type = 'elementwise_mul'
        lhs = paddle.static.data(name='X', shape=[8, 8], dtype='float')
        rhs = paddle.static.data(name='Y', shape=[8, 8], dtype='float')
        result = self.layer_help.create_variable_for_type_inference(
            dtype=lhs.dtype)

        self.input = {'X': lhs, 'Y': rhs}
        self.output = {'Out': result}
        self.attrs = {}
        self.orig2prim_args = (lhs, rhs)
        # Same-shape inputs with default scales lower to a single mul_p.
        self.all_ops = ['elementwise_mul', 'mul_p']
        # Maps prim-op output index -> original op output variable.
        self.out_map = {0: result}


class TestMatmulV2Orig2Prim(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'matmul_v2'
Expand Down Expand Up @@ -201,7 +220,27 @@ def init_data(self):
self.out_map = {0: self.output['Out']}


class TestPNormOrig2Prim(TestElementWiseAddOrig2Prim):
class TestPNormOrig2Prim1(TestElementWiseAddOrig2Prim):
    def init_data(self):
        """Configure the orig2prim test fixture for `p_norm` with porder=1."""
        self.op_type = 'p_norm'
        inp = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
        result = self.layer_help.create_variable_for_type_inference(
            dtype=inp.dtype)

        self.input = {'X': inp}
        self.output = {'Out': result}
        # asvector=True flattens the input; porder=1 lowers to
        # reduce(sqrt(x * x)).
        self.attrs = {
            'porder': 1,
            'asvector': True,
        }
        self.orig2prim_args = (inp, )
        self.all_ops = ['p_norm', 'reshape_p', 'sqrt_p', 'reduce_p', 'mul_p']
        # Maps prim-op output index -> original op output variable.
        self.out_map = {0: result}


class TestPNormOrig2Prim2(TestElementWiseAddOrig2Prim):
def init_data(self):
self.op_type = 'p_norm'
X = paddle.static.data(name='X', shape=[5, 6], dtype='int64')
Expand Down

0 comments on commit 9b3609f

Please sign in to comment.