diff --git a/sot/infer_meta.py b/sot/infer_meta.py
index 5fe1182bf..d7f18d944 100644
--- a/sot/infer_meta.py
+++ b/sot/infer_meta.py
@@ -19,14 +19,28 @@ def value_fn(self, *args, **kwargs):
 
 
 class MetaInfo:
-    def __init__(self, shape, dtype, stop_gradient):
+    def __init__(
+        self, shape, dtype, stop_gradient, name, persistable, type, place
+    ):
+        self.name = name
+        self.persistable = persistable
+        self.type = type
+        self.place = place
         self.shape = shape
         self.dtype = dtype
         self.stop_gradient = stop_gradient
 
     @staticmethod
     def from_tensor(tensor):
-        return MetaInfo(tensor.shape, tensor.dtype, tensor.stop_gradient)
+        return MetaInfo(
+            list(tensor.shape),
+            tensor.dtype,
+            tensor.stop_gradient,
+            tensor.name,
+            tensor.persistable,
+            tensor.type,
+            tensor.place,
+        )
 
     def is_dynamic_shape(self):
         """
@@ -40,6 +54,9 @@ def to_input_spec(self):
             self.shape, dtype=self.dtype, stop_gradient=self.stop_gradient
         )
 
+    def guard_str(self):
+        return f"({self.shape}, {self.dtype}, {self.stop_gradient})"
+
     def __repr__(self):
         return meta_str(self.shape, self.dtype, self.stop_gradient)
 
@@ -128,11 +145,7 @@ def variable_to_meta_info(args):
     return map_if(
         args,
         pred=lambda x: isinstance(x, paddle.static.Variable),
-        true_fn=lambda x: MetaInfo(
-            list(x.shape),
-            x.dtype,
-            x.stop_gradient,
-        ),
+        true_fn=lambda x: MetaInfo.from_tensor(x),
         false_fn=lambda x: x,
     )
 
@@ -153,11 +166,7 @@ def infer_meta_for_layer(layer, *args, **kwargs):
     args, kwargs = convert_to_input_spec(args), convert_to_input_spec(kwargs)
     concrete_program = layer.forward.get_concrete_program(*args, **kwargs)[0]
     out = concrete_program.outputs[0]
-    out = MetaInfo(
-        list(out.shape),
-        out.dtype,
-        out.stop_gradient,
-    )
+    out = MetaInfo.from_tensor(out)
     layer.forward.rollback()
     return out
 
diff --git a/sot/opcode_translator/executor/variable_dispatch.py b/sot/opcode_translator/executor/variable_dispatch.py
index 1efcee4e8..e65a1c18d 100644
--- a/sot/opcode_translator/executor/variable_dispatch.py
+++ b/sot/opcode_translator/executor/variable_dispatch.py
@@ -4,6 +4,8 @@
 from functools import partial
 from typing import TYPE_CHECKING
 
+import paddle
+
 from ...utils import BreakGraphError, NotImplementException
 from ...utils.magic_methods import (
     BINARY_OPS,
@@ -97,6 +99,37 @@
     lambda var: var.bool(),
 )
 
+# TensorVariable
+Dispatcher.register(
+    paddle.is_tensor,
+    ("TensorVariable",),
+    {},
+    lambda var: var.is_tensor(),
+)
+Dispatcher.register(
+    paddle.is_complex,
+    ("TensorVariable",),
+    {},
+    lambda var: var.is_complex(),
+)
+Dispatcher.register(
+    paddle.is_integer,
+    ("TensorVariable",),
+    {},
+    lambda var: var.is_integer(),
+)
+Dispatcher.register(
+    paddle.is_floating_point,
+    ("TensorVariable",),
+    {},
+    lambda var: var.is_floating_point(),
+)
+Dispatcher.register(
+    paddle.rank,
+    ("TensorVariable",),
+    {},
+    lambda var: var.ndim,
+)
 
 # VariableBase
 Dispatcher.register(
diff --git a/sot/opcode_translator/executor/variables/basic.py b/sot/opcode_translator/executor/variables/basic.py
index 574e78dd2..a38603af4 100644
--- a/sot/opcode_translator/executor/variables/basic.py
+++ b/sot/opcode_translator/executor/variables/basic.py
@@ -33,21 +33,34 @@
 if TYPE_CHECKING:
     from ..function_graph import FunctionGraph
 
-DTYPE_ABBRS = {
+
+FP_DTYPE_ABBRS = {
     paddle.bfloat16: 'bfloat16',
     paddle.float64: 'float64',
     paddle.float32: 'float32',
     paddle.float16: 'float16',
+}
+
+CP_DTYPE_ABBRS = {
     paddle.complex64: 'complex64',
     paddle.complex128: 'complex128',
+}
+
+INT_DTYPE_ABBRS = {
     paddle.int8: 'int8',
     paddle.int16: 'int16',
     paddle.int32: 'int32',
     paddle.int64: 'int64',
-    paddle.bool: 'bool',
     paddle.uint8: 'uint8',
 }
 
+DTYPE_ABBRS = {
+    **FP_DTYPE_ABBRS,
+    **CP_DTYPE_ABBRS,
+    **INT_DTYPE_ABBRS,
+    paddle.bool: 'bool',
+}
+
 
 class ConstantVariable(VariableBase):
     def __init__(
@@ -95,6 +108,14 @@ def wrap_literal(value: Any) -> ConstantVariable:
         return ConstantVariable(value, ConstTracker(value))
 
 
+IMPLEMENTED_TENSOR_PROPERTIES = set()
+
+
+def tensor_property(func):
+    IMPLEMENTED_TENSOR_PROPERTIES.add(func.__name__)
+    return property(func)
+
+
 class TensorVariable(VariableBase):
     var_name_generator = NameGenerator("var_")
 
@@ -151,7 +172,7 @@ def make_stringify_guard(self) -> StringifyExpression:
             ),
         )
         return StringifyExpression(
-            f"str(MetaInfo.from_tensor({frame_value_tracer.expr})) == '{self.meta}'",
+            f"MetaInfo.from_tensor({frame_value_tracer.expr}).guard_str() == '{self.meta.guard_str()}'",
            union_free_vars(
                {"MetaInfo": MetaInfo},
                frame_value_tracer.free_vars,
@@ -186,7 +207,7 @@ def __setitem__(self, key, value):
             value,
         )
 
-    @property
+    @tensor_property
     def T(self):
         perm = list(range(len(self.meta.shape) - 1, -1, -1))
         perm_var = VariableFactory.from_value(
@@ -195,40 +216,79 @@ def T(self):
         out = self.graph.call_paddle_api(paddle.transpose, self, perm_var)
         return out
 
-    @property
+    @tensor_property
     def ndim(self):
         return ConstantVariable.wrap_literal(len(self.meta.shape))
 
-    @property
-    def shape(self):
+    @tensor_property
+    def size(self):
+        # TODO: maybe break graph.
         if self.meta.is_dynamic_shape():
             raise BreakGraphError(
                 f"Getting size for a dynamic shape tensor causes graph break. shape = {self.meta.shape}"
             )
+        elements = reduce(operator.mul, self.meta.shape, 1)
+        return ConstantVariable.wrap_literal(elements)
+
+    @tensor_property
+    def shape(self):
+        if self.meta.is_dynamic_shape():
+            raise BreakGraphError(
+                f"Getting shape for a dynamic shape tensor causes graph break. shape = {self.meta.shape}"
+            )
         self.graph.add_global_guarded_variable(self)
         return VariableFactory.from_value(
             self.meta.shape, self.graph, tracker=ConstTracker(self.meta.shape)
         )
 
-    @property
-    def size(self):
-        # TODO: maybe break graph.
-        if self.meta.is_dynamic_shape():
-            raise BreakGraphError(
-                f"Getting size for a dynamic shape tensor causes graph break. shape = {self.meta.shape}"
-            )
-        elements = reduce(operator.mul, self.meta.shape, 1)
-        return ConstantVariable.wrap_literal(elements)
+    def is_tensor(self):
+        return ConstantVariable.wrap_literal(True)
+
+    def is_complex(self):
+        dtype = self.meta.dtype
+        is_cp_dtype = dtype in CP_DTYPE_ABBRS
+        return ConstantVariable.wrap_literal(is_cp_dtype)
+
+    def is_integer(self):
+        dtype = self.meta.dtype
+        is_int_dtype = dtype in INT_DTYPE_ABBRS
+        return ConstantVariable.wrap_literal(is_int_dtype)
+
+    def is_floating_point(self):
+        dtype = self.meta.dtype
+        is_fp_dtype = dtype in FP_DTYPE_ABBRS
+        return ConstantVariable.wrap_literal(is_fp_dtype)
 
     def getattr(self, name: str):
-        if name in ["shape", "dtype", "stop_gradient"]:
+        method_name_to_builtin_fn = {
+            "dim": paddle.rank,
+            "ndimension": paddle.rank,
+            "is_tensor": paddle.is_tensor,
+            "is_complex": paddle.is_complex,
+            "is_integer": paddle.is_integer,
+            "is_floating_point": paddle.is_floating_point,
+        }
+        if name in ["dtype", "type", "name", "persistable", "stop_gradient"]:
+            if name == "name" and self.meta.name.startswith(
+                "infer_meta_variable_tmp"
+            ):
+                raise BreakGraphError(f"{self.meta.name} is a middle tensor.")
             return VariableFactory.from_value(
                 getattr(self.meta, name),
                 self.graph,
                 tracker=GetAttrTracker(self, name),
             )
-        elif name in ["T", "ndim", "size"]:
+        elif name in IMPLEMENTED_TENSOR_PROPERTIES:
             return getattr(self, name)
+        elif name in method_name_to_builtin_fn:
+            # TODO: backward, gradient
+            from .callable import BuiltinVariable
+
+            builtin_fn = method_name_to_builtin_fn[name]
+
+            return BuiltinVariable(
+                builtin_fn, self.graph, DanglingTracker()
+            ).bind(self, name)
         elif name in paddle_tensor_methods:
             from .callable import TensorFunctionVariable
 
diff --git a/tests/test_18_tensor_method.py b/tests/test_18_tensor_method.py
index 6cd5b1472..3f56c0e90 100644
--- a/tests/test_18_tensor_method.py
+++ b/tests/test_18_tensor_method.py
@@ -26,7 +26,21 @@ def tensor_method_passed_by_user(a: paddle.Tensor, func: paddle.Tensor):
 
 
 def tensor_method_property(a: paddle.Tensor, b: paddle.Tensor):
-    return a @ b.T + len(a.shape) + b.size + a.ndim
+    return (
+        a.name,
+        str(a.place),
+        a.persistable,
+        a.dtype,
+        a.type,
+        a.is_tensor(),
+        a.clear_gradient(),
+        a @ b.T + len(a.shape) + b.size + a.ndim + a.dim() + a.rank(),
+    )
+
+
+def middle_tensor_name(a: paddle.Tensor, b: paddle.Tensor):
+    c = a + b
+    return c.name
 
 
 class TestTensorMethod(TestCaseBase):
@@ -47,9 +61,15 @@ def test_tensor_method_passed_by_user(self):
         self.assert_results(tensor_method_passed_by_user, x, y.add)
 
     def test_tensor_method_property(self):
+        x = paddle.rand([42, 24], dtype='float64')
+        y = paddle.rand([42, 24], dtype='float32')
+        self.assert_results(tensor_method_property, x, y)
+
+    @unittest.skip("TODO: dynamic tensor name is different")
+    def test_middle_tensor_name(self):
         x = paddle.rand([42, 24])
         y = paddle.rand([42, 24])
-        self.assert_results(tensor_method_property, x, y)
+        self.assert_results(middle_tensor_name, x, y)
 
 
 if __name__ == "__main__":