Skip to content

Commit

Permalink
[TVMScript] Add syntax sugar for T.handle and T.match_buffer (#9492)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuanjing Shi authored Dec 8, 2021
1 parent b54beed commit e8889ae
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 13 deletions.
2 changes: 1 addition & 1 deletion docker/install/ubuntu_install_python_package.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,6 @@ pip3 install \
pytest-xdist \
requests \
scipy \
synr==0.5.0 \
synr==0.6.0 \
six \
tornado
2 changes: 1 addition & 1 deletion python/gen_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@
("sphinx_autodoc_annotation", None),
("sphinx_gallery", None),
("sphinx_rtd_theme", None),
("synr", "==0.5.0"),
("synr", "==0.6.0"),
("tensorflow", None),
("tensorflow-estimator", None),
("tflite", None),
Expand Down
61 changes: 52 additions & 9 deletions python/tvm/script/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from tvm._ffi.base import TVMError
from tvm.ir import GlobalVar
from tvm.ir.function import BaseFunc
from tvm.tir import buffer
from tvm.tir.function import PrimFunc
from . import _ffi_api
from . import tir
Expand Down Expand Up @@ -154,10 +155,10 @@ class TVMScriptParser(Transformer):
ast.BuiltinOp.Not: tvm.tir.Not,
}

def __init__(self, base_lienno, tir_namespace):
def __init__(self, base_lineno, tir_namespace):
self.context = None

self.base_lineno = base_lienno
self.base_lineno = base_lineno
self.current_lineno = 0
self.current_col_offset = 0
self.tir_namespace = tir_namespace
Expand Down Expand Up @@ -249,20 +250,23 @@ def parse_arg_list(self, func, node_call):
func : Function
The function that provides the signature
node_call: ast.Call
node_call: Union[ast.Call, ast.TypeApply, ast.TypeCall]
The AST call node that calls into the function.
Returns
-------
arg_list : list
The parsed positional argument.
"""
assert isinstance(node_call, ast.Call)
assert isinstance(node_call, (ast.Call, ast.TypeApply, ast.TypeCall))
# collect arguments
args = [self.transform(arg) for arg in node_call.params]
kw_args = {
self.transform(k): self.transform(v) for k, v in node_call.keyword_params.items()
}
if isinstance(node_call, ast.TypeApply):
kw_args = {} # TypeApply (e.g. foo[bar]) doesn't have kwargs defined in synr
else:
kw_args = {
self.transform(k): self.transform(v) for k, v in node_call.keyword_params.items()
}
# get the name and parameter list of func
if isinstance(func, (Intrin, ScopeHandler, SpecialStmt)):
func_name, param_list = func.signature()
Expand All @@ -276,6 +280,7 @@ def parse_arg_list(self, func, node_call):
reader = CallArgumentReader(func_name, args, kw_args, self, node_call)
pos_only, kwargs, varargs = param_list
internal_args = list()

for i, arg_name in enumerate(pos_only):
internal_args.append(reader.get_pos_only_arg(i + 1, arg_name))
for i, arg_info in enumerate(kwargs):
Expand Down Expand Up @@ -439,8 +444,22 @@ def check_decorator(decorators: List[ast.Expr]) -> bool:

# add parameters of function
for arg in node.params:
arg_var = tvm.te.var(arg.name, self.parse_type(arg.ty, arg))
self.context.update_symbol(arg.name, arg_var, node)
# Note that this case is for T.match_buffer syntax sugar
if isinstance(arg.ty, (ast.TypeCall, ast.TypeApply)):
result = self.handle_match_buffer_type(arg.ty, arg.name)
if not isinstance(result, buffer.Buffer):
self.report_error(
"The result type of evaluating TypeCall and TypeApply stmt"
f" is wrong: {type(result)}. It should be a Buffer",
node.span,
)
arg_name_with_handle = arg.name + "_handle"
arg_var = tvm.te.var(arg_name_with_handle, tvm.ir.PrimType("handle"))
self.context.func_buffer_map[arg_var] = result
self.context.update_symbol(arg.name, result, node)
else:
arg_var = tvm.te.var(arg.name, self.parse_type(arg.ty, arg))
self.context.update_symbol(arg.name, arg_var, node)
self.context.func_params.append(arg_var)

if not check_decorator(node.decorators):
Expand Down Expand Up @@ -1110,6 +1129,30 @@ def transform_TypeConstant(self, node):
"""
return node.value

def transform_TypeTuple(self, node):
"""Tuple value visitor for types.
Mostly used in `transform_TypeCall` and `transform_TypeApply`.
"""
return [self.transform(value) for value in node.values]

def handle_match_buffer_type(self, node, buffer_name):
"""special function to handle syntax sugar for match buffer.
This method is for buffer declarations in the function parameters.
"""
func = self.transform(node.func_name)
assert isinstance(func, SpecialStmt)

# parse args and kwargs for TypeCall and TypeApply
arg_list = self.parse_arg_list(func, node)
# Note that the third element in arg_list would always be the 'name'
# TODO: This index is hardcoded as a workaround. Better to make it programmatic
if arg_list[2] is None:
arg_list[2] = buffer_name
buf = func.handle(node, self.context, arg_list, node.func_name.span)
return buf

def transform_Return(self, node):
self.report_error(
"TVM script does not support return statements. Instead the last statement in any "
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/script/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@

# Type system
from .ty import int8, int16, int32, int64, float16, float32, float64
from .ty import boolean, handle, Ptr, Tuple
from .ty import boolean, handle, Ptr, Tuple, Buffer

from .prim_func import prim_func
73 changes: 73 additions & 0 deletions python/tvm/script/tir/ty.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"""
# pylint: disable=invalid-name
import tvm
from .special_stmt import SpecialStmt, convert_to_int


class TypeGeneric: # pylint: disable=too-few-public-methods
Expand Down Expand Up @@ -67,6 +68,75 @@ def __getitem__(self, vtypes):
return ConcreteType(tvm.ir.TupleType([vtype.evaluate() for vtype in vtypes]))


class GenericBufferType(SpecialStmt): # pylint: disable=too-few-public-methods, abstract-method
"""TVM script typing class for uniform Type objects"""

def __init__(self, vtype):
def match_buffer_syntax_sugar(
shape,
dtype: str = "float32",
name: str = None,
data=None,
strides=None,
elem_offset=None,
scope="global",
align=-1,
offset_factor=0,
buffer_type="default",
span=None,
):
if strides is None:
strides = []
align = convert_to_int(align, "align", self.context.report_error, self.node.span)
offset_factor = convert_to_int(
offset_factor, "offset_factor", self.context.report_error, self.node.span
)
buffer = tvm.tir.decl_buffer(
shape,
dtype,
name,
data,
strides,
elem_offset,
scope,
align,
offset_factor,
buffer_type,
span=span,
)
return buffer

self.type = vtype
super().__init__(match_buffer_syntax_sugar, def_symbol=True)

def __call__(
self,
shape,
dtype="float32",
*,
name: str = None,
data=None,
strides=None,
elem_offset=None,
scope="global",
align=-1,
offset_factor=0,
buffer_type="default",
span=None,
):
"""
This function is for Buffer(...) syntax sugar.
"""
pass # pylint: disable=unnecessary-pass

def __getitem__(self, args):
"""
This function is for Buffer[...] syntax sugar
Note that args is the list of all arguments
"""
pass # pylint: disable=unnecessary-pass


int8 = ConcreteType("int8")
int16 = ConcreteType("int16")
int32 = ConcreteType("int32")
Expand All @@ -78,3 +148,6 @@ def __getitem__(self, vtypes):
handle = ConcreteType("handle")
Ptr = GenericPtrType()
Tuple = GenericTupleType()
# we don't have 'buffer' type on the cpp side
# thus 'handle' is used here for convenience's sake
Buffer = GenericBufferType("handle")
45 changes: 45 additions & 0 deletions tests/python/unittest/test_tvmscript_syntax_sugar.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,5 +101,50 @@ def test_syntax_sugar_fail():
check_error(loop_syntax_sugar_fail, 3)


# match buffer - use kwargs
@T.prim_func
def elementwise_handle(
a: T.handle,
b: T.handle,
) -> None:
A = T.match_buffer(a, (128, 128, 128, 128))
B = T.match_buffer(b, (128, 128, 128, 128))
for i, j, k, l in T.grid(128, 128, 128, 128):
with T.block("B"):
vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0


# match buffer - use buffer with kwargs
@T.prim_func
def elementwise_buffer_kwargs(
a: T.Buffer(shape=(128, 128, 128, 128), dtype="float32", elem_offset=None),
b: T.Buffer(shape=(128, 128, 128, 128), dtype="float32", elem_offset=None),
) -> None:
for i, j, k, l in T.grid(128, 128, 128, 128):
with T.block("B"):
vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
b[vi, vj, vk, vl] = a[vi, vj, vk, vl] * 2.0


# match buffer - use buffer without kwargs
@T.prim_func
def elementwise_buffer_no_kwargs(
a: T.Buffer[(128, 128, 128, 128), "float32"],
b: T.Buffer[(128, 128, 128, 128), "float32"],
) -> None:
for i, j, k, l in T.grid(128, 128, 128, 128):
with T.block("B"):
vi, vj, vk, vl = T.axis.remap("SSSS", [i, j, k, l])
b[vi, vj, vk, vl] = a[vi, vj, vk, vl] * 2.0


def test_match_buffer_syntax_sugar():
# with kwargs
assert_structural_equal(elementwise_handle, elementwise_buffer_kwargs)
# without kwargs
assert_structural_equal(elementwise_handle, elementwise_buffer_no_kwargs)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))
2 changes: 1 addition & 1 deletion tests/scripts/task_ci_setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ set -o pipefail
#
echo "Addtiional setup in" ${CI_IMAGE_NAME}

python3 -m pip install --user tlcpack-sphinx-addon==0.2.1 synr==0.5.0
python3 -m pip install --user tlcpack-sphinx-addon==0.2.1 synr==0.6.0

# Rebuild standalone_crt in build/ tree. This file is not currently archived by pack_lib() in
# Jenkinsfile. We expect config.cmake to be present from pack_lib().
Expand Down

0 comments on commit e8889ae

Please sign in to comment.