Skip to content

Commit

Permalink
Allow parsing python function instead of string (apache#45)
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao authored and Hzfengsy committed Jul 27, 2022
1 parent 77b71e0 commit cdd5f99
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 31 deletions.
12 changes: 10 additions & 2 deletions python/tvm/script/builder/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@
unroll,
vectorized,
)
from .prim_func_frame import arg, func_attr, func_ret, prim_func, match_buffer, preflattened_buffer
from .var import Buffer
from .op import *
from .prim_func_frame import (
arg,
func_attr,
func_name,
func_ret,
match_buffer,
preflattened_buffer,
prim_func,
)
from .var import Buffer
23 changes: 15 additions & 8 deletions python/tvm/script/builder/tir/prim_func_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,13 @@
# specific language governing permissions and limitations
# under the License.
"""TVM Script TIR Prim Func Frame"""
from typing import Union, Dict, Any
from typing import Any, Callable, Dict, Optional, Union

from tvm._ffi import register_object as _register_object
from tvm.ir import Type
from tvm.tir.buffer import Buffer
from tvm.tir.expr import Var
from tvm.ir import Type

from ..builder import Builder
from . import _ffi_api
from .base import TIRFrame

Expand All @@ -32,15 +31,23 @@ class PrimFuncFrame(TIRFrame):
...


def prim_func(name) -> PrimFuncFrame:
return _ffi_api.PrimFuncFrame(name) # pylint: disable=no-member # type: ignore
def prim_func(f: Optional[Callable] = None) -> PrimFuncFrame:
if f is not None:
from tvm.script.parse import parse # pylint: disable=import-outside-toplevel

return parse(f)
return _ffi_api.PrimFuncFrame() # pylint: disable=no-member # type: ignore


setattr(prim_func, "dispatch_token", "tir")


def arg(name, obj) -> Union[Var, Buffer]:
return _ffi_api.Arg(name, obj) # pylint: disable=no-member # type: ignore


setattr(prim_func, "dispatch_token", "tir")
def func_name(name) -> str:
return _ffi_api.FuncName(name) # pylint: disable=no-member # type: ignore


