diff --git a/docker/install/ubuntu_install_python_package.sh b/docker/install/ubuntu_install_python_package.sh index fb0f596d6552..8b79455d0cd1 100755 --- a/docker/install/ubuntu_install_python_package.sh +++ b/docker/install/ubuntu_install_python_package.sh @@ -36,6 +36,6 @@ pip3 install \ pytest-xdist \ requests \ scipy \ - synr==0.5.0 \ + synr==0.6.0 \ six \ tornado diff --git a/python/gen_requirements.py b/python/gen_requirements.py index b4f3907bbc0f..bcd8ccd9b531 100755 --- a/python/gen_requirements.py +++ b/python/gen_requirements.py @@ -255,7 +255,7 @@ ("sphinx_autodoc_annotation", None), ("sphinx_gallery", None), ("sphinx_rtd_theme", None), - ("synr", "==0.5.0"), + ("synr", "==0.6.0"), ("tensorflow", None), ("tensorflow-estimator", None), ("tflite", None), diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index 6cb22aeb5f47..0132025024b2 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -32,6 +32,7 @@ from tvm._ffi.base import TVMError from tvm.ir import GlobalVar from tvm.ir.function import BaseFunc +from tvm.tir import buffer from tvm.tir.function import PrimFunc from . import _ffi_api from . import tir @@ -154,10 +155,10 @@ class TVMScriptParser(Transformer): ast.BuiltinOp.Not: tvm.tir.Not, } - def __init__(self, base_lienno, tir_namespace): + def __init__(self, base_lineno, tir_namespace): self.context = None - self.base_lineno = base_lienno + self.base_lineno = base_lineno self.current_lineno = 0 self.current_col_offset = 0 self.tir_namespace = tir_namespace @@ -249,7 +250,7 @@ def parse_arg_list(self, func, node_call): func : Function The function that provides the signature - node_call: ast.Call + node_call: Union[ast.Call, ast.TypeApply, ast.TypeCall] The AST call node that calls into the function. Returns @@ -257,12 +258,15 @@ def parse_arg_list(self, func, node_call): arg_list : list The parsed positional argument. """ - assert isinstance(node_call, ast.Call) + assert isinstance(node_call, (ast.Call, ast.TypeApply, ast.TypeCall)) # collect arguments args = [self.transform(arg) for arg in node_call.params] - kw_args = { - self.transform(k): self.transform(v) for k, v in node_call.keyword_params.items() - } + if isinstance(node_call, ast.TypeApply): + kw_args = {} # TypeApply (e.g. foo[bar]) doesn't have kwargs defined in synr + else: + kw_args = { + self.transform(k): self.transform(v) for k, v in node_call.keyword_params.items() + } # get the name and parameter list of func if isinstance(func, (Intrin, ScopeHandler, SpecialStmt)): func_name, param_list = func.signature() @@ -276,6 +280,7 @@ def parse_arg_list(self, func, node_call): reader = CallArgumentReader(func_name, args, kw_args, self, node_call) pos_only, kwargs, varargs = param_list internal_args = list() + for i, arg_name in enumerate(pos_only): internal_args.append(reader.get_pos_only_arg(i + 1, arg_name)) for i, arg_info in enumerate(kwargs): @@ -439,8 +444,22 @@ def check_decorator(decorators: List[ast.Expr]) -> bool: # add parameters of function for arg in node.params: - arg_var = tvm.te.var(arg.name, self.parse_type(arg.ty, arg)) - self.context.update_symbol(arg.name, arg_var, node) + # Note that this case is for T.match_buffer syntax sugar + if isinstance(arg.ty, (ast.TypeCall, ast.TypeApply)): + result = self.handle_match_buffer_type(arg.ty, arg.name) + if not isinstance(result, buffer.Buffer): + self.report_error( + "The result type of evaluating TypeCall and TypeApply stmt" + f" is wrong: {type(result)}. It should be a Buffer", + node.span, + ) + arg_name_with_handle = arg.name + "_handle" + arg_var = tvm.te.var(arg_name_with_handle, tvm.ir.PrimType("handle")) + self.context.func_buffer_map[arg_var] = result + self.context.update_symbol(arg.name, result, node) + else: + arg_var = tvm.te.var(arg.name, self.parse_type(arg.ty, arg)) + self.context.update_symbol(arg.name, arg_var, node) self.context.func_params.append(arg_var) if not check_decorator(node.decorators): @@ -1110,6 +1129,30 @@ def transform_TypeConstant(self, node): """ return node.value + def transform_TypeTuple(self, node): + """Tuple value visitor for types. + + Mostly used in `transform_TypeCall` and `transform_TypeApply`. + """ + return [self.transform(value) for value in node.values] + + def handle_match_buffer_type(self, node, buffer_name): + """special function to handle syntax sugar for match buffer. + + This method is for buffer declarations in the function parameters. + """ + func = self.transform(node.func_name) + assert isinstance(func, SpecialStmt) + + # parse args and kwargs for TypeCall and TypeApply + arg_list = self.parse_arg_list(func, node) + # Note that the third element in arg_list would always be the 'name' + # TODO: This index is hardcoded as a workaround. Better to make it programmatic + if arg_list[2] is None: + arg_list[2] = buffer_name + buf = func.handle(node, self.context, arg_list, node.func_name.span) + return buf + def transform_Return(self, node): self.report_error( "TVM script does not support return statements. Instead the last statement in any " diff --git a/python/tvm/script/tir/__init__.py b/python/tvm/script/tir/__init__.py index 6aa7eb33ec8b..472b3de0e43b 100644 --- a/python/tvm/script/tir/__init__.py +++ b/python/tvm/script/tir/__init__.py @@ -18,6 +18,6 @@ # Type system from .ty import int8, int16, int32, int64, float16, float32, float64 -from .ty import boolean, handle, Ptr, Tuple +from .ty import boolean, handle, Ptr, Tuple, Buffer from .prim_func import prim_func diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/tir/ty.py index 9140310d4733..2808e7a48735 100644 --- a/python/tvm/script/tir/ty.py +++ b/python/tvm/script/tir/ty.py @@ -21,6 +21,7 @@ """ # pylint: disable=invalid-name import tvm +from .special_stmt import SpecialStmt, convert_to_int class TypeGeneric: # pylint: disable=too-few-public-methods @@ -67,6 +68,75 @@ def __getitem__(self, vtypes): return ConcreteType(tvm.ir.TupleType([vtype.evaluate() for vtype in vtypes])) +class GenericBufferType(SpecialStmt): # pylint: disable=too-few-public-methods, abstract-method + """TVM script typing class for uniform Type objects""" + + def __init__(self, vtype): + def match_buffer_syntax_sugar( + shape, + dtype: str = "float32", + name: str = None, + data=None, + strides=None, + elem_offset=None, + scope="global", + align=-1, + offset_factor=0, + buffer_type="default", + span=None, + ): + if strides is None: + strides = [] + align = convert_to_int(align, "align", self.context.report_error, self.node.span) + offset_factor = convert_to_int( + offset_factor, "offset_factor", self.context.report_error, self.node.span + ) + buffer = tvm.tir.decl_buffer( + shape, + dtype, + name, + data, + strides, + elem_offset, + scope, + align, + offset_factor, + buffer_type, + span=span, + ) + return buffer + + self.type = vtype + super().__init__(match_buffer_syntax_sugar, def_symbol=True) + + def __call__( + self, + shape, + dtype="float32", + *, + name: str = None, + data=None, + strides=None, + elem_offset=None, + scope="global", + align=-1, + offset_factor=0, + buffer_type="default", + span=None, + ): + """ + This function is for Buffer(...) syntax sugar. + """ + pass # pylint: disable=unnecessary-pass + + def __getitem__(self, args): + """ + This function is for Buffer[...] syntax sugar + Note that args is the list of all arguments + """ + pass # pylint: disable=unnecessary-pass + + int8 = ConcreteType("int8") int16 = ConcreteType("int16") int32 = ConcreteType("int32") @@ -78,3 +148,6 @@ def __getitem__(self, vtypes): handle = ConcreteType("handle") Ptr = GenericPtrType() Tuple = GenericTupleType() +# we don't have 'buffer' type on the cpp side +# thus 'handle' is used here for convenience's sake +Buffer = GenericBufferType("handle") diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py index b8d123236982..0d4c833160b4 100644 --- a/tests/python/unittest/test_tvmscript_syntax_sugar.py +++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py @@ -101,5 +101,50 @@ def test_syntax_sugar_fail(): check_error(loop_syntax_sugar_fail, 3) +# match buffer - use kwargs +@T.prim_func +def elementwise_handle( + a: T.handle, + b: T.handle, +) -> None: + A = T.match_buffer(a, (128, 128, 128, 128)) + B = T.match_buffer(b, (128, 128, 128, 128)) + for i, j, k, l in T.grid(128, 128, 128, 128): + with T.block("B"): + vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) + B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 + + +# match buffer - use buffer with kwargs +@T.prim_func +def elementwise_buffer_kwargs( + a: T.Buffer(shape=(128, 128, 128, 128), dtype="float32", elem_offset=None), + b: T.Buffer(shape=(128, 128, 128, 128), dtype="float32", elem_offset=None), +) -> None: + for i, j, k, l in T.grid(128, 128, 128, 128): + with T.block("B"): + vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) + b[vi, vj, vk, vl] = a[vi, vj, vk, vl] * 2.0 + + +# match buffer - use buffer without kwargs +@T.prim_func +def elementwise_buffer_no_kwargs( + a: T.Buffer[(128, 128, 128, 128), "float32"], + b: T.Buffer[(128, 128, 128, 128), "float32"], +) -> None: + for i, j, k, l in T.grid(128, 128, 128, 128): + with T.block("B"): + vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l]) + b[vi, vj, vk, vl] = a[vi, vj, vk, vl] * 2.0 + + +def test_match_buffer_syntax_sugar(): + # with kwargs + assert_structural_equal(elementwise_handle, elementwise_buffer_kwargs) + # without kwargs + assert_structural_equal(elementwise_handle, elementwise_buffer_no_kwargs) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/scripts/task_ci_setup.sh b/tests/scripts/task_ci_setup.sh index dfd2a32165f1..323bc0752801 100755 --- a/tests/scripts/task_ci_setup.sh +++ b/tests/scripts/task_ci_setup.sh @@ -30,7 +30,7 @@ set -o pipefail # echo "Addtiional setup in" ${CI_IMAGE_NAME} -python3 -m pip install --user tlcpack-sphinx-addon==0.2.1 synr==0.5.0 +python3 -m pip install --user tlcpack-sphinx-addon==0.2.1 synr==0.6.0 # Rebuild standalone_crt in build/ tree. This file is not currently archived by pack_lib() in # Jenkinsfile. We expect config.cmake to be present from pack_lib().