This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

support property process in TensorVariable #170

Merged · 11 commits · Jun 25, 2023
33 changes: 21 additions & 12 deletions sot/infer_meta.py
@@ -19,14 +19,28 @@ def value_fn(self, *args, **kwargs):
 
 
 class MetaInfo:
-    def __init__(self, shape, dtype, stop_gradient):
+    def __init__(
+        self, shape, dtype, stop_gradient, name, persistable, type, place
+    ):
+        self.name = name
+        self.persistable = persistable
+        self.type = type
+        self.place = place
         self.shape = shape
         self.dtype = dtype
         self.stop_gradient = stop_gradient
 
     @staticmethod
     def from_tensor(tensor):
-        return MetaInfo(tensor.shape, tensor.dtype, tensor.stop_gradient)
+        return MetaInfo(
+            list(tensor.shape),
+            tensor.dtype,
+            tensor.stop_gradient,
+            tensor.name,
+            tensor.persistable,
+            tensor.type,
+            tensor.place,
+        )
Member:

This currently affects the Tensor guard.

@2742195759 should this information also go into the meta?

Collaborator:

See the CacheKey used by AST dynamic-to-static:

[image]

If MetaInfo includes these fields, then MetaInfo's guard checks are aligned with dynamic-to-static.


     def is_dynamic_shape(self):
         """
@@ -40,6 +54,9 @@ def to_input_spec(self):
             self.shape, dtype=self.dtype, stop_gradient=self.stop_gradient
         )
 
+    def guard_str(self):
+        return f"({self.shape}, {self.dtype}, {self.stop_gradient})"
+
     def __repr__(self):
         return meta_str(self.shape, self.dtype, self.stop_gradient)

@@ -128,11 +145,7 @@ def variable_to_meta_info(args):
     return map_if(
         args,
         pred=lambda x: isinstance(x, paddle.static.Variable),
-        true_fn=lambda x: MetaInfo(
-            list(x.shape),
-            x.dtype,
-            x.stop_gradient,
-        ),
+        true_fn=lambda x: MetaInfo.from_tensor(x),
         false_fn=lambda x: x,
     )

@@ -153,11 +166,7 @@ def infer_meta_for_layer(layer, *args, **kwargs):
     args, kwargs = convert_to_input_spec(args), convert_to_input_spec(kwargs)
     concrete_program = layer.forward.get_concrete_program(*args, **kwargs)[0]
     out = concrete_program.outputs[0]
-    out = MetaInfo(
-        list(out.shape),
-        out.dtype,
-        out.stop_gradient,
-    )
+    out = MetaInfo.from_tensor(out)
     layer.forward.rollback()
     return out

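For reference, a minimal sketch of how the new MetaInfo fields and guard_str() behave (assuming the module path sot.infer_meta from the diff header above; the exact dtype repr and auto-generated tensor name depend on the Paddle version):

```python
import paddle

from sot.infer_meta import MetaInfo  # module path as in the diff above

x = paddle.rand([2, 3])  # float32, stop_gradient=True by default
meta = MetaInfo.from_tensor(x)

# guard_str() only covers shape/dtype/stop_gradient, so guards do not
# over-specialize on name/place/persistable/type.
print(meta.guard_str())       # e.g. "([2, 3], paddle.float32, True)"
print(meta.name, meta.place)  # the extra fields are still carried on the meta
```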
33 changes: 33 additions & 0 deletions sot/opcode_translator/executor/variable_dispatch.py
@@ -4,6 +4,8 @@
 from functools import partial
 from typing import TYPE_CHECKING
 
+import paddle
+
 from ...utils import BreakGraphError, NotImplementException
 from ...utils.magic_methods import (
     BINARY_OPS,
@@ -97,6 +99,37 @@
     lambda var: var.bool(),
 )

+# TensorVariable
+Dispatcher.register(
+    paddle.is_tensor,
+    ("TensorVariable",),
+    {},
+    lambda var: var.is_tensor(),
+)
+Dispatcher.register(
+    paddle.is_complex,
+    ("TensorVariable",),
+    {},
+    lambda var: var.is_complex(),
+)
+Dispatcher.register(
+    paddle.is_integer,
+    ("TensorVariable",),
+    {},
+    lambda var: var.is_integer(),
+)
+Dispatcher.register(
+    paddle.is_floating_point,
+    ("TensorVariable",),
+    {},
+    lambda var: var.is_floating_point(),
+)
+Dispatcher.register(
+    paddle.rank,
+    ("TensorVariable",),
+    {},
+    lambda var: var.ndim,
+)
 
 # VariableBase
 Dispatcher.register(
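A minimal usage sketch of what these registrations enable (assuming `from sot import symbolic_translate` is the public entry point of this repository; the snippet is illustrative and not part of the diff):

```python
import paddle

from sot import symbolic_translate  # assumed public entry point


def describe(x: paddle.Tensor):
    # Inside traced code, these paddle.* calls are now answered from the
    # TensorVariable's meta info via the dispatchers above, instead of
    # falling back to a graph break.
    return paddle.is_tensor(x), paddle.is_floating_point(x), paddle.rank(x)


x = paddle.rand([4, 5])
print(symbolic_translate(describe)(x))
```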
96 changes: 78 additions & 18 deletions sot/opcode_translator/executor/variables/basic.py
@@ -33,21 +33,34 @@
 if TYPE_CHECKING:
     from ..function_graph import FunctionGraph
 
-DTYPE_ABBRS = {
+
+FP_DTYPE_ABBRS = {
     paddle.bfloat16: 'bfloat16',
     paddle.float64: 'float64',
     paddle.float32: 'float32',
     paddle.float16: 'float16',
+}
+
+CP_DTYPE_ABBRS = {
     paddle.complex64: 'complex64',
     paddle.complex128: 'complex128',
+}
+
+INT_DTYPE_ABBRS = {
     paddle.int8: 'int8',
     paddle.int16: 'int16',
     paddle.int32: 'int32',
     paddle.int64: 'int64',
-    paddle.bool: 'bool',
     paddle.uint8: 'uint8',
+}
+
+DTYPE_ABBRS = {
+    **FP_DTYPE_ABBRS,
+    **CP_DTYPE_ABBRS,
+    **INT_DTYPE_ABBRS,
+    paddle.bool: 'bool',
 }
 
 
 class ConstantVariable(VariableBase):
     def __init__(
@@ -95,6 +108,14 @@ def wrap_literal(value: Any) -> ConstantVariable:
         return ConstantVariable(value, ConstTracker(value))
 
 
+implemented_property = set()
+
+
+def tensor_property(func):
+    implemented_property.add(func.__name__)
+    return property(func)
+
+
 class TensorVariable(VariableBase):
     var_name_generator = NameGenerator("var_")
 
@@ -151,7 +172,7 @@ def make_stringify_guard(self) -> StringifyExpression:
             ),
         )
         return StringifyExpression(
-            f"str(MetaInfo.from_tensor({frame_value_tracer.expr})) == '{self.meta}'",
+            f"MetaInfo.from_tensor({frame_value_tracer.expr}).guard_str() == '{self.meta.guard_str()}'",
             union_free_vars(
                 {"MetaInfo": MetaInfo},
                 frame_value_tracer.free_vars,
@@ -186,7 +207,7 @@ def __setitem__(self, key, value):
             value,
         )
 
-    @property
+    @tensor_property
     def T(self):
         perm = list(range(len(self.meta.shape) - 1, -1, -1))
         perm_var = VariableFactory.from_value(
@@ -195,40 +216,79 @@ def T(self):
         out = self.graph.call_paddle_api(paddle.transpose, self, perm_var)
         return out
 
-    @property
+    @tensor_property
     def ndim(self):
         return ConstantVariable.wrap_literal(len(self.meta.shape))
 
-    @property
-    def shape(self):
+    @tensor_property
+    def size(self):
+        # TODO: maybe break graph.
+        if self.meta.is_dynamic_shape():
+            raise BreakGraphError(
+                f"Getting size for a dynamic shape tensor causes graph break. shape = {self.meta.shape}"
+            )
+        elements = reduce(operator.mul, self.meta.shape, 1)
+        return ConstantVariable.wrap_literal(elements)
+
+    @tensor_property
+    def shape(self):
         if self.meta.is_dynamic_shape():
             raise BreakGraphError(
                 f"Getting shape for a dynamic shape tensor causes graph break. shape = {self.meta.shape}"
             )
         self.graph.add_global_guarded_variable(self)
         return VariableFactory.from_value(
             self.meta.shape, self.graph, tracker=ConstTracker(self.meta.shape)
         )
 
-    @property
-    def size(self):
-        # TODO: maybe break graph.
-        if self.meta.is_dynamic_shape():
-            raise BreakGraphError(
-                f"Getting size for a dynamic shape tensor causes graph break. shape = {self.meta.shape}"
-            )
-        elements = reduce(operator.mul, self.meta.shape, 1)
-        return ConstantVariable.wrap_literal(elements)
+    def is_tensor(self):
+        return ConstantVariable.wrap_literal(True)
+
+    def is_complex(self):
+        dtype = self.meta.dtype
+        is_cp_dtype = dtype in CP_DTYPE_ABBRS
+        return ConstantVariable.wrap_literal(is_cp_dtype)
+
+    def is_integer(self):
+        dtype = self.meta.dtype
+        is_int_dtype = dtype in INT_DTYPE_ABBRS
+        return ConstantVariable.wrap_literal(is_int_dtype)
+
+    def is_floating_point(self):
+        dtype = self.meta.dtype
+        is_fp_dtype = dtype in FP_DTYPE_ABBRS
+        return ConstantVariable.wrap_literal(is_fp_dtype)
 
     def getattr(self, name: str):
-        if name in ["shape", "dtype", "stop_gradient"]:
+        method_name_to_builtin_fn = {
+            "dim": paddle.rank,
+            "ndimension": paddle.rank,
+            "is_tensor": paddle.is_tensor,
+            "is_complex": paddle.is_complex,
+            "is_integer": paddle.is_integer,
+            "is_floating_point": paddle.is_floating_point,
+        }
+        if name in ["dtype", "type", "name", "persistable", "stop_gradient"]:
+            if name == "name" and self.meta.name.startswith(
+                "infer_meta_variable_tmp"
+            ):
+                raise BreakGraphError(f"{self.meta.name} is a middle tensor.")
             return VariableFactory.from_value(
                 getattr(self.meta, name),
                 self.graph,
                 tracker=GetAttrTracker(self, name),
             )
-        elif name in ["T", "ndim", "size"]:
+        elif name in implemented_property:
             return getattr(self, name)
+        elif name in method_name_to_builtin_fn:
+            # TODO: backward, gradient
+            from .callable import BuiltinVariable
+
+            builtin_fn = method_name_to_builtin_fn[name]
+
+            return BuiltinVariable(
+                builtin_fn, self.graph, DanglingTracker()
+            ).bind(self, name)
         elif name in paddle_tensor_methods:
             from .callable import TensorFunctionVariable
 
24 changes: 22 additions & 2 deletions tests/test_18_tensor_method.py
@@ -26,7 +26,21 @@ def tensor_method_passed_by_user(a: paddle.Tensor, func: paddle.Tensor):
 
 
 def tensor_method_property(a: paddle.Tensor, b: paddle.Tensor):
-    return a @ b.T + len(a.shape) + b.size + a.ndim
+    return (
+        a.name,
Member:

A question just occurred to me: `a.name` here should be fine, but `(a + b).name` is computed via infer meta, so it would be wrong here.

`a.name` on an intermediate node should break the graph.

Following this line of thought, please check whether the other properties have similar problems.

Member (author):

When fetching the name of an intermediate variable, without a graph break the two sides generate names independently:

- infer_meta_variable_tmp_0
+ eager_tmp_2

With a graph break, the numbering keeps incrementing instead:

AssertionError: 'eager_tmp_2' != 'eager_tmp_3'
- eager_tmp_2
?           ^
+ eager_tmp_3
?           ^

So this test case probably cannot be added here.

Also, is it enough to decide that a tensor is an intermediate variable simply by checking whether its name starts with `infer_meta_variable_tmp`?

Member:

> Also, is it enough to decide that a tensor is an intermediate variable simply by checking whether its name starts with `infer_meta_variable_tmp`?

`self.value == None` means it is an intermediate variable, but don't check that directly; wrap it in a helper function, e.g. `is_leaf` (for the not-None case), or some other name.
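A possible shape for the suggested helper (hypothetical sketch following the comment above; the name and placement are not decided in this PR):

```python
# Hypothetical method on TensorVariable, not part of this PR.
def is_leaf(self) -> bool:
    # Leaf (user-provided) tensors carry a concrete eager value; intermediates
    # produced only through infer_meta during tracing have no value bound yet.
    return self.value is not None
```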

+        str(a.place),
+        a.persistable,
+        a.dtype,
+        a.type,
+        a.is_tensor(),
+        a.clear_gradient(),
+        a @ b.T + len(a.shape) + b.size + a.ndim + a.dim() + a.rank(),
+    )
+
+
+def middle_tensor_name(a: paddle.Tensor, b: paddle.Tensor):
+    c = a + b
+    return c.name
 
 
 class TestTensorMethod(TestCaseBase):
@@ -47,9 +61,15 @@ def test_tensor_method_passed_by_user(self):
         self.assert_results(tensor_method_passed_by_user, x, y.add)
 
     def test_tensor_method_property(self):
+        x = paddle.rand([42, 24], dtype='float64')
+        y = paddle.rand([42, 24], dtype='float32')
+        self.assert_results(tensor_method_property, x, y)
+
+    @unittest.skip("TODO: dynamic tensor name is different")
Member:

The intermediate-variable issue can be handled in a follow-up PR.

+    def test_middle_tensor_name(self):
         x = paddle.rand([42, 24])
         y = paddle.rand([42, 24])
-        self.assert_results(tensor_method_property, x, y)
+        self.assert_results(middle_tensor_name, x, y)


if __name__ == "__main__":