Skip to content

Commit

Permalink
[OpAtttr]Add attribute var interface for Operator class (#45525)
Browse files Browse the repository at this point in the history
* [OpAtttr]Add attribute var interface for Operator class

* fix unittest

* fix unittest
  • Loading branch information
Aurelius84 authored Aug 30, 2022
1 parent fe321f9 commit e221a60
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 8 deletions.
52 changes: 44 additions & 8 deletions python/paddle/fluid/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -3138,7 +3138,7 @@ def attr_type(self, name):
Returns:
core.AttrType: the attribute type.
"""
return self.desc.attr_type(name)
return self.desc.attr_type(name, True)

def _set_attr(self, name, val):
"""
Expand Down Expand Up @@ -3290,6 +3290,41 @@ def _blocks_attr_ids(self, name):

return self.desc._blocks_attr_ids(name)

def _var_attr(self, name):
"""
Get the Variable attribute by name.
Args:
name(str): the attribute name.
Returns:
Variable: the Variable attribute.
"""
attr_type = self.desc.attr_type(name, True)
assert attr_type == core.AttrType.VAR, "Required type attr({}) is Variable, but received {}".format(
name, attr_type)
attr_var_name = self.desc.attr(name, True).name()
return self.block._var_recursive(attr_var_name)

def _vars_attr(self, name):
"""
Get the Variables attribute by name.
Args:
name(str): the attribute name.
Returns:
Variables: the Variables attribute.
"""
attr_type = self.desc.attr_type(name, True)
assert attr_type == core.AttrType.VARS, "Required type attr({}) is list[Variable], but received {}".format(
name, attr_type)
attr_vars = [
self.block._var_recursive(var.name())
for var in self.desc.attr(name, True)
]
return attr_vars

def all_attrs(self):
"""
Get the attribute dict.
Expand All @@ -3300,16 +3335,17 @@ def all_attrs(self):
attr_names = self.attr_names
attr_map = {}
for n in attr_names:
attr_type = self.desc.attr_type(n)
attr_type = self.desc.attr_type(n, True)
if attr_type == core.AttrType.BLOCK:
attr_map[n] = self._block_attr(n)
continue

if attr_type == core.AttrType.BLOCKS:
elif attr_type == core.AttrType.BLOCKS:
attr_map[n] = self._blocks_attr(n)
continue

attr_map[n] = self.attr(n)
elif attr_type == core.AttrType.VAR:
attr_map[n] = self._var_attr(n)
elif attr_type == core.AttrType.VARS:
attr_map[n] = self._vars_attr(n)
else:
attr_map[n] = self.attr(n)

return attr_map

Expand Down
4 changes: 4 additions & 0 deletions python/paddle/fluid/tests/unittests/test_attribute_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ def test_static(self):
infer_out = self.infer_prog()
self.assertEqual(infer_out.shape, (10, 10))

self.assertEqual(
main_prog.block(0).ops[4].all_attrs()['dropout_prob'].name,
p.name)


class TestTileTensorList(UnittestBase):

Expand Down
6 changes: 6 additions & 0 deletions python/paddle/fluid/tests/unittests/test_reverse_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,12 @@ def call_func(self, x):
# axes is a List[Variable]
axes = [paddle.assign([0]), paddle.assign([2])]
out = paddle.fluid.layers.reverse(x, axes)

# check attrs
axis_attrs = paddle.static.default_main_program().block(
0).ops[-1].all_attrs()["axis"]
self.assertTrue(axis_attrs[0].name, axes[0].name)
self.assertTrue(axis_attrs[1].name, axes[1].name)
return out


Expand Down

0 comments on commit e221a60

Please sign in to comment.