Skip to content

Commit

Permalink
Add some assert message for primx.py
Browse files Browse the repository at this point in the history
  • Loading branch information
0x45f committed May 14, 2022
1 parent b87ce07 commit 288949f
Showing 1 changed file with 30 additions and 11 deletions.
41 changes: 30 additions & 11 deletions python/paddle/incubate/autograd/primx.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def topo_path(xs, ys, block=None):

# Initialize reached vars
for x in xs:
assert x is None or x.block == block
assert x is None or x.block == block, f'x is not None and x.block != block'
reached_vars[id(x)] = x

# Reaching test, returning whether an op is reached from the given input
Expand Down Expand Up @@ -122,7 +122,10 @@ def add_rec(self, key_vars, value_vars):
f'value_vars must be Variable, but got {type(value_vars)}')
self.tab[id(key_vars)] = id(value_vars)
else:
assert len(key_vars) == len(value_vars)
assert len(key_vars) == len(value_vars), (
f'len(key_vars) shoule be equal to len(value_vars), '
f'but len(key_vars)={len(key_vars)} and len(value_vars)={len(value_vars)}.'
)
for key_var, value_var in zip(key_vars, value_vars):
self.add_rec(key_var, value_var)

Expand Down Expand Up @@ -224,7 +227,7 @@ def dot2bar_rec(self, dots):

if isinstance(dots, paddle.fluid.framework.Variable):
bar = self.dot2bar.lookup(dots)
assert bar is not None
assert bar is not None, 'bar is None.'
return bar

bars = [self.dot2bar_rec(dot) for dot in dots]
Expand All @@ -251,11 +254,17 @@ def linearize(self, xs, ys, xs_dot=None):
xs_dot = [fill_const(1.0, shape=x.shape, dtype=x.dtype) for x in xs]
self.add_vars(xs_dot)
else:
assert len(xs) == len(xs_dot)
assert len(xs) == len(xs_dot), (
f'len(xs) should be equal to len(xs_dot), '
f'but len(xs)={len(xs)} and len(xs_dot)={len(xs_dot)}')

for x, dot in zip(xs, xs_dot):
assert x.dtype == dot.dtype
assert x.shape == dot.shape
assert x.dtype == dot.dtype, (
f'x.dtype should be equal to dot.dtype, '
f'but x.dtype={x.dtype} and dot.dtype={dot.dtype}')
assert x.shape == dot.shape, (
f'x.shape should be equal to dot.shape, '
f'but x.shape={x.shape} and dot.shape={dot.shape}')
self.var2dot.add(x, dot)

path, unused_xs, _ = topo_path(xs, ys, self.block)
Expand Down Expand Up @@ -308,10 +317,18 @@ def transpose(self, ys_dot, xs_dot, ys_bar=None, retain_fwd=False):
ys_bar.append(fill_const(1.0, shape=y.shape, dtype=y.dtype))
self.add_vars(ys_bar)
else:
assert len(ys_dot) == len(ys_bar)
assert len(ys_dot) == len(ys_bar), (
f'len(ys_dot) should be equal to len(ys_bar), '
f'but len(ys_dot)={len(ys_dot)} and len(ys_bar)={len(ys_bar)}')
for y_dot, y_bar in zip(ys_dot, ys_bar):
assert y_dot.shape == y_bar.shape
assert y_dot.dtype == y_bar.dtype
assert y_dot.shape == y_bar.shape, (
f'y_dot.shape should be equal to y_bar.shape, '
f'but y_dot.shape={y_dot.shape} and y_bar.shape={y_bar.shape}'
)
assert y_dot.dtype == y_bar.dtype, (
f'y_dot.dtype should be equal to y_bar.dtype, '
f'but y_dot.dtype={y_dot.dtype} and y_bar.dtype={y_bar.dtype}'
)

for dot, bar in zip(ys_dot, ys_bar):
self.dot2bar.add(dot, bar)
Expand Down Expand Up @@ -344,7 +361,9 @@ def transpose(self, ys_dot, xs_dot, ys_bar=None, retain_fwd=False):

ins_bar = flatten(ins_bar_rec)
ins = flatten(op_position_inputs(op))
assert len(ins) == len(ins_bar)
assert len(ins) == len(ins_bar), (
f'len(ins) should be equal to len(ins_bar), '
f'but len(ins)={len(ins)} and len(ins_bar)={len(ins_bar)}')

for dot, bar in zip(ins, ins_bar):
if bar is not None:
Expand Down Expand Up @@ -414,7 +433,7 @@ def _gradients(ys, xs, ys_bar=None):
for var in xs_dot:
if var is not None:
op_index = block.ops.index(var.op)
assert op_index >= 0
assert op_index >= 0, f'op_index should be greater than or equal to 0, but op_index={op_index}.'
op_indexes.append(op_index)

ad.erase_ops(sorted(op_indexes))
Expand Down

0 comments on commit 288949f

Please sign in to comment.