Skip to content

Commit

Permalink
Add metadata section, support constant and metadata in parser & print…
Browse files Browse the repository at this point in the history
…er (apache#76)

* [CI] Set up CI; format and lint relax code to pass CI (apache#72)

* init

* fix lint

* update task_lint

* more lint

* more lint

* lint

* jenkinsfile

* jenkinsfile

* run relax only tests

* python3.7 for pytest

* point to personal ci-cpu docker

* docker pull

* test

* fix cmake config

* update

* update

* rebase

* rebase

* AutoTIR integration (apache#58)

* [WIP] Basic task extraction mechanism is implemented.

* [WIP] For gradual integration with Relay pipeline, meta_schedule/integration.py is created for relax to avoid potential conflict.

* support tir tuning and injection mode

* Add target field for Relax Extracted Task

* 1. Create relax namespace/tvm objects/... for metaschedule to preserve relay support. 2. Promote target field from Optional<Target> to Target

* Support ApplyHistoryBest

* Reflect feedback from Yuchen

* minor improvement and fix linter issue

* add ASF header

* Reorganize file structure

* fix lint errors

* remove the import-outside-toplevel

* Reflect comments

* remove redundant comment

* As per discussion w/ Yuchen, ApplyHistoryBest is introduced as a Relax transformation pass.

* remove redundant print msg

* fix lint

* reflect comments

* Yuchen's change

* relax ConstantNode in parser and printer

* Add constant data in the metasection

* rebase

* Support ir_module(metadata=json_str)

* update test case

* remove print info

* Update tests

* clang-format

* pylint

* fix ci

* Save a copy of metadata in RelaxTransformer

* Fix comments

* fix comments

Co-authored-by: Yuchen Jin <yuchenj@cs.washington.edu>
Co-authored-by: Sunghyun Park <49998730+sunggg@users.noreply.github.com>
  • Loading branch information
3 people committed Aug 14, 2022
1 parent 140ed55 commit 899e05f
Show file tree
Hide file tree
Showing 16 changed files with 601 additions and 172 deletions.
8 changes: 7 additions & 1 deletion include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,12 @@ class IRModuleNode : public Object {
return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
}

/*!
* \brief Get the metadata attributes.
* \returns The additional meta-data attributes
*/
DictAttrs GetAttrs() const { return attrs; }

/*!
* \brief Check whether the module has an non-zero integer attr.
*
Expand Down Expand Up @@ -357,7 +363,7 @@ class IRModule : public ObjectRef {
* \param type_definitions Type definitions in the module.
* \param import_set Set of imported files in the module.
* \param map The module source map.
* \param attrs The module attributes.
* \param attrs The module meta-data attributes.
*/
TVM_DLL explicit IRModule(Map<GlobalVar, BaseFunc> functions,
Map<GlobalTypeVar, TypeData> type_definitions = {},
Expand Down
1 change: 0 additions & 1 deletion include/tvm/relax/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ class VarNode : public ExprNode {
bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal(vid, other->vid) && equal(type_annotation, other->type_annotation) &&
// Do we use the analysis information in equality?
equal(checked_type_, other->checked_type_) && equal(shape_, other->shape_);
}

Expand Down
1 change: 1 addition & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class ConstantNode : public ExprNode {
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
v->Visit("shape_", &shape_);
}

bool SEqualReduce(const ConstantNode* other, SEqualReducer equal) const {
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/ir/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def assert_structural_equal(lhs, rhs, map_free_vars=False):
The left operand.
rhs : Object
The left operand.
The right operand.
map_free_vars : bool
Whether or not shall we map free vars that does
Expand Down
27 changes: 24 additions & 3 deletions python/tvm/ir/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
"""IRModule that holds the functions and type definitions."""
from typing import Optional

import ast
from tvm._ffi.base import string_types
import tvm._ffi

Expand All @@ -39,7 +39,7 @@ class IRModule(Node):
Map of global var to BaseFunc
"""

def __init__(self, functions=None, type_definitions=None):
def __init__(self, functions=None, type_definitions=None, attrs=None):
if functions is None:
functions = {}
elif isinstance(functions, dict):
Expand All @@ -62,7 +62,17 @@ def __init__(self, functions=None, type_definitions=None):
raise TypeError("Expect type_definitions to be Dict[GlobalTypeVar, Type]")
mapped_type_defs[k] = v
type_definitions = mapped_type_defs
self.__init_handle_by_constructor__(_ffi_api.IRModule, functions, type_definitions)

attrs = None if not attrs else attrs
if attrs is not None:
attrs = ast.literal_eval(str(attrs))
attrs = tvm.ir.make_node("DictAttrs", **attrs)
self.__init_handle_by_constructor__(
_ffi_api.IRModule,
functions,
type_definitions,
attrs,
)

def __setitem__(self, var, val):
"""Add a mapping to the module.
Expand Down Expand Up @@ -308,6 +318,17 @@ def get_attr(self, attr_key):

return _ffi_api.Module_GetAttr(self, attr_key)

def get_attrs(self):
"""Get the meta_data attributes.
Returns
-------
meta_data : DictAttrs
meta_data attributes
"""

return _ffi_api.Module_GetAttrs(self)

def with_attr(self, attr_key, attr_value):
"""Copy the IRModule and add an attribute to it.
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
Call = relay.Call
If = relay.If
const = relay.const
Constant = relay.Constant


@tvm._ffi.register_object("relax.expr.ShapeExpr")
Expand Down
56 changes: 55 additions & 1 deletion python/tvm/relax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,66 @@
from tvm.tir import PrimFunc
from tvm import IRModule

# Simply extracts tir PrimFuncs from the input IRModule

def tir_partitioner(mod: IRModule) -> List[IRModule]:
"""Extracts tir PrimFuncs from the input IRModule.
Parameters
----------
mod : IRModule
The input IRModule.
Returns
-------
output : List[IRModule]
The result tir PrimFuncs.
"""
partitions = []
for gvar in mod.get_global_vars():
if isinstance(mod[gvar], PrimFunc):
tir_mod = IRModule({})
tir_mod[gvar] = mod[gvar]
partitions.append(tir_mod)
return partitions


def metadata_partitioner(rx_txt: str) -> List[str]:
"""Extract Relax program and metadata section.
Parameters
----------
rx_txt : str
The input relax text.
Returns
-------
output : List[str]
The result list of partitioned text, the first element
is the relax program, and the second is metadata section.
"""
partitions = []
left_curly = 0
meta_start = 0
meta_end = 0
for i, char in enumerate(rx_txt):
if i < 0:
raise ValueError("The program is invalid.")
if char == "{":
if meta_start == 0:
meta_start = i
left_curly += 1
elif char == "}":
left_curly -= 1
if left_curly == 0:
meta_end = i + 1
break

if meta_end == 0:
raise ValueError("The metadata section was not found.")
metadata = rx_txt[meta_start:meta_end]
rx_program = rx_txt[meta_end:-1]

partitions.append(rx_program)
partitions.append(metadata)

return partitions
41 changes: 28 additions & 13 deletions python/tvm/script/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@
different python versions. Synr also provides an error handling context that we
use for error reporting.
"""
# pylint: disable=invalid-name, inconsistent-return-statements, no-else-return, broad-except
# pylint: disable=invalid-name, inconsistent-return-statements, no-else-return, broad-except, import-outside-toplevel
import types
import json
import operator
import inspect
import functools
from typing import Any, Callable, Dict, List, Optional, Union
from synr import ast, Transformer, to_ast

Expand Down Expand Up @@ -1364,26 +1365,40 @@ def from_source(
raise TypeError("Only function definitions are supported.")


def ir_module(input_module: type) -> IRModule:
def ir_module(input_module=None, metadata=None) -> IRModule:
"""Decorate a python class as tvm IRModule.
Parameters
----------
input_module : type
The python class to be parsed.
metadata : Optional[Union[str, DictAttrs]]
The metadata attributes to be parsed.
Returns
-------
output : IRModule
mod : IRModule
The result IRModule.
"""
if inspect.isclass(input_module):
func_dict = {
name: f for name, f in input_module.__dict__.items() if isinstance(f, BaseFunc)
}
mod = IRModule(func_dict)
mod = relax.transform.ResolveGlobals()(mod)
# FIXME(@altanh): where is the source map?
return mod

raise TypeError("Only class definitions are supported.")
if metadata is not None:
from .relax.parser import RelaxTransformer as _RelaxTransformer

_RelaxTransformer.update_meta(metadata)

if input_module is None:
return functools.partial(ir_module, metadata=metadata)

def _ir_module(input_module: type) -> IRModule:
if inspect.isclass(input_module):
func_dict = {
name: f for name, f in input_module.__dict__.items() if isinstance(f, BaseFunc)
}
mod = IRModule(func_dict, attrs=metadata)
mod = relax.transform.ResolveGlobals()(mod)
# FIXME(@altanh): where is the source map?
return mod

raise TypeError("Only class definitions are supported.")

return _ir_module(input_module)
29 changes: 22 additions & 7 deletions python/tvm/script/relax/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,32 +15,47 @@
# specific language governing permissions and limitations
# under the License.
"""TVM Script Interface for Relax Functions"""
# pylint: disable=import-outside-toplevel

import inspect
from typing import Callable
import functools

from tvm.relax import Function

from .parser import from_source


def function(input_func: Callable) -> Function:
def function(input_func=None, metadata=None) -> Function:
"""Decorate a Python function as a Relax function in TVM script.
Parameters
----------
input_func : Callable
The function to be parsed.
metadata : Optional[Union[str, DictAttrs]]
The meta_data attributes to be parsed.
Returns
-------
output : Function
The parsed Relax Function.
"""
if inspect.isfunction(input_func):
result = from_source(input_func)
result.__name__ = input_func.__name__
result.__qualname__ = input_func.__qualname__
return result
if metadata is not None:
from .parser import RelaxTransformer as _RelaxTransformer

_RelaxTransformer.update_meta(metadata)

if input_func is None:
return functools.partial(function, metadata=metadata)

def _function(input_func: Callable) -> Function:
if inspect.isfunction(input_func):
result = from_source(input_func)
result.__name__ = input_func.__name__
result.__qualname__ = input_func.__qualname__
return result
raise TypeError("Only function definitions are supported.")

raise TypeError("Only function definitions are supported.")
return _function(input_func)
Loading

0 comments on commit 899e05f

Please sign in to comment.