Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR][TVMScript] specialize #8354

Merged
merged 4 commits into from
Jul 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ TVM_DLL Array<Var> UndefinedVars(const PrimExpr& expr);
TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr);

/*!
* \brief Whether e expression used any var in variable set..
* \brief Whether e expression used any var in variable set.
* \param expr The expression to be checked.
* \param vset_contains The check function to see if var is in the vset.
* \return Whether e uses vset.
Expand Down
1 change: 1 addition & 0 deletions include/tvm/tir/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ class Buffer : public ObjectRef {
TVM_DLL Stmt vstore(Array<PrimExpr> begin, PrimExpr value) const;

TVM_DEFINE_OBJECT_REF_METHODS(Buffer, ObjectRef, BufferNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferNode);
};

/*!
Expand Down
38 changes: 38 additions & 0 deletions include/tvm/tir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,44 @@ class LinkedParam : public ObjectRef {
TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode);
};

/*!
* \brief Specialize parameters of PrimFunc.
* \param func The PrimFunc to be specialized.
* \param param_map The mapping from function params to the instance.
* \return The new function with parameter specialized.
* \note We can define a Meta TIR function with symbolic shape:
*
* \code
* @tvm.script.tir
* def mem_copy(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32) -> None:
* A = tir.match_buffer(a, (m, n), "float32")
* B = tir.match_buffer(b, (m, n), "float32")
*
* with tir.block([m, n], "") as [vi, vj]:
* B[vi, vj] = A[vi, vj]
* \endcode
*
* Then we can make it specialized with given shapes or buffers.
*
* \code
* a, _, m, n = mem_copy.params
* func = mem_copy.specialize({a: tir.decl_buffer((16, 16))})
* # or
* func = mem_copy.specialize({n: 16, m: 16})
* \endcode
*
* \code {.language-id}
* @tvm.script.tir
* def mem_copy_16_16(a: ty.handle, b: ty.handle) -> None:
* A = tir.match_buffer(a, (16, 16), "float32")
* B = tir.match_buffer(b, (16, 16), "float32")
*
* with tir.block([16, 16], "") as [vi, vj]:
* B[vi, vj] = A[vi, vj]
* \endcode
*/
PrimFunc Specialize(PrimFunc func, const Map<Var, ObjectRef>& param_map);

/*!
* \brief PrimFunc specific attribute names.
*
Expand Down
55 changes: 54 additions & 1 deletion python/tvm/tir/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@
# under the License.
"""Function data types."""

from typing import Mapping, Union

import tvm._ffi
import tvm.runtime
from tvm.runtime import Object
from tvm.ir import BaseFunc
from .buffer import Buffer
from .expr import Var
from .expr import Var, PrimExpr
from . import _ffi_api


Expand Down Expand Up @@ -85,3 +87,54 @@ def with_body(self, new_body, span=None):
The created new function.
"""
return PrimFunc(self.params, new_body, self.ret_type, self.buffer_map, self.attrs, span)

def specialize(self, param_map: Mapping[Var, Union[PrimExpr, Buffer]]):
"""Specialize parameters of PrimFunc

Hzfengsy marked this conversation as resolved.
Show resolved Hide resolved
Parameters
----------

param_map : Mapping[Var, Union[PrimExpr, Buffer]]
The mapping from function params to the instance

Examples
--------
We can define a Meta TIR function with symbolic shape:

.. code-block:: python

@tvm.script.tir
def mem_copy(a: ty.handle, b: ty.handle, m: ty.int32, n: ty.int32) -> None:
A = tir.match_buffer(a, (m, n), "float32")
B = tir.match_buffer(b, (m, n), "float32")

with tir.block([m, n], "") as [vi, vj]:
B[vi, vj] = A[vi, vj]

Then we can make it specialized with given shapes or buffers.

.. code-block:: python

a, _, m, n = mem_copy.params
func = mem_copy.specialize({a: tir.decl_buffer((16, 16))})
# or
func = mem_copy.specialize({n: 16, m: 16})

The specialized function:

.. code-block:: python

@tvm.script.tir
def mem_copy_16_16(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (16, 16), "float32")
B = tir.match_buffer(b, (16, 16), "float32")

with tir.block([16, 16], "") as [vi, vj]:
B[vi, vj] = A[vi, vj]

Returns
-------
func : PrimFunc
The new function with parameter specialized
"""
return _ffi_api.Specialize(self, param_map)
Loading