Skip to content

Commit

Permalink
[Unity][Pass] BindParams pass, FoldConstant pass (#14016)
Browse files Browse the repository at this point in the history
This PR introduces FoldConstant/BindParam passes.
  • Loading branch information
sunggg authored Feb 17, 2023
1 parent 9af6afc commit ba47501
Show file tree
Hide file tree
Showing 7 changed files with 861 additions and 47 deletions.
133 changes: 87 additions & 46 deletions include/tvm/ir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,68 @@ enum class CallingConv : int {
kDeviceKernelLaunch = 2,
};

/*!
* \brief Supported linkage types.
*/
enum class LinkageType : int {
/*!
* \brief Internal linkage.
*/
kInternal = 0,
/*!
* \brief External linkage.
- Function with external linkage should have a global symbol attached to it.
*/
kExternal = 1
};

/*!
* \brief Generic attribute names that can be attached to any function.
*
* \sa tvm::tir::attr, tvm::relay::attr
*/
namespace attr {
/*!
* \brief Indicates the special calling convention.
*
* Type: Integer
*
* \sa tvm::CallingConv
*/
constexpr const char* kCallingConv = "calling_conv";

/*!
* \brief Compilation target of the function.
*
* Type: Target
*
* \sa tvm::Target
*/
constexpr const char* kTarget = "target";

/*!
* \brief Global linker symbol of the function in generated code.
*
* This option forces the code generator to name the
* function with the given.
*
* For example, we could set a global_symbol of a function
* early to make sure that we can always refer to it by
* the symbol name in the generated DLL.
*
* We should not set the attribute for local functions,
* so that the compiler can freely rename them.
*
* A unique global symbol will be automatically assigned
* to each function in the module before the target code
* generation phase.
*
* Type: String
*/
constexpr const char* kGlobalSymbol = "global_symbol";

} // namespace attr

/*!
* \brief Base node of all functions.
*
Expand Down Expand Up @@ -130,6 +192,31 @@ class BaseFuncNode : public RelayExprNode {
* \endcode
*/
bool HasNonzeroAttr(const std::string& attr_key) const { return attrs.HasNonzeroAttr(attr_key); }
/*!
* \brief Get the type of the linkage.
*
* Currently, we only consider external/internal linkage.
* This can be extended in the future when necessary.
*
* \return Linkage type.
*
* \code
*
* void Example(const BaseFunc& f) {
* if (f->GetLinkageType() == tvm::LinkageType::kExternal) {
* // Do not remove a function with external linkage
* }
* }
*
* \endcode
*/

LinkageType GetLinkageType() const {
if (GetAttr<String>(attr::kGlobalSymbol))
return LinkageType::kExternal;
else
return LinkageType::kInternal;
}

static constexpr const char* _type_key = "BaseFunc";
static constexpr const uint32_t _type_child_slots = 2;
Expand All @@ -145,51 +232,5 @@ class BaseFunc : public RelayExpr {
TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode);
};

/*!
* \brief Generic attribute names that can be attached to any function.
*
* \sa tvm::tir::attr, tvm::relay::attr
*/
namespace attr {
/*!
* \brief Indicates the special calling convention.
*
* Type: Integer
*
* \sa tvm::CallingConv
*/
constexpr const char* kCallingConv = "calling_conv";

/*!
* \brief Compilation target of the function.
*
* Type: Target
*
* \sa tvm::Target
*/
constexpr const char* kTarget = "target";

/*!
* \brief Global linker symbol of the function in generated code.
*
* This option forces the code generator to name the
* function with the given.
*
* For example, we could set a global_symbol of a function
* early to make sure that we can always refer to it by
* the symbol name in the generated DLL.
*
* We should not set the attribute for local functions,
* so that the compiler can freely rename them.
*
* A unique global symbol will be automatically assigned
* to each function in the module before the target code
* generation phase.
*
* Type: String
*/
constexpr const char* kGlobalSymbol = "global_symbol";

} // namespace attr
} // namespace tvm
#endif // TVM_IR_FUNCTION_H_
15 changes: 15 additions & 0 deletions include/tvm/relax/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,22 @@ TVM_DLL Pass RewriteDataflowReshape();
* \return The Pass.
*/
TVM_DLL Pass AttachGlobalSymbol();
/*!
* \brief Bind params of function of the module to constant tensors.
*
* \param func_name The name of the function to bind parameters.
* \param params The parameters to bind.
*
* \return The Pass.
*/
TVM_DLL Pass BindParams(String func_name, Map<String, runtime::NDArray> params);

/*!
* \brief Fold constant expressions.
*
* \return The Pass.
*/
TVM_DLL Pass FoldConstant();
} // namespace transform
} // namespace relax
} // namespace tvm
Expand Down
62 changes: 61 additions & 1 deletion python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
import functools
import inspect
import types
from typing import Callable, Union
from typing import Callable, Dict, Union, Optional, List
import numpy as np # type: ignore

import tvm.ir
from . import _ffi_api
Expand Down Expand Up @@ -115,6 +116,65 @@ def AttachGlobalSymbol() -> tvm.ir.transform.Pass:
return _ffi_api.AttachGlobalSymbol() # type: ignore


def BindParams(
func_name: str,
params: Dict[str, Union[tvm.runtime.NDArray, np.ndarray]],
) -> tvm.ir.transform.Pass:
"""Bind params of function of the module to constant tensors.
Parameters
----------
func_name: str
The function name to be bound
params : Dict[str, Union[tvm.runtime.NDArray, np.ndarray]]
The map from param name to constant tensors.
Returns
-------
ret: tvm.ir.transform.Pass
"""
tvm_params = {}
for k, v in params.items():
if isinstance(v, np.ndarray):
v = tvm.nd.array(v)
assert isinstance(
v, tvm.runtime.NDArray
), f"param values are expected to be TVM.NDArray or numpy.ndarray, but got {type(v)}"
tvm_params[k] = v

return _ffi_api.BindParams(func_name, tvm_params) # type: ignore


def RemoveUnusedFunctions(entry_functions: Optional[List[str]] = None) -> tvm.ir.transform.Pass:
"""Remove unused relax/prim functions without external linkage in a IRModule.
Parameters
----------
entry_functions: Optional[List[str]]
The set of entry functions to start from.
Returns
-------
ret : tvm.transform.Pass
The registered pass to remove unused functions.
"""
if entry_functions is None:
entry_functions = ["main"]
return _ffi_api.RemoveUnusedFunctions(entry_functions) # type: ignore


def FoldConstant() -> tvm.ir.transform.Pass:
"""Fold constant expressions.
Returns
-------
ret: tvm.ir.transform.Pass
"""
return _ffi_api.FoldConstant() # type: ignore


def AnnotateTIROpPattern() -> tvm.ir.transform.Pass:
"""Annotate Op Pattern Kind for TIR functions
Expand Down
113 changes: 113 additions & 0 deletions src/relax/transform/bind_params.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <tvm/driver/driver_api.h>
#include <tvm/ir/function.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>
#include <tvm/relax/type.h>
#include <tvm/tir/op.h>

#include <utility>

namespace tvm {
namespace relax {

/*!
* \brief Bind params to function by using name
* \param func Relax function
* \param params params dict
* \return Function
*/
inline Function BindParamsByName(Function func, const Map<String, runtime::NDArray>& params) {
std::unordered_map<std::string, Var> name_dict;
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> repeat_var;
for (auto arg : func->params) {
const auto& name = arg->name_hint();
if (name_dict.count(name)) {
repeat_var.insert(name_dict[name]);
} else {
name_dict[name] = arg;
}
}

std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> bind_dict;
for (auto& kv : params) {
if (name_dict.count(kv.first) == 0) {
continue;
}
auto arg = name_dict.at(kv.first);
if (repeat_var.count(arg)) {
LOG(FATAL) << "ValueError: Multiple args in the function have name " << kv.first;
}
bind_dict[arg] = Constant(kv.second);
}
Expr bound_expr = Bind(func, bind_dict);
Function ret = Downcast<Function>(bound_expr);
ICHECK(ret.defined()) << "The returning type is expected to be a Relax Function."
<< "\n";
return ret;
}

/*!
* \brief Bind params to a specific function in a module
* \param m The module
* \param func_name The name of the specific function
* \param param The param dict
* \return The module after binding params.
*/
IRModule BindParam(IRModule m, String func_name, Map<String, runtime::NDArray> param) {
IRModuleNode* new_module = m.CopyOnWrite();
Map<GlobalVar, BaseFunc> functions = m->functions;
for (const auto& func_pr : functions) {
if (const auto* relax_f = func_pr.second.as<FunctionNode>()) {
if (relax_f->GetLinkageType() == LinkageType::kExternal) {
// Use global_symbol if it's external linkage
Optional<String> gsymbol = relax_f->GetAttr<String>(tvm::attr::kGlobalSymbol);
if (gsymbol.defined() && gsymbol.value() == func_name) {
Function f_after_bind = BindParamsByName(GetRef<Function>(relax_f), param);
new_module->Update(func_pr.first, f_after_bind);
}
} else {
// Use global var's name_hint if it's internal linkage
if (func_pr.first->name_hint == func_name) {
Function f_after_bind = BindParamsByName(GetRef<Function>(relax_f), param);
new_module->Update(func_pr.first, f_after_bind);
}
}
}
}
return GetRef<IRModule>(new_module);
}

namespace transform {

Pass BindParams(String func_name, Map<String, runtime::NDArray> params) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule mod, PassContext pc) { return BindParam(std::move(mod), func_name, params); };
return CreateModulePass(pass_func, 0, "BindParams", {});
}

TVM_REGISTER_GLOBAL("relax.transform.BindParams").set_body_typed(BindParams);

} // namespace transform

} // namespace relax
} // namespace tvm
Loading

0 comments on commit ba47501

Please sign in to comment.