-
Notifications
You must be signed in to change notification settings - Fork 25
support property process in TensorVariable #170
Changes from 10 commits
3b7b51e
f490585
2dfaa13
e746cfd
3286022
676d1b6
0d1e599
836a8f0
e85d739
df3a024
132649d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,7 +26,21 @@ def tensor_method_passed_by_user(a: paddle.Tensor, func: paddle.Tensor): | |
|
||
|
||
def tensor_method_property(a: paddle.Tensor, b: paddle.Tensor): | ||
return a @ b.T + len(a.shape) + b.size + a.ndim | ||
return ( | ||
a.name, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 突然想到一个问题,这里 a.name 应该没啥问题,但
按照这个思路可以再看看其他几个是否有类似的问题 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 当获取中间变量的name,如果不打断是这种分别生成的 - infer_meta_variable_tmp_0
+ eager_tmp_2 如果打断是这种序号会持续累加的 AssertionError: 'eager_tmp_2' != 'eager_tmp_3'
- eager_tmp_2
? ^
+ eager_tmp_3
? ^ 所以可能这里的测试case是不能加进去的 另外这里我是不是只需要通过是否以 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
str(a.place), | ||
a.persistable, | ||
a.dtype, | ||
a.type, | ||
a.is_tensor(), | ||
a.clear_gradient(), | ||
a @ b.T + len(a.shape) + b.size + a.ndim + a.dim() + a.rank(), | ||
) | ||
|
||
|
||
def middle_tensor_name(a: paddle.Tensor, b: paddle.Tensor): | ||
c = a + b | ||
return c.name | ||
|
||
|
||
class TestTensorMethod(TestCaseBase): | ||
|
@@ -47,9 +61,15 @@ def test_tensor_method_passed_by_user(self): | |
self.assert_results(tensor_method_passed_by_user, x, y.add) | ||
|
||
def test_tensor_method_property(self): | ||
x = paddle.rand([42, 24], dtype='float64') | ||
y = paddle.rand([42, 24], dtype='float32') | ||
self.assert_results(tensor_method_property, x, y) | ||
|
||
@unittest.skip("TODO: dynamic tensor name is different") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个中间变量的问题可以在下一个 PR 处理~ |
||
def test_middle_tensor_name(self): | ||
x = paddle.rand([42, 24]) | ||
y = paddle.rand([42, 24]) | ||
self.assert_results(tensor_method_property, x, y) | ||
self.assert_results(middle_tensor_name, x, y) | ||
|
||
|
||
if __name__ == "__main__": | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
目前这样会影响 Tensor guard
@2742195759 这些信息也应该放到 meta 里嘛?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
见AST动转静的CacheKey:
这里如果 MetaInfo 添加了这些,那么 MetaInfo 进行Guard判定时对齐动转静。