Fix comments
0x45f committed May 11, 2022
1 parent 6f6c4be commit 2875959
Showing 9 changed files with 156 additions and 132 deletions.
3 changes: 1 addition & 2 deletions python/paddle/incubate/autograd/primops.py
@@ -154,8 +154,7 @@ def split(x, num_or_sections, axis=0, outs=None):

@REGISTER_FN('concat_p', 'XS', 'Y')
def concat(xs, axis=0, out=None):
# TODO(lml): This is hacky, refine it later
if not isinstance(xs, (list, tuple)):
if isinstance(xs, paddle.fluid.framework.Variable):
xs = [xs]
attrs = {'axis': axis}
helper = LayerHelper('concat_p', **locals())
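As an aside, the tightened check can be illustrated with a short self-contained sketch; the helper name `_as_list` is hypothetical, while `paddle.fluid.framework.Variable` is the class referenced in the diff:

```python
# Hedged sketch: only a bare Variable is wrapped; lists and tuples pass
# through unchanged, and any other type now surfaces downstream instead of
# being silently wrapped (the old `not isinstance(xs, (list, tuple))` did).
import paddle

def _as_list(xs):
    if isinstance(xs, paddle.fluid.framework.Variable):
        xs = [xs]
    return xs
```

The old negative test wrapped any non-list/tuple value, including wrong-typed arguments; the positive `isinstance` test wraps only a bare Variable and lets anything else fail loudly downstream.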
43 changes: 27 additions & 16 deletions python/paddle/incubate/autograd/primreg.py
@@ -24,7 +24,8 @@ def __init__(self, name):
self.tab = {}

def register(self, name, value):
assert name not in self.tab
assert name not in self.tab, 'name "{}" should not be registered before.'.format(
name)
self.tab[name] = value

def lookup(self, name):
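For illustration, the duplicate-registration guard above can be exercised with a self-contained sketch; the `Registry` shape follows this diff, and the `'tanh_p'` key is a made-up example:

```python
class Registry(object):
    """Minimal re-statement of the registry from this file."""

    def __init__(self, name):
        self.name = name
        self.tab = {}

    def register(self, name, value):
        assert name not in self.tab, (
            'name "{}" should not be registered before.'.format(name))
        self.tab[name] = value

reg = Registry('demo')
reg.register('tanh_p', object())
# reg.register('tanh_p', object())  # would raise AssertionError with the new message
```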
@@ -67,20 +68,23 @@ def op_position_inputs(op):
```
@REGISTER_FN('div_p', 'X', 'Y', 'Z')
def div(x, y, out=None):
...
return _simple_binop(LayerHelper('div_p', **locals()))
```
The registered inputs are ['X', 'Y'] for div_p and accordingly this
function will return inputs in the order of X then Y.
"""
args = _primop_position_argnames.lookup(op.type)
assert args is not None
assert args is not None, 'args should not be None in op_position_inputs().'
*input_names, _ = args

inputs = []
for name in input_names:
vars = list(map(op.block.var, op.input(name)))
assert len(vars) >= 0
assert len(
vars
) >= 0, 'len(vars) should be greater than or equal to 0, but len(vars)={}.'.format(
(len(vars)))
if len(vars) > 1:
inputs.append(vars)
else:
@@ -97,18 +101,21 @@ def op_position_output(op):
```
@REGISTER_FN('div_p', 'X', 'Y', 'Z')
def div(x, y, out=None):
...
return _simple_binop(LayerHelper('div_p', **locals()))
```
The registered output is ['Z'] for div_p and accordingly this
function will return output Z.
"""
args = _primop_position_argnames.lookup(op.type)
assert args is not None
assert args is not None, 'args should not be None in op_position_output().'
*_, output_name = args

outvars = list(map(op.block.var, op.output(output_name)))
assert len(outvars) >= 0
assert len(
outvars
) >= 0, 'len(outvars) should be greater than or equal to 0, but len(outvars)={}.'.format(
len(outvars))
if len(outvars) > 1:
output = outvars
else:
@@ -120,7 +127,7 @@ def div(x, y, out=None):
def REGISTER_FN(op_type, *position_argnames):
"""Decorator for registering the Python function for a primitive op."""

assert isinstance(op_type, str)
assert isinstance(op_type, str), 'type(op_type) must be str.'

_primop_position_argnames.register(op_type, position_argnames)

@@ -142,11 +149,12 @@ def tanh_orig2prim(op):
return primops.tanh(x)
"""
assert isinstance(op_type, str)
assert isinstance(op_type, str), 'type(op_type) must be str.'

def wrapper(f):
def _lower(op, *args, **kwargs):
assert op.type == op_type
assert op.type == op_type, 'op.type should be equal to op_type, but op.type is {} and op_type is {}'.format(
op.type, op_type)
return f(op, *args, **kwargs)

_orig2prim.register(op_type, _lower)
@@ -165,11 +173,12 @@ def tanh_prim2orig(op):
return paddle.tanh(x)
"""
assert isinstance(op_type, str)
assert isinstance(op_type, str), 'type(op_type) must be str.'

def wrapper(f):
def _lower(op, *args, **kwargs):
assert op.type == op_type
assert op.type == op_type, 'op.type should be equal to op_type, but op.type is {} and op_type is {}'.format(
op.type, op_type)
return f(op, *args, **kwargs)

_prim2orig.register(op_type, _lower)
@@ -187,11 +196,12 @@ def add_jvp(op, x_dot, y_dot):
return primops.add(x_dot, y_dot)
"""
assert isinstance(op_type, str)
assert isinstance(op_type, str), 'type(op_type) must be str.'

def wrapper(f):
def _jvp(op, *args, **kwargs):
assert op.type == op_type
assert op.type == op_type, 'op.type should be equal to op_type, but op.type is {} and op_type is {}'.format(
op.type, op_type)
return f(op, *args, **kwargs)

_primop_jvp.register(op_type, _jvp)
@@ -211,11 +221,12 @@ def add_transpose(op, z_bar):
return z_bar, z_bar
"""
assert isinstance(op_type, str)
assert isinstance(op_type, str), 'type(op_type) must be str.'

def wrapper(f):
def _transpose(op, dot_checker, *args, **kwargs):
assert op.type == op_type
assert op.type == op_type, 'op.type should be equal to op_type, but op.type is {} and op_type is {}'.format(
op.type, op_type)
return f(op, dot_checker, *args, **kwargs)

_primop_transpose.register(op_type, _transpose)
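To show how the pieces of this file fit together, here is a self-contained sketch of the registration flow; a plain dict stands in for the module-level `Registry` (`_primop_position_argnames`), so treat this as an approximation rather than the real module:

```python
# Hedged sketch: a dict stands in for the module-level Registry.
_primop_position_argnames = {}

def REGISTER_FN(op_type, *position_argnames):
    assert isinstance(op_type, str), 'type(op_type) must be str.'
    _primop_position_argnames[op_type] = position_argnames

    def wrapper(f):
        return f

    return wrapper

@REGISTER_FN('div_p', 'X', 'Y', 'Z')
def div(x, y, out=None):
    pass  # body elided; only the registration side effect matters here

# op_position_inputs/op_position_output split the registered names like so:
*input_names, output_name = _primop_position_argnames['div_p']
print(input_names, output_name)  # ['X', 'Y'] Z
```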
161 changes: 80 additions & 81 deletions python/paddle/incubate/autograd/primrules.py
@@ -74,9 +74,81 @@ def linear_jvp(op, *args, **kwargs):
"""


@REGISTER_ORIG2PRIM('sqrt')
def sqrt_orig2prim(op, x):
return sqrt(x)
@REGISTER_ORIG2PRIM('elementwise_add')
def elementwise_add_orig2prim(op, x, y):
if x.shape != y.shape:
y = broadcast(y, shape=x.shape)
if op.attr('Scale_x') - 1.0 > 1e-5:
tmp = fill_const(shape=x.shape, dtype=x.dtype, value=op.attr('Scale_x'))
x = mul(x, tmp)
if op.attr('Scale_y') - 1.0 > 1e-5:
tmp = fill_const(shape=y.shape, dtype=y.dtype, value=op.attr('Scale_y'))
y = mul(y, tmp)
z = add(x, y)
if op.attr('Scale_out') - 1.0 > 1e-5:
tmp = fill_const(
shape=z.shape, dtype=z.dtype, value=op.attr('Scale_out'))
z = mul(z, tmp)
return z
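The Scale_* attributes are folded in by materializing a constant tensor and multiplying; note that the guard only fires when the scale exceeds 1.0 by more than 1e-5. A NumPy stand-in for the Scale_x branch (an illustration, not the real primitive graph):

```python
import numpy as np

def fold_scale(x, scale_x):
    # Mirrors the guard above: scales at (or below) 1.0 are left untouched.
    if scale_x - 1.0 > 1e-5:
        x = x * np.full(x.shape, scale_x, dtype=x.dtype)
    return x

print(fold_scale(np.array([1.0, 2.0]), 2.0))  # [2. 4.]
print(fold_scale(np.array([1.0, 2.0]), 1.0))  # unchanged: [1. 2.]
```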


@REGISTER_ORIG2PRIM('tanh')
def tanh_orig2prim(op, x):
return tanh(x)


@REGISTER_ORIG2PRIM('fill_zeros_like')
def fill_zeros_like_orig2prim(op, x):
return fill_const(value=0.0, shape=x.shape, dtype=x.dtype)


@REGISTER_ORIG2PRIM('sum')
def sum_orig2prim(op, xs):
x0 = xs[0]
for x in xs[1:]:
x0 = add(x0, x)
return x0


@REGISTER_ORIG2PRIM('index_select')
def index_select_orig2prim(op, index_t, x):
return gather(x, indextensor=index_t, axis=op.attr('dim'))


@REGISTER_ORIG2PRIM('elementwise_sub')
def elementwise_sub_orig2prim(op, x, y):
if x.shape != y.shape:
y = broadcast(y, shape=x.shape)
if op.attr('Scale_x') - 1.0 > 1e-5:
tmp = fill_const(shape=x.shape, dtype=x.dtype, value=op.attr('Scale_x'))
x = mul(x, tmp)
if op.attr('Scale_y') - 1.0 > 1e-5:
tmp = fill_const(shape=y.shape, dtype=y.dtype, value=op.attr('Scale_y'))
y = mul(y, tmp)
z = sub(x, y)
if op.attr('Scale_out') - 1.0 > 1e-5:
tmp = fill_const(
shape=z.shape, dtype=z.dtype, value=op.attr('Scale_out'))
z = mul(z, tmp)
return z


@REGISTER_ORIG2PRIM('scale')
def scale_orig2prim(op, scale_t, x):
if scale_t is None:
scale_t = fill_const(
shape=x.shape, dtype=x.dtype, value=op.attr('scale'))
bias_t = fill_const(shape=x.shape, dtype=x.dtype, value=op.attr('bias'))
if op.attr('bias_after_scale'):
return add(mul(x, scale_t), bias_t)
else:
return mul(add(x, bias_t), scale_t)
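As a sanity check on the scale lowering, the two branches compute `scale * x + bias` and `scale * (x + bias)` respectively; NumPy stands in for the Paddle primitives here, which is an assumption for illustration only:

```python
import numpy as np

x, scale, bias = np.array([1.0, 2.0]), 3.0, 0.5

# bias_after_scale=True:  add(mul(x, scale_t), bias_t)
y_after = x * scale + bias          # [3.5, 6.5]
# bias_after_scale=False: mul(add(x, bias_t), scale_t)
y_before = (x + bias) * scale       # [4.5, 7.5]
```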


@REGISTER_ORIG2PRIM('assign')
def assign_orig2prim(op, x):
zero_t = fill_const(shape=x.shape, dtype=x.dtype, value=0.0)
return add(x, zero_t)


@REGISTER_ORIG2PRIM('elementwise_mul')
@@ -97,6 +169,11 @@ def elementwise_mul_orig2prim(op, x, y):
return z


@REGISTER_ORIG2PRIM('sqrt')
def sqrt_orig2prim(op, x):
return sqrt(x)


@REGISTER_ORIG2PRIM('matmul_v2')
def matmul_v2_orig2prim(op, x, y):
def trans(shape):
@@ -118,29 +195,6 @@ def trans(shape):
return matmul(x, y)


@REGISTER_ORIG2PRIM('elementwise_add')
def elementwise_add_orig2prim(op, x, y):
if x.shape != y.shape:
y = broadcast(y, shape=x.shape)
if op.attr('Scale_x') - 1.0 > 1e-5:
tmp = fill_const(shape=x.shape, dtype=x.dtype, value=op.attr('Scale_x'))
x = mul(x, tmp)
if op.attr('Scale_y') - 1.0 > 1e-5:
tmp = fill_const(shape=y.shape, dtype=y.dtype, value=op.attr('Scale_y'))
y = mul(y, tmp)
z = add(x, y)
if op.attr('Scale_out') - 1.0 > 1e-5:
tmp = fill_const(
shape=z.shape, dtype=z.dtype, value=op.attr('Scale_out'))
z = mul(z, tmp)
return z


@REGISTER_ORIG2PRIM('tanh')
def tanh_orig2prim(op, x):
return tanh(x)


## NOTE(lml): The second output of reshape2, Xshape, which is only used in reshape2_grad, is meaningless in the new autograd mechanism, so we use a zero tensor instead.
@REGISTER_ORIG2PRIM('reshape2')
def reshape2_orig2prim(op, shape_t, shape_tl, x):
@@ -169,25 +223,11 @@ def slice_orig2prim(op, ends_t, ends_tl, x, starts_t, starts_tl):
strides = [1 for _ in starts]
axis = op.attr('axes')
y = slice_select(x, starts=starts, ends=ends, strides=strides, axis=axis)
# op.attr('decrease_axis') is [] when no axis is decreased
if op.attr('decrease_axis'):
y = reshape(y, shape=get_output_var_list(op)[0].shape)
return y
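The `decrease_axis` handling mirrors how integer indexing squeezes an axis; a hedged NumPy sketch (the real rule reshapes to the recorded output variable's shape):

```python
import numpy as np

x = np.ones((3, 4))
y = x[1:2, :]          # slice_select keeps the axis: shape (1, 4)
# When decrease_axis is non-empty, the original slice op returned shape (4,),
# so the lowering reshapes to the recorded output shape.
y = y.reshape((4,))
```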


@REGISTER_ORIG2PRIM('fill_zeros_like')
def fill_zeros_like_orig2prim(op, x):
return fill_const(value=0.0, shape=x.shape, dtype=x.dtype)


@REGISTER_ORIG2PRIM('sum')
def sum_orig2prim(op, xs):
x0 = xs[0]
for x in xs[1:]:
x0 = add(x0, x)
return x0


@REGISTER_ORIG2PRIM('p_norm')
def p_norm_orig2prim(op, x):
def num_el(shape):
@@ -209,47 +249,6 @@ def num_el(shape):
raise RuntimeError('Only support lower l2/l1 norm currently')


@REGISTER_ORIG2PRIM('index_select')
def index_select_orig2prim(op, index_t, x):
return gather(x, indextensor=index_t, axis=op.attr('dim'))


@REGISTER_ORIG2PRIM('elementwise_sub')
def elementwise_sub_orig2prim(op, x, y):
if x.shape != y.shape:
y = broadcast(y, shape=x.shape)
if op.attr('Scale_x') - 1.0 > 1e-5:
tmp = fill_const(shape=x.shape, dtype=x.dtype, value=op.attr('Scale_x'))
x = mul(x, tmp)
if op.attr('Scale_y') - 1.0 > 1e-5:
tmp = fill_const(shape=y.shape, dtype=y.dtype, value=op.attr('Scale_y'))
y = mul(y, tmp)
z = sub(x, y)
if op.attr('Scale_out') - 1.0 > 1e-5:
tmp = fill_const(
shape=z.shape, dtype=z.dtype, value=op.attr('Scale_out'))
z = mul(z, tmp)
return z


@REGISTER_ORIG2PRIM('scale')
def scale_orig2prim(op, scale_t, x):
if scale_t is None:
scale_t = fill_const(
shape=x.shape, dtype=x.dtype, value=op.attr('scale'))
bias_t = fill_const(shape=x.shape, dtype=x.dtype, value=op.attr('bias'))
if op.attr('bias_after_scale'):
return add(mul(x, scale_t), bias_t)
else:
return mul(add(x, bias_t), scale_t)


@REGISTER_ORIG2PRIM('assign')
def assign_orig2prim(op, x):
zero_t = fill_const(shape=x.shape, dtype=x.dtype, value=0.0)
return add(x, zero_t)


## Register prim2orig lower rules

