Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR][TRANSFORM] Enable CopyOnWrite for TIR passes. #5309

Merged
merged 1 commit into from
Apr 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class PrimExpr : public BaseExpr {
private:
// Internal function for conversion.
friend struct runtime::PackedFuncValueConverter<PrimExpr>;
TVM_DLL static PrimExpr FromObject_(ObjectPtr<Object> ptr);
TVM_DLL static PrimExpr FromObject_(ObjectRef ref);
};

/*!
Expand Down Expand Up @@ -464,9 +464,8 @@ struct PackedFuncValueConverter<PrimExpr> {
if (val.type_code() == kDLFloat) {
return PrimExpr(static_cast<float>(val.operator double()));
}
TVM_CHECK_TYPE_CODE(val.type_code(), kTVMObjectHandle);
Object* ptr = val.ptr<Object>();
return PrimExpr::FromObject_(GetObjectPtr<Object>(ptr));

return PrimExpr::FromObject_(val.AsObjectRef<ObjectRef>());
}
};
} // namespace runtime
Expand Down
23 changes: 16 additions & 7 deletions include/tvm/ir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
#include <tvm/ir/error.h>
#include <tvm/ir/module.h>
#include <string>
#include <utility>

namespace tvm {
namespace transform {
Expand Down Expand Up @@ -251,8 +252,8 @@ class PassNode : public Object {
*
* \return The transformed module.
*/
IRModule operator()(const IRModule& mod) const {
return this->operator()(mod, PassContext::Current());
IRModule operator()(IRModule mod) const {
return this->operator()(std::move(mod), PassContext::Current());
}

/*!
Expand All @@ -263,7 +264,7 @@ class PassNode : public Object {
*
* \return The transformed module.
*/
virtual IRModule operator()(const IRModule& mod,
virtual IRModule operator()(IRModule mod,
const PassContext& pass_ctx) const = 0;

void VisitAttrs(AttrVisitor* v) {}
Expand All @@ -277,14 +278,22 @@ class Pass : public ObjectRef {
/*!
* \brief Transform mod using the default PassContext in the current scope.
*
* \code
*
* // If you do no longer need the input module
* // it is recommended to use std::move to move your input module.
* mod = pass(std::move(mod));
*
* \endcode
*
* \param mod The module that an optimization pass runs on.
*
* \return The transformed module.
*/
IRModule operator()(const IRModule& mod) const {
IRModule operator()(IRModule mod) const {
const PassNode* node = operator->();
CHECK(node != nullptr);
return node->operator()(mod);
return node->operator()(std::move(mod));
}
/*!
* \brief Transform mod using a functor under a given pass context.
Expand All @@ -294,11 +303,11 @@ class Pass : public ObjectRef {
*
* \return The transformed module.
*/
IRModule operator()(const IRModule& mod,
IRModule operator()(IRModule mod,
const PassContext& pass_ctx) const {
const PassNode* node = operator->();
CHECK(node != nullptr);
return node->operator()(mod, pass_ctx);
return node->operator()(std::move(mod), pass_ctx);
}

TVM_DEFINE_OBJECT_REF_METHODS(Pass, ObjectRef, PassNode);
Expand Down
1 change: 1 addition & 0 deletions include/tvm/runtime/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ inline const char* TypeCode2Str(int type_code) {
case kTVMModuleHandle: return "ModuleHandle";
case kTVMNDArrayHandle: return "NDArrayContainer";
case kTVMObjectHandle: return "Object";
case kTVMObjectRValueRefArg: return "ObjectRValueRefArg";
default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return "";
}
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -885,7 +885,7 @@ inline ObjectPtr<BaseType> GetObjectPtr(ObjType* ptr) {

template <typename SubRef, typename BaseRef>
inline SubRef Downcast(BaseRef ref) {
CHECK(ref->template IsInstance<typename SubRef::ContainerType>())
CHECK(!ref.defined() || ref->template IsInstance<typename SubRef::ContainerType>())
<< "Downcast from " << ref->GetTypeKey() << " to "
<< SubRef::ContainerType::_type_key << " failed.";
return SubRef(std::move(ref.data_));
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -1357,7 +1357,7 @@ inline void TVMArgsSetter::SetObject(size_t i, T&& value) const {
ptr->IsInstance<Module::ContainerType>())) {
values_[i].v_handle = ptr;
type_codes_[i] = kTVMModuleHandle;
} else if (std::is_rvalue_reference<T>::value) {
} else if (std::is_rvalue_reference<decltype(value)>::value) {
values_[i].v_handle = const_cast<Object**>(&(value.data_.data_));
type_codes_[i] = kTVMObjectRValueRefArg;
} else {
Expand Down
5 changes: 3 additions & 2 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,11 +197,12 @@ TVM_DLL Pass CombineContextCall();
/*!
* \brief Narrow down PrimExpr datatype in stmt to target_bits.
*
* \note Run this pass after StorageFlatten.
* \param target_bits The target bits
*
* \note Run this pass after storage flatten.
* \return The pass.
*/
TVM_DLL Pass NarrowDataType();
TVM_DLL Pass NarrowDataType(int target_bits);

} // namespace transform
} // namespace tir
Expand Down
1 change: 1 addition & 0 deletions python/tvm/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(self, msg):
register_error("ValueError", ValueError)
register_error("TypeError", TypeError)
register_error("AttributeError", AttributeError)
register_error("KeyError", KeyError)


@register_error
Expand Down
13 changes: 9 additions & 4 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def Apply(ftransform):
# pylint: disable=unused-argument
def _transform(func, mod, ctx):
return ftransform(func)
return _fpass.prim_func_pass(_transform, opt_level=0)
return _fpass.prim_func_pass(_transform, opt_level=0, name="Apply")


def Filter(fcond):
Expand All @@ -57,7 +57,7 @@ def Filter(fcond):
# pylint: disable=unused-argument
def _transform(func, mod, ctx):
return func if fcond(func) else None
return _fpass.prim_func_pass(_transform, opt_level=0)
return _fpass.prim_func_pass(_transform, opt_level=0, name="Filter")


def LowerCustomDatatypes():
Expand Down Expand Up @@ -221,9 +221,14 @@ def CombineContextCall():
return _ffi_api.CombineContextCall()


def NarrowDataType():
def NarrowDataType(target_bits):
"""Narrow down PrimExpr datatype in stmt to target_bits.

Parameters
----------
target_bits : int
The target bit configuration.

Returns
-------
fpass : tvm.ir.transform.Pass
Expand All @@ -233,4 +238,4 @@ def NarrowDataType():
----
Run this pass after StorageFlatten.
"""
return _ffi_api.NarrowDataType()
return _ffi_api.NarrowDataType(target_bits)
20 changes: 10 additions & 10 deletions src/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,21 @@ PrimExpr::PrimExpr(float value)
PrimExpr::PrimExpr(runtime::String value)
: PrimExpr(tir::StringImmNode::make(value)) {}

PrimExpr PrimExpr::FromObject_(ObjectPtr<Object> ptr) {
PrimExpr PrimExpr::FromObject_(ObjectRef ref) {
using runtime::ObjectTypeChecker;
if (ptr->IsInstance<tir::IterVarNode>()) {
return tir::IterVar(ptr)->var;
if (auto* ptr = ref.as<tir::IterVarNode>()) {
return GetRef<tir::IterVar>(ptr)->var;
}
if (ptr->IsInstance<te::TensorNode>()) {
return te::Tensor(ptr)();
if (auto* ptr = ref.as<te::TensorNode>()) {
return GetRef<te::Tensor>(ptr)();
}
if (ptr->IsInstance<runtime::StringObj>()) {
return tir::StringImmNode::make(runtime::String(ptr));
if (auto* ptr = ref.as<runtime::StringObj>()) {
return tir::StringImmNode::make(GetRef<runtime::String>(ptr));
}
CHECK(ObjectTypeChecker<PrimExpr>::Check(ptr.get()))
CHECK(ObjectTypeChecker<PrimExpr>::Check(ref.get()))
<< "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
<< " but get " << ptr->GetTypeKey();
return PrimExpr(ptr);
<< " but get " << ref->GetTypeKey();
return Downcast<PrimExpr>(ref);
}


Expand Down
16 changes: 14 additions & 2 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,20 @@ bool IRModuleNode::ContainGlobalTypeVar(const std::string& name) const {

GlobalVar IRModuleNode::GetGlobalVar(const std::string& name) const {
auto it = global_var_map_.find(name);
CHECK(it != global_var_map_.end())
<< "Cannot find global var " << name << " in the Module";
if (it == global_var_map_.end()) {
std::ostringstream msg;
msg << "ValueError: Cannot find global var \"" << name << "\" in the Module\n"
<< "candidates are: [";
int counter = 0;
for (auto kv : global_var_map_) {
if (counter++ != 0) {
msg << ", ";
}
msg << "\"" << kv.first << "\"";
}
msg << "]";
LOG(FATAL) << msg.str();
}
return (*it).second;
}

Expand Down
29 changes: 13 additions & 16 deletions src/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ class ModulePassNode : public PassNode {
*
* \return Return the updated module.
*/
IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final;
IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;

/*!
* \brief Get the pass information/meta data.
Expand Down Expand Up @@ -205,7 +205,7 @@ class SequentialNode : public PassNode {
*
* \return Return the updated module.
*/
IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final;
IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;

static constexpr const char* _type_key = "transform.Sequential";
TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode);
Expand All @@ -231,19 +231,20 @@ ModulePass::ModulePass(
}

// Module -> Module optimizations.
IRModule ModulePassNode::operator()(const IRModule& mod,
IRModule ModulePassNode::operator()(IRModule mod,
const PassContext& pass_ctx) const {
const PassInfo& pass_info = Info();
DLOG(INFO) << "Executing module pass : "
<< pass_info->name
<< " with opt level: "
<< pass_info->opt_level;

CHECK(mod.defined());
pass_ctx.Trace(mod, pass_info, true);
IRModule updated_mod = pass_func(mod, pass_ctx);
CHECK(updated_mod.defined());
pass_ctx.Trace(updated_mod, pass_info, false);
return updated_mod;
mod = pass_func(std::move(mod), pass_ctx);
CHECK(mod.defined());
pass_ctx.Trace(mod, pass_info, false);
return mod;
}

Sequential::Sequential(tvm::Array<Pass> passes, PassInfo pass_info) {
Expand Down Expand Up @@ -314,18 +315,17 @@ Pass GetPass(const std::string& pass_name) {
// TODO(zhiics): we currenlty only sequentially execute each pass in
// a Sequential without the consideration of their orders. The phase
// ordering problem needs to be handled in the future.
IRModule SequentialNode::operator()(const IRModule& module,
IRModule SequentialNode::operator()(IRModule mod,
const PassContext& pass_ctx) const {
IRModule mod = module;
for (const Pass& pass : passes) {
CHECK(pass.defined()) << "Found undefined pass for optimization.";
const PassInfo& pass_info = pass->Info();
if (!PassEnabled(pass_info)) continue;
// resolve dependencies
for (const auto& it : pass_info->required) {
mod = GetPass(it)(mod, pass_ctx);
mod = GetPass(it)(std::move(mod), pass_ctx);
}
mod = pass(mod, pass_ctx);
mod = pass(std::move(mod), pass_ctx);
}
return mod;
}
Expand Down Expand Up @@ -375,11 +375,8 @@ TVM_REGISTER_GLOBAL("transform.MakeModulePass")
});

TVM_REGISTER_GLOBAL("transform.RunPass")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Pass pass = args[0];
IRModule mod = args[1];
ObjectRef ref = args[1];
*ret = pass(mod);
.set_body_typed([](Pass pass, IRModule mod) {
return pass(std::move(mod));
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
Expand Down
9 changes: 8 additions & 1 deletion src/node/container.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#include <tvm/runtime/container.h>
#include <tvm/node/container.h>
#include <tvm/tir/expr.h>
#include <cstring>
#include "../support/str_escape.h"

namespace tvm {

Expand Down Expand Up @@ -63,6 +63,13 @@ TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait)
static_cast<const runtime::StringObj*>(n)).operator std::string();
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<runtime::StringObj>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const runtime::StringObj*>(node.get());
p->stream << '"' << support::StrEscape(op->data, op->size) << '"';
});


struct ADTObjTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;

Expand Down
5 changes: 3 additions & 2 deletions src/relay/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class FunctionPassNode : public PassNode {
*
* \return Return the updated module.
*/
IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final;
IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;

/*!
* \brief Get the pass information/meta data.
Expand Down Expand Up @@ -113,7 +113,7 @@ FunctionPass::FunctionPass(
}

// Perform Module -> Module optimizations at the Function level.
IRModule FunctionPassNode::operator()(const IRModule& mod,
IRModule FunctionPassNode::operator()(IRModule mod,
const PassContext& pass_ctx) const {
const PassInfo& pass_info = Info();
CHECK(mod.defined());
Expand All @@ -122,6 +122,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod,
<< " with opt level: "
<< pass_info->opt_level;
pass_ctx.Trace(mod, pass_info, true);

// Execute the pass function and return a new module.
IRModule updated_mod = IRModule(mod->functions, mod->type_definitions, mod->Imports());
std::vector<std::pair<GlobalVar, Function> > updates;
Expand Down
Loading