support property process in TensorVariable #170
Conversation
Thanks for your contribution!
@@ -105,7 +105,7 @@ def register(
 ):
     if fn not in cls.handlers:
         cls.handlers[fn] = []
-    cls.handlers[fn].append((Pattern(*types, **kwtypes), handler))
+    cls.handlers[fn].insert(0, (Pattern(*types, **kwtypes), handler))
Why insert at the front? We currently rely on this order: handlers that were appended earlier are searched first.
At the time I did this because a subclass might need to override the parent class's Dispatcher handler, along the lines of the implementation below:
Dispatcher.register(
    getattr,
    ("VariableBase", "str"),
    {},
    lambda var, name: var.getattr(name),
)
@Dispatcher.register_decorator()
def getattr(var: TensorVariable, name: str):
    if name in ["dtype", "type", "persistable", "name", "stop_gradient"]:
        return VariableFactory.from_value(
            getattr(var.meta, name),
            var.graph,
            tracker=GetAttrTracker(var, name),
        )
    elif name in implemented_property:
        return getattr(var, name)
    elif name in implemented_method:
        # TODO: backward, gradient
        from .callable import MethodVariable

        attr = getattr(var, name)
        return MethodVariable.wrap_method(
            value=attr,
            instance=var,
            graph=var.graph,
            tracker=GetAttrTracker(var, name),
            method_name=name,
        )
    elif name in paddle_tensor_methods:
        from .callable import TensorFunctionVariable

        fn_var = TensorFunctionVariable(
            name, graph=var.graph, tracker=DanglingTracker()
        )
        return fn_var.bind(var, name)
    else:
        raise InnerError(f"Unknown Tensor attribute: {name}")
But I later found this approach was even less readable than the original, so I reverted it. Still, my general intuition was that later registrations should override earlier ones, so I kept the insert(0, ...) change.
I hadn't noticed that we currently depend on append order; I'll change it back now.
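As an illustration of why the order matters, here is a minimal sketch (not the real Dispatcher; names and structure are simplified assumptions): lookup scans the handler list front to back, so append makes earlier registrations win, while insert(0, ...) makes later ones win.

# Minimal sketch, not the real Dispatcher: lookup scans front to back,
# so where a new handler is placed decides which one wins.
handlers = []

def register(pattern, handler, override=False):
    if override:
        handlers.insert(0, (pattern, handler))  # later registration wins
    else:
        handlers.append((pattern, handler))     # earlier registration wins

def dispatch(value):
    for pattern, handler in handlers:
        if isinstance(value, pattern):
            return handler(value)
    raise TypeError(f"no handler for {type(value)}")

register(object, lambda v: "generic handler")
register(int, lambda v: "int handler")

print(dispatch(1))  # "generic handler": object was appended first and matches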
@@ -25,7 +25,7 @@
 ]


-ConstTypes = (int, float, str, bool, type(None))
+ConstTypes = (int, float, str, bool, type(None), paddle.dtype)
Should dtype really go into ConstantVariable? Right now it should be an ObjectVariable; does that cause any problem?
I just tried it on the latest code and it no longer seems to be a problem. I remember adding this because of some error earlier; it was probably a bug in my code at the time.
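For context, a rough sketch of the routing that ConstTypes typically gates (a hypothetical simplification, not the project's actual VariableFactory): values whose type is listed become ConstantVariable, everything else falls back to a generic ObjectVariable, which is why listing paddle.dtype changes which variable class wraps a dtype.

# Hypothetical simplification of the routing that ConstTypes controls.
import paddle

ConstTypes = (int, float, str, bool, type(None), paddle.dtype)

def from_value(value):
    if type(value) in ConstTypes:
        return ("ConstantVariable", value)  # guarded as a literal constant
    return ("ObjectVariable", value)        # generic fallback

print(from_value(3))               # ConstantVariable
print(from_value(paddle.float32))  # ConstantVariable once paddle.dtype is listed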
    tensor.persistable,
    tensor.type,
    tensor.place,
)
Right now this affects the Tensor guard.
@2742195759 should this information also go into the meta?
@tensor_method
def is_tensor(self):
    if self.value is None:
        return False
Why return False directly when self.value is None? A None value means the Tensor is an intermediate result, not that it isn't a Tensor.
Some of these is_xxx methods should be answerable directly from the metadata; where they can't be, the subgraph needs to be broken.
I looked at Paddle's implementation of is_tensor; it seems to only check whether the object is a Tensor, so here a TensorVariable would always be True, right?
> I looked at Paddle's implementation of is_tensor; it seems to only check whether the object is a Tensor, so here a TensorVariable would always be True, right?

Yeah, that should be fine.
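A sketch of where this discussion lands, assuming the @tensor_method decorator and a self.meta carrying dtype as seen elsewhere in this PR (the decorator below is a stand-in so the snippet runs on its own): is_tensor is constantly True for a TensorVariable, and other is_* queries can be answered from the meta without a graph break.

import paddle

def tensor_method(fn):
    # stand-in for the project's decorator, just so this sketch is runnable
    return fn

@tensor_method
def is_tensor(self):
    # A TensorVariable always models a paddle.Tensor, so this is constant.
    return True

@tensor_method
def is_floating_point(self):
    # Decidable from metadata alone; no graph break required.
    return self.meta.dtype in (paddle.float16, paddle.float32, paddle.float64)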
from .callable import MethodVariable

attr = getattr(self, name)
return MethodVariable.wrap_method(
Hmm, is the FunctionVariable generated here a UserDefinedFunctionVariable? Is it inline-called directly? You could use BuiltinVariable's dispatch mechanism to forward to these methods; see dict.keys etc. for reference.
Done, changed.
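For reference, the forwarding the reviewer describes would look roughly like the Dispatcher.register call quoted earlier in this thread, but targeting TensorVariable (a sketch only; the exact registration in the PR may differ):

# Sketch: route getattr on a TensorVariable through the Dispatcher, so the
# BuiltinVariable machinery (as with dict.keys) handles the forwarding.
Dispatcher.register(
    getattr,
    ("TensorVariable", "str"),
    {},
    lambda var, name: var.getattr(name),
)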
or dtype == paddle.uint8
or dtype == paddle.int16
or dtype == paddle.int32
or dtype == paddle.int64
-or dtype == paddle.int64
+is_int_dtype = dtype in [paddle.int8, paddle.uint8, ....]

You can use an in-list check here.
OK, I've just changed this; it now reuses this new code directly:

FP_DTYPE_ABBRS = {
    ...
}
CP_DTYPE_ABBRS = {
    ...
}
INT_DTYPE_ABBRS = {
    ...
}
DTYPE_ABBRS = {
    **FP_DTYPE_ABBRS,
    **CP_DTYPE_ABBRS,
    **INT_DTYPE_ABBRS,
    paddle.bool: 'bool',
}

and the check becomes:

dtype in FP_DTYPE_ABBRS
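To make the elided tables concrete, a possible filling (the entries below are assumptions; the authoritative tables are in the PR diff):

import paddle

# Assumed contents, for illustration only.
FP_DTYPE_ABBRS = {
    paddle.bfloat16: 'bfloat16',
    paddle.float16: 'float16',
    paddle.float32: 'float32',
    paddle.float64: 'float64',
}
CP_DTYPE_ABBRS = {
    paddle.complex64: 'complex64',
    paddle.complex128: 'complex128',
}
INT_DTYPE_ABBRS = {
    paddle.int8: 'int8',
    paddle.uint8: 'uint8',
    paddle.int16: 'int16',
    paddle.int32: 'int32',
    paddle.int64: 'int64',
}

def is_integer_dtype(dtype):
    # membership test replaces the chain of "or dtype == ..." comparisons
    return dtype in INT_DTYPE_ABBRS

print(is_integer_dtype(paddle.int32))    # True
print(is_integer_dtype(paddle.float32))  # False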
@2742195759 could you take another look and see if anything else is needed?
No further issues on my side, but this PR needs confirmation from @2742195759.
.gitignore (Outdated)

# Build
build/
*.egg-info
Strange, why do these diffs that were already merged upstream show up in this PR?
I did a rebase; it should be fine now.
@@ -26,7 +26,16 @@ def tensor_method_passed_by_user(a: paddle.Tensor, func: paddle.Tensor):


 def tensor_method_property(a: paddle.Tensor, b: paddle.Tensor):
-    return a @ b.T + len(a.shape) + b.size + a.ndim
+    return (
+        a.name,
Just thought of a problem: a.name should be fine here, but (a + b).name is computed via infer meta, so it would be wrong here.
Accessing name on an intermediate node should break the graph.
Following this line of thought, check whether the other properties have similar issues.
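A sketch of what breaking the graph on such an access might look like (BreakGraphError below is a stand-in for whatever mechanism the project uses to interrupt tracing, and the self.value is None test follows the convention discussed later in this thread):

class BreakGraphError(Exception):
    # stand-in for the project's graph-break exception
    pass

def name(self):
    if self.value is None:
        # An intermediate result has no stable eager name at trace time.
        raise BreakGraphError("getting .name of an intermediate Tensor")
    return self.value.name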
When fetching the name of an intermediate variable, without a graph break the two sides generate names independently:

- infer_meta_variable_tmp_0
+ eager_tmp_2

With a graph break, the numbering keeps incrementing across runs:

AssertionError: 'eager_tmp_2' != 'eager_tmp_3'
- eager_tmp_2
?           ^
+ eager_tmp_3
?           ^

So this test case probably can't be included.
Also, is it enough here to treat names starting with infer_meta_variable_tmp as the marker for intermediate variables?
> Also, is it enough here to treat names starting with infer_meta_variable_tmp as the marker for intermediate variables?

self.value == None does mean it is an intermediate variable, but don't test that directly; wrap it in a function, e.g. is_leaf (for the not-None case), or some other name.
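As a sketch, the wrapper the reviewer suggests could look like this (its exact placement on TensorVariable is an assumption):

def is_leaf(self):
    # A concrete eager value marks a leaf; value is None marks an
    # intermediate result produced during tracing.
    return self.value is not None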
The current situation is about guard hits. In the original AST dynamic-to-static conversion, the Guard only used shape, dtype, and stop_gradient as the CacheKey; since I'm reusing dynamic-to-static here, we should stay aligned with that. Extra fields can be added to MetaInfo, but the Guard's comparison of MetaInfo must strip those extra fields and compare only the three above.
The current __eq__ magic method still compares these three parameters as before; is that enough?

def __eq__(self, meta):
    return (
        self.shape == meta.shape
        and self.dtype == meta.dtype
        and self.stop_gradient == meta.stop_gradient
    )
But our Guard currently compares strings; isn't that a problem? You could explicitly compare the three fields in the Guard and chain them with and.
Good point. The concrete implementation I see is str(MetaInfo.from_tensor(frame.f_locals['func'].__self__)) == '(shape: [42], dtype: paddle.float32, stop_gradient: True)'. For ease of maintenance it looks like we need to add a __str__; I'll try that approach.
Hmm, that doesn't quite work. I found there is already an implementation:

def meta_str(shape, dtype, stop_gradient):
    return f"(shape: {shape}, dtype: {dtype}, stop_gradient: {stop_gradient})"

which effectively pins down the string form already.
The result of str(MetaInfo.from_tensor(frame.f_locals['func'].__self__)) here also includes only those three parameters.
Don't rely on this behavior here; once it changes in the future this will break.
OK, fixed.
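A sketch of the field-wise comparison suggested above, instead of relying on the string form (make_meta_guard is a hypothetical name; MetaInfo is the project's class):

def make_meta_guard(expected):
    # Compare the three guarded fields explicitly, chained with and,
    # rather than comparing str(MetaInfo) representations.
    def guard(tensor):
        meta = MetaInfo.from_tensor(tensor)
        return (
            meta.shape == expected.shape
            and meta.dtype == expected.dtype
            and meta.stop_gradient == expected.stop_gradient
        )
    return guard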
y = paddle.rand([42, 24], dtype='float32')
self.assert_results(tensor_method_property, x, y)

@unittest.skip("TODO: dynamic tensor name is different")
The intermediate-variable issue can be handled in a follow-up PR~
LGTM
LGTM
This PR filtered all of Tensor's attrs according to the following rules:

All of the properties are implemented: ['ndim', 'size', 'T']
Some attrs were already implemented: ['stop_gradient', 'shape', 'dtype']
Methods whose return type is not Tensor and which both Tensor and Variable support are split into methods and unbound functions.
The methods are [('numpy', <class 'numpy.ndarray'>), ('clear_gradient', <class 'NoneType'>), ('element_size', <class 'int'>)]; the current handling works for these without problems.
The unbound functions are [('backward', <class 'NoneType'>), ('gradient', <class 'NoneType'>), ('dim', <class 'int'>), ('ndimension', <class 'int'>), ('is_tensor', <class 'bool'>), ('is_complex', <class 'bool'>), ('is_integer', <class 'bool'>), ('is_floating_point', <class 'bool'>)]; all of them are implemented except backward and gradient.
The code for the filtering step is sketched below.
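(The original filtering snippet is not reproduced in this excerpt; the following is a rough reconstruction of how such filtering might be done, with the inspect-based classification being an assumption.)

import inspect
import paddle

properties, methods = [], []
for attr_name in dir(paddle.Tensor):
    if attr_name.startswith('_'):
        continue
    attr = getattr(paddle.Tensor, attr_name, None)
    if isinstance(attr, property) or inspect.isdatadescriptor(attr):
        properties.append(attr_name)  # e.g. 'T', 'ndim', 'size'
    elif callable(attr):
        methods.append(attr_name)     # candidates for method / unbound function

print(properties)
print(methods)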
closes: #151