Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hybrid script] Backend support #2477

Merged
merged 31 commits into from
Feb 13, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ include(cmake/modules/contrib/BLAS.cmake)
include(cmake/modules/contrib/Random.cmake)
include(cmake/modules/contrib/Sort.cmake)
include(cmake/modules/contrib/NNPack.cmake)
include(cmake/modules/contrib/HybridDump.cmake)

add_library(tvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS})
add_library(tvm_topi SHARED ${TOPI_SRCS})
Expand Down
3 changes: 3 additions & 0 deletions cmake/modules/contrib/HybridDump.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
message(STATUS "Build with contrib.hybriddump")
file(GLOB HYBRID_CONTRIB_SRC src/contrib/hybrid/*.cc)
list(APPEND COMPILER_SRCS ${HYBRID_CONTRIB_SRC})
14 changes: 14 additions & 0 deletions docs/langref/hybrid_script.rst
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,20 @@ You can also do loop-thread bind by writing code like this:
a[tx] = b[tx]


Assert Statement
~~~~~~~~~~~~~~~~

Assert statement is supported, you can simply use it as it is in standard Python.

.. code-block:: python

assert cond, mesg

.. note::

``Assert`` is NOT a function call. Users are encouraged to use assert in the way
presented above --- condition followed by message. It fits both Python AST and HalideIR.

Keywords
~~~~~~~~
- For keywords: ``serial``, ``range``, ``unroll``, ``parallel``, ``vectorize``, ``bind``, ``const_expr``
Expand Down
25 changes: 20 additions & 5 deletions python/tvm/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,25 @@ def get_binds(args, binds=None):
return binds, arg_list


def form_body(sch):
"""According to the given schedule, form the raw body
Parameters
----------
sch : tvm.schedule.Schedule
The given scheduler to form the raw body

Returns
-------
The body formed according to the given schedule
"""
# normalize schedule first
sch = sch.normalize()
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.InjectPrefetch(stmt)
return stmt


def lower(sch,
args,
name="default_function",
Expand Down Expand Up @@ -337,11 +356,7 @@ def lower(sch,

# Phase 0
if isinstance(sch, schedule.Schedule):
# normalize schedule first
sch = sch.normalize()
bounds = schedule.InferBound(sch)
stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.InjectPrefetch(stmt)
stmt = form_body(sch)

for f in lower_phase0:
stmt = f(stmt)
Expand Down
75 changes: 72 additions & 3 deletions python/tvm/hybrid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,77 @@
1. Users can write some preliminary versions of the computation patterns
have not been supported yet and verify it across the real execution and
python semantic emulation.
2. Developers can build HalideIR by writing Python code.
2. So far, it is a text format dedicated to HalideIR Phase 0. Refer tvm.lower
for more details. A larger ambition of this module is to support all levels of
HalideIR.
"""

from .api import script
from .parser import parse_python
# TODO(@were): Make this module more complete.
# 1. Support HalideIR dumping to Hybrid Script
# 2. Support multi-level HalideIR

from __future__ import absolute_import as _abs

from .._ffi.base import decorate
from .._ffi.function import _init_api
from ..build_module import form_body

from .module import HybridModule
from .parser import source_to_op
from .util import _pruned_source


def script(pyfunc):
"""Decorate a python function function as hybrid script.

The hybrid function support emulation mode and parsing to
the internal language IR.
were marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
hybrid_func : function
A decorated hybrid script function.
"""
def wrapped_func(func, *args, **kwargs): #pylint: disable=missing-docstring
from .util import _is_tvm_arg_types
if _is_tvm_arg_types(args):
src = _pruned_source(func)
return source_to_op(src, func.__globals__, args)

from .runtime import _enter_hybrid_runtime, _restore_runtime
intersect = _enter_hybrid_runtime(func)
value = func(*args, **kwargs)
_restore_runtime(func, intersect)
return value

return decorate(pyfunc, wrapped_func)


def build(sch, inputs, outputs, name="hybrid_func"):
"""Dump the corrent schedule to hybrid module

Parameters
----------
sch: Schedule
The schedule to be dumped

inputs: An array of Tensors or Vars
The inputs of the function body

outputs: An array of Tensors
The outputs of the function body

Returns
-------
module: HybridModule
The built results is wrapped in a HybridModule.
The usage of HybridModule is roughly the same as normal TVM-built modules.
"""

stmt = form_body(sch)
src = _Dump(stmt, inputs, outputs, name)

return HybridModule(src, name)


_init_api("tvm.hybrid")
43 changes: 0 additions & 43 deletions python/tvm/hybrid/api.py

This file was deleted.

27 changes: 27 additions & 0 deletions python/tvm/hybrid/calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .. import ir_pass
from ..stmt import For
from .util import _internal_assert
from ..intrin import call_pure_intrin

#pylint: disable=redefined-builtin

Expand Down Expand Up @@ -104,3 +105,29 @@ def len(func_id, args):
except: #pylint: disable=bare-except
_internal_assert(args[0].shape.__len__() == 1, "Only one-dimension array can get len")
return _api.convert(args[0].shape[0])


def _cast(func_id, args):
_internal_assert(args.__len__() == 1 and isinstance(args[0], _expr.Expr), \
"Only one expression can be cast")
return _make.Cast(func_id, args[0])

float16 = float32 = float64 = _cast #pylint: disable=invalid-name
int8 = int16 = int32 = int64 = _cast #pylint: disable=invalid-name
uint8 = uint16 = uint32 = uint64 = _cast #pylint: disable=invalid-name


def ceil_div(func_id, args):
_internal_assert(func_id == "ceil_div", "This function cannot be directly invoked!")
_internal_assert(args.__len__() == 2, "2 arguments expected for division!")
_internal_assert(isinstance(args[0], _expr.Expr), "Only expressions can div")
_internal_assert(isinstance(args[1], _expr.Expr), "Only expressions can div")
a, b = args[0], args[1]
return (a + b - 1) / b


def likely(func_id, args):
_internal_assert(args.__len__() == 1, \
"Only one expression can be likely")
_internal_assert(func_id == "likely", "This function cannot be directly invoked!")
return call_pure_intrin(args[0].dtype, 'likely', *args)
100 changes: 100 additions & 0 deletions python/tvm/hybrid/module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""Methods and data structures to support dumping HalideIR to Hybrid Script.
This allows users to do quick hack to generated HalideIR and cast it back to
TVM modules.

To enable this feature, you need to build with -DUSE_HYBRID_DUMP=ON.
"""

import ast
import imp

from ..contrib import util
from .util import _internal_assert
from .util import _is_tvm_arg_types
from .parser import source_to_op


class HybridModule(object):
"""The usage of Hybrid Module is very similar to conventional TVM module,
but conventional TVM module requires a function body which is already fully
lowered. This contradicts to the fact that Hybrid Module is originally a text
format for Phase 0 HalideIR. Thus, a totally separated module is defined."""


def __init__(self, src=None, name=None):
"""The constructor of this a hybrid module

Parameters
----------
src : str
The source code of this module

name : str
The name of this module
"""
self.src_ = self.name = self.func_ = self.root_ = None
if src is not None:
temp = util.tempdir()
dst = temp.relpath("script.py")
with open(dst, 'w') as f:
f.write("import tvm\n@tvm.hybrid.script\n%s" % src)

if name is not None:
self.name = name
self.load(dst)


def __call__(self, *args):
if _is_tvm_arg_types(args):
return source_to_op(self.root_, globals(), args)
return self.func_(*args)


def get_source(self):
return self.src_


def save(self, path):
if not path.endswith('.py'):
path = path + '.py'
with open(path, 'w') as f:
f.write(self.src_)


def load(self, path):
"""Load the module from a python file

Parameters
----------
path : str
Path to the given python file
"""
with open(path, 'r') as f:
self.src_ = f.read()

src = self.src_

class FindFunc(ast.NodeVisitor):
""" Find the function in module to be loaded module. """
#pylint: disable=invalid-name
def __init__(self):
self.name = None
self.root = None


def visit_FunctionDef(self, node):
_internal_assert(self.name is None, "For now, only one function supported!")
self.name = node.name
_internal_assert(self.root is None, "For now, only one function supported!")
self.root = node

root = ast.parse(src)
finder = FindFunc()
finder.visit(root)
_internal_assert(finder.name is not None and finder.root is not None, \
"No function found!")
if self.name is None:
self.name = finder.name
self.root_ = finder.root
py_module = imp.load_source(self.name, path)
self.func_ = getattr(py_module, self.name)
Loading