Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REFACTOR][TIR] Introduce PrimFuncPass. #5139

Merged
merged 3 commits into from
Mar 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions docs/api/python/tir.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,12 @@ tvm.tir
:imported-members:
:exclude-members: PrimExpr, const
:autosummary:



tvm.tir.transform
-----------------
.. automodule:: tvm.tir.transform
:members:
:imported-members:
:autosummary:
2 changes: 1 addition & 1 deletion include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ class RelayExprNode : public BaseExprNode {
/*!
* \return The checked_type
*/
const Type& checked_type() const;
inline const Type& checked_type() const;
/*!
* \brief Check if the inferred(checked) type of the Expr
* is backed by a TTypeNode and return it.
Expand Down
4 changes: 4 additions & 0 deletions include/tvm/ir/type_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class TypeFunctor<R(const Type& n, Args...)> {
virtual R VisitType_(const TypeCallNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const TypeDataNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const PrimTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const PointerTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitTypeDefault_(const Object* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
throw; // unreachable, written to stop compiler warning
Expand All @@ -115,6 +116,7 @@ class TypeFunctor<R(const Type& n, Args...)> {
TVM_TYPE_FUNCTOR_DISPATCH(TypeCallNode);
TVM_TYPE_FUNCTOR_DISPATCH(TypeDataNode);
TVM_TYPE_FUNCTOR_DISPATCH(PrimTypeNode);
TVM_TYPE_FUNCTOR_DISPATCH(PointerTypeNode);
return vtable;
}
};
Expand All @@ -138,6 +140,7 @@ class TVM_DLL TypeVisitor :
void VisitType_(const TypeCallNode* op) override;
void VisitType_(const TypeDataNode* op) override;
void VisitType_(const PrimTypeNode* op) override;
void VisitType_(const PointerTypeNode* op) override;
};

/*!
Expand All @@ -158,6 +161,7 @@ class TVM_DLL TypeMutator :
Type VisitType_(const TypeCallNode* op) override;
Type VisitType_(const TypeDataNode* op) override;
Type VisitType_(const PrimTypeNode* op) override;
Type VisitType_(const PointerTypeNode* op) override;

private:
Array<Type> MutateArray(Array<Type> arr);
Expand Down
72 changes: 72 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/tir/transform.h
* \brief TIR specific transformation passes.
*/
#ifndef TVM_TIR_TRANSFORM_H_
#define TVM_TIR_TRANSFORM_H_

#include <tvm/ir/transform.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>

#include <string>

namespace tvm {
namespace tir {
namespace transform {

using tvm::transform::Pass;
using tvm::transform::PassNode;
using tvm::transform::PassInfo;
using tvm::transform::PassInfoNode;
using tvm::transform::PassContext;
using tvm::transform::PassContextNode;
using tvm::transform::Sequential;

/*
* \brief Create a function pass that optimizes PrimFuncs.
*
* \param pass_func The packed function that contains the optimization.
* \param opt_level The optimization level of the function pass.
* \param name The name of the function pass.
* \param required The list of the passes that the function pass is dependent on.
*
* \return The created function pass.
*/
TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::PrimExpr>& required);

/*!
* \brief Create PrimFuncPass to combine context calls in the host function.
*
* \return The pass.
*/
Pass CombineContextCall();

} // namespace transform
} // namespace tir
} // namespace tvm

#endif // TVM_TIR_TRANSFORM_H_
1 change: 1 addition & 0 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,4 @@

from . import ir_builder
from . import ir_pass
from . import transform
21 changes: 21 additions & 0 deletions python/tvm/tir/transform/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Namespace of all TIR transformations"""
# pylint: disable=wildcard-import, invalid-name

from .function_pass import prim_func_pass, PrimFuncPass
from .transform import *
21 changes: 21 additions & 0 deletions python/tvm/tir/transform/_ffi_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for tvm.tir.transform"""
import tvm._ffi