def func_attr(attrs: Dict[str, Any]) -> None:
Expand All @@ -65,7 +72,7 @@ def match_buffer(
axis_separators=None,
span=None,
) -> Buffer:
return _ffi_api.MatchBuffer(
return _ffi_api.MatchBuffer( # pylint: disable=no-member # type: ignore
param,
shape,
dtype,
Expand Down Expand Up @@ -95,7 +102,7 @@ def preflattened_buffer(
axis_separators=None,
span=None,
) -> None:
_ffi_api.PreflattenedBuffer(
_ffi_api.PreflattenedBuffer( # pylint: disable=no-member # type: ignore
postflattened,
shape,
dtype,
Expand Down
17 changes: 8 additions & 9 deletions python/tvm/script/parse/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,14 @@ def __init__(self, program: Union[str, doc.AST]):
else:
self.source_name = inspect.getsourcefile(program) # type: ignore
lines, self.start_line = inspect.getsourcelines(program) # type: ignore

if lines:
self.start_column = len(lines[0]) - len(lines[0].lstrip())
else:
self.start_column = 0
if self.start_column and lines:
self.source = "\n".join([l[self.start_column :].rstrip() for l in lines])
else:
self.source = ""
self.source = "".join(lines)
try:
# It will cause a problem when running in Jupyter Notebook.
# `mod` will be <module '__main__'>, which is a built-in module
Expand All @@ -69,16 +68,16 @@ def as_ast(self) -> doc.AST:
return doc.parse(self.source)


def parse(
program: Union[doc.AST, Any, str],
extra_vars: Optional[Dict[str, Any]] = None,
):
def parse(program: Union[doc.AST, Any, str]):
# TODO: `extra_vars` is a hack
from tvm.script.builder import tir as T

extra_vars = {"T": T}
program_ast = SourceCode(program).as_ast()
parser = Parser()
with Builder() as builder:
with parser.var_table.with_frame():
if extra_vars:
for k, v in extra_vars.items():
parser.var_table.add(k, v)
for k, v in extra_vars.items():
parser.var_table.add(k, v)
parser.visit(program_ast)
return builder.get()
3 changes: 2 additions & 1 deletion python/tvm/script/parse/tir/tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ def visit_with(self: Parser, node: doc.With) -> None:
def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
with self.var_table.with_frame():
self.var_table.add("range", T.serial)
with T.prim_func(node.name):
with T.prim_func():
T.func_name(node.name)
with self.with_dispatch_token("tir"):
# TODO: define the GlobalVar, handle the return value
self.visit(node.args)
Expand Down
10 changes: 8 additions & 2 deletions src/script/builder/tir/prim_func_frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ void PrimFuncFrameNode::ExitWithScope() {
}
}

PrimFuncFrame PrimFunc_(String name) {
PrimFuncFrame PrimFunc_() {
ObjectPtr<PrimFuncFrameNode> n = make_object<PrimFuncFrameNode>();
n->name = name;
n->name = "";
n->args.clear();
n->ret_type = TupleType::Empty();
n->buffer_map.clear();
Expand All @@ -78,6 +78,11 @@ tvm::tir::Buffer Arg(String name, tvm::tir::Buffer buffer) {
return buffer;
}

void FuncName(String name) {
PrimFuncFrame frame = Builder::Current()->FindFrame<PrimFuncFrame>().value();
frame->name = name;
}

void FuncAttrs(Map<String, ObjectRef> attrs) {
using namespace tvm::tir;
PrimFuncFrame frame = Builder::Current()->FindFrame<PrimFuncFrame>().value();
Expand Down Expand Up @@ -165,6 +170,7 @@ TVM_REGISTER_GLOBAL("script.builder.tir.Arg")
LOG(FATAL) << "ValueError: Unexpected type for TIR Arg.";
throw;
});
TVM_REGISTER_GLOBAL("script.builder.tir.FuncName").set_body_typed(FuncName);
TVM_REGISTER_GLOBAL("script.builder.tir.FuncAttrs").set_body_typed(FuncAttrs);
TVM_REGISTER_GLOBAL("script.builder.tir.FuncRet").set_body_typed(FuncRet);
TVM_REGISTER_GLOBAL("script.builder.tir.MatchBuffer").set_body_typed(MatchBuffer);
Expand Down
3 changes: 2 additions & 1 deletion src/script/builder/tir/prim_func_frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,10 @@ class PrimFuncFrame : public TIRFrame {
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PrimFuncFrame, TIRFrame, PrimFuncFrameNode);
};

PrimFuncFrame PrimFunc_(String name);
PrimFuncFrame PrimFunc_();
tvm::tir::Var Arg(String name, tvm::tir::Var var);
tvm::tir::Buffer Arg(String name, tvm::tir::Buffer buffer);
void FuncName(String name);
void FuncAttrs(Map<String, ObjectRef> attrs);
tvm::Type FuncRet(tvm::Type ret_type);

Expand Down
5 changes: 3 additions & 2 deletions tests/python/tvmscript/test_builder_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,16 @@

def test_builder_basic():
with Builder() as b:
with T.prim_func(name="main"):
with T.prim_func():
T.func_name("main")
T.func_attr({"global_symbol": "main"})
T.func_ret(tvm.ir.PrimType("int8"))
arg_a = T.arg("a", T.handle())
arg_b = T.arg("b", T.handle())
buffer_c = T.Buffer((128,), "float32")
buffer_d = T.Buffer((128,), "float32")
arg_c = T.arg("c", buffer_c)
arg_d = T.arg("d", buffer_d)
T.func_ret(tvm.ir.PrimType("int8"))
A = def_("A", T.match_buffer(arg_a, (128, 128, 128)))
B = def_("B", T.match_buffer(arg_b, (128, 128, 128)))
T.preflattened_buffer(buffer_c, (128,), data=buffer_c.data)
Expand Down
14 changes: 8 additions & 6 deletions tests/python/tvmscript/test_parse_basic.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
from tvm.script.builder import tir as T
from tvm.script.parse import parse

elementwise = """

# pylint: disable=unused-argument,unused-variable,invalid-name
@T.prim_func
def elementwise(
A: T.Buffer(shape=(128, 128, 128), dtype="float32"),
B: T.Buffer(shape=(128, 128, 128), dtype="float32"),
A: T.Buffer(shape=(128, 128, 128), dtype="float32"), # type: ignore
B: T.Buffer(shape=(128, 128, 128), dtype="float32"), # type: ignore
) -> None:
for i, j, *vvv, k in T.grid(128, 128, 128, 128, 128, 128, 128):
with T.block("inner_block"):
# vi, vj, vk = T.axis.remap("SSR", [i, j, k])
vi = T.axis.S(128, i + 1)
vj = T.axis.S(128, j + 20)
vk = T.axis.R(128, k - i)
"""


# pylint: enable=unused-argument,unused-variable,invalid-name


def main():
result = parse(elementwise, extra_vars={"T": T})
result = elementwise
print(result.script())


Expand Down

0 comments on commit cdd5f99

Please sign in to comment.