Skip to content

Commit

Permalink
pass simple lower pass
Browse files Browse the repository at this point in the history
  • Loading branch information
levi131 committed Apr 15, 2022
1 parent 4d5a92b commit 3ef4017
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 6 deletions.
6 changes: 3 additions & 3 deletions python/paddle/autograd/primrules.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,14 @@ def trans(shape):
def elementwise_add_orig2prim(op, x, y):
if x.shape != y.shape:
y = broadcast(y, shape=x.shape)
if op.attr('Scale_x'):
if op.attr('Scale_x') - 1.0 > 1e-5:
tmp = fill_const(shape=x.shape, dtype=x.dtype, value=op.attr('Scale_x'))
x = mul(x, tmp)
if op.attr('Scale_y'):
if op.attr('Scale_y') - 1.0 > 1e-5:
tmp = fill_const(shape=y.shape, dtype=y.dtype, value=op.attr('Scale_y'))
y = mul(y, tmp)
z = add(x, y)
if op.attr('Scale_out'):
if op.attr('Scale_out') - 1.0 > 1e-5:
tmp = fill_const(
shape=z.shape, dtype=z.dtype, value=op.attr('Scale_out'))
z = mul(z, tmp)
Expand Down
10 changes: 8 additions & 2 deletions python/paddle/autograd/primx.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from paddle.fluid.framework import default_main_program, default_startup_program
from paddle.fluid import unique_name, core
from .primops import fill_const, add
from .primrules import get_input_vars, get_output_vars, _orig2prim, _prim2orig, _jvp, _transpose
from .primreg import op_position_inputs, op_position_output, lookup_orig2prim, lookup_prim2orig
from .primrules import get_input_vars, get_output_vars, _orig2prim, _prim2orig, _jvp, _transpose
from collections import OrderedDict
Expand Down Expand Up @@ -266,6 +265,13 @@ def prim2orig(block=None):
_lower(block, reverse=True)


def to_tensors(xs):
if isinstance(xs, list or tuple):
return xs
else:
return [xs]


def _lower(block, reverse):
lower_fn = _prim2orig if reverse else _orig2prim
lookup_fn = lookup_prim2orig if reverse else lookup_orig2prim
Expand Down Expand Up @@ -293,7 +299,7 @@ def _lower(block, reverse):
print("op_type: ", op.type)
print("input_args: ", input_args)
for orig_out, new_out in zip(
get_output_vars(op), lower_fn(op, *input_args)):
get_output_vars(op), to_tensors(lower_fn(op, *input_args))):
assert not (orig_out is None) ^ (
new_out is None), "orig_out and new_out should match."
vars_to_remove.append(orig_out.name)
Expand Down
7 changes: 6 additions & 1 deletion python/paddle/fluid/tests/unittests/test_primops.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def setUp(self):
# for op in topo_path(vs, grads):
# print(op)

def test_orig2prim(self):
def test_lower(self):
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
Expand All @@ -191,6 +191,11 @@ def test_orig2prim(self):
print(f'-------test_orig2prim: prim-------')
print(x.block)

prim2orig(x.block)

print(f'-------test_orig2prim: orig-------')
print(x.block)


if __name__ == '__main__':
unittest.main()

0 comments on commit 3ef4017

Please sign in to comment.