Commit
[Relay] Re-run PlanDevices after LowerTE to flow new memory scope constraints. (apache#9613)

* [Relay] Re-run PlanDevices after LowerTE to flow new memory scope constraints.

This PR:
 1) Makes PlanDevices consider lowered calls when solving device domain constraints.
 2) Connects the storage scopes on PrimFunc parameters (encoded in their Buffer data
    Var type annotation PointerTypes storage_scope fields) to the memory_scope
    fields of the SEScopes which PlanDevices unifies over.
 3) Allows new device_copies to be inserted on the arguments and results of lowered
    calls so as to account for any memory scope mismatches which are now apparent.

[device_planner.cc has the main changes; the rest is secondary.]

In the short term we'd like to use this machinery to flow memory scope choices made
during lowering back out into the overall Relay program. In the longer term we'd
also like to be able to use memory scopes to influence the lowering of
yet-to-be-lowered functions (or lowered functions which have yet to be scheduled,
a distinction now possible with TensorIR).

 - Memory scope constraints can flow both out of and into PrimFuncs
   introduced by LowerTE. In TIR memory scopes are represented by
   'storage scopes' on the PointerType type annotations on TIR Buffer data
   variables.
    - It is straightforward to extract memory scopes from PrimFuncs by
      looking at the PrimFunc's buffer_map. We do this in 'phase 1' of
      PlanDevices, which collects all the device constraints implied by
      the program (see the sketch after this list).
    - However, pushing memory constraints into PrimFuncs is more challenging
      due to buffer aliasing. This aspect is still experimental.

 - Allow device_copies to be inserted for both the arguments and
   results of PrimFunc calls, on the assumption that PlanDevices has
   already established a consistent device assignment prior to
   lowering, so any new mismatch exists only to reconcile memory scopes.
   We use the new 'free' on_device annotations to implement this.
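
For concreteness, here is a minimal sketch (not part of this diff; the helper name is
hypothetical and error handling is elided) of how 'phase 1' can read the memory scopes
off a lowered PrimFunc's parameters via its buffer_map:

#include <tvm/ir/type.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/function.h>

#include <vector>

// Hypothetical helper: collect the storage scope of each buffer-backed PrimFunc parameter.
std::vector<tvm::String> ParamStorageScopes(const tvm::tir::PrimFunc& prim_func) {
  std::vector<tvm::String> scopes;
  for (const tvm::tir::Var& param : prim_func->params) {
    auto it = prim_func->buffer_map.find(param);
    if (it == prim_func->buffer_map.end()) {
      continue;  // Scalar parameter with no backing buffer.
    }
    tvm::tir::Buffer buffer = (*it).second;
    // The buffer's data Var is annotated with a PointerType whose storage_scope field
    // carries the memory scope chosen during lowering (e.g. "global", "global.texture").
    if (const auto* pointer_type = buffer->data->type_annotation.as<tvm::PointerTypeNode>()) {
      scopes.push_back(pointer_type->storage_scope);
    }
  }
  return scopes;
}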

Coming along for the ride:

 - To make unit tests of mixed Relay/TIR functions possible we needed
   to be able to supply a checked_type to GlobalVar, since that's currently
   the only way to give a Relay type to PrimFuncs (see the sketch after this list).

 - Use GenSym to get unique var names in ANF & partial eval so that it is
   easier to diff debug output between passes and to connect program fragments
   back into the overall program. Relying on pretty-printing to
   automagically unique-ify var names is certainly cute, but until we
   have better span support it is very hard to work with.

 - Realized both dead_code.cc and fold_constant.cc would
   happily move values into a different lexical virtual
   device context since device_planner.cc was being
   'clever' and eliding on_devices for let-bound values
   when there's no change. Fixed so that every let-bound
   value has an on_device. Will be much better after
   apache/tvm-rfcs#45 is implemented.

 - Make build -Werror clean for clang-12 (mostly move fixups).

 - Address post-submit comments from apache#9693.
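
As an illustration of the GlobalVar change (a sketch only; the function name, shape, and
dtype below are made-up assumptions), a mixed Relay/TIR test can now construct the
GlobalVar naming a PrimFunc with an explicit Relay function type:

#include <tvm/ir/expr.h>
#include <tvm/ir/tensor_type.h>
#include <tvm/ir/type.h>

// Sketch: a GlobalVar that names a lowered PrimFunc, given a Relay-visible type up front.
tvm::GlobalVar MakeTypedPrimFuncVar() {
  tvm::Type arg_type = tvm::TensorType({16, 16}, tvm::DataType::Float(32));
  // Assume the PrimFunc takes one tensor and returns one tensor of the same type.
  tvm::FuncType func_type({arg_type}, arg_type, /*type_params=*/{}, /*type_constraints=*/{});
  // New in this change: the checked type can be supplied at construction time.
  return tvm::GlobalVar("my_lowered_prim_func", func_type);
}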

* [checkpoint] thread safe GenSym
mbs-octoml authored and ylc committed Jan 7, 2022
1 parent 7035583 commit 864a113
Showing 24 changed files with 316 additions and 55 deletions.
2 changes: 1 addition & 1 deletion include/tvm/ir/expr.h
@@ -245,7 +245,7 @@ class GlobalVarNode : public RelayExprNode {
*/
class GlobalVar : public RelayExpr {
public:
TVM_DLL explicit GlobalVar(String name_hint);
TVM_DLL explicit GlobalVar(String name_hint, Type type = {});

TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode);
};
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/on_device.h
@@ -65,7 +65,7 @@ struct OnDeviceAttrs : public tvm::AttrsNode<OnDeviceAttrs> {
SEScope se_scope = SEScope::FullyUnconstrained();

/*!
* \brief If fales (the default), the result of the "on_device" call is not constrained to be
* \brief If false (the default), the result of the "on_device" call is not constrained to be
* \p se_scope.
*/
bool constrain_result = false;
9 changes: 9 additions & 0 deletions include/tvm/relay/expr.h
@@ -234,6 +234,15 @@ class Var : public Expr {
*/
TVM_DLL Var(Id vid, Type type_annotation, Span span = Span());

/*!
* \brief Return a globally fresh name. Helps with debugging to follow the same
* variable between passes and sub-expressions.
*
* TODO(mbs): Replace with name creation w.r.t. scopes once available as part of
* name gen overhaul.
*/
static Var GenSym(Type type_annotation = {}, Span span = {});

TVM_DEFINE_OBJECT_REF_METHODS(Var, RelayExpr, VarNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(VarNode);
};
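
A short usage sketch of the new helper (not from this diff; the type annotation below is an
illustrative assumption):

#include <tvm/ir/tensor_type.h>
#include <tvm/relay/expr.h>

void GenSymExample() {
  // Two fresh variables with distinct, stable name hints such as "x_0" and "x_1".
  tvm::relay::Var a = tvm::relay::Var::GenSym(tvm::TensorType({}, tvm::DataType::Float(32)));
  tvm::relay::Var b = tvm::relay::Var::GenSym();  // Unannotated; the type is left to inference.
  // a->name_hint() and b->name_hint() differ, which makes per-pass debug output easy to diff.
}
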
8 changes: 5 additions & 3 deletions include/tvm/target/se_scope.h
@@ -159,19 +159,21 @@ using MemoryScope = String;
*
*/
class SEScopeNode : public AttrsNode<SEScopeNode> {
public:
private:
/*!
* \brief The \p DLDeviceType (represtented as an int) of the virtual device. If \p target is
* \brief The \p DLDeviceType (represented as an int) of the virtual device. If \p target is
* known then this will be equal to \p target->kind->device_type. If \p target is null then the
* target is to be determined later.
*
* This is needed to support the legacy "on_device" and "device_copy" calls which only allow
* a \p DLDeviceTypes (as an integer) to be given.
*
* kInvalidDeviceType denotes unconstrained.
* kInvalidDeviceType denotes unconstrained. An int since the DLDeviceType enum representation
* is not fixed. Private to discourage further int vs DLDeviceType confusion.
*/
int /* actually DLDeviceType */ device_type_int;

public:
DLDeviceType device_type() const { return static_cast<DLDeviceType>(device_type_int); }

/*!
4 changes: 2 additions & 2 deletions python/tvm/ir/expr.py
@@ -64,8 +64,8 @@ class GlobalVar(RelayExpr):
The name of the variable.
"""

def __init__(self, name_hint):
self.__init_handle_by_constructor__(_ffi_api.GlobalVar, name_hint)
def __init__(self, name_hint, type_annot=None):
self.__init_handle_by_constructor__(_ffi_api.GlobalVar, name_hint, type_annot)

def __call__(self, *args):
"""Call the global variable.
7 changes: 5 additions & 2 deletions src/ir/expr.cc
@@ -141,15 +141,18 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')';
});

GlobalVar::GlobalVar(String name_hint) {
GlobalVar::GlobalVar(String name_hint, Type type) {
ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>();
n->name_hint = std::move(name_hint);
n->checked_type_ = std::move(type);
data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(GlobalVarNode);

TVM_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](String name) { return GlobalVar(name); });
TVM_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](String name, Type type) {
return GlobalVar(name, type);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<GlobalVarNode>([](const ObjectRef& ref, ReprPrinter* p) {
4 changes: 3 additions & 1 deletion src/relay/backend/contrib/cmsisnn/extract_constants.cc
@@ -46,6 +46,8 @@ class ExtractConstantsMutator : public MixedModeMutator {
private:
String gen_var_name() { return "tvm_var_extract_const_" + std::to_string(var_count_++); }

using MixedModeMutator::VisitExpr_;

Expr VisitExpr_(const FunctionNode* function) final {
Function func = GetRef<Function>(function);
function_to_constants_.Set(func, Array<Constant>{});
@@ -56,7 +58,7 @@ class ExtractConstantsMutator : public MixedModeMutator {
func = Function(FreeVars(new_body), new_body, func->ret_type, FreeTypeVars(new_body, mod_),
func->attrs);
}
return func;
return std::move(func);
}

Expr Rewrite_(const CallNode* call, const Expr& post) final {
2 changes: 1 addition & 1 deletion src/relay/backend/contrib/cmsisnn/generate_constants.cc
@@ -179,7 +179,7 @@ class GenerateConstantsMutator : public MixedModeMutator {
if (clip_call) {
ret_call = Call(clip_call->op, {ret_call}, clip_call->attrs, {});
}
return ret_call;
return std::move(ret_call);
}

Expr Rewrite_(const CallNode* call, const Expr& post) final {
4 changes: 3 additions & 1 deletion src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
@@ -101,8 +101,10 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
int clip_max;
};

using codegen::CodeGenCHost::VisitStmt_;

/*! * \brief Emits CMSIS-NN APIs for every call_extern */
void VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
void VisitExpr_(const CallNode* op, std::ostream& os) final {
if (!op->op.same_as(builtin::call_extern())) {
CodeGenCHost::VisitExpr_(op, os);
return;
3 changes: 1 addition & 2 deletions src/relay/backend/graph_executor_codegen.cc
@@ -694,8 +694,7 @@ class GraphExecutorCodegenModule : public runtime::ModuleNode {
*rv = this->output_.external_mods;
});
} else if (name == "get_devices") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = Array<String>(); });
return PackedFunc([sptr_to_self](TVMArgs args, TVMRetValue* rv) { *rv = Array<String>(); });
} else if (name == "get_metadata") {
return PackedFunc(
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->output_.metadata; });
4 changes: 4 additions & 0 deletions src/relay/backend/vm/compiler.cc
@@ -1088,6 +1088,10 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
// let-bound functions.
pass_seqs.push_back(DeadCodeElimination(/*inline_once=*/false));

// Now that we have PrimFuncs, flow and solve SEScope constraints again to account for
// any memory scopes which lowering has settled on.
pass_seqs.push_back(transform::PlanDevices(config_));

// Inline the functions that are lifted to the module scope. We perform this
// pass after all other optimization passes but before the memory allocation
// pass. This is because memory allocation pass will insert `invoke_tvm_op`
4 changes: 2 additions & 2 deletions src/relay/ir/adt.cc
@@ -115,7 +115,7 @@ Clause WithFields(Clause clause, Optional<Pattern> opt_lhs, Optional<Expr> opt_r
cow_clause_node->lhs = lhs;
cow_clause_node->rhs = rhs;
}
return std::move(clause);
return clause;
}

TVM_REGISTER_NODE_TYPE(ClauseNode);
@@ -168,7 +168,7 @@ Match WithFields(Match match, Optional<Expr> opt_data, Optional<Array<Clause>> o
cow_match_node->complete = complete;
cow_match_node->span = span;
}
return std::move(match);
return match;
}

TVM_REGISTER_NODE_TYPE(MatchNode);
49 changes: 28 additions & 21 deletions src/relay/ir/expr.cc
@@ -107,7 +107,7 @@ Tuple WithFields(Tuple tuple, Optional<Array<Expr>> opt_fields,
cow_tuple_node->virtual_device_ = virtual_device;
cow_tuple_node->span = span;
}
return std::move(tuple);
return tuple;
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
@@ -124,6 +124,13 @@ Var::Var(Id vid, Type type_annotation, Span span) {
data_ = std::move(n);
}

/* static */ Var Var::GenSym(Type type_annotation, Span span) {
static size_t next_id = std::atomic<size_t>(0);
std::ostringstream os;
os << "x_" << next_id++;
return Var(os.str(), std::move(type_annotation), std::move(span));
}

Var WithFields(Var var, Optional<Id> opt_vid, Optional<Type> opt_type_annotation,
Optional<SEScope> opt_virtual_device, Optional<Span> opt_span) {
Id vid = opt_vid.value_or(var->vid);
Expand All @@ -141,7 +148,7 @@ Var WithFields(Var var, Optional<Id> opt_vid, Optional<Type> opt_type_annotation
cow_var_node->virtual_device_ = virtual_device;
cow_var_node->span = span;
}
return std::move(var);
return var;
}

TVM_REGISTER_NODE_TYPE(VarNode);
@@ -219,7 +226,7 @@ Call WithFields(Call call, Optional<Expr> opt_op, Optional<Array<Expr>> opt_args
cow_call_node->virtual_device_ = virtual_device;
cow_call_node->span = span;
}
return std::move(call);
return call;
}

TVM_REGISTER_NODE_TYPE(CallNode);
@@ -264,7 +271,7 @@ Let WithFields(Let let, Optional<Var> opt_var, Optional<Expr> opt_value, Optiona
cow_let_node->virtual_device_ = virtual_device;
cow_let_node->span = span;
}
return std::move(let);
return let;
}

TVM_REGISTER_NODE_TYPE(LetNode);
@@ -308,7 +315,7 @@ If WithFields(If if_expr, Optional<Expr> opt_cond, Optional<Expr> opt_true_branc
cow_if_node->virtual_device_ = virtual_device;
cow_if_node->span = span;
}
return std::move(if_expr);
return if_expr;
}

TVM_REGISTER_NODE_TYPE(IfNode);
@@ -350,7 +357,7 @@ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional<Expr> opt_tuple,
cow_tuple_get_item_node->span = span;
cow_tuple_get_item_node->virtual_device_ = virtual_device;
}
return std::move(tuple_get_item);
return tuple_get_item;
}

TVM_REGISTER_NODE_TYPE(TupleGetItemNode);
@@ -385,7 +392,7 @@ RefCreate WithFields(RefCreate ref_create, Optional<Expr> opt_value,
cow_ref_create_node->virtual_device_ = virtual_device;
cow_ref_create_node->span = span;
}
return std::move(ref_create);
return ref_create;
}

TVM_REGISTER_NODE_TYPE(RefCreateNode);
@@ -420,7 +427,7 @@ RefRead WithFields(RefRead ref_read, Optional<Expr> opt_ref, Optional<SEScope> o
cow_ref_read_node->virtual_device_ = virtual_device;
cow_ref_read_node->span = span;
}
return std::move(ref_read);
return ref_read;
}

TVM_REGISTER_NODE_TYPE(RefReadNode);
@@ -457,7 +464,7 @@ RefWrite WithFields(RefWrite ref_write, Optional<Expr> opt_ref, Optional<Expr> o
cow_ref_write_node->virtual_device_ = virtual_device;
cow_ref_write_node->span = span;
}
return std::move(ref_write);
return ref_write;
}

TVM_REGISTER_NODE_TYPE(RefWriteNode);
@@ -510,29 +517,29 @@ inline void Dismantle(const Expr& expr) {
stack.top().second = true;

// special handling
if (const CallNode* op = node.as<CallNode>()) {
if (const auto* call_node = node.as<CallNode>()) {
// do not process args if used elsewhere
if (op->args.use_count() < 2) {
for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) {
if (call_node->args.use_count() < 2) {
for (auto it = call_node->args.rbegin(); it != call_node->args.rend(); ++it) {
fpush_to_stack(*it);
}
}
} else if (const TupleNode* op = node.as<TupleNode>()) {
} else if (const auto* tuple_node = node.as<TupleNode>()) {
// do not process fields if used elsewhere
if (op->fields.use_count() < 2) {
for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) {
if (tuple_node->fields.use_count() < 2) {
for (auto it = tuple_node->fields.rbegin(); it != tuple_node->fields.rend(); ++it) {
fpush_to_stack(*it);
}
}
} else if (const TupleGetItemNode* op = node.as<TupleGetItemNode>()) {
} else if (const auto* tuple_get_item_node = node.as<TupleGetItemNode>()) {
// do not process tuple if used elsewhere
if (op->tuple.use_count() < 2) {
fpush_to_stack(op->tuple);
if (tuple_get_item_node->tuple.use_count() < 2) {
fpush_to_stack(tuple_get_item_node->tuple);
}
} else if (const LetNode* op = node.as<LetNode>()) {
} else if (const auto* let_node = node.as<LetNode>()) {
// do not process let if used elsewhere
if (op->body.use_count() < 2) {
fpush_to_stack(op->body);
if (let_node->body.use_count() < 2) {
fpush_to_stack(let_node->body);
}
}
}
2 changes: 1 addition & 1 deletion src/relay/ir/function.cc
@@ -91,7 +91,7 @@ Function WithFields(Function function, Optional<Array<Var>> opt_params, Optional
cow_function_node->virtual_device_ = virtual_device;
cow_function_node->span = span;
}
return std::move(function);
return function;
}

FuncType FunctionNode::func_type_annotation() const {
4 changes: 3 additions & 1 deletion src/relay/op/memory/on_device.cc
@@ -99,7 +99,9 @@ Expr MaybeOnDevice(Expr body, SEScope se_scope, bool constrain_result, bool cons
ICHECK(inner == outer)
<< "Cannot constrain intermediate result of nested on_device calls to different SEScopes";
}
// We can now ignore the intermediate constraints, if any.
// We can now ignore the middle constraint.
// If the outer on_device has any constraint then use se_scope given for it.
// Otherwise we can use the existing inner se_scope.
return OnDevice(props.body, (constrain_inner || constrain_outer) ? outer : inner,
constrain_outer, constrain_inner);
} else {
9 changes: 6 additions & 3 deletions src/relay/op/memory/on_device.h
@@ -66,15 +66,18 @@ struct OnDeviceProps {
};

/*!
* \brief As for OnDevice, but taking all fields other than \p body from \p props.
* \brief Wraps \p body in an "on_device" CallNode, taking all fields other than \p body from \p
* props.
*/
inline Call OnDeviceWithProps(Expr body, const OnDeviceProps& props) {
return OnDevice(std::move(body), props.se_scope, props.constrain_result, props.constrain_body);
}

/*!
* \brief As for OnDevice, but don't constrain the body or result to any particular virtual device.
* This allows a "device_copy" when required.
* \brief Wraps \p body in an "on_device" CallNode, but don't constrain the body or result to
* any particular virtual device. This allows a "device_copy" to be inserted by PlanDevices
* where required, while at the same time not introducing unnecessary freedom in the device
* choices.
*/
inline Call OnDeviceCopyOk(Expr body) {
return OnDevice(std::move(body), SEScope::FullyUnconstrained(),
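
A minimal sketch (an assumption, not code from this diff; the helper name is hypothetical) of
how the 'free' annotation is meant to be used: wrap each argument of an already-lowered call so
a later PlanDevices run may insert device_copies where memory scopes disagree, without otherwise
loosening the established device assignment. It assumes the includes and tvm::relay namespace
of on_device.h above.

// 'call' is a call to an already-lowered PrimFunc whose devices were planned before lowering.
Call WrapLoweredCallArgs(const Call& call) {
  Array<Expr> new_args;
  new_args.reserve(call->args.size());
  for (const Expr& arg : call->args) {
    // Leave the argument's SEScope fully unconstrained so PlanDevices can place a
    // device_copy here if the PrimFunc's buffer demands a different memory scope.
    new_args.push_back(OnDeviceCopyOk(arg));
  }
  return WithFields(call, /*opt_op=*/{}, new_args);
}
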
12 changes: 9 additions & 3 deletions src/relay/transforms/device_domains.cc
@@ -199,9 +199,15 @@ DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) {
DeviceCopyProps device_copy_props = GetDeviceCopyProps(call.get());
CallLoweredProps call_lowered_props = GetCallLoweredProps(call.get());

// TODO(mbs): Support call_lowered to PrimFuncs.
ICHECK(!call_lowered_props.lowered_func.defined());
if (on_device_props.body.defined()) {
if (call_lowered_props.lowered_func.defined()) {
// Presumably we've already seen the call to the "primitive" Function from which this lowered
// function was derived in an earlier PlanDevices pass. Thus we've already established that
// all the argument and result devices domains must be equal, ignoring memory scopes.
// So at this point we'll let all the arguments and result be free so that memory scopes can
// differ.
// TODO(mbs): As per header comments, need to revisit when can setup sub-SEScope constraints.
return DomainFor(call_lowered_props.lowered_func);
} else if (on_device_props.body.defined()) {
// By default:
// on_device(expr, se_scope=<t>)
// on_device : fn(<t>):?x?
(Diffs for the remaining changed files are not shown.)