Skip to content

Commit

Permalink
Add some assert message
Browse files Browse the repository at this point in the history
  • Loading branch information
0x45f committed May 15, 2022
1 parent c764b95 commit 28e812e
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 16 deletions.
42 changes: 27 additions & 15 deletions python/paddle/incubate/autograd/primrules.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,9 @@ def scatter_add_jvp(op, x_dot, y_dot):
@REGISTER_TRANSPOSE('add_p')
def add_transpose(op, check_dot, z_bar):
x, y = op_position_inputs(op)
assert check_dot(x) or check_dot(y)
assert check_dot(x) or check_dot(y), (
f'(check_dot(x) or check_dot(y)) must be True, '
f'but check_dot(x)={check_dot(x)} and check_dot(y)={check_dot(y)}.')
x_bar = z_bar if check_dot(x) else None
y_bar = z_bar if check_dot(y) else None
return x_bar, y_bar
Expand All @@ -561,7 +563,9 @@ def add_transpose(op, check_dot, z_bar):
@REGISTER_TRANSPOSE('sub_p')
def sub_transpose(op, check_dot, z_bar):
x, y = op_position_inputs(op)
assert check_dot(x) or check_dot(y)
assert check_dot(x) or check_dot(y), (
f'(check_dot(x) or check_dot(y)) must be True, '
f'but check_dot(x)={check_dot(x)} and check_dot(y)={check_dot(y)}.')
x_bar = z_bar if check_dot(x) else None
y_bar = neg(z_bar) if check_dot(y) else None
return x_bar, y_bar
Expand All @@ -570,7 +574,9 @@ def sub_transpose(op, check_dot, z_bar):
@REGISTER_TRANSPOSE('mul_p')
def mul_transpose(op, check_dot, z_bar):
x, y = op_position_inputs(op)
assert check_dot(x) ^ check_dot(y)
assert check_dot(x) ^ check_dot(y), (
f'(check_dot(x) ^ check_dot(y)) must be True, '
f'but check_dot(x)={check_dot(x)} and check_dot(y)={check_dot(y)}.')
if check_dot(x):
return mul(z_bar, y), None
else:
Expand All @@ -580,22 +586,22 @@ def mul_transpose(op, check_dot, z_bar):
@REGISTER_TRANSPOSE('div_p')
def div_transpose(op, check_dot, z_bar):
x, y = op_position_inputs(op)
assert not check_dot(y)
assert not check_dot(y), 'check_dot(y) must be False'
x_bar = div(z_bar, y) if check_dot(x) else None
return x_bar, None


@REGISTER_TRANSPOSE('reshape_p')
def reshape_transpose(op, check_dot, y_bar):
x, = op_position_inputs(op)
assert check_dot(x)
assert check_dot(x), 'check_dot(x) must be True'
return reshape(y_bar, shape=x.shape)


@REGISTER_TRANSPOSE('broadcast_p')
def broadcast_transpose(op, check_dot, y_bar):
x, = op_position_inputs(op)
assert check_dot(x)
assert check_dot(x), 'check_dot(x) must be True'
bat = len(y_bar.shape) - len(x.shape)
axis = list(range(bat))
keepdim = [(bat + i) for i, s in enumerate(x.shape) if s == 1]
Expand All @@ -608,7 +614,7 @@ def broadcast_transpose(op, check_dot, y_bar):
@REGISTER_TRANSPOSE('transpose_p')
def transpose_transpose(op, check_dot, y_bar):
x, = op_position_inputs(op)
assert check_dot(x)
assert check_dot(x), 'check_dot(x) must be True'
axis = op.attr('axis')
reordered = sorted((k, i) for i, k in enumerate(axis))
axis = [i for k, i in reordered]
Expand All @@ -618,15 +624,15 @@ def transpose_transpose(op, check_dot, y_bar):
@REGISTER_TRANSPOSE('split_p')
def split_transpose(op, check_dot, ys_bar):
x, = op_position_inputs(op)
assert check_dot(x)
assert check_dot(x), 'check_dot(x) must be True'
return concat(ys_bar, axis=op.attr('axis'))


@REGISTER_TRANSPOSE('concat_p')
def concat_transpose(op, check_dot, y_bar):
xs, = op_position_inputs(op)
for x in xs:
assert check_dot(x)
assert check_dot(x), 'check_dot(x) must be True'
axis = op.attr('axis')
sections = [x.shape[axis] for x in xs]
return split(y_bar, num_or_sections=sections, axis=axis)
Expand All @@ -635,7 +641,7 @@ def concat_transpose(op, check_dot, y_bar):
@REGISTER_TRANSPOSE('reduce_p')
def reduce_transpose(op, check_dot, y_bar):
x, = op_position_inputs(op)
assert check_dot(x)
assert check_dot(x), 'check_dot(x) must be True'
axes = op.attr('axis')
shape = tuple(1 if i in axes else size for i, size in enumerate(x.shape))
t = reshape(y_bar, shape=shape)
Expand All @@ -645,7 +651,9 @@ def reduce_transpose(op, check_dot, y_bar):
@REGISTER_TRANSPOSE('matmul_p')
def matmul_transpose(op, check_dot, z_bar):
x, y = op_position_inputs(op)
assert check_dot(x) ^ check_dot(y)
assert check_dot(x) ^ check_dot(y), (
f'(check_dot(x) ^ check_dot(y)) must be True, '
f'but check_dot(x)={check_dot(x)} and check_dot(y)={check_dot(y)}.')
# TODO: replace it. this is hacky
axis = [1, 0] if len(x.shape) == 2 else [0, 2, 1]
if check_dot(x):
Expand All @@ -657,7 +665,7 @@ def matmul_transpose(op, check_dot, z_bar):
@REGISTER_TRANSPOSE('slice_select_p')
def slice_select_transpose(op, check_dot, y_bar):
x, = op_position_inputs(op)
assert check_dot(x)
assert check_dot(x), 'check_dot(x) must be True'
zeros = fill_const(value=0.0, shape=x.shape, dtype=x.dtype)
axis = op.attr('axis')
starts = op.attr('starts')
Expand All @@ -670,7 +678,9 @@ def slice_select_transpose(op, check_dot, y_bar):
@REGISTER_TRANSPOSE('slice_assign_p')
def slice_assign_transpose(op, check_dot, z_bar):
x, y = op_position_inputs(op)
assert check_dot(x) and check_dot(y)
assert check_dot(x) and check_dot(y), (
f'(check_dot(x) and check_dot(y)) must be True, '
f'but check_dot(x)={check_dot(x)} and check_dot(y)={check_dot(y)}.')
zeros = fill_const(value=0.0, shape=y.shape, dtype=y.dtype)
axis = op.attr('axis')
starts = op.attr('starts')
Expand All @@ -686,7 +696,7 @@ def slice_assign_transpose(op, check_dot, z_bar):
@REGISTER_TRANSPOSE('gather_p')
def gather_transpose(op, check_dot, y_bar):
x, indextensor = op_position_inputs(op)
assert check_dot(x)
assert check_dot(x), 'check_dot(x) must be True'
axis = op.attr('axis')
zeros = fill_const(0.0, x.shape, x.dtype)
x_bar = scatter_add(zeros, y_bar, indextensor, axis=axis)
Expand All @@ -697,7 +707,9 @@ def gather_transpose(op, check_dot, y_bar):
@REGISTER_TRANSPOSE('scatter_add_p')
def scatter_add_transpose(op, check_dot, z_bar):
x, y, indextensor = op_position_inputs(op)
assert check_dot(x) and check_dot(y)
assert check_dot(x) and check_dot(y), (
f'(check_dot(x) and check_dot(y)) must be True, '
f'but check_dot(x)={check_dot(x)} and check_dot(y)={check_dot(y)}.')
axis = op.attr('axis')
zeros = fill_const(value=0.0, shape=y.shape, dtype=y.dtype)
x_bar = scatter_add(z_bar, zeros, indextensor, axis=axis)
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/incubate/autograd/primx.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def dot2bar_rec(self, dots):

if isinstance(dots, paddle.fluid.framework.Variable):
bar = self.dot2bar.lookup(dots)
assert bar is not None, 'bar is None.'
assert bar is not None, 'bar must be not None'
return bar

bars = [self.dot2bar_rec(dot) for dot in dots]
Expand Down

0 comments on commit 28e812e

Please sign in to comment.