tvm._ffi._init_api("tir.transform", __name__)
149 changes: 149 additions & 0 deletions python/tvm/tir/transform/function_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TIR specific function pass support."""
import inspect
import functools

import tvm._ffi
from tvm.ir.transform import Pass, PassInfo

from . import _ffi_api


@tvm._ffi.register_object("tir.PrimFuncPass")
class PrimFuncPass(Pass):
"""A pass that works on each :py:func:`tvm.tir.PrimFunc` in a module. A function
pass class should be created through py:func:`tvm.tir.transform.function_pass`.
"""


def _wrap_class_function_pass(pass_cls, pass_info):
"""Wrap a python class as function pass"""
class PyFunctionPass(PrimFuncPass):
"""Internal wrapper class to create a class instance."""
def __init__(self, *args, **kwargs):
# initialize handle in cass pass_cls creation failed.fg
self.handle = None
inst = pass_cls(*args, **kwargs)
# it is important not to capture self to
# avoid a cyclic dependency
def _pass_func(func, mod, ctx):
return inst.transform_function(func, mod, ctx)
self.__init_handle_by_constructor__(
_ffi_api.CreatePrimFuncPass, _pass_func, pass_info)
self._inst = inst

def __getattr__(self, name):
# fall back to instance attribute if there is not any
return self._inst.__getattribute__(name)

functools.update_wrapper(PyFunctionPass.__init__, pass_cls.__init__)
PyFunctionPass.__name__ = pass_cls.__name__
PyFunctionPass.__doc__ = pass_cls.__doc__
PyFunctionPass.__module__ = pass_cls.__module__
return PyFunctionPass


def prim_func_pass(pass_func=None, opt_level=None, name=None, required=None):
"""Decorate a function pass.

This function returns a callback when pass_func
is provided. Otherwise, it returns the created function pass using the
given optimization function.

Parameters
----------
pass_func : Optional[Callable[(PrimFunc, IRModule, PassContext) -> PrimFunc]]
The transformation function or class.

opt_level : int
The optimization level of this module pass.

name : Optional[str]
The name of the function pass. The name could be empty. In this case, the
name of the optimization function will be used as the pass name.

required : Optional[List[str]]
The list of passes that the function pass is dependent on.

Returns
-------
create_function_pass : Union[Callable, FunctionPass]

A decorator will be returned if pass_func is not provided,
otherwise return the decorated result.
The returned decorator has two behaviors depending on the input:
A new FunctionPass will be returned when we decorate a pass function.
A new FunctionPass class will be returned when we decorate a class type.

Examples
--------
The following code block decorates a function pass class.

.. code-block:: python

@tvm.tir.transform.prim_func_pass(opt_level=1)
class TestReplaceFunc:
def __init__(self, new_func):
self.new_func = new_func

def transform_function(self, func, mod, ctx):
# just for demo purposes
# transform func to new_func
return self.new_func

The following code creates a function pass by decorating
a user defined transform function.

.. code-block:: python

@tvm.tir.transform.prim_func_pass(opt_level=2)
def transform(func, mod, ctx):
# my transformations here.
return func

function_pass = transform
assert isinstance(function_pass, transform.FunctionPass)
assert function_pass.info.opt_level == 2

# Given a module m, the optimization could be invoked as the follwoing:
updated_mod = function_pass(m)
# Now constant folding should have been applied to every function in
# the provided module m. And the updated module will be returned.
"""

if opt_level is None:
raise ValueError("Please provide opt_level for the funtion pass.")

required = required if required else []
if not isinstance(required, (list, tuple)):
raise TypeError("Required is expected to be the type of " +
"list/tuple.")

def create_function_pass(pass_arg):
"""Internal function that creates a function pass"""
fname = name if name else pass_arg.__name__
info = PassInfo(opt_level, fname, required)
if inspect.isclass(pass_arg):
return _wrap_class_function_pass(pass_arg, info)
if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")
return _ffi_api.MakeFunctionPass(pass_arg, info)

if pass_func:
return create_function_pass(pass_func)
return create_function_pass
31 changes: 31 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Wrapping existing transformations."""
# pylint: disable=invalid-name

from . import _ffi_api


def CombineContextCall():
"""Combine context calls in the host function.

Returns
-------
fpass : tvm.ir.transform.Pass
The result pass
"""
return _ffi_api.CombineContextCall()
4 changes: 2 additions & 2 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,13 @@ void IRModuleNode::Add(const GlobalVar& var,
GetRef<relay::Function>(ptr));
}

auto type = checked_func->checked_type();
Type type = checked_func->checked_type();
CHECK(type.as<relay::IncompleteTypeNode>() == nullptr);

if (functions.find(var) != functions.end()) {
CHECK(update)
<< "Already have definition for " << var->name_hint;
auto old_type = functions[var].as<relay::FunctionNode>()->checked_type();
auto old_type = functions[var]->checked_type();
CHECK(relay::AlphaEqual(type, old_type))
<< "Module#update changes type, not possible in this mode.";
}
Expand Down
Loading