diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index d33606676944..a6e5c8de73a7 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -245,7 +245,7 @@ class GlobalVarNode : public RelayExprNode { */ class GlobalVar : public RelayExpr { public: - TVM_DLL explicit GlobalVar(String name_hint); + TVM_DLL explicit GlobalVar(String name_hint, Type type = {}); TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode); }; diff --git a/include/tvm/relay/attrs/on_device.h b/include/tvm/relay/attrs/on_device.h index b1f1e6a6dc45..0931865fa88e 100644 --- a/include/tvm/relay/attrs/on_device.h +++ b/include/tvm/relay/attrs/on_device.h @@ -65,7 +65,7 @@ struct OnDeviceAttrs : public tvm::AttrsNode { SEScope se_scope = SEScope::FullyUnconstrained(); /*! - * \brief If fales (the default), the result of the "on_device" call is not constrained to be + * \brief If false (the default), the result of the "on_device" call is not constrained to be * \p se_scope. */ bool constrain_result = false; diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 03200d3a3dfb..8bec72490ab1 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -234,6 +234,15 @@ class Var : public Expr { */ TVM_DLL Var(Id vid, Type type_annotation, Span span = Span()); + /*! + * \brief Return a globally fresh name. Helps with debugging to follow the same + * variable between passes and sub-expressions. + * + * TODO(mbs): Replace with name creation w.r.t. scopes once available as part of + * name gen overhaul. 
+ */ + static Var GenSym(Type type_annotation = {}, Span span = {}); + TVM_DEFINE_OBJECT_REF_METHODS(Var, RelayExpr, VarNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(VarNode); }; diff --git a/include/tvm/target/se_scope.h b/include/tvm/target/se_scope.h index ec5da3a80cae..314bf054d7ea 100644 --- a/include/tvm/target/se_scope.h +++ b/include/tvm/target/se_scope.h @@ -159,19 +159,21 @@ using MemoryScope = String; * */ class SEScopeNode : public AttrsNode { - public: + private: /*! - * \brief The \p DLDeviceType (represtented as an int) of the virtual device. If \p target is + * \brief The \p DLDeviceType (represented as an int) of the virtual device. If \p target is * known then this will be equal to \p target->kind->device_type. If \p target is null then the * target is to be determined later. * * This is needed to support the legacy "on_device" and "device_copy" calls which only allow * a \p DLDeviceTypes (as an integer) to be given. * - * kInvalidDeviceType denotes unconstrained. + * kInvalidDeviceType denotes unconstrained. An int since the DLDeviceType enum representation + * is not fixed. Private to discourage further int vs DLDeviceType confusion. */ int /* actually DLDeviceType */ device_type_int; + public: DLDeviceType device_type() const { return static_cast(device_type_int); } /*! diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index 139e2b8d97fa..43cba1a83530 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -64,8 +64,8 @@ class GlobalVar(RelayExpr): The name of the variable. """ - def __init__(self, name_hint): - self.__init_handle_by_constructor__(_ffi_api.GlobalVar, name_hint) + def __init__(self, name_hint, type_annot=None): + self.__init_handle_by_constructor__(_ffi_api.GlobalVar, name_hint, type_annot) def __call__(self, *args): """Call the global variable. 
diff --git a/src/ir/expr.cc b/src/ir/expr.cc index caddf0efcc77..399873492f04 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -141,15 +141,18 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')'; }); -GlobalVar::GlobalVar(String name_hint) { +GlobalVar::GlobalVar(String name_hint, Type type) { ObjectPtr n = make_object(); n->name_hint = std::move(name_hint); + n->checked_type_ = std::move(type); data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(GlobalVarNode); -TVM_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](String name) { return GlobalVar(name); }); +TVM_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](String name, Type type) { + return GlobalVar(name, type); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { diff --git a/src/relay/backend/contrib/cmsisnn/extract_constants.cc b/src/relay/backend/contrib/cmsisnn/extract_constants.cc index 5ed23ad1ad6a..ca003d80c1d9 100644 --- a/src/relay/backend/contrib/cmsisnn/extract_constants.cc +++ b/src/relay/backend/contrib/cmsisnn/extract_constants.cc @@ -46,6 +46,8 @@ class ExtractConstantsMutator : public MixedModeMutator { private: String gen_var_name() { return "tvm_var_extract_const_" + std::to_string(var_count_++); } + using MixedModeMutator::VisitExpr_; + Expr VisitExpr_(const FunctionNode* function) final { Function func = GetRef(function); function_to_constants_.Set(func, Array{}); @@ -56,7 +58,7 @@ class ExtractConstantsMutator : public MixedModeMutator { func = Function(FreeVars(new_body), new_body, func->ret_type, FreeTypeVars(new_body, mod_), func->attrs); } - return func; + return std::move(func); } Expr Rewrite_(const CallNode* call, const Expr& post) final { diff --git a/src/relay/backend/contrib/cmsisnn/generate_constants.cc b/src/relay/backend/contrib/cmsisnn/generate_constants.cc index 056784b6675d..472f93a0a1f0 100644 --- a/src/relay/backend/contrib/cmsisnn/generate_constants.cc +++ 
b/src/relay/backend/contrib/cmsisnn/generate_constants.cc @@ -179,7 +179,7 @@ class GenerateConstantsMutator : public MixedModeMutator { if (clip_call) { ret_call = Call(clip_call->op, {ret_call}, clip_call->attrs, {}); } - return ret_call; + return std::move(ret_call); } Expr Rewrite_(const CallNode* call, const Expr& post) final { diff --git a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc index 2a7d0ae21769..e0e5aa962239 100644 --- a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc +++ b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc @@ -101,8 +101,10 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost { int clip_max; }; + using codegen::CodeGenCHost::VisitStmt_; + /*! * \brief Emits CMSIS-NN APIs for every call_extern */ - void VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) + void VisitExpr_(const CallNode* op, std::ostream& os) final { if (!op->op.same_as(builtin::call_extern())) { CodeGenCHost::VisitExpr_(op, os); return; diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index 3d889cdf6561..16b1ddb3c82f 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -694,8 +694,7 @@ class GraphExecutorCodegenModule : public runtime::ModuleNode { *rv = this->output_.external_mods; }); } else if (name == "get_devices") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = Array(); }); + return PackedFunc([sptr_to_self](TVMArgs args, TVMRetValue* rv) { *rv = Array(); }); } else if (name == "get_metadata") { return PackedFunc( [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->output_.metadata; }); diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 93b2bcb8d7ef..23aee452ba09 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -1088,6 +1088,10 @@ IRModule 
VMCompiler::OptimizeModuleImpl(IRModule mod) { // let-bound functions. pass_seqs.push_back(DeadCodeElimination(/*inline_once=*/false)); + // Now that we have PrimFuncs, flow and solve SEScope constraints again to account for + // any memory scopes which lowering has settled on. + pass_seqs.push_back(transform::PlanDevices(config_)); + // Inline the functions that are lifted to the module scope. We perform this // pass after all other optimization passes but before the memory allocation // pass. This is because memory allocation pass will insert `invoke_tvm_op` diff --git a/src/relay/ir/adt.cc b/src/relay/ir/adt.cc index c2b8fd641d03..0389547a78f9 100644 --- a/src/relay/ir/adt.cc +++ b/src/relay/ir/adt.cc @@ -115,7 +115,7 @@ Clause WithFields(Clause clause, Optional opt_lhs, Optional opt_r cow_clause_node->lhs = lhs; cow_clause_node->rhs = rhs; } - return std::move(clause); + return clause; } TVM_REGISTER_NODE_TYPE(ClauseNode); @@ -168,7 +168,7 @@ Match WithFields(Match match, Optional opt_data, Optional> o cow_match_node->complete = complete; cow_match_node->span = span; } - return std::move(match); + return match; } TVM_REGISTER_NODE_TYPE(MatchNode); diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 6b4b2f16ce1e..18e83f998e24 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -107,7 +107,7 @@ Tuple WithFields(Tuple tuple, Optional> opt_fields, cow_tuple_node->virtual_device_ = virtual_device; cow_tuple_node->span = span; } - return std::move(tuple); + return tuple; } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -124,6 +124,13 @@ Var::Var(Id vid, Type type_annotation, Span span) { data_ = std::move(n); } +/* static */ Var Var::GenSym(Type type_annotation, Span span) { + static size_t next_id = std::atomic(0); + std::ostringstream os; + os << "x_" << next_id++; + return Var(os.str(), std::move(type_annotation), std::move(span)); +} + Var WithFields(Var var, Optional opt_vid, Optional opt_type_annotation, Optional opt_virtual_device, 
Optional opt_span) { Id vid = opt_vid.value_or(var->vid); @@ -141,7 +148,7 @@ Var WithFields(Var var, Optional opt_vid, Optional opt_type_annotation cow_var_node->virtual_device_ = virtual_device; cow_var_node->span = span; } - return std::move(var); + return var; } TVM_REGISTER_NODE_TYPE(VarNode); @@ -219,7 +226,7 @@ Call WithFields(Call call, Optional opt_op, Optional> opt_args cow_call_node->virtual_device_ = virtual_device; cow_call_node->span = span; } - return std::move(call); + return call; } TVM_REGISTER_NODE_TYPE(CallNode); @@ -264,7 +271,7 @@ Let WithFields(Let let, Optional opt_var, Optional opt_value, Optiona cow_let_node->virtual_device_ = virtual_device; cow_let_node->span = span; } - return std::move(let); + return let; } TVM_REGISTER_NODE_TYPE(LetNode); @@ -308,7 +315,7 @@ If WithFields(If if_expr, Optional opt_cond, Optional opt_true_branc cow_if_node->virtual_device_ = virtual_device; cow_if_node->span = span; } - return std::move(if_expr); + return if_expr; } TVM_REGISTER_NODE_TYPE(IfNode); @@ -350,7 +357,7 @@ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple, cow_tuple_get_item_node->span = span; cow_tuple_get_item_node->virtual_device_ = virtual_device; } - return std::move(tuple_get_item); + return tuple_get_item; } TVM_REGISTER_NODE_TYPE(TupleGetItemNode); @@ -385,7 +392,7 @@ RefCreate WithFields(RefCreate ref_create, Optional opt_value, cow_ref_create_node->virtual_device_ = virtual_device; cow_ref_create_node->span = span; } - return std::move(ref_create); + return ref_create; } TVM_REGISTER_NODE_TYPE(RefCreateNode); @@ -420,7 +427,7 @@ RefRead WithFields(RefRead ref_read, Optional opt_ref, Optional o cow_ref_read_node->virtual_device_ = virtual_device; cow_ref_read_node->span = span; } - return std::move(ref_read); + return ref_read; } TVM_REGISTER_NODE_TYPE(RefReadNode); @@ -457,7 +464,7 @@ RefWrite WithFields(RefWrite ref_write, Optional opt_ref, Optional o cow_ref_write_node->virtual_device_ = virtual_device; 
cow_ref_write_node->span = span; } - return std::move(ref_write); + return ref_write; } TVM_REGISTER_NODE_TYPE(RefWriteNode); @@ -510,29 +517,29 @@ inline void Dismantle(const Expr& expr) { stack.top().second = true; // special handling - if (const CallNode* op = node.as()) { + if (const auto* call_node = node.as()) { // do not process args if used elsewhere - if (op->args.use_count() < 2) { - for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) { + if (call_node->args.use_count() < 2) { + for (auto it = call_node->args.rbegin(); it != call_node->args.rend(); ++it) { fpush_to_stack(*it); } } - } else if (const TupleNode* op = node.as()) { + } else if (const auto* tuple_node = node.as()) { // do not process fields if used elsewhere - if (op->fields.use_count() < 2) { - for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) { + if (tuple_node->fields.use_count() < 2) { + for (auto it = tuple_node->fields.rbegin(); it != tuple_node->fields.rend(); ++it) { fpush_to_stack(*it); } } - } else if (const TupleGetItemNode* op = node.as()) { + } else if (const auto* tuple_get_item_node = node.as()) { // do not process tuple if used elsewhere - if (op->tuple.use_count() < 2) { - fpush_to_stack(op->tuple); + if (tuple_get_item_node->tuple.use_count() < 2) { + fpush_to_stack(tuple_get_item_node->tuple); } - } else if (const LetNode* op = node.as()) { + } else if (const auto* let_node = node.as()) { // do not process let if used elsewhere - if (op->body.use_count() < 2) { - fpush_to_stack(op->body); + if (let_node->body.use_count() < 2) { + fpush_to_stack(let_node->body); } } } diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index f2cb02194009..4c5b867e49da 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -91,7 +91,7 @@ Function WithFields(Function function, Optional> opt_params, Optional cow_function_node->virtual_device_ = virtual_device; cow_function_node->span = span; } - return std::move(function); + return 
function; } FuncType FunctionNode::func_type_annotation() const { diff --git a/src/relay/op/memory/on_device.cc b/src/relay/op/memory/on_device.cc index ae5ef33da6d0..0fd86d3de67c 100644 --- a/src/relay/op/memory/on_device.cc +++ b/src/relay/op/memory/on_device.cc @@ -99,7 +99,9 @@ Expr MaybeOnDevice(Expr body, SEScope se_scope, bool constrain_result, bool cons ICHECK(inner == outer) << "Cannot constrain intermediate result of nested on_device calls to different SEScopes"; } - // We can now ignore the intermediate constraints, if any. + // We can now ignore the middle constraint. + // If the outer on_device has any constraint then use se_scope given for it. + // Otherwise we can use the existing inner se_scope. return OnDevice(props.body, (constrain_inner || constrain_outer) ? outer : inner, constrain_outer, constrain_inner); } else { diff --git a/src/relay/op/memory/on_device.h b/src/relay/op/memory/on_device.h index bac6695ac35b..2ebaf034c760 100644 --- a/src/relay/op/memory/on_device.h +++ b/src/relay/op/memory/on_device.h @@ -66,15 +66,18 @@ struct OnDeviceProps { }; /*! - * \brief As for OnDevice, but taking all fields other than \p body from \p props. + * \brief Wraps \p body in an "on_device" CallNode, taking all fields other than \p body from \p + * props. */ inline Call OnDeviceWithProps(Expr body, const OnDeviceProps& props) { return OnDevice(std::move(body), props.se_scope, props.constrain_result, props.constrain_body); } /*! - * \brief As for OnDevice, but don't constrain the body or result to any particular virtual device. - * This allows a "device_copy" when required. + * \brief Wraps \p body in an "on_device" CallNode, but don't constrain the body or result to + * any particular virtual device. This allows a "device_copy" to be inserted by PlanDevices + * where required, while at the same time not introducing unnecessary freedom in the device + * choices. 
*/ inline Call OnDeviceCopyOk(Expr body) { return OnDevice(std::move(body), SEScope::FullyUnconstrained(), diff --git a/src/relay/transforms/device_domains.cc b/src/relay/transforms/device_domains.cc index 76697d8437f4..fd46a6dc0563 100644 --- a/src/relay/transforms/device_domains.cc +++ b/src/relay/transforms/device_domains.cc @@ -199,9 +199,15 @@ DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) { DeviceCopyProps device_copy_props = GetDeviceCopyProps(call.get()); CallLoweredProps call_lowered_props = GetCallLoweredProps(call.get()); - // TODO(mbs): Support call_lowered to PrimFuncs. - ICHECK(!call_lowered_props.lowered_func.defined()); - if (on_device_props.body.defined()) { + if (call_lowered_props.lowered_func.defined()) { + // Presumably we've already seen the call to the "primitive" Function from which this lowered + // function was derived in an earlier PlanDevices pass. Thus we've already established that + // all the argument and result devices domains must be equal, ignoring memory scopes. + // So at this point we'll let all the arguments and result be free so that memory scopes can + // differ. + // TODO(mbs): As per header comments, need to revisit when can setup sub-SEScope constraints. + return DomainFor(call_lowered_props.lowered_func); + } else if (on_device_props.body.defined()) { // By default: // on_device(expr, se_scope=) // on_device : fn():?x? diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc index a85233de17e5..bad8363f4783 100644 --- a/src/relay/transforms/device_planner.cc +++ b/src/relay/transforms/device_planner.cc @@ -57,6 +57,8 @@ * idempotent. * - Some special operators require their arguments or results to be on the 'host' (typcially * a CPU) \p SEScope, see below. + * - Any \p PrimFuncs in the \p IRModule (if \p LowerTEPass has already run) may constrain their + * argument buffers to have a specific memory scope, which is part of \p SEScope. 
* - Annotations left over from a previous run of this pass, such as 'param_se_scopes' and * 'result_se_scope' function attributes we introduce below. This is so the pass is idempotent * and can be re-run to flow additional memory scope constraints. @@ -71,10 +73,14 @@ * - We wish to treat \code on_device(expr, device_type=d).0 \endcode as if it were written * \code on_device(expr.0, device_type_d) \endcode. I.e. we prefer to copy the projection from * the tuple rather than project from a copy of the tuple. We'll do this by rewriting. + * - We are prepared to insert device_copies on the arguments and result of calls to PrimFuncs, + * on the assumption a) we already ran PlanDevices before lowering so we are not allowing + * any new cross-device copies, but b) after lowering we may have new memory scope constraints + * to deal with. * * Phase 1 * ------- - * We flow constraints from the "on_device" and "device_copy" calls, + * We flow constraints from the "on_device" and "device_copy" calls, PrimFunc buffer memory scopes, * and some special ops, to all other Relay sub-expressions. * * For a primitive such as \code add(e1, e2) \endcode all arguments and results must be on the @@ -102,6 +108,10 @@ * different from each other. Every call to the function must use the same choice of parameter * and result devices -- there is no 'device polymorphism' for Relay functions. * + * Currently \p PrimFuncs and external functions do not carry over their parameter and result + * devices from their original Relay Function representations. However we know all calls to those + * functions are device-consistent, thus no information is lost. + * * Phase 2 * ------- * After flowing constraints we apply some defaulting heuristics (using a global default \p SEScope) * @@ -138,6 +148,7 @@ * around a var or global var. These uses of "on_device" imply both the argument and result are * on the same device. 
We signal this by setting the 'is_fixed' OnDeviceAttrs field to true, * which helps make this pass idempotent. + * - The buffer maps for called PrimFuncs are updated to capture memory scopes. * * Helper visitors (in device_aware_visitors.h) can be used by downstream transforms to recover * the device for any expression for their own use, e.g. during memory planning. All downstream @@ -268,6 +279,7 @@ #include +#include "../../tir/analysis/device_constraint_utils.h" #include "../op/annotation/annotation.h" #include "../op/memory/device_copy.h" #include "../op/memory/on_device.h" @@ -301,6 +313,17 @@ namespace { * on_device(e).0 * ==> on_device(e.0) * \endcode + * + * - Be prepared to copy arguments and results on primitive call boundaries in case memory + * scopes don't line up. We'll use the 'fully unconstrained' version of on_device so that + * we can allow for a device_copy without knowing the specific device for the arguments. + * \code + * call_lowered(@prim, (a, b)) + * ==> copy_ok(call_lowered(@prim, (copy_ok(a), copy_ok(b)))) + * where + * copy_ok(x) = on_device(x, se_scope=SEScope::FullyUnconstrained, + * constrain_body=False, constrain_result=False) + * \endcode */ class RewriteOnDevices : public ExprMutator { public: @@ -358,6 +381,26 @@ class RewriteOnDevices : public ExprMutator { return WithFields(GetRef(function_node), function_node->params, std::move(body)); } + Expr VisitExpr_(const CallNode* call_node) final { + CallLoweredProps props = GetCallLoweredProps(call_node); + if (props.lowered_func.defined()) { + BaseFunc base_func = mod_->Lookup(props.lowered_func); + if (base_func.as()) { + VLOG(2) << "allowing device_copy on PrimFunc arguments and result"; + Array new_args; + new_args.reserve(props.arguments.size()); + for (const auto& arg : props.arguments) { + Expr new_arg = VisitExpr(arg); + new_args.push_back(OnDeviceCopyOk(std::move(new_arg))); + } + Call new_call = CallLowered(std::move(props.lowered_func), std::move(new_args), props.attrs, + 
call_node->span); + return OnDeviceCopyOk(std::move(new_call)); + } + } + return ExprMutator::VisitExpr_(call_node); + } + /*! \brief Module we are rewriting, so we can lookup global definitions. */ IRModule mod_; }; @@ -398,6 +441,10 @@ class DeviceAnalyzer : public ExprVisitor { VLOG(2) << "collecting constraints from Relay Function '" << kv.first->name_hint << "'"; domains_->UnifyExprExact(kv.first, kv.second); VisitExpr(GetRef(function_node)); + } else if (const auto* prim_func_node = kv.second.as()) { + VLOG(2) << "collecting constraints from TIR PrimFunc '" << kv.first->name_hint << "'"; + domains_->UnifyExprExact( + kv.first, DomainForPrimFunc(kv.first, GetRef(prim_func_node))); + } else { + VLOG(2) << "skipping '" << kv.first->name_hint << "'"; + } @@ -406,6 +453,40 @@ } private: + /*! + * \brief Return the domain representing \p prim_func which, before lowering, had + * the Relay \p type. + */ + DeviceDomainPtr DomainForPrimFunc(const GlobalVar& global_var, const tir::PrimFunc& prim_func) { + // CAUTION: The prim_func->checked_type() is currently w.r.t. the flattened and DPS form + // of the prim func, however here we wish to remain within the Relay view of all functions. + // Thus we'll use the global var whose checked_type is in Relay form. + auto func_domain = domains_->DomainFor(global_var); // higher-order + + // TODO(mbs): We don't visit the body of the function -- there's currently nothing to be done. + const auto* func_type_node = global_var->checked_type().as(); + ICHECK(func_type_node); + ICHECK_EQ(func_domain->function_arity(), func_type_node->arg_types.size()); + + Array se_scopes = + tir::GetPrimFuncArgAndResultConstraints(prim_func, GetRef(func_type_node)); + + // Build the implied domain (in terms of the function's Relay type) implied by any memory scope + // constraints in the function's buffers, for both arguments and results. 
+ std::vector args_and_result_domains; + args_and_result_domains.reserve(se_scopes.size()); + for (size_t i = 0; i < func_type_node->arg_types.size(); ++i) { + const SEScope& param_se_scope = se_scopes[i]; + VLOG(2) << "param_se_scope[" << i << "] = " << param_se_scope; + args_and_result_domains.push_back(domains_->MakeFirstOrderDomain(param_se_scope)); + } + const SEScope& ret_se_scope = se_scopes.back(); + VLOG(2) << "ret_se_scope = " << ret_se_scope; + args_and_result_domains.push_back(domains_->MakeFirstOrderDomain(ret_se_scope)); + + return domains_->MakeHigherOrderDomain(std::move(args_and_result_domains)); + } + void VisitExpr_(const CallNode* call_node) final { auto call = GetRef(call_node); @@ -849,6 +930,15 @@ class DeviceCapturer : public ExprMutator { if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) { VLOG(2) << "capturing devices for Relay Function '" << kv.first->name_hint << "'"; result->Add(kv.first, Downcast(Mutate(GetRef(function_node)))); + } else if (const auto* prim_func_node = kv.second.as()) { + VLOG(2) << "capturing devices for TIR PrimFunc '" << kv.first->name_hint << "'"; + auto prim_func = GetRef(prim_func_node); + tir::PrimFunc new_prim_func = UpdatePrimFunc(kv.first, prim_func); + VLOG(2) << "Rewritten prim func:" << std::endl + << PrettyPrint(prim_func) << std::endl + << "to:" << std::endl + << PrettyPrint(new_prim_func); + result->Add(kv.first, std::move(new_prim_func)); } else { VLOG(2) << "skipping '" << kv.first->name_hint << "'"; result->Add(kv.first, kv.second); @@ -858,6 +948,34 @@ class DeviceCapturer : public ExprMutator { } private: + /*! + * \brief Returns \p prim_func updated to capture any memory scope's implied by its device + * domain. + */ + tir::PrimFunc UpdatePrimFunc(const GlobalVar& global_var, const tir::PrimFunc& prim_func) { + // CAUTION: Same caution as for DeviceAnalyzer::DomainForPrimFunc. 
+ auto func_domain = domains_->DomainFor(global_var); + ICHECK(func_domain->is_higher_order()); + + const auto* func_type_node = global_var->checked_type().as(); + ICHECK(func_type_node); + ICHECK_EQ(func_domain->function_arity(), func_type_node->arg_types.size()); + + std::vector arg_and_result_se_scopes; + arg_and_result_se_scopes.reserve(func_type_node->arg_types.size() + 1); + for (size_t i = 0; i < func_type_node->arg_types.size(); ++i) { + SEScope param_se_scope = domains_->ResultSEScope(func_domain->function_param(i)); + VLOG(2) << "param_se_scope[" << i << "] = " << param_se_scope; + arg_and_result_se_scopes.push_back(param_se_scope); + } + SEScope ret_se_scope = domains_->ResultSEScope(func_domain->function_result()); + VLOG(2) << "ret_se_scope = " << ret_se_scope; + arg_and_result_se_scopes.push_back(ret_se_scope); + + return tir::ApplyPrimFuncArgAndResultConstraints(prim_func, GetRef(func_type_node), + arg_and_result_se_scopes); + } + // Nothing interesting for VarNode, ConstantNode, GlobalVarNode, OpNode and ConstructorNode Expr VisitExpr_(const TupleNode* tuple_node) final { @@ -932,8 +1050,7 @@ class DeviceCapturer : public ExprMutator { // match. return VisitExpr(device_copy_props.body); } else { - return VisitChild(/*lexical_se_scope=*/ - dst_se_scope, + return VisitChild(/*lexical_se_scope=*/dst_se_scope, /*expected_se_scope=*/dst_se_scope, /*child_se_scope=*/src_se_scope, device_copy_props.body); } diff --git a/src/relay/transforms/let_list.h b/src/relay/transforms/let_list.h index 56875f6c16a1..f449d6c3b011 100644 --- a/src/relay/transforms/let_list.h +++ b/src/relay/transforms/let_list.h @@ -79,7 +79,7 @@ class LetList { * * \return a Var that hold the inserted expr. */ - Var Push(Expr expr, Type ty) { return Push(Var("x", ty), expr); } + Var Push(Expr expr, Type ty) { return Push(Var::GenSym(ty), expr); } /*! * \brief insert a binding. 
diff --git a/src/relay/transforms/partial_eval.cc b/src/relay/transforms/partial_eval.cc index 8f5e9e146d54..28d1aa5532bf 100644 --- a/src/relay/transforms/partial_eval.cc +++ b/src/relay/transforms/partial_eval.cc @@ -842,7 +842,7 @@ class PartialEvaluator : public ExprFunctor }); } - PStatic VisitFunc(const Function& func, LetList* ll, const Var& name = Var("x", Type())) { + PStatic VisitFunc(const Function& func, LetList* ll, const Var& name) { Func f = VisitFuncStatic(func, name); Function u_func = AsFunc(RegisterFuncId(DeDup(AnnotateFuncId(func)))); // TODO(@M.K.): we seems to reduce landin knot into letrec. @@ -851,7 +851,7 @@ class PartialEvaluator : public ExprFunctor } PStatic VisitExpr_(const FunctionNode* op, LetList* ll) final { - return VisitFunc(GetRef(op), ll); + return VisitFunc(GetRef(op), ll, Var::GenSym()); } struct ReflectError : Error { diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc index c955269e3412..741de6d7ea9b 100644 --- a/src/relay/transforms/to_a_normal_form.cc +++ b/src/relay/transforms/to_a_normal_form.cc @@ -219,7 +219,7 @@ class Fill : ExprFunctor, private transform::Lexi // v is already defined (e.g. coming from a Let expression). Otherwise return `now` directly Expr Compound(const Expr& orig, const Expr& now, const Var& v) { Expr annotated_expr = MaybeOnDeviceFixed(now, GetSEScope(orig)); - Var var = v.defined() ? v : Var(String("x"), Type()); + Var var = v.defined() ? 
v : Var::GenSym(); bool not_included = include_set_ && include_set_->find(orig) == include_set_->end(); if (!v.defined() && not_included) { return annotated_expr; diff --git a/src/runtime/contrib/verilator/verilator_runtime.h b/src/runtime/contrib/verilator/verilator_runtime.h index 588b3f172a3e..9ef17d7481ab 100644 --- a/src/runtime/contrib/verilator/verilator_runtime.h +++ b/src/runtime/contrib/verilator/verilator_runtime.h @@ -94,7 +94,7 @@ class VerilatorRuntime : public JSONRuntimeBase { ~VerilatorRuntime(); - const char* type_key() const { return "verilator"; } + const char* type_key() const final { return "verilator"; } /*! \brief set verilator library */ void SetLibrary(const std::string& lib_name); diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc index fb5255664af3..2eea869af516 100644 --- a/src/tir/transforms/lower_cross_thread_reduction.cc +++ b/src/tir/transforms/lower_cross_thread_reduction.cc @@ -207,7 +207,7 @@ class InThreadReducerMaker : private StmtMutator { if (res->thread_binding.defined()) { return res->body; } else { - return res; + return std::move(res); } } else { return Stmt{nullptr}; @@ -564,7 +564,7 @@ class CrossThreadReductionTransformer : public StmtMutator { } } } - return new_block; + return std::move(new_block); } Stmt VisitStmt_(const BlockRealizeNode* realize) final { diff --git a/tests/python/relay/test_pass_plan_devices.py b/tests/python/relay/test_pass_plan_devices.py index 6bf103ea0c2a..ee9cfc909585 100644 --- a/tests/python/relay/test_pass_plan_devices.py +++ b/tests/python/relay/test_pass_plan_devices.py @@ -23,6 +23,7 @@ import tvm from tvm import relay +from tvm.script import tir as T import tvm.testing import numpy as np @@ -1580,6 +1581,110 @@ def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[( exercise(input(), expected(), None, None) +def test_lowered(): + """ + Tests propagation of memory scopes from PrimFuncs and 
insertion + of device_copies to mediate any scope changes. + """ + + @T.prim_func + def input_gem(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, [128, 128], scope="scopeA") # will flow out + B = T.match_buffer(b, [128, 128], scope="") # will flow in + C = T.match_buffer(c, [128, 128], scope="scopeB") # will flow out + D = T.match_buffer(d, [128, 128], scope="scopeA") # will flow out + + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + D[vi, vj] = C[vi, vj] + D[vi, vj] = D[vi, vj] + A[vi, vk] * B[vj, vk] + + @T.prim_func + def expected_gem(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, [128, 128], scope="scopeA") + B = T.match_buffer(b, [128, 128], scope="scopeB") # flowed in + C = T.match_buffer(c, [128, 128], scope="scopeB") + D = T.match_buffer(d, [128, 128], scope="scopeA") + + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + D[vi, vj] = C[vi, vj] + D[vi, vj] = D[vi, vj] + A[vi, vk] * B[vj, vk] + + metatable = { + "SEScope": [ + CPU, # meta[SEScope][0], no memory scope + CPU_SCOPE_A, # meta[SEScope][1], "scopeA" + CPU_SCOPE_B, + ] + } # meta[SEScope][2], "scopeB" + gem_ty = relay.FuncType( + [ + relay.TensorType((128, 128), "float32"), + relay.TensorType((128, 128), "float32"), + relay.TensorType((128, 128), "float32"), + ], + relay.TensorType((128, 128), "float32"), + ) + gem_gv = relay.GlobalVar("gem", type_annot=gem_ty) + + def input(): + mod = tvm.ir.IRModule() + mod[gem_gv] = input_gem + # - %x on CPU, no memory scope constraint, so will be constrained by first param of gem to "scopeA". + # - %y on CPU "scopeB", so will flow in to second param of gem. + # - %z on CPU "scopeA", so will clash with third param of gem and will need device_copy. 
+ # - result on CPU "scopeB", but result of gem on "scopeA" so will need device_copy + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x : Tensor[(128, 128), float32], + %y : Tensor[(128, 128), float32], + %z : Tensor[(128, 128), float32], + param_se_scopes=[meta[SEScope][0], meta[SEScope][2], meta[SEScope][1]], + result_se_scope=meta[SEScope][2]) { + call_lowered(@gem, (%x, %y, %z)) + } + """, + "from_string", + mod, + metatable, + ) + + def expected(): + mod = tvm.ir.IRModule() + mod[gem_gv] = expected_gem + # - %x now on CPU "scopeA", no device_copy needed. + # - %y still on CPU "scopeB", no device_copy needed. + # - %z still on CPU "scopeA", needs device_copy to "scopeB". + # - result still on CPU "scopeB", needs device_copy from "scopeA". + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x : Tensor[(128, 128), float32], + %y : Tensor[(128, 128), float32], + %z : Tensor[(128, 128), float32], + param_se_scopes=[meta[SEScope][1], meta[SEScope][2], meta[SEScope][1]], + result_se_scope=meta[SEScope][2]) { + %0 = device_copy(%z, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][2]); + %1 = on_device(%0, se_scope=meta[SEScope][2], constrain_result=True); + %2 = call_lowered(@gem, (%x, %y, %1)); + %3 = on_device(%2, se_scope=meta[SEScope][1], constrain_result=True); + device_copy(%3, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][2]) + } + """, + "from_string", + mod, + metatable, + ) + + exercise(input(), expected(), None, None) + + if __name__ == "__main__": import sys import pytest