Skip to content

Commit

Permalink
[Relay] PlanDevices can run after LowerTE
Browse files Browse the repository at this point in the history
 - Allow device_copy into and out of PrimFuncs
 - Realized both dead_code.cc and fold_constant.cc would
   happily move values into a different lexical virtual
   device context since device_planner.cc was being
   'clever' and eliding on_devices for let-bound values
   when there's no change. Fixed so that every let-bound
   value has an on_device. Will be much better after
   apache/tvm-rfcs#45 is implemented.
  • Loading branch information
mbs-octoml committed Dec 4, 2021
1 parent 7f683da commit 715692f
Show file tree
Hide file tree
Showing 41 changed files with 826 additions and 447 deletions.
2 changes: 1 addition & 1 deletion include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ class GlobalVarNode : public RelayExprNode {
*/
class GlobalVar : public RelayExpr {
public:
TVM_DLL explicit GlobalVar(String name_hint);
explicit TVM_DLL GlobalVar(String name_hint, Type type = {});

TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode);
};
Expand Down
4 changes: 3 additions & 1 deletion include/tvm/relay/attrs/call.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ struct CallLoweredAttrs : public tvm::AttrsNode<CallLoweredAttrs> {
Map<String, ObjectRef> metadata;

TVM_DECLARE_ATTRS(CallLoweredAttrs, "relay.attrs.CallLoweredAttrs") {
TVM_ATTR_FIELD(metadata).describe("Metadata attached to the lowered function call.");
TVM_ATTR_FIELD(metadata)
.describe("Metadata attached to the lowered function call.")
.set_default(Map<String, ObjectRef>());
}
};

Expand Down
62 changes: 33 additions & 29 deletions include/tvm/relay/attrs/on_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

/*!
* \file tvm/relay/attrs/on_device.h
* \brief Attribute for the on device annotation.
* \brief Attribute for the "on_device" annotation (ie operator).
*/
#ifndef TVM_RELAY_ATTRS_ON_DEVICE_H_
#define TVM_RELAY_ATTRS_ON_DEVICE_H_
Expand All @@ -33,9 +33,9 @@ namespace tvm {
namespace relay {

/*!
* \brief Attributes for the "on_device" special operator.
* \brief Attributes for the "on_device" annotation (ie operator).
*
* The Relay call (aka 'annotation'):
* The Relay call:
* \code
* on_device(sub_expr, se_scope=S)
* \endcode
Expand All @@ -54,44 +54,48 @@ namespace relay {
* multiply(device_copy(add(%x, %y), src_se_scope=GPU, dst_se_scope=CPU), %z)
* \endcode
*
* The Relay call
* \code
* on_device(sub_expr, se_scope=S, is_fixed=True)
* \endcode
* is similar to the above, however the annotation itself must appear in an expression on the
* same \p SEScope \p S. The compiler will check the \p SEScopes are consistent, and will not
* insert any "device_copy" call. This form of annotation shouldn't be necessary in user programs.
* However it is needed by the \p PlanDevices pass to fully specify the results of device planning
* so that the pass is idempotent.
*
* E.g.: The following program is equivalent to the above:
* \code
* let %a = on_device(add(%x, %y), se_scope=GPU, is_fixed=True)
* multiply(device_copy(%a, src_se_scope=GPU, dst_se_scope=CPU), %z)
* \endcode
* The "on_device" annotation with \p is_fixed=True indicates unambiguously that \p %a is stored
* on the GPU.
* The \p constraint_body (default true) and \p constraint_result (default false) fields can be
* used by passes for finer-grained control over how the \p SEScope constraint should be applied.
*/
struct OnDeviceAttrs : public tvm::AttrsNode<OnDeviceAttrs> {
/*!
* \brief (Virtual) \p SEScope on which the result of the argument expression should be stored.
* \brief The \p SEScope to constraint to apply to the body, result, or both body and result
* of the "on_device" call.
*/
SEScope se_scope = SEScope::FullyUnconstrained();

/*!
* \brief If fales (the default), the result of the "on_device" call is not constrained to be
* \p se_scope.
*/
bool constrain_result = false;

/*!
* \brief If true (the default), the body of the "on_device" call is constrained to be \p
* se_scope.
*/
bool constrain_body = true;

/*!
* \brief Returns true if both the body and result are constrained.
*/
bool is_fixed() const { return constrain_result && constrain_body; }

/*!
* \brief If true, the result \p SEScope must also be \p se_scope, and device planning should
* not insert any "device_copy" calls to respect this annotation.
*
* This is used by the device planning pass itself when annotating the planned program.
* \brief Returns true only the body is constrained (the 'normal' case).
*/
bool is_fixed = false;
bool is_normal() const { return !constrain_result && constrain_body; }

TVM_DECLARE_ATTRS(OnDeviceAttrs, "relay.attrs.OnDeviceAttrs") {
TVM_ATTR_FIELD(se_scope)
.describe("The (virtual) device and scope holding the expression result.")
.describe("The (virtual) device to constrain to.")
.set_default(SEScope::FullyUnconstrained());
TVM_ATTR_FIELD(is_fixed)
.describe("If true, do not insert a \"device_copy\" call to respect this annotation.")
TVM_ATTR_FIELD(constrain_result)
.describe("Whether the constraint applies to the overall expression")
.set_default(false);
TVM_ATTR_FIELD(constrain_body)
.describe("Whether the constraint applies to the body sub-expression.")
.set_default(true);
}
};

Expand Down
8 changes: 8 additions & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,14 @@ class Var : public Expr {
*/
TVM_DLL Var(Id vid, Type type_annotation, Span span = Span());

/*!
* \brief Return a globally fresh name. Helps with debugging to follow the same
* variable between passes and sub-expressions.
*
* TODO(mbs): Replace with name creation w.r.t. scopes.
*/
static Var GenSym(Type type_annotation = {}, Span span = {});

TVM_DEFINE_OBJECT_REF_METHODS(Var, RelayExpr, VarNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(VarNode);
};
Expand Down
7 changes: 6 additions & 1 deletion include/tvm/target/se_scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ class SEScopeNode : public AttrsNode<SEScopeNode> {
*
* kInvalidDeviceType denotes unconstrained.
*/
int device_type_int;
int /* actually DLDeviceType */ device_type_int;

DLDeviceType device_type() const { return static_cast<DLDeviceType>(device_type_int); }

Expand Down Expand Up @@ -303,6 +303,11 @@ class SEScope : public ObjectRef {
return SEScope(device_type, /*virtual_device_id=*/0, std::move(target));
}

/*! \brief Returns the \p SEScope for \p memory_scope alone. */
static SEScope ForMemoryScope(MemoryScope memory_scope) {
return SEScope(kInvalidDeviceType, -1, {}, std::move(memory_scope));
}

/*! \brief Returns the \p SEScope for \p device, \p target and \p memory_scope. */
TVM_DLL static SEScope ForDeviceTargetAndMemoryScope(const Device& device, Target target,
MemoryScope memory_scope) {
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/tir/var.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class VarNode : public PrimExprNode {
*/
String name_hint;
/*!
* \brief type annotaion of the variable.
* \brief type annotation of the variable.
*
* It is an optional field that provides a refined type of the variable than dtype.
*
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/ir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ class GlobalVar(RelayExpr):
The name of the variable.
"""

def __init__(self, name_hint):
self.__init_handle_by_constructor__(_ffi_api.GlobalVar, name_hint)
def __init__(self, name_hint, type_annot=None):
self.__init_handle_by_constructor__(_ffi_api.GlobalVar, name_hint, type_annot)

def __call__(self, *args):
"""Call the global variable.
Expand Down
24 changes: 15 additions & 9 deletions python/tvm/relay/op/annotation/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,29 +31,35 @@ def _make_se_scope(device):
raise ValueError("expecting a Device or device name, but received a %s" % (type(device)))


def on_device(data, device, is_fixed=False):
"""Annotates an expression with the device type on which its result should be stored.
def on_device(body, device, constrain_result=False, constrain_body=True):
"""Annotates a body expression with device constraints. The constraint influences
how the body is compiled, where the body is evaluated, and where the result of
evaluation is stored.
Note that the defaults for the constrain_body and constrain_result parameters should
almost never need to be overridden by the user. These parameters are exposed here
to help unit tests exercise the PlanDevices pass machinery.
Parameters
----------
data : tvm.relay.Expr
body : tvm.relay.Expr
The expression to be annotated.
device : Union[:py:class:`Device`, str]
The device to annotate with.
is_fixed : bool
If false (the default), a device_copy
If true, the annotation does not imply a device_copy may be inserted to
reconcile the device of the data argument with the device for the context of the
annotated expression.
constrain_result : bool
If false (the default), the result of the on_device is not constrained to be on device.
constrain_body : bool
If true (the default), the body of the on_device is constrained to be on device.
Returns
-------
result : tvm.relay.Expr
The annotated expression.
"""
return _make.OnDevice(data, _make_se_scope(device), is_fixed)
return _make.OnDevice(body, _make_se_scope(device), constrain_result, constrain_body)


def function_on_device(function, param_devices, result_device):
Expand Down
7 changes: 5 additions & 2 deletions src/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,18 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
});

GlobalVar::GlobalVar(String name_hint) {
GlobalVar::GlobalVar(String name_hint, Type type) {
ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>();
n->name_hint = std::move(name_hint);
n->checked_type_ = std::move(type);
data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(GlobalVarNode);

TVM_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](String name) { return GlobalVar(name); });
TVM_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](String name, Type type) {
return GlobalVar(name, type);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<GlobalVarNode>([](const ObjectRef& ref, ReprPrinter* p) {
Expand Down
11 changes: 5 additions & 6 deletions src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -502,12 +502,6 @@ Doc RelayTextPrinter::VisitExpr_(const FunctionNode* op) {
Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) {
Doc doc;
doc << "@" << op->name_hint;
#if TVM_LOG_DEBUG
if (op->checked_type_.defined()) {
doc << " /* type=" << PrintType(op->checked_type_, /*meta=*/false) << " */";
}
doc << " /* id=" << reinterpret_cast<uint64_t>(op) << " */";
#endif
return doc;
}

Expand All @@ -521,6 +515,11 @@ Doc RelayTextPrinter::VisitExpr_(const CallNode* op) {
for (const Expr& arg : op->args) {
args.push_back(Print(arg));
}
#if TVM_LOG_DEBUG
for (const Type& type_arg : op->type_args) {
args.push_back(Print(type_arg));
}
#endif
for (const Doc& d : PrintCallAttrs(op->attrs, op->op)) {
args.push_back(d);
}
Expand Down
57 changes: 40 additions & 17 deletions src/printer/text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include <tvm/tir/function.h>

#include <algorithm>
#include <string>

namespace tvm {
Expand All @@ -36,49 +37,71 @@ static const char* kSemVer = "0.0.5";
Doc TextPrinter::PrintMod(const IRModule& mod) {
Doc doc;
int counter = 0;

// We'll print in alphabetical order to make a/b diffs easier to work with.

// type definitions
std::vector<GlobalTypeVar> tyvars;
for (const auto& kv : mod->type_definitions) {
tyvars.emplace_back(kv.first);
}
std::sort(tyvars.begin(), tyvars.end(),
[](const GlobalTypeVar& left, const GlobalTypeVar& right) {
return left->name_hint < right->name_hint;
});
for (const auto& tyvar : tyvars) {
if (counter++ != 0) {
doc << Doc::NewLine();
}
doc << relay_text_printer_.Print(kv.second);
doc << relay_text_printer_.Print(mod->type_definitions[tyvar]);
doc << Doc::NewLine();
}

// functions
std::vector<GlobalVar> vars;
for (const auto& kv : mod->functions) {
if (kv.second.as<relay::FunctionNode>()) {
vars.emplace_back(kv.first);
}
std::sort(vars.begin(), vars.end(), [](const GlobalVar& left, const GlobalVar& right) {
return left->name_hint < right->name_hint;
});
for (const auto& var : vars) {
const BaseFunc& base_func = mod->functions[var];
if (base_func.as<relay::FunctionNode>()) {
relay_text_printer_.dg_ =
relay::DependencyGraph::Create(&relay_text_printer_.arena_, kv.second);
relay::DependencyGraph::Create(&relay_text_printer_.arena_, base_func);
}
if (counter++ != 0) {
doc << Doc::NewLine();
}
if (kv.second.as<relay::FunctionNode>()) {
if (base_func.as<relay::FunctionNode>()) {
std::ostringstream os;
os << "def @" << kv.first->name_hint;
#if TVM_LOG_DEBUG
os << " /* id=" << reinterpret_cast<uint64_t>(kv.first.get()) << " */";
#endif
doc << relay_text_printer_.PrintFunc(Doc::Text(os.str()), kv.second);
} else if (kv.second.as<tir::PrimFuncNode>()) {
doc << "@" << kv.first->name_hint;
#if TVM_LOG_DEBUG
doc << " /* id=" << reinterpret_cast<uint64_t>(kv.first.get()) << " */";
#endif
doc << " = " << tir_text_printer_.PrintPrimFunc(Downcast<tir::PrimFunc>(kv.second));
os << "def @" << var->name_hint;
doc << relay_text_printer_.PrintFunc(Doc::Text(os.str()), base_func);
} else if (base_func.as<tir::PrimFuncNode>()) {
doc << "@" << var->name_hint;
doc << " = " << tir_text_printer_.PrintPrimFunc(Downcast<tir::PrimFunc>(base_func));
}
doc << Doc::NewLine();
}

#if TVM_LOG_DEBUG
// attributes
// TODO(mbs): Make this official, including support from parser.
if (mod->attrs.defined() && !mod->attrs->dict.empty()) {
doc << "attributes {" << Doc::NewLine();
std::vector<String> keys;
for (const auto& kv : mod->attrs->dict) {
doc << " '" << kv.first << "' = " << PrettyPrint(kv.second) << Doc::NewLine();
keys.emplace_back(kv.first);
}
std::sort(keys.begin(), keys.end());
doc << "attributes {" << Doc::NewLine();
for (const auto& key : keys) {
doc << " '" << key << "' = " << PrettyPrint(mod->attrs->dict[key]) << Doc::NewLine();
}
doc << "}" << Doc::NewLine();
}
#endif

return doc;
}

Expand Down
4 changes: 3 additions & 1 deletion src/relay/backend/contrib/cmsisnn/extract_constants.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ class ExtractConstantsMutator : public MixedModeMutator {
private:
String gen_var_name() { return "tvm_var_extract_const_" + std::to_string(var_count_++); }

using MixedModeMutator::VisitExpr_;

Expr VisitExpr_(const FunctionNode* function) final {
Function func = GetRef<Function>(function);
function_to_constants_.Set(func, Array<Constant>{});
Expand All @@ -56,7 +58,7 @@ class ExtractConstantsMutator : public MixedModeMutator {
func = Function(FreeVars(new_body), new_body, func->ret_type, FreeTypeVars(new_body, mod_),
func->attrs);
}
return func;
return std::move(func);
}

Expr Rewrite_(const CallNode* call, const Expr& post) final {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/contrib/cmsisnn/generate_constants.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ class GenerateConstantsMutator : public MixedModeMutator {
if (clip_call) {
ret_call = Call(clip_call->op, {ret_call}, clip_call->attrs, {});
}
return ret_call;
return std::move(ret_call);
}

Expr Rewrite_(const CallNode* call, const Expr& post) final {
Expand Down
4 changes: 3 additions & 1 deletion src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,10 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
int depth_multiplier;
};

using codegen::CodeGenCHost::VisitStmt_;

/*! * \brief Emit the CMSIS-NN context buffer */
void VisitStmt_(const AllocateNode* op) {
void VisitStmt_(const AllocateNode* op) final {
context_buffer_name_ = op->buffer_var->name_hint;
context_buffer_size_ = op->constant_allocation_size();
CodeGenC::VisitStmt_(op);
Expand Down
Loading

0 comments on commit 715692f

Please sign in to comment.