diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index b910d32ceca47..d4a42f5c3776e 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -219,7 +219,7 @@ class GlobalVarNode : public RelayExprNode { */ class GlobalVar : public RelayExpr { public: - TVM_DLL explicit GlobalVar(String name_hint); + explicit TVM_DLL GlobalVar(String name_hint, Type type = {}); TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode); }; diff --git a/include/tvm/relay/attrs/call.h b/include/tvm/relay/attrs/call.h index 2b02c6a5edac0..167a593ff377b 100644 --- a/include/tvm/relay/attrs/call.h +++ b/include/tvm/relay/attrs/call.h @@ -39,7 +39,9 @@ struct CallLoweredAttrs : public tvm::AttrsNode { Map metadata; TVM_DECLARE_ATTRS(CallLoweredAttrs, "relay.attrs.CallLoweredAttrs") { - TVM_ATTR_FIELD(metadata).describe("Metadata attached to the lowered function call."); + TVM_ATTR_FIELD(metadata) + .describe("Metadata attached to the lowered function call.") + .set_default(Map()); } }; diff --git a/include/tvm/relay/attrs/on_device.h b/include/tvm/relay/attrs/on_device.h index 405926e209c69..b1f1e6a6dc451 100644 --- a/include/tvm/relay/attrs/on_device.h +++ b/include/tvm/relay/attrs/on_device.h @@ -19,7 +19,7 @@ /*! * \file tvm/relay/attrs/on_device.h - * \brief Attribute for the on device annotation. + * \brief Attribute for the "on_device" annotation (ie operator). */ #ifndef TVM_RELAY_ATTRS_ON_DEVICE_H_ #define TVM_RELAY_ATTRS_ON_DEVICE_H_ @@ -33,9 +33,9 @@ namespace tvm { namespace relay { /*! - * \brief Attributes for the "on_device" special operator. + * \brief Attributes for the "on_device" annotation (ie operator). 
* - * The Relay call (aka 'annotation'): + * The Relay call: * \code * on_device(sub_expr, se_scope=S) * \endcode @@ -54,44 +54,48 @@ namespace relay { * multiply(device_copy(add(%x, %y), src_se_scope=GPU, dst_se_scope=CPU), %z) * \endcode * - * The Relay call - * \code - * on_device(sub_expr, se_scope=S, is_fixed=True) - * \endcode - * is similar to the above, however the annotation itself must appear in an expression on the - * same \p SEScope \p S. The compiler will check the \p SEScopes are consistent, and will not - * insert any "device_copy" call. This form of annotation shouldn't be necessary in user programs. - * However it is needed by the \p PlanDevices pass to fully specify the results of device planning - * so that the pass is idempotent. - * - * E.g.: The following program is equivalent to the above: - * \code - * let %a = on_device(add(%x, %y), se_scope=GPU, is_fixed=True) - * multiply(device_copy(%a, src_se_scope=GPU, dst_se_scope=CPU), %z) - * \endcode - * The "on_device" annotation with \p is_fixed=True indicates unambiguously that \p %a is stored - * on the GPU. + * The \p constrain_body (default true) and \p constrain_result (default false) fields can be + * used by passes for finer-grained control over how the \p SEScope constraint should be applied. */ struct OnDeviceAttrs : public tvm::AttrsNode { /*! - * \brief (Virtual) \p SEScope on which the result of the argument expression should be stored. + * \brief The \p SEScope constraint to apply to the body, result, or both body and result + * of the "on_device" call. */ SEScope se_scope = SEScope::FullyUnconstrained(); + + /*! + * \brief If false (the default), the result of the "on_device" call is not constrained to be + * \p se_scope. + */ + bool constrain_result = false; + + /*! + * \brief If true (the default), the body of the "on_device" call is constrained to be \p + * se_scope. + */ + bool constrain_body = true; + + /*! 
+ * \brief Returns true if both the body and result are constrained. + */ + bool is_fixed() const { return constrain_result && constrain_body; } + /*! - * \brief If true, the result \p SEScope must also be \p se_scope, and device planning should - * not insert any "device_copy" calls to respect this annotation. - * - * This is used by the device planning pass itself when annotating the planned program. + * \brief Returns true if only the body is constrained (the 'normal' case). */ - bool is_fixed = false; + bool is_normal() const { return !constrain_result && constrain_body; } TVM_DECLARE_ATTRS(OnDeviceAttrs, "relay.attrs.OnDeviceAttrs") { TVM_ATTR_FIELD(se_scope) - .describe("The (virtual) device and scope holding the expression result.") + .describe("The (virtual) device to constrain to.") .set_default(SEScope::FullyUnconstrained()); - TVM_ATTR_FIELD(is_fixed) - .describe("If true, do not insert a \"device_copy\" call to respect this annotation.") + TVM_ATTR_FIELD(constrain_result) + .describe("Whether the constraint applies to the overall expression") + .set_default(false); + TVM_ATTR_FIELD(constrain_body) + .describe("Whether the constraint applies to the body sub-expression.") + .set_default(true); } }; diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index f57b2d1a1952c..e7e24d95abed6 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -229,6 +229,14 @@ class Var : public Expr { */ TVM_DLL Var(Id vid, Type type_annotation, Span span = Span()); + /*! + * \brief Return a globally fresh name. Helps with debugging to follow the same + * variable between passes and sub-expressions. + * + * TODO(mbs): Replace with name creation w.r.t. scopes. 
+ */ + static Var GenSym(Type type_annotation = {}, Span span = {}); + TVM_DEFINE_OBJECT_REF_METHODS(Var, RelayExpr, VarNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(VarNode); }; diff --git a/include/tvm/target/se_scope.h b/include/tvm/target/se_scope.h index 595f986686ed1..ec5da3a80cae2 100644 --- a/include/tvm/target/se_scope.h +++ b/include/tvm/target/se_scope.h @@ -170,7 +170,7 @@ class SEScopeNode : public AttrsNode { * * kInvalidDeviceType denotes unconstrained. */ - int device_type_int; + int /* actually DLDeviceType */ device_type_int; DLDeviceType device_type() const { return static_cast(device_type_int); } @@ -303,6 +303,11 @@ class SEScope : public ObjectRef { return SEScope(device_type, /*virtual_device_id=*/0, std::move(target)); } + /*! \brief Returns the \p SEScope for \p memory_scope alone. */ + static SEScope ForMemoryScope(MemoryScope memory_scope) { + return SEScope(kInvalidDeviceType, -1, {}, std::move(memory_scope)); + } + /*! \brief Returns the \p SEScope for \p device, \p target and \p memory_scope. */ TVM_DLL static SEScope ForDeviceTargetAndMemoryScope(const Device& device, Target target, MemoryScope memory_scope) { diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index 40a0d1ab2f741..1ac58e18db3ef 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -52,7 +52,7 @@ class VarNode : public PrimExprNode { */ String name_hint; /*! - * \brief type annotaion of the variable. + * \brief type annotation of the variable. * * It is an optional field that provides a refined type of the variable than dtype. * diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index 139e2b8d97fa4..43cba1a835306 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -64,8 +64,8 @@ class GlobalVar(RelayExpr): The name of the variable. 
""" - def __init__(self, name_hint): - self.__init_handle_by_constructor__(_ffi_api.GlobalVar, name_hint) + def __init__(self, name_hint, type_annot=None): + self.__init_handle_by_constructor__(_ffi_api.GlobalVar, name_hint, type_annot) def __call__(self, *args): """Call the global variable. diff --git a/python/tvm/relay/op/annotation/annotation.py b/python/tvm/relay/op/annotation/annotation.py index cf70dc6e267e5..cb4e628ebc92f 100644 --- a/python/tvm/relay/op/annotation/annotation.py +++ b/python/tvm/relay/op/annotation/annotation.py @@ -31,29 +31,35 @@ def _make_se_scope(device): raise ValueError("expecting a Device or device name, but received a %s" % (type(device))) -def on_device(data, device, is_fixed=False): - """Annotates an expression with the device type on which its result should be stored. +def on_device(body, device, constrain_result=False, constrain_body=True): + """Annotates a body expression with device constraints. The constraint influences + how the body is compiled, where the body is evaluated, and where the result of + evaluation is stored. + + Note that the defaults for the constrain_body and constrain_result parameters should + almost never need to be overridden by the user. These parameters are exposed here + to help unit tests exercise the PlanDevices pass machinery. Parameters ---------- - data : tvm.relay.Expr + body : tvm.relay.Expr The expression to be annotated. device : Union[:py:class:`Device`, str] The device to annotate with. - is_fixed : bool - If false (the default), a device_copy - If true, the annotation does not imply a device_copy may be inserted to - reconcile the device of the data argument with the device for the context of the - annotated expression. + constrain_result : bool + If false (the default), the result of the on_device is not constrained to be on device. + + constrain_body : bool + If true (the default), the body of the on_device is constrained to be on device. 
Returns ------- result : tvm.relay.Expr The annotated expression. """ - return _make.OnDevice(data, _make_se_scope(device), is_fixed) + return _make.OnDevice(body, _make_se_scope(device), constrain_result, constrain_body) def function_on_device(function, param_devices, result_device): diff --git a/src/ir/expr.cc b/src/ir/expr.cc index caddf0efcc77a..399873492f041 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -141,15 +141,18 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "range(min=" << op->min << ", ext=" << op->extent << ')'; }); -GlobalVar::GlobalVar(String name_hint) { +GlobalVar::GlobalVar(String name_hint, Type type) { ObjectPtr n = make_object(); n->name_hint = std::move(name_hint); + n->checked_type_ = std::move(type); data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(GlobalVarNode); -TVM_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](String name) { return GlobalVar(name); }); +TVM_REGISTER_GLOBAL("ir.GlobalVar").set_body_typed([](String name, Type type) { + return GlobalVar(name, type); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index ed5f5f62af947..d0c2cfebbbd82 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -502,12 +502,6 @@ Doc RelayTextPrinter::VisitExpr_(const FunctionNode* op) { Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) { Doc doc; doc << "@" << op->name_hint; -#if TVM_LOG_DEBUG - if (op->checked_type_.defined()) { - doc << " /* type=" << PrintType(op->checked_type_, /*meta=*/false) << " */"; - } - doc << " /* id=" << reinterpret_cast(op) << " */"; -#endif return doc; } @@ -521,6 +515,11 @@ Doc RelayTextPrinter::VisitExpr_(const CallNode* op) { for (const Expr& arg : op->args) { args.push_back(Print(arg)); } +#if TVM_LOG_DEBUG + for (const Type& type_arg : op->type_args) { + args.push_back(Print(type_arg)); + } +#endif for (const Doc& 
d : PrintCallAttrs(op->attrs, op->op)) { args.push_back(d); } diff --git a/src/printer/text_printer.cc b/src/printer/text_printer.cc index 5acb9bd3f1dc2..4d4113fef694d 100644 --- a/src/printer/text_printer.cc +++ b/src/printer/text_printer.cc @@ -27,6 +27,7 @@ #include +#include #include namespace tvm { @@ -36,49 +37,71 @@ static const char* kSemVer = "0.0.5"; Doc TextPrinter::PrintMod(const IRModule& mod) { Doc doc; int counter = 0; + + // We'll print in alphabetical order to make a/b diffs easier to work with. + // type definitions + std::vector tyvars; for (const auto& kv : mod->type_definitions) { + tyvars.emplace_back(kv.first); + } + std::sort(tyvars.begin(), tyvars.end(), + [](const GlobalTypeVar& left, const GlobalTypeVar& right) { + return left->name_hint < right->name_hint; + }); + for (const auto& tyvar : tyvars) { if (counter++ != 0) { doc << Doc::NewLine(); } - doc << relay_text_printer_.Print(kv.second); + doc << relay_text_printer_.Print(mod->type_definitions[tyvar]); doc << Doc::NewLine(); } + // functions + std::vector vars; for (const auto& kv : mod->functions) { - if (kv.second.as()) { + vars.emplace_back(kv.first); + } + std::sort(vars.begin(), vars.end(), [](const GlobalVar& left, const GlobalVar& right) { + return left->name_hint < right->name_hint; + }); + for (const auto& var : vars) { + const BaseFunc& base_func = mod->functions[var]; + if (base_func.as()) { relay_text_printer_.dg_ = - relay::DependencyGraph::Create(&relay_text_printer_.arena_, kv.second); + relay::DependencyGraph::Create(&relay_text_printer_.arena_, base_func); } if (counter++ != 0) { doc << Doc::NewLine(); } - if (kv.second.as()) { + if (base_func.as()) { std::ostringstream os; - os << "def @" << kv.first->name_hint; -#if TVM_LOG_DEBUG - os << " /* id=" << reinterpret_cast(kv.first.get()) << " */"; -#endif - doc << relay_text_printer_.PrintFunc(Doc::Text(os.str()), kv.second); - } else if (kv.second.as()) { - doc << "@" << kv.first->name_hint; -#if TVM_LOG_DEBUG - doc << 
" /* id=" << reinterpret_cast(kv.first.get()) << " */"; -#endif - doc << " = " << tir_text_printer_.PrintPrimFunc(Downcast(kv.second)); + os << "def @" << var->name_hint; + doc << relay_text_printer_.PrintFunc(Doc::Text(os.str()), base_func); + } else if (base_func.as()) { + doc << "@" << var->name_hint; + doc << " = " << tir_text_printer_.PrintPrimFunc(Downcast(base_func)); } doc << Doc::NewLine(); } + #if TVM_LOG_DEBUG // attributes + // TODO(mbs): Make this official, including support from parser. if (mod->attrs.defined() && !mod->attrs->dict.empty()) { - doc << "attributes {" << Doc::NewLine(); + std::vector keys; for (const auto& kv : mod->attrs->dict) { - doc << " '" << kv.first << "' = " << PrettyPrint(kv.second) << Doc::NewLine(); + keys.emplace_back(kv.first); + } + std::sort(keys.begin(), keys.end()); + doc << "attributes {" << Doc::NewLine(); + for (const auto& key : keys) { + doc << " '" << key << "' = " << PrettyPrint(mod->attrs->dict[key]) << Doc::NewLine(); } doc << "}" << Doc::NewLine(); } #endif + return doc; } diff --git a/src/relay/backend/contrib/cmsisnn/extract_constants.cc b/src/relay/backend/contrib/cmsisnn/extract_constants.cc index 5ed23ad1ad6a5..ca003d80c1d9e 100644 --- a/src/relay/backend/contrib/cmsisnn/extract_constants.cc +++ b/src/relay/backend/contrib/cmsisnn/extract_constants.cc @@ -46,6 +46,8 @@ class ExtractConstantsMutator : public MixedModeMutator { private: String gen_var_name() { return "tvm_var_extract_const_" + std::to_string(var_count_++); } + using MixedModeMutator::VisitExpr_; + Expr VisitExpr_(const FunctionNode* function) final { Function func = GetRef(function); function_to_constants_.Set(func, Array{}); @@ -56,7 +58,7 @@ class ExtractConstantsMutator : public MixedModeMutator { func = Function(FreeVars(new_body), new_body, func->ret_type, FreeTypeVars(new_body, mod_), func->attrs); } - return func; + return std::move(func); } Expr Rewrite_(const CallNode* call, const Expr& post) final { diff --git 
a/src/relay/backend/contrib/cmsisnn/generate_constants.cc b/src/relay/backend/contrib/cmsisnn/generate_constants.cc index 2e12697f36f1c..b0b14a02eee93 100644 --- a/src/relay/backend/contrib/cmsisnn/generate_constants.cc +++ b/src/relay/backend/contrib/cmsisnn/generate_constants.cc @@ -179,7 +179,7 @@ class GenerateConstantsMutator : public MixedModeMutator { if (clip_call) { ret_call = Call(clip_call->op, {ret_call}, clip_call->attrs, {}); } - return ret_call; + return std::move(ret_call); } Expr Rewrite_(const CallNode* call, const Expr& post) final { diff --git a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc index 85923b3ed08e1..4410c99f01376 100644 --- a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc +++ b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc @@ -81,8 +81,10 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost { int depth_multiplier; }; + using codegen::CodeGenCHost::VisitStmt_; + /*! * \brief Emit the CMSIS-NN context buffer */ - void VisitStmt_(const AllocateNode* op) { + void VisitStmt_(const AllocateNode* op) final { context_buffer_name_ = op->buffer_var->name_hint; context_buffer_size_ = op->constant_allocation_size(); CodeGenC::VisitStmt_(op); diff --git a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc index c41399e314ef0..9f206e120a1b4 100644 --- a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc +++ b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc @@ -116,10 +116,10 @@ class ConvertAddToSubtract : public MixedModeMutator { // Since we are replacing the Relay function with a call to a TIR function, we must use the // call_lowered op. 
- auto call_lowered_attrs = make_object(); - call_lowered_attrs->metadata.Set("relay_attrs", call->attrs); - return CallLowered(std::move(new_global_var), call->args, - std::move(Attrs(call_lowered_attrs)), call->type_args, call->span); + CallLoweredAttrs attrs; + attrs.metadata.Set("relay_attrs", call->attrs); + ICHECK(call->type_args.empty()) << "lowered functions cannot be polymorphic"; + return CallLowered(std::move(new_global_var), call->args, std::move(attrs), call->span); } } diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index 3d889cdf65617..16b1ddb3c82f2 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -694,8 +694,7 @@ class GraphExecutorCodegenModule : public runtime::ModuleNode { *rv = this->output_.external_mods; }); } else if (name == "get_devices") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = Array(); }); + return PackedFunc([sptr_to_self](TVMArgs args, TVMRetValue* rv) { *rv = Array(); }); } else if (name == "get_metadata") { return PackedFunc( [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->output_.metadata; }); diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index b47bc401b37f9..528df647fe4a1 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -560,8 +560,7 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { * to the TIR implementation, and attributes to attach to the call to identify it as * a TIR call. 
*/ - Expr MakeLoweredCall(Function func, Array visited_args, Array type_args, Span span, - Target target) { + Expr MakeLoweredCall(Function func, Array visited_args, Span span, Target target) { CCacheKey key = CCacheKey(func, target); CachedFunc cfunc = compiler_->Lower(key, module_name_); ICHECK(cfunc.defined()); @@ -594,16 +593,16 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, cfunc->target); this->process_fn_(func_with_metadata); - auto call_lowered_attrs = make_object(); + CallLoweredAttrs call_lowered_attrs; // Non-External Relay Function // TODO(mbs): "reshape" cleanup. if (!opt_compiler && func->HasNonzeroAttr(attr::kReshapeOnly)) { - call_lowered_attrs->metadata.Set(attr::kReshapeOnly, tvm::Integer(1)); + call_lowered_attrs.metadata.Set(attr::kReshapeOnly, tvm::Integer(1)); } - call_lowered_attrs->metadata.Set("relay_attrs", func->attrs); - call_lowered_attrs->metadata.Set("all_prim_fn_vars", all_prim_fn_vars); + call_lowered_attrs.metadata.Set("relay_attrs", func->attrs); + call_lowered_attrs.metadata.Set("all_prim_fn_vars", all_prim_fn_vars); if (IsDynamic(func->ret_type)) { // Also lower the companion dynamic shape function. @@ -616,12 +615,12 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { // Capture the shape function's global var and parameters 'states' in call // annotations so calling convention can be recovered. // TODO(mbs): Shape cleanup. 
- call_lowered_attrs->metadata.Set("prim_shape_fn_var", lowered_shape_func->prim_fn_var); - call_lowered_attrs->metadata.Set("prim_shape_fn_states", - lowered_shape_func->shape_func_param_states); - call_lowered_attrs->metadata.Set( - "prim_shape_fn_num_inputs", Integer(static_cast(lowered_shape_func->inputs.size()))); - call_lowered_attrs->metadata.Set( + call_lowered_attrs.metadata.Set("prim_shape_fn_var", lowered_shape_func->prim_fn_var); + call_lowered_attrs.metadata.Set("prim_shape_fn_states", + lowered_shape_func->shape_func_param_states); + call_lowered_attrs.metadata.Set("prim_shape_fn_num_inputs", + Integer(static_cast(lowered_shape_func->inputs.size()))); + call_lowered_attrs.metadata.Set( "prim_shape_fn_num_outputs", Integer(static_cast(lowered_shape_func->outputs.size()))); Array all_prim_shape_fn_vars; @@ -629,11 +628,11 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { CHECK(kv.second.as()) << "must be a prim fn"; all_prim_shape_fn_vars.push_back(kv.first); } - call_lowered_attrs->metadata.Set("all_prim_shape_fn_vars", all_prim_shape_fn_vars); + call_lowered_attrs.metadata.Set("all_prim_shape_fn_vars", all_prim_shape_fn_vars); } - return CallLowered(cfunc->prim_fn_var, std::move(visited_args), Attrs(call_lowered_attrs), - type_args, std::move(span)); + return CallLowered(cfunc->prim_fn_var, std::move(visited_args), std::move(call_lowered_attrs), + std::move(span)); } std::pair PreVisitLetBinding_(const Var& var, const Expr& value) final { @@ -729,12 +728,13 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { }); ICHECK(!IsDynamic(call_node->checked_type())); - auto call_lowered_attrs = make_object(); - call_lowered_attrs->metadata.Set("relay_attrs", primitive_func->attrs); + CallLoweredAttrs call_lowered_attrs; + call_lowered_attrs.metadata.Set("relay_attrs", primitive_func->attrs); process_fn_(func_with_metadata); - return CallLowered(call_node->op, std::move(new_args), Attrs(std::move(call_lowered_attrs)), - 
call_node->type_args, call_node->span); + ICHECK(call_node->type_args.empty()) << "lowered functions cannot be polymorphic"; + return CallLowered(prim_func_var, std::move(new_args), std::move(call_lowered_attrs), + call_node->span); } // Typical case: call to fused primitive Relay Function. @@ -754,8 +754,8 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { // Lower the primitive function for that target. Function function = Downcast(primitive_func); - return MakeLoweredCall(function, std::move(new_args), call_node->type_args, call_node->span, - target); + ICHECK(call_node->type_args.empty()) << "lowered functions cannot be polymorphic"; + return MakeLoweredCall(function, std::move(new_args), call_node->span, target); } IRModule module_; diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index b7c0999ecc72b..23aee452ba09e 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -301,7 +301,8 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { size_t NewRegister() { return registers_num_++; } inline void Emit(const Instruction& instr) { - VLOG(2) << "VMCompiler::Emit: instr=" << instr; + size_t instruction_index = instructions_.size(); + VLOG(2) << "instruction[" << instruction_index << "] = " << instr; ICHECK((int)instr.op < 100) << "Invalid opcode " << (int)instr.op; switch (instr.op) { case Opcode::AllocADT: @@ -336,10 +337,9 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { * in emitted code. Note that the host device is always at index 0. 
*/ Index GetDeviceIndex(const SEScope& se_scope) { - VLOG(2) << "getting device index for " << se_scope; + ICHECK(!se_scope->IsFullyUnconstrained()); auto itr = std::find(context_->se_scopes_.begin(), context_->se_scopes_.end(), se_scope); if (itr != context_->se_scopes_.end()) { - VLOG(2) << "reusing existing scope"; return std::distance(context_->se_scopes_.begin(), itr); } @@ -367,7 +367,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { ICHECK(se_scope != host_se_scope_); Index index = context_->se_scopes_.size(); - VLOG(2) << "adding new scope"; + VLOG(2) << "se_scope[" << index << "] = " << se_scope; context_->se_scopes_.push_back(se_scope); return index; @@ -378,11 +378,13 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { void VisitExpr_(const ConstantNode* const_node) final { // Check the shape is valid NDArray data = const_node->data; - size_t konst_idx = context_->constants.size(); + size_t const_index = context_->constants.size(); auto con = GetRef(const_node); - context_->const_device_indexes.push_back(GetDeviceIndex(GetSEScope(con))); + Index device_index = GetDeviceIndex(GetSEScope(con)); + VLOG(2) << "constant[" << const_index << "] on device[" << device_index << "]"; + context_->const_device_indexes.push_back(device_index); context_->constants.push_back(const_node->data); - Emit(Instruction::LoadConst(konst_idx, NewRegister())); + Emit(Instruction::LoadConst(const_index, NewRegister())); } void VisitExpr_(const VarNode* var_node) final { @@ -872,6 +874,7 @@ void VMCompiler::Lower(IRModule mod, TargetMap targets, tvm::Target target_host) // The first device is always for the host. CHECK(context_.se_scopes_.empty()); + VLOG(2) << "se_scope[0] = " << config_->host_se_scope << " (host)"; context_.se_scopes_.push_back(config_->host_se_scope); // Run the optimizations necessary to target the VM. @@ -1085,6 +1088,10 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) { // let-bound functions. 
pass_seqs.push_back(DeadCodeElimination(/*inline_once=*/false)); + // Now that we have PrimFuncs, flow and solve SEScope constraints again to account for + // any memory scopes which lowering has settled on. + pass_seqs.push_back(transform::PlanDevices(config_)); + // Inline the functions that are lifted to the module scope. We perform this // pass after all other optimization passes but before the memory allocation // pass. This is because memory allocation pass will insert `invoke_tvm_op` diff --git a/src/relay/ir/adt.cc b/src/relay/ir/adt.cc index c2b8fd641d030..0389547a78f95 100644 --- a/src/relay/ir/adt.cc +++ b/src/relay/ir/adt.cc @@ -115,7 +115,7 @@ Clause WithFields(Clause clause, Optional opt_lhs, Optional opt_r cow_clause_node->lhs = lhs; cow_clause_node->rhs = rhs; } - return std::move(clause); + return clause; } TVM_REGISTER_NODE_TYPE(ClauseNode); @@ -168,7 +168,7 @@ Match WithFields(Match match, Optional opt_data, Optional> o cow_match_node->complete = complete; cow_match_node->span = span; } - return std::move(match); + return match; } TVM_REGISTER_NODE_TYPE(MatchNode); diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 8998f4e1573db..6af1b9a647e94 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -95,7 +95,7 @@ Tuple WithFields(Tuple tuple, Optional> opt_fields, Optional o cow_tuple_node->fields = fields; cow_tuple_node->span = span; } - return std::move(tuple); + return tuple; } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -112,6 +112,13 @@ Var::Var(Id vid, Type type_annotation, Span span) { data_ = std::move(n); } +/* static */ Var Var::GenSym(Type type_annotation, Span span) { + static size_t next_id = 0; + std::ostringstream os; + os << "x_" << next_id++; + return Var(os.str(), std::move(type_annotation), std::move(span)); +} + Var WithFields(Var var, Optional opt_vid, Optional opt_type_annotation, Optional opt_span) { Id vid = opt_vid.value_or(var->vid); @@ -127,7 +134,7 @@ Var WithFields(Var var, Optional opt_vid, 
Optional opt_type_annotation cow_var_node->type_annotation = type_annotation; cow_var_node->span = span; } - return std::move(var); + return var; } TVM_REGISTER_NODE_TYPE(VarNode); @@ -203,7 +210,7 @@ Call WithFields(Call call, Optional opt_op, Optional> opt_args cow_call_node->type_args = type_args; cow_call_node->span = span; } - return std::move(call); + return call; } TVM_REGISTER_NODE_TYPE(CallNode); @@ -246,7 +253,7 @@ Let WithFields(Let let, Optional opt_var, Optional opt_value, Optiona cow_let_node->body = body; cow_let_node->span = span; } - return std::move(let); + return let; } TVM_REGISTER_NODE_TYPE(LetNode); @@ -286,7 +293,7 @@ If WithFields(If if_expr, Optional opt_cond, Optional opt_true_branc cow_if_node->true_branch = true_branch; cow_if_node->false_branch = false_branch; } - return std::move(if_expr); + return if_expr; } TVM_REGISTER_NODE_TYPE(IfNode); @@ -325,7 +332,7 @@ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple, cow_tuple_get_item_node->index = index; cow_tuple_get_item_node->span = span; } - return std::move(tuple_get_item); + return tuple_get_item; } TVM_REGISTER_NODE_TYPE(TupleGetItemNode); @@ -357,7 +364,7 @@ RefCreate WithFields(RefCreate ref_create, Optional opt_value, Optionalvalue = value; cow_ref_create_node->span = span; } - return std::move(ref_create); + return ref_create; } TVM_REGISTER_NODE_TYPE(RefCreateNode); @@ -389,7 +396,7 @@ RefRead WithFields(RefRead ref_read, Optional opt_ref, Optional opt_ cow_ref_read_node->ref = ref; cow_ref_read_node->span = span; } - return std::move(ref_read); + return ref_read; } TVM_REGISTER_NODE_TYPE(RefReadNode); @@ -424,7 +431,7 @@ RefWrite WithFields(RefWrite ref_write, Optional opt_ref, Optional o cow_ref_write_node->value = value; cow_ref_write_node->span = span; } - return std::move(ref_write); + return ref_write; } TVM_REGISTER_NODE_TYPE(RefWriteNode); @@ -477,29 +484,29 @@ inline void Dismantle(const Expr& expr) { stack.top().second = true; // special handling 
- if (const CallNode* op = node.as()) { + if (const auto* call_node = node.as()) { // do not process args if used elsewhere - if (op->args.use_count() < 2) { - for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) { + if (call_node->args.use_count() < 2) { + for (auto it = call_node->args.rbegin(); it != call_node->args.rend(); ++it) { fpush_to_stack(*it); } } - } else if (const TupleNode* op = node.as()) { + } else if (const auto* tuple_node = node.as()) { // do not process fields if used elsewhere - if (op->fields.use_count() < 2) { - for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) { + if (tuple_node->fields.use_count() < 2) { + for (auto it = tuple_node->fields.rbegin(); it != tuple_node->fields.rend(); ++it) { fpush_to_stack(*it); } } - } else if (const TupleGetItemNode* op = node.as()) { + } else if (const auto* tuple_get_item_node = node.as()) { // do not process tuple if used elsewhere - if (op->tuple.use_count() < 2) { - fpush_to_stack(op->tuple); + if (tuple_get_item_node->tuple.use_count() < 2) { + fpush_to_stack(tuple_get_item_node->tuple); } - } else if (const LetNode* op = node.as()) { + } else if (const auto* let_node = node.as()) { // do not process let if used elsewhere - if (op->body.use_count() < 2) { - fpush_to_stack(op->body); + if (let_node->body.use_count() < 2) { + fpush_to_stack(let_node->body); } } } diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index f24dd6d1fb4f8..dca452740d978 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -88,7 +88,7 @@ Function WithFields(Function function, Optional> opt_params, Optional cow_function_node->attrs = attrs; cow_function_node->span = span; } - return std::move(function); + return function; } FuncType FunctionNode::func_type_annotation() const { diff --git a/src/relay/op/call/call.cc b/src/relay/op/call/call.cc index 6f2d3a8ddbada..ab8e7e12d2132 100644 --- a/src/relay/op/call/call.cc +++ b/src/relay/op/call/call.cc @@ -63,23 +63,23 @@ 
bool CallLoweredRel(const Array& types, int num_inputs, const Attrs& attrs const Op& CallLoweredOp() { return Op::Get("call_lowered"); } -Expr CallLowered(Expr func, Array inputs, Attrs attrs, Array type_args, Span span) { - // Right now, call_lowered only supports func being a global var pointing to the lowered - // function. - ICHECK(func.as()) - << "Function to call should be GlobalVarNode, but got:" << std::endl - << PrettyPrint(func); - ICHECK(attrs.as()) - << "Expected attributes to be CallLoweredAttrs, but got " << attrs->GetTypeKey(); - return Call(CallLoweredOp(), {std::move(func), Tuple(std::move(inputs))}, std::move(attrs), - std::move(type_args), std::move(span)); +Call CallLowered(GlobalVar lowered_func, Array args, CallLoweredAttrs call_lowered_attrs, + Span span) { + auto attrs = make_object(std::move(call_lowered_attrs)); + return Call(CallLoweredOp(), {std::move(lowered_func), Tuple(std::move(args))}, + Attrs(std::move(attrs)), /*type_args=*/{}, std::move(span)); } TVM_REGISTER_GLOBAL("relay.op.call_lowered") - .set_body_typed([](Expr func, Array inputs, Attrs attrs, Array type_args, - Span span) { - const TupleNode* tuple_node = inputs.as(); - return CallLowered(func, tuple_node->fields, attrs, type_args, span); + .set_body_typed([](Expr lowered_func, Array args, Attrs attrs, Span span) { + const auto* lowered_func_node = lowered_func.as(); + ICHECK(lowered_func_node) << "Function to call should be GlobalVarNode, but got:" << std::endl + << PrettyPrint(lowered_func); + const auto* call_lowered_attrs = attrs.as(); + ICHECK(call_lowered_attrs) << "Expected attributes to be CallLoweredAttrs, but got " + << attrs->GetTypeKey(); + return CallLowered(GetRef(lowered_func_node), std::move(args), *call_lowered_attrs, + std::move(span)); }); RELAY_REGISTER_OP("call_lowered") @@ -105,10 +105,12 @@ CallLoweredProps GetCallLoweredProps(const CallNode* call_node) { ICHECK(tuple_args) << "Expected second arg to call_lowered to be a Tuple of input arguments."; 
ICHECK(call_node->attrs.defined()) << "Expecting call_lowered to have attributes."; - const auto* attrs = call_node->attrs.as(); - ICHECK(attrs) << "Expected call_lowered op to have CallLoweredAttrs, but found " - << call_node->attrs->GetTypeKey(); - return CallLoweredProps{GetRef(function_node), tuple_args->fields, *attrs}; + const auto* call_lowered_attrs = call_node->attrs.as(); + ICHECK(call_lowered_attrs) << "Expected call_lowered op to have CallLoweredAttrs, but found " + << call_node->attrs->GetTypeKey(); + // If the call_node has type_args then they are for the polymorphic 'call_lowered' operator + // itself which expects the function type and argument type as parameters. + return {GetRef(function_node), tuple_args->fields, *call_lowered_attrs}; } return {}; } @@ -116,8 +118,9 @@ CallLoweredProps GetCallLoweredProps(const CallNode* call_node) { Call GetAnyCall(const CallNode* call_node) { CallLoweredProps props = GetCallLoweredProps(call_node); if (props.lowered_func.defined()) { - auto attrs = make_object(props.attrs); - return Call(std::move(props.lowered_func), props.arguments, Attrs(std::move(attrs)), + auto call_lowered_attrs = make_object(props.attrs); + return Call(std::move(props.lowered_func), std::move(props.arguments), + Attrs(std::move(call_lowered_attrs)), /*type_args=*/{}, call_node->span); } else { return GetRef(call_node); diff --git a/src/relay/op/call/call.h b/src/relay/op/call/call.h index a2d30eaa87635..6193c9249ee26 100644 --- a/src/relay/op/call/call.h +++ b/src/relay/op/call/call.h @@ -32,23 +32,33 @@ namespace tvm { namespace relay { -/*! - * \brief Helper to construct a Relay call with the call_lowered op. - * \param func Lowered function to call with call_lowered. - * \param inputs Arguments to be passed to the function. - * \param attrs Function attributes, should be TIRCallAttrs. - * \param type_args Type arguments for the call. - * \param span TVM span for propogating debugging info. 
- * \return - */ -Expr CallLowered(Expr func, Array inputs, Attrs attrs, Array type_args, Span span); - /*! * \brief Returns the Relay call_lowered op. Use this helper to avoid extraneous calls to * Registry::Get. */ const Op& CallLoweredOp(); +/*! + * \brief Helper to construct a Relay call with the "call_lowered" op. + * + * The callee must: + * - Be a global bound to a PrimFunc or an externally defined functions. + * - Accept only tensor arguments and return tensor results. + * - Arguments and results correspond to the flattened form (see FlattenTupleType) of the + * Relay Function type. + * - Return results by output pointer, ie use DPS. + * The arguments remain in Relay form (ie not flattened). + * The result remains in Relay form (ie returned from the call and not flattened). + * + * \param lowered_func Lowered function to call with call_lowered. + * \param args Arguments to be passed to the function. + * \param call_lowered_attrs Function attributes. + * \param span TVM span for propagating debugging info. + * \return + */ +Call CallLowered(GlobalVar lowered_func, Array args, CallLoweredAttrs call_lowered_attrs, + Span span); + /*! * \brief Lowered function and the arguments to call it with. */ @@ -57,7 +67,7 @@ struct CallLoweredProps { GlobalVar lowered_func; /*! \brief Array of the arguments to call lowered_func with. */ Array arguments; - /*! \brief Arguments from the call_lowered op. */ + /*! \brief Attributes from the call_lowered op. 
*/ CallLoweredAttrs attrs; }; diff --git a/src/relay/op/memory/on_device.cc b/src/relay/op/memory/on_device.cc index 9541d4122a2f9..ae5ef33da6d01 100644 --- a/src/relay/op/memory/on_device.cc +++ b/src/relay/op/memory/on_device.cc @@ -29,8 +29,6 @@ #include #include #include -#include -#include #include "../../transforms/infer_layout_utils.h" #include "../type_relations.h" @@ -45,53 +43,74 @@ const Op& OnDeviceOp() { return op; } -Expr OnDevice(Expr expr, SEScope se_scope, bool is_fixed) { - ICHECK(!se_scope->IsFullyUnconstrained()); +Call OnDevice(Expr body, SEScope se_scope, bool constrain_result, bool constrain_body) { + ICHECK((!constrain_result && !constrain_body) || !se_scope->IsFullyUnconstrained()); auto attrs = make_object(); - attrs->se_scope = std::move(se_scope); - attrs->is_fixed = is_fixed; - Span span = expr->span; - return Call(OnDeviceOp(), {std::move(expr)}, Attrs(std::move(attrs)), /*type_args=*/{}, + attrs->se_scope = + (constrain_result || constrain_body) ? std::move(se_scope) : SEScope::FullyUnconstrained(); + attrs->constrain_result = constrain_result; + attrs->constrain_body = constrain_body; + Span span = body->span; // about to be moved + return Call(OnDeviceOp(), {std::move(body)}, Attrs(std::move(attrs)), /*type_args=*/{}, std::move(span)); } TVM_REGISTER_GLOBAL("relay.op.annotation._make.OnDevice").set_body_typed(OnDevice); -Expr MaybeOnDevice(Expr expr, SEScope se_scope, bool is_fixed) { +Expr MaybeOnDevice(Expr body, SEScope se_scope, bool constrain_result, bool constrain_body) { if (se_scope->IsFullyUnconstrained()) { // Nothing to annotate with. - return expr; + return body; } - if (expr->IsInstance() || expr->IsInstance()) { + if (body->IsInstance() || body->IsInstance()) { // These operators are device polymorphic so no annotation is required. 
- return expr; + return body; } - if (expr->IsInstance() || expr->IsInstance()) { + if (body->IsInstance() || body->IsInstance()) { // The device can be recovered from the binding site of the global or local variable. - return expr; + return body; } - if (expr->IsInstance()) { + if (body->IsInstance()) { // If a primitive function then it is device polymorphic. Otherwise the device is captured // by the function's "result_se_scope" attribute. - return expr; + return body; } - OnDeviceProps props = GetOnDeviceProps(expr); + OnDeviceProps props = GetOnDeviceProps(body); if (props.body.defined()) { - // Don't nest on_devices. - // If the inner and outer device types differ then we need to be careful: - // - If the inner on_device is_fixed then it disagrees with the outer. - // - If the outer on_device is_fixed then it implies a hidden device_copy - // Otherwise just use the inner device type and ignore the outer. - ICHECK(props.se_scope == se_scope || (!is_fixed && !props.is_fixed)); - return OnDevice(props.body, se_scope, is_fixed || props.is_fixed); + // The user is asking for + // on_device(on_device(body, se_scope=inner), se_scope=outer) + // ^ ^ ^ + // outer middle inner + // First recover the implied constraints (if any) for outer and inner, and check they don't + // contradict. + const SEScope& inner = props.se_scope; + const SEScope& outer = se_scope; + bool constrain_outer = constrain_result; + bool constrain_inner = props.constrain_body; + if (constrain_outer && constrain_inner) { + ICHECK(inner == outer) + << "Cannot constrain result and body of nested on_device calls to different SEScopes"; + } + // There are two possible ways the middle sub-expression may be constrained, check they don't + // contradict. 
+ bool constrain_middle_via_outer = constrain_body; + bool constrain_middle_via_inner = props.constrain_result; + if (constrain_middle_via_outer && constrain_middle_via_inner) { + ICHECK(inner == outer) + << "Cannot constrain intermediate result of nested on_device calls to different SEScopes"; + } + // We can now ignore the intermediate constraints, if any. + return OnDevice(props.body, (constrain_inner || constrain_outer) ? outer : inner, + constrain_outer, constrain_inner); + } else { + return OnDevice(body, std::move(se_scope), constrain_result, constrain_body); } - return OnDevice(expr, std::move(se_scope), is_fixed); } RELAY_REGISTER_OP("on_device") .describe(R"code(Annotate an expression with device type)code" TVM_ADD_FILELINE) .set_num_inputs(1) - .add_argument("data", "Tensor", "The input data.") + .add_argument("body", "Expr", "The sub-expression to be annotated.") .set_support_level(10) .add_type_rel("Identity", IdentityRel) .set_attrs_type_key("relay.attrs.OnDeviceAttrs") @@ -106,14 +125,8 @@ OnDeviceProps GetOnDeviceProps(const CallNode* call_node) { ICHECK(call_node->attrs.defined()) << "on_device requires attributes"; const auto* on_device_attrs = call_node->attrs.as(); ICHECK(on_device_attrs != nullptr) << "on_device requires OnDeviceAttrs"; - // Follow nesting: - // on_device(on_device(expr, se_scope=S), se_scope=T) == {expr, S} - auto inner = GetOnDeviceProps(call_node->args[0]); - if (inner.body.defined()) { - return {inner.body, inner.se_scope, on_device_attrs->is_fixed || inner.is_fixed}; - } else { - return {call_node->args[0], on_device_attrs->se_scope, on_device_attrs->is_fixed}; - } + return {call_node->args[0], on_device_attrs->se_scope, on_device_attrs->constrain_result, + on_device_attrs->constrain_body}; } return {}; } diff --git a/src/relay/op/memory/on_device.h b/src/relay/op/memory/on_device.h index a7b6cb7cf52a5..bac6695ac35b1 100644 --- a/src/relay/op/memory/on_device.h +++ b/src/relay/op/memory/on_device.h @@ -39,14 +39,50 @@ 
namespace relay { const Op& OnDeviceOp(); /*! - * \brief Wraps \p expr in an "on_device" CallNode for \p se_scope and \p is_fixed. + * \brief Wraps \p body in an "on_device" CallNode for \p se_scope. * * See \p OnDeviceAttrs for an overview. */ -Expr OnDevice(Expr expr, SEScope se_scope, bool is_fixed); +Call OnDevice(Expr body, SEScope se_scope, bool constrain_result = false, + bool constrain_body = true); + +/*! \brief Result of \p GetOnDeviceProps. */ +struct OnDeviceProps { + Expr body; // = null + SEScope se_scope = SEScope::FullyUnconstrained(); + bool constrain_result = false; + bool constrain_body = false; + + OnDeviceProps() = default; + + OnDeviceProps(Expr body, SEScope se_scope, bool constrain_result, bool constrain_body) + : body(std::move(body)), + se_scope(std::move(se_scope)), + constrain_result(constrain_result), + constrain_body(constrain_body) {} + + bool is_fixed() const { return constrain_result && constrain_body; } + bool is_normal() const { return !constrain_result && constrain_body; } +}; + +/*! + * \brief As for OnDevice, but taking all fields other than \p body from \p props. + */ +inline Call OnDeviceWithProps(Expr body, const OnDeviceProps& props) { + return OnDevice(std::move(body), props.se_scope, props.constrain_result, props.constrain_body); +} /*! - * \brief Wraps \p expr in an "on_device" CallNode for \p se_scope and \p is_fixed if the + * \brief As for OnDevice, but don't constrain the body or result to any particular virtual device. + * This allows a "device_copy" when required. + */ +inline Call OnDeviceCopyOk(Expr body) { + return OnDevice(std::move(body), SEScope::FullyUnconstrained(), + /*constrain_result=*/false, /*constrain_body=*/false); +} + +/*! + * \brief Wraps \p expr in an "on_device" CallNode for \p se_scope and \p constraint if the * \p SEScope for \p expr cannot otherwise be recovered by the lexical scoping convention. 
* This means we will NOT wrap if: * - \p se_scope is full unconstrained, which signals there are no device annotations @@ -57,33 +93,34 @@ Expr OnDevice(Expr expr, SEScope se_scope, bool is_fixed); * - \p expr is a global var. The device is on the function attributes the global is bound to. * - \p expr is a local var. The device is tracked by the device aware visitors for us. * - \p expr is a constructor. These are device polymorphic. - * + * Nested on_device calls will never be constructed, they are instead merged on-the-fly. */ -Expr MaybeOnDevice(Expr expr, SEScope se_scope, bool is_fixed); - -/*! \brief Result of \p GetOnDeviceProps. */ -struct OnDeviceProps { - Expr body; // = null - SEScope se_scope = SEScope::FullyUnconstrained(); - bool is_fixed = false; +Expr MaybeOnDevice(Expr body, SEScope se_scope, bool constrain_result = false, + bool constrain_body = true); - OnDeviceProps() = default; +/*! \brief As for MaybeOnDevice, but with both body and result constrained. */ +inline Expr MaybeOnDeviceFixed(Expr body, SEScope se_scope) { + return MaybeOnDevice(std::move(body), std::move(se_scope), /*constrain_result=*/true, + /*constrain_body=*/true); +} - OnDeviceProps(Expr body, SEScope se_scope, bool isFixed) - : body(std::move(body)), se_scope(std::move(se_scope)), is_fixed(isFixed) {} -}; +/*! \brief As for MaybeOnDevice, but with fields other than body taken from \p props. */ +inline Expr MaybeOnDeviceWithProps(Expr body, const OnDeviceProps& props) { + return MaybeOnDevice(std::move(body), props.se_scope, props.constrain_result, + props.constrain_body); +} /*! - * \brief Returns the body expression, \p SEScope, and is_fixed field for \p call_node if it + * \brief Returns the body expression, \p SEScope, and constraint field for \p call_node if it * is an "on_device" CallNode. Otherwise returns the null expression, the unconstrained - * \p SEScope, and false. + * \p SEScope, and \p kBody. */ OnDeviceProps GetOnDeviceProps(const CallNode* call_node); /*! 
- * \brief Returns the body expression, \p SEScope, and is_fixed field for \p expr if it is an + * \brief Returns the body expression, \p SEScope, and constraint field for \p expr if it is an * "on_device" CallNode. Otherwise returns the null expression, the unconstrained \p SEScope, - * and \p false. + * and \p kBody. */ OnDeviceProps GetOnDeviceProps(const Expr& expr); diff --git a/src/relay/transforms/device_aware_visitors.cc b/src/relay/transforms/device_aware_visitors.cc index e3d5a821c58e4..29965d2dac97d 100644 --- a/src/relay/transforms/device_aware_visitors.cc +++ b/src/relay/transforms/device_aware_visitors.cc @@ -36,11 +36,12 @@ namespace transform { LexicalOnDeviceMixin::LexicalOnDeviceMixin(const Optional& maybe_mod) { if (maybe_mod) { - for (const auto& pair : maybe_mod.value()->functions) { - if (const auto* function_node = pair.second.as()) { + for (const auto& kv : maybe_mod.value()->functions) { + if (const auto* function_node = kv.second.as()) { SEScope se_scope = GetFunctionResultSEScope(function_node); if (!se_scope->IsFullyUnconstrained()) { - global_var_se_scopes_.emplace(pair.first, se_scope); + VLOG(2) << "global '" << kv.first->name_hint << "' has scope " << se_scope; + global_var_se_scopes_.emplace(kv.first, se_scope); } } } @@ -49,7 +50,7 @@ LexicalOnDeviceMixin::LexicalOnDeviceMixin(const Optional& maybe_mod) SEScope LexicalOnDeviceMixin::GetSEScope(const Expr& expr) const { OnDeviceProps props = GetOnDeviceProps(expr); - if (props.body.defined() && props.is_fixed) { + if (props.body.defined() && props.is_fixed()) { return props.se_scope; } else if (const auto* var_node = expr.as()) { // Lookup variable binding. 
@@ -175,7 +176,7 @@ void DeviceAwareExprVisitor::VisitExpr_(const LetNode* let_node) { void DeviceAwareExprVisitor::VisitExpr_(const CallNode* call_node) { OnDeviceProps props = GetOnDeviceProps(call_node); - if (props.body.defined() && props.is_fixed) { + if (props.body.defined() && props.is_fixed()) { // Entering lexical scope of fixed "on_device" call. PushSEScope(props.se_scope); VisitExpr(props.body); @@ -266,13 +267,13 @@ Expr DeviceAwareExprMutator::VisitExpr_(const LetNode* let_node) { Expr DeviceAwareExprMutator::VisitExpr_(const CallNode* call_node) { OnDeviceProps props = GetOnDeviceProps(call_node); - if (props.body.defined() && props.is_fixed) { + if (props.body.defined() && props.is_fixed()) { // Entering lexical scope of fixed "on_device" call. PushSEScope(props.se_scope); Expr expr = VisitExpr(props.body); // Leaving lexical scope of "on_device" call. PopSEScope(); - return MaybeOnDevice(expr, props.se_scope, props.is_fixed); + return MaybeOnDeviceWithProps(expr, props); } else { return DeviceAwareVisitExpr_(call_node); } diff --git a/src/relay/transforms/device_aware_visitors.h b/src/relay/transforms/device_aware_visitors.h index 8cdf0db74ebd3..044cda85c5796 100644 --- a/src/relay/transforms/device_aware_visitors.h +++ b/src/relay/transforms/device_aware_visitors.h @@ -146,7 +146,10 @@ class DeviceAwareExprFunctor : public ExprFunctorparams[i], GetFunctionParamSEScope(function_node, i)); } // Entering scope of function body. - PushSEScope(GetFunctionResultSEScope(function_node)); + SEScope se_scope = GetFunctionResultSEScope(function_node); + VLOG(2) << "entering " << se_scope << " for function:" << std::endl + << PrettyPrint(GetRef(function_node)); + PushSEScope(se_scope); EnterFunctionBody(); DeviceAwareVisitExpr_(function_node); @@ -154,6 +157,8 @@ class DeviceAwareExprFunctor : public ExprFunctor(function_node)); // Function parameters go out of scope. 
for (size_t i = 0; i < function_node->params.size(); ++i) { PopBoundVar(function_node->params[i]); @@ -168,7 +173,9 @@ class DeviceAwareExprFunctor : public ExprFunctor()) { // Let-bound var (in pre visited version) goes into scope. // (We'll just assume this is a letrec.) - PushBoundVar(inner_let_node->var, GetSEScope(inner_let_node->value)); + SEScope se_scope = GetSEScope(inner_let_node->value); + VLOG(2) << "var '" << inner_let_node->var->name_hint() << "' has scope " << se_scope; + PushBoundVar(inner_let_node->var, se_scope); PreVisitLetBinding_(inner_let_node->var, inner_let_node->value); bindings.emplace_back(inner_let_node); expr = inner_let_node->body; @@ -187,12 +194,16 @@ class DeviceAwareExprFunctor : public ExprFunctor(call_node)); PushSEScope(props.se_scope); VisitExpr(props.body); // Leaving lexical scope of "on_device" call. PopSEScope(); + VLOG(2) << "leaving " << props.se_scope << " for on_device:" << std::endl + << PrettyPrint(GetRef(call_node)); } else { DeviceAwareVisitExpr_(call_node); } diff --git a/src/relay/transforms/device_domains.cc b/src/relay/transforms/device_domains.cc index dafd8bc814bbe..3cc9310b13598 100644 --- a/src/relay/transforms/device_domains.cc +++ b/src/relay/transforms/device_domains.cc @@ -199,16 +199,23 @@ DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) { DeviceCopyProps device_copy_props = GetDeviceCopyProps(call.get()); CallLoweredProps call_lowered_props = GetCallLoweredProps(call.get()); - if (on_device_props.body.defined()) { - // on_device(expr, se_scope=, is_fixed=false) - // on_device : fn():?x? 
- // - // on_device(expr, se_scope=, is_fixed=true) - // on_device: fn(): - args_and_result.emplace_back( - ForSEScope(on_device_props.body->checked_type(), on_device_props.se_scope)); - if (on_device_props.is_fixed) { - args_and_result.emplace_back(args_and_result.front()); + if (call_lowered_props.lowered_func.defined()) { + return DomainFor(call_lowered_props.lowered_func); + } else if (on_device_props.body.defined()) { + // By default: + // on_device(expr, se_scope=) + // on_device : fn():?x? + // However we'll interpret the constrain_body and constrain_result fields to decide + // on free vs constrained domains for the argument and result respectively. + if (on_device_props.constrain_body) { + args_and_result.emplace_back( + ForSEScope(on_device_props.body->checked_type(), on_device_props.se_scope)); + } else { + args_and_result.emplace_back(Free(on_device_props.body->checked_type())); + } + if (on_device_props.constrain_result) { + args_and_result.emplace_back( + ForSEScope(on_device_props.body->checked_type(), on_device_props.se_scope)); } else { args_and_result.emplace_back(Free(on_device_props.body->checked_type())); } @@ -286,11 +293,9 @@ DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) { args_and_result.emplace_back(param_domain); } args_and_result.emplace_back(result_domain); - } else if (call_lowered_props.lowered_func.defined()) { - return DomainFor(call_lowered_props.lowered_func); } else { // We still need to handle the case where the function / op is not lowered - // because the device planner runs before and after lowering. + // because the device planner runs both before and after lowering. return DomainFor(call->op); } auto domain = MakeHigherOrderDomain(std::move(args_and_result)); diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc index 8ea5f5dac0a4c..0f1af3340daf6 100644 --- a/src/relay/transforms/device_planner.cc +++ b/src/relay/transforms/device_planner.cc @@ -20,6 +20,10 @@ /*! 
* \file src/relay/transforms/device_planner.cc * \brief Determines a unique \p SEScope to hold the result of every Relay sub-expression. + * This pass can be run multiple times, and can be run both before and after lowering. + * + * TODO(mbs): Rename SEScope |-> VirtualDevice, and use 'virtual device' (or just 'device') + * throughout. * * We say a Relay expression E is 'on device D' if the result of executing E is stored on D. * We represent D by an \p SEScope, which means we can track anywhere from an arbitrary device @@ -29,17 +33,21 @@ * Note that 'stored on device D' is almost but not quite the same as 'executes on device D', * see below. * - * This pass assumes the module already contains some "on_device" and/or "device_copy" CallNodes: - * - "device_copy" CallNodes (with a \p DeviceCopyAttrs attribute) specify a 'src_se_scope' and - * 'dst_se_scope' \p SEScopes, which constrain the argument and context of the call - * respectively. It is ok if source and destination devices are the same, such no-op copies - * will be removed after accounting for the device preference. - * - "on_device" CallNodes (with a \p OnDeviceAttrs attribute) specify an 'se_scope', which - * constrains the argument of the call, but (usually, see below) leaves the context + * This pass works by collecting and solving device constraints, using defaulting heuristics to + * resolve any remaining undetermined devices, and encoding the results on the output in a form + * that's reasonably friendly to downstream passes. + * + * Specific \p SEScopes flow into the constraints from five places: + * - Existing "device_copy" CallNodes (with a \p DeviceCopyAttrs attribute) specify a + * 'src_se_scope' and 'dst_se_scope' \p SEScope. Those constrain the argument and context of + * the call respectively. It is ok if source and destination devices are the same, such no-op + * copies will be removed after accounting for the device preference. 
+ * - Existing "on_device" CallNodes (with a \p OnDeviceAttrs attribute) specify an 'se_scope', + * which constrains the argument of the call, but (usually, see below) leaves the context * unconstrained. These are called 'annotations' in the rest of the code, have no operational - * significance by themselves, but may trigger the insertion of a new "device_copy". - * - In two situations the result of an "on_device" CallNode may also be constrained to the - * given device: + * significance by themselves, but may trigger the insertion of a new "device_copy" call by + * this pass. In two situations the result of an "on_device" CallNode may also be constrained + * to the given 'se_scope': * - The "on_device" call occurs at the top-level of a function body, or occurs as an * immediately let-bound expression. In this situation the extra degree of freedom in * the function result and let-binding leads to surprising device copies, so we simply @@ -47,6 +55,13 @@ * - The \p OnDeviceAttrs has an \p is_fixed field of \p true, which indicates we inserted * it ourselves during an earlier invocation of this pass. This helps make this pass * idempotent. + * - Some special operators require their arguments or results to be on the 'host' (typcially + * a CPU) \p SEScope, see below. + * - Any \p PrimFuncs in the \p IRModule (if \p LowerTEPass has already run) may constrain their + * argument buffers to have a specific memory scope, which is part of \p SEScope. + * - Annotations left over from a previous run of this pass, such as 'param_se_scopes' and + * 'result_se_scope' function attributes we introduce below. This is so the pass is idempotent + * and can be re-run to flow additional memory scope constraints. * * We proceed in four phases: * @@ -61,14 +76,13 @@ * * Phase 1 * ------- - * We flow constraints from the "on_device" and "device_copy" calls (and some special ops, see - * below) to all other Relay sub-expressions. 
(For idempotence we also respect any existing - * "param_se_scopes" and "result_se_scope" function attributes we introduce below.) + * We flow constraints from the "on_device" and "device_copy" calls, PrimFunc buffer memory scopes, + * and some special ops, to all other Relay sub-expressions. * * For a primitive such as \code add(e1, e2) \endcode all arguments and results must be on the * same device. However each call site can use a different device. In other words primitives are * 'device polymorphic' since we compile and execute them for each required device. ADT constructors - * are similarly polymorphic. + * are similarly polymorphic, but require all constructor args to be on the same device. * * For most Relay expressions the device for the overall expression is the same as the device * for its sub-expressions. E.g. each field of a tuple must be on the same device as the tuple @@ -90,6 +104,10 @@ * different from each other. Every call to the function must use the same choice of parameter * and result devices -- there is no 'device polymorphism' for Relay functions. * + * Currently \p PrimFuncs and external functions do not carry over their parameter and result + * devices from their original Relay Function representations. However we know all calls to those + * functions are device-consistent, thus no information is lost. + * * Phase 2 * ------- * After flowing constraints we apply some defaulting heuristics (using a global default \p SEScope) @@ -99,6 +117,8 @@ * - Unconstrained let-bound expression devices default to the device for the overall let. * TODO(mbs): These are very simple minded heuristics, and ultimately we'd like to treat the * assignment of the remaining unconstrained sub-expressions as an optimiziation problem in itself. + * This requires a formal notion of 'choicepoint' inside the compiler which can integrate with + * automation. * * Phase 3 * ------- @@ -108,10 +128,15 @@ * the function's parameters and the result. 
* - Additional "device_copy" CallNodes where a copy is required in order to respect the * intent of the original "on_device" CallNodes. - * - Additional "on_device" CallNodes where the device type of an expression does not match - * that of the lexically enclosing "on_device" CallNode or function attribute. In practice + * - Additional "on_device" CallNodes where the device type of an expression is not trivially + * implied by the lexically enclosing "on_device" CallNode or function attribute. In practice * this means "on_device" CallNodes may appear in two places: - * - On a let-bound expression if its device differs from the overall let expression. + * - On let-bound expressions. It is tempting to elide the "on_device" if the let-bound value + * has the same device as the overall let expression. However this would mean passes which + * inline let-bound values, such as FoldConstant and DeadCodeElimination, would need to us + * a DeviceAware visitor which in turn requires the expression to be in ANF to avoid + * deep recursion. To minimize disruption we always include the "on_device" so that it + * can follow the inline. * - On a call argument if its device differs from the call result. In particular, the * argument to a "device_copy" call will always be wrapped in an "on_device". (That may * seem pedantic but simplifies downstream handling.) @@ -132,8 +157,7 @@ * f(on_device(%x1, se_scope=CPU)) * \endcode * - * This pass can be run before FuseOps it can use device-specific fusion rules. - * TODO(mbs): We also need to support running after FuseOps. + * This pass can be run before FuseOps so that it can use device-specific fusion rules. * * 'Stored on' vs 'Executes on' * ---------------------------- @@ -227,17 +251,13 @@ * | * `-- Mark's stamp of completeness :-) * - * TODO(mbs): - * * Proper diagnostics for unification failure using spans. 
- * * Support running the pass post FuseOps (so need to understand primitive functions, both - outlines and lined) and post the VM transforms (probably need to support more intrinsic - forms?). - * * Don't hardcode the 'CPU' device for shape funcs etc, and distinguish between the default - * device for primitives vs the default device for the rest of Relay. - * * We may want some 'device polymorphism' for Relay functions. Eg it's ok for the function - * to be called with params/result on different (virtual) device ids provided the target and - * memory scopes are consistent. - * * Switch to expr.CopyWith(...) form once implemented to avoid unnecessary copies. + * TODO(mbs): Proper diagnostics for unification failure using spans. + * TODO(mbs): Support running the pass post FuseOps (so need to understand primitive functions, both + * outlined and inlined) and post the VM transforms (probably need to support more intrinsic + * forms?). + * TODO(mbs): We may want some 'device polymorphism' for Relay functions. Eg it's ok for the + * function to be called with params/result on different (virtual) device ids provided the target + * and memory scopes are consistent. */ #include @@ -252,6 +272,7 @@ #include #include #include +#include #include @@ -266,23 +287,39 @@ namespace transform { namespace { -/****** -******* Phase 0 -*******/ +/* =============== Phase 0 =============== */ /*! * \brief Rewrites "on_device" calls to handle some special cases.
* - * \code - * let %x = on_device(e, se_scope=d) - * ==> let %x = on_device(e, se_scope=d, is_fixed=True) + * - Don't let the device for %x remain unconstrained: + * \code + * let %x = on_device(e, se_scope=d) + * ==> let %x = on_device(e, se_scope=d, constraint=kBoth) + * \endcode * - * fn(%x) { on_device(e, se_scope=d) } - * ==> fn(%x) { on_device(e, se_scope=d, is_fixed=True) + * - Don't let the function result remain unconstrained: + * \code + * fn(%x) { on_device(e, se_scope=d) } + * ==> fn(%x) { on_device(e, se_scope=d, constraint=kBoth) + * \endcode * - * on_device(e).0 - * ==> on_device(e.0) - * \endcode + * - Project-then-copy rather than copy-then-project: + * \code + * on_device(e).0 + * ==> on_device(e.0) + * \endcode + * + * - Be prepared to copy arguments and results on primitive call boundaries in case memory + * scopes don't line up. We'll use the 'fully unconstrained' version of on_device so that + * we can allow for a device_copy without knowing the specific device for the arguments: + * \code + * call_lowered(@prim, (a, b)) + * ==> copy_ok(call_lowered(@prim, (copy_ok(a), copy_ok(b)))) + * where + * copy_ok(x) = on_device(x, se_scope=SEScope::FullyUnconstrained, + * constrain_body=False, constrain_result=False) + * \endcode */ class RewriteOnDevices : public ExprMutator { public: @@ -294,11 +331,11 @@ class RewriteOnDevices : public ExprMutator { OnDeviceProps props = GetOnDeviceProps(tuple); Expr tuple_get_item = WithFields(GetRef(tuple_get_item_node), std::move(tuple)); - if (props.body.defined() && !props.is_fixed) { + if (props.body.defined() && props.is_normal()) { VLOG(2) << "wrapping tuple get item:" << std::endl << PrettyPrint(GetRef(tuple_get_item_node)) << std::endl << "with \"on_device\" for SEScope " << props.se_scope; - return OnDevice(tuple_get_item, props.se_scope, /*is_fixed=*/false); + return OnDeviceWithProps(tuple_get_item, props); } else { return tuple_get_item; } @@ -311,19 +348,19 @@ class RewriteOnDevices : public 
ExprMutator { Let inner_let = GetRef(inner_let_node); Expr value = VisitExpr(inner_let_node->value); OnDeviceProps props = GetOnDeviceProps(value); - if (props.body.defined() && !props.is_fixed) { + if (props.body.defined() && props.is_normal()) { VLOG(2) << "revising let-bound expression of let:" << std::endl << PrettyPrint(expr) << std::endl << "to be fixed to SEScope " << props.se_scope; - value = OnDevice(props.body, props.se_scope, /*is_fixed=*/true); + value = MaybeOnDeviceFixed(props.body, props.se_scope); } bindings.emplace_back(inner_let, value); expr = inner_let_node->body; } expr = VisitExpr(expr); for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { - expr = WithFields(/*let=*/std::move(std::get<0>(*itr)), /*var = unchanged*/ {}, - /*value=*/std::move(std::get<1>(*itr)), /*body=*/std::move(expr)); + expr = WithFields(/*let=*/std::move(std::get<0>(*itr)), /*opt_var=*/{}, + /*opt_value=*/std::move(std::get<1>(*itr)), /*opt_body=*/std::move(expr)); } return expr; } @@ -331,20 +368,35 @@ class RewriteOnDevices : public ExprMutator { Expr VisitExpr_(const FunctionNode* function_node) final { Expr body = VisitExpr(function_node->body); OnDeviceProps props = GetOnDeviceProps(body); - if (props.body.defined() && !props.is_fixed) { + if (props.body.defined() && props.is_normal()) { VLOG(2) << "revising body of function:" << std::endl << PrettyPrint(GetRef(function_node)) << std::endl << "to be fixed to SEScope " << props.se_scope; - body = OnDevice(props.body, props.se_scope, /*is_fixed=*/true); + body = MaybeOnDeviceFixed(props.body, props.se_scope); + } + return WithFields(GetRef(function_node), function_node->params, std::move(body)); + } + + Expr VisitExpr_(const CallNode* call_node) final { + CallLoweredProps props = GetCallLoweredProps(call_node); + if (props.lowered_func.defined()) { + VLOG(2) << "allowing device_copy on PrimFunc arguments"; + Array new_args; + new_args.reserve(props.arguments.size()); + for (const auto& arg : 
props.arguments) { + new_args.push_back(OnDeviceCopyOk(VisitExpr(arg))); + } + Call new_call = CallLowered(std::move(props.lowered_func), std::move(new_args), props.attrs, + call_node->span); + return OnDeviceCopyOk(std::move(new_call)); + + } else { + return ExprMutator::VisitExpr_(call_node); } - return WithFields(GetRef(function_node), std::move(function_node->params), - std::move(body)); } }; -/****** -******* Phase 1 -*******/ +/* =============== Phase 1 =============== */ /* * \brief Collects the system of device constraints for all sub-expressions in a module. @@ -374,28 +426,133 @@ class DeviceAnalyzer : public ExprVisitor { */ std::unique_ptr Analyze() { VLOG_CONTEXT << "DeviceAnalyzer"; - for (const auto& pair : mod_->functions) { - VLOG(2) << "collecting constraints for '" << PrettyPrint(pair.first) << "'"; - domains_->UnifyExprExact(pair.first, pair.second); - VisitExpr(pair.second); + for (const auto& kv : mod_->functions) { + // The global variable and what it is bound to must obviously agree on domain. + if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) { + VLOG(2) << "collecting constraints from Relay Function '" << kv.first->name_hint << "'"; + domains_->UnifyExprExact(kv.first, kv.second); + VisitExpr(GetRef(function_node)); + } else if (const auto* prim_func_node = kv.second.as()) { + VLOG(2) << "collecting constraints from TIR PrimFunc '" << kv.first->name_hint << "'"; + domains_->UnifyExprExact( + kv.first, + DomainForPrimFunc(kv.first->checked_type(), GetRef(prim_func_node))); + } else { + VLOG(2) << "skipping '" << kv.first->name_hint << "'"; + } } return std::move(domains_); } private: + /*! + * \brief Return the (consistent) memory scope to use for the parameter or result of \p type, + * given we are currently looking at \p prim_func_memory_scopes[curr_scope_index] in the + * overall memory scopes extracted from the \p PrimFunc. 
 Fails if memory scopes are not + * consistent (eg \p type is a tuple for which the \p PrimFunc is attempting to map to + * multiple memory scopes after flattening). Returns empty if no memory scope constraints + * arise from \p PrimFunc. + */ + static std::string ConsistentMemoryScope(const std::vector& prim_func_memory_scopes, + size_t* curr_scope_index, const Type& type) { + std::string memory_scope; + for (size_t i = 0; i < FlattenTupleType(type).size(); ++i) { + ICHECK_LT(*curr_scope_index, prim_func_memory_scopes.size()) + << "mismatch between PrimFunc and Function arguments"; + const std::string& tensor_memory_scope = prim_func_memory_scopes[(*curr_scope_index)++]; + if (memory_scope.empty()) { + memory_scope = tensor_memory_scope; + } else if (tensor_memory_scope.empty()) { + // No constraint. + } else { + // Tuples must be homogeneous on their SEScope and thus memory scope. + ICHECK_EQ(tensor_memory_scope, memory_scope); + } + } + return memory_scope; + } + + /*! + * \brief Return the domain representing \p prim_func which, before lowering, had + * the Relay \p type. + */ + DeviceDomainPtr DomainForPrimFunc(const Type& type, const tir::PrimFunc& prim_func) { + auto func_domain = domains_->DomainFor(prim_func); // higher-order + + // TODO(mbs): We don't visit the body of the functions -- there's currently nothing to be done. + const auto* func_type_node = type.as(); + ICHECK(func_type_node); + + // Extract the memory scopes from all the prim_func's buffers. Note that the prim_func + // is in flattened and DPS form so we'll need to work backwards to project these scopes back + // into the appropriate Relay argument. 
+ std::vector prim_func_memory_scopes; + for (const auto& var : prim_func->params) { + auto itr = prim_func->buffer_map.find(var); + if (itr == prim_func->buffer_map.end()) { + VLOG(1) << "no buffer map entry for '" << var->name_hint << "'"; + continue; + } + if (!(*itr).second->data->type_annotation.defined()) { + VLOG(1) << "no type annotation for '" << var->name_hint << "'"; + continue; + } + const auto* pointer_type_node = (*itr).second->data->type_annotation.as(); + if (pointer_type_node == nullptr) { + VLOG(1) << "not a pointer type for '" << var->name_hint << "'"; + continue; + } + const std::string& memory_scope = pointer_type_node->storage_scope; + // An empty memory_scope signals no constraint, which matches the 'unconstrained' value + // in SEScope. + VLOG(2) << "prim_func_memory_scopes[" << prim_func_memory_scopes.size() << "] = '" + << memory_scope << "'"; + prim_func_memory_scopes.emplace_back(memory_scope); + } + + // Build the implied domain (in terms of the function's Relay type) implied by any memory scope + // constraints in the function's buffers, for both arguments and results. + std::vector args_and_result_domains; + args_and_result_domains.reserve(func_type_node->arg_types.size() + 1); + size_t curr_scope_index = 0; + + // For each Relay parameter... + for (const auto& param_type : func_type_node->arg_types) { + std::string param_memory_scope = + ConsistentMemoryScope(prim_func_memory_scopes, &curr_scope_index, param_type); + args_and_result_domains.push_back( + domains_->MakeFirstOrderDomain(SEScope::ForMemoryScope(param_memory_scope))); + } + + // For the Relay result... + std::string ret_memory_scope = + ConsistentMemoryScope(prim_func_memory_scopes, &curr_scope_index, func_type_node->ret_type); + args_and_result_domains.push_back( + domains_->MakeFirstOrderDomain(SEScope::ForMemoryScope(ret_memory_scope))); + + // All PrimFunc pointer args should have been accounted for. 
+ ICHECK_EQ(curr_scope_index, prim_func_memory_scopes.size()) + << "mismatch between PrimFunc and Function arguments"; + + return domains_->MakeHigherOrderDomain(std::move(args_and_result_domains)); + } + void VisitExpr_(const CallNode* call_node) final { auto call = GetRef(call_node); + // We don't care if the call is in pre- or post-lowered form. + auto vanilla_call = GetAnyCall(call_node); + // Find the higher-order domain for the callee. See DomainForCallee for the special rules // for primitives. - VisitExpr(call_node->op); + VisitExpr(vanilla_call->op); auto func_domain = domains_->DomainForCallee(call); // higher-order // Build the domain for the function implied by its arguments and call context. - ICHECK_EQ(func_domain->function_arity(), call_node->args.size()); + ICHECK_EQ(func_domain->function_arity(), vanilla_call->args.size()) << PrettyPrint(call); std::vector args_and_result_domains; - args_and_result_domains.reserve(call_node->args.size() + 1); - for (const auto& arg : call_node->args) { + args_and_result_domains.reserve(vanilla_call->args.size() + 1); + for (const auto& arg : vanilla_call->args) { args_and_result_domains.emplace_back(domains_->DomainFor(arg)); VisitExpr(arg); } @@ -416,9 +573,9 @@ class DeviceAnalyzer : public ExprVisitor { LOG(FATAL) << "Function parameters and result SEScopes do not match those of call. Call:" << std::endl << PrettyPrint(call) << std::endl - << "with function scopes:" << std::endl + << "with function virtual devices:" << std::endl << domains_->ToString(func_domain) << std::endl - << "and implied call scopes:" << std::endl + << "and implied call virtual devices:" << std::endl << domains_->ToString(implied_domain); } @@ -494,9 +651,9 @@ class DeviceAnalyzer : public ExprVisitor { << "Function SEScopes are incompatible with its \"on_device\" annotation. 
Function:" << std::endl << PrettyPrint(function) << std::endl - << "with function scopes:" << std::endl + << "with function virtual devices:" << std::endl << domains_->ToString(func_domain) << std::endl - << "and annotation scopes:" << std::endl + << "and annotation virtual devices:" << std::endl << domains_->ToString(annotation_domain); } } @@ -621,9 +778,7 @@ class DeviceAnalyzer : public ExprVisitor { std::unique_ptr domains_; }; -/****** -******* Phase 2 -*******/ +/* =============== Phase 2 =============== */ /*! * \brief Ensures every sub-expression in a module has a device type, using both the global @@ -637,7 +792,7 @@ class DeviceAnalyzer : public ExprVisitor { * \endcode * we know the parameter \p %z must be on device \p d, but the devices for \p %x and \p %y, * and the device for the function result, are still 'free'. The global 'default' device type - * is first used to 'fix' \p @main's result type, which in turn 'fixes' \p %x and \p %y, which + * is first used to 'fix' \p main's result type, which in turn 'fixes' \p %x and \p %y, which * in turn 'fixes' the device on which the \p add and \p multiply are executed. * * TODO(mbs): I think this is deterministic? 
We do however visit the top-level defs in hashmap @@ -651,9 +806,13 @@ class DeviceDefaulter : public ExprVisitor { std::unique_ptr Default() { VLOG_CONTEXT << "DeviceDefaulter"; VLOG(0) << "defaulting to SEScope " << domains_->config()->default_primitive_se_scope; - for (const auto& pair : mod_->functions) { - VLOG(2) << "defaulting devices for '" << PrettyPrint(pair.first) << "'"; - VisitExpr(pair.second); + for (const auto& kv : mod_->functions) { + if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) { + VLOG(2) << "defaulting devices for '" << kv.first->name_hint << "'"; + VisitExpr(GetRef(function_node)); + } else { + VLOG(2) << "skipping '" << kv.first->name_hint << "'"; + } } return std::move(domains_); } @@ -678,8 +837,12 @@ class DeviceDefaulter : public ExprVisitor { void VisitExpr_(const CallNode* call_node) final { auto call = GetRef(call_node); + + // We don't care if the call is pre- or post-lowered. + auto vanilla_call = GetAnyCall(call_node); + auto func_domain = domains_->DomainForCallee(call); // higher-order - ICHECK_EQ(func_domain->function_arity(), call_node->args.size()); + ICHECK_EQ(func_domain->function_arity(), vanilla_call->args.size()); if (!domains_->IsFullyConstrained(func_domain)) { // For calls to Relay functions this step is identical to that for VisitExpr_(FunctionNode*) // above. But for calls to primitives we may still need to force free domains to be @@ -720,9 +883,7 @@ class DeviceDefaulter : public ExprVisitor { std::unique_ptr domains_; }; -/****** -******* Phase 3 -*******/ +/* =============== Phase 3 =============== */ /*! 
* \brief Inserts missing "device_copy" CallNodes, and ensures the device type of every @@ -767,9 +928,14 @@ class DeviceCapturer : public ExprMutator { VLOG_CONTEXT << "CaptureDevices"; IRModule result(/*functions=*/{}, mod_->type_definitions, mod_->Imports(), mod_->source_map, mod_->attrs); - for (const auto& pair : mod_->functions) { - VLOG(2) << "capturing devices for '" << PrettyPrint(pair.first) << "'"; - result->Add(pair.first, Downcast(Mutate(pair.second))); + for (const auto& kv : mod_->functions) { + if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) { + VLOG(2) << "capturing devices for '" << kv.first->name_hint << "'"; + result->Add(kv.first, Downcast(Mutate(GetRef(function_node)))); + } else { + VLOG(2) << "skipping '" << kv.first->name_hint << "'"; + result->Add(kv.first, kv.second); + } } return result; } @@ -825,6 +991,11 @@ class DeviceCapturer : public ExprMutator { Expr VisitExpr_(const CallNode* call_node) final { auto call = GetRef(call_node); + + // We don't care if the call is pre- or post-lowered + // (However we'll preserve the form in the result below.) 
+ auto vanilla_call = GetAnyCall(call_node); + SEScope call_se_scope = GetSEScope(call); auto on_device_props = GetOnDeviceProps(call_node); @@ -854,7 +1025,7 @@ class DeviceCapturer : public ExprMutator { auto func_domain = domains_->DomainForCallee(call); // higher-order VLOG(2) << "considering call:" << std::endl << PrettyPrint(call) << std::endl - << "in scope " << call_se_scope << " with function domain:" << std::endl + << "in scope " << call_se_scope << " with function virtual devices:" << std::endl << domains_->ToString(func_domain); SEScope result_se_scope = domains_->ResultSEScope(func_domain); ICHECK(!result_se_scope->IsFullyUnconstrained()); @@ -863,25 +1034,32 @@ class DeviceCapturer : public ExprMutator { Expr op = VisitChild( /*lexical_se_scope=*/call_se_scope, /*expected_se_scope=*/call_se_scope, - /*child_se_scope=*/result_se_scope, call_node->op); + /*child_se_scope=*/result_se_scope, vanilla_call->op); // Each argument can be on the device for the corresponding function parameter. However if // any of those differ from the overall call device then wrap them in an "on_device" to // help downstream transforms track devices lexically. 
Array args; - args.reserve(call_node->args.size()); - ICHECK_EQ(func_domain->function_arity(), call->args.size()); - for (size_t i = 0; i < call_node->args.size(); ++i) { + args.reserve(vanilla_call->args.size()); + ICHECK_EQ(func_domain->function_arity(), vanilla_call->args.size()); + for (size_t i = 0; i < vanilla_call->args.size(); ++i) { SEScope param_se_scope = domains_->ResultSEScope(func_domain->function_param(i)); ICHECK(!param_se_scope->IsFullyUnconstrained()) << "for parameter " << i << " for call:" << std::endl << PrettyPrint(call); args.push_back(VisitChild(/*lexical_se_scope=*/call_se_scope, /*expected_se_scope=*/param_se_scope, - /*child_se_scope=*/GetSEScope(call_node->args[i]), - call_node->args[i])); + /*child_se_scope=*/GetSEScope(vanilla_call->args[i]), + vanilla_call->args[i])); + } + + if (call_node->op == CallLoweredOp()) { + Call new_call = + CallLowered(Downcast(op), args, /*call_lowered_attrs=*/{}, /*span=*/{}); + return WithFields(call, std::move(new_call->op), std::move(new_call->args)); + } else { + return WithFields(call, std::move(op), std::move(args)); } - return WithFields(GetRef(call_node), std::move(op), std::move(args)); } Expr VisitExpr_(const LetNode* let_node) final { @@ -895,11 +1073,11 @@ class DeviceCapturer : public ExprMutator { // We have a device transition which needs to be handled. break; } - // The let-bound value can be on a different device than the overall let. However if those - // devices don't agree wrap the let-bound value in an "on_device" to help downstream - // transforms track devices lexically. + // The let-bound value can be on a different device than the overall let. + // By using the fully-unconstrained SEScope for the 'lexical' scope we'll force the let-bound + // value to *always* be wrapped by an "on_device" (see introductory comment for motivation.) 
Expr value = - VisitChild(/*lexical_se_scope=*/let_se_scope, + VisitChild(/*lexical_se_scope=*/SEScope::FullyUnconstrained(), /*expected_se_scope=*/GetSEScope(inner_let_node->var), /*child_se_scope=*/GetSEScope(inner_let_node->value), inner_let_node->value); bindings.emplace_back(inner_let_node->var, value, inner_let_node->span); @@ -967,7 +1145,7 @@ class DeviceCapturer : public ExprMutator { OnDeviceProps props = GetOnDeviceProps(expr); Expr true_expr = props.body.defined() ? props.body : expr; ICHECK(domains_->contains(true_expr)); - // If expr is higher order we'll return only the result domain's SEScope. + // If expr is higher order we'll return only the result domain's device. SEScope se_scope = domains_->ResultSEScope(domains_->DomainFor(true_expr)); ICHECK(!se_scope->IsFullyUnconstrained()) << "no SEScope was determined for expression:" << std::endl @@ -1003,7 +1181,6 @@ class DeviceCapturer : public ExprMutator { */ Expr VisitChild(const SEScope& lexical_se_scope, const SEScope& expected_se_scope, const SEScope& child_se_scope, const Expr& child) { - ICHECK(!lexical_se_scope->IsFullyUnconstrained()); ICHECK(!expected_se_scope->IsFullyUnconstrained()); if (child->IsInstance() || child->IsInstance()) { // Primitive operators and contructors don't need to be rewritten and can have a @@ -1012,19 +1189,19 @@ class DeviceCapturer : public ExprMutator { } Expr result = VisitExpr(child); if (child_se_scope != expected_se_scope) { - VLOG(2) << "creating " << DeviceCopyOp()->name << " from scope " << child_se_scope - << " to scope " << expected_se_scope << " for:" << std::endl + VLOG(2) << "creating " << DeviceCopyOp()->name << " from virtual device " << child_se_scope + << " to virtual device " << expected_se_scope << " for:" << std::endl << PrettyPrint(result); // Also wrap the child in an "on_device" so downstream transforms can track devices // lexically. 
- result = MaybeOnDevice(result, child_se_scope, /*is_fixed=*/true); + result = MaybeOnDeviceFixed(result, child_se_scope); result = DeviceCopy(result, child_se_scope, expected_se_scope); } if (expected_se_scope != lexical_se_scope) { - VLOG(2) << "creating " << OnDeviceOp()->name << " for scope " << expected_se_scope + VLOG(2) << "creating " << OnDeviceOp()->name << " for virtual device " << expected_se_scope << " for:" << std::endl << PrettyPrint(result); - result = MaybeOnDevice(result, expected_se_scope, /*is_fixed=*/true); + result = MaybeOnDeviceFixed(result, expected_se_scope); } return result; } @@ -1078,9 +1255,7 @@ tvm::transform::Pass PlanDevicesCore(CompilationConfig config) { } // namespace -/****** -******* Overall composite Pass -*******/ +/* =============== Driver =============== */ // This function is declared in the public . tvm::transform::Pass PlanDevices(CompilationConfig config) { diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 05ee9d5ad5921..831d28b485408 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -145,7 +145,7 @@ class ConstantFolder : public MixedModeMutator { Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final { Call pre_call = GetRef(pre_call_node); if (inside_primitive_) { - return pre_call; + return std::move(pre_call); } Call post_call = Downcast(post); @@ -215,12 +215,12 @@ class ConstantFolder : public MixedModeMutator { OnDeviceProps props = GetOnDeviceProps(post_tuple_get_item_node->tuple); if (props.body.defined()) { // (on_device((x, y, z), se_scope=D).1 ==> on_device(y, se_scope=D) - return MaybeOnDevice(result, props.se_scope, props.is_fixed); + return MaybeOnDeviceWithProps(result, props); } else { return result; } } - return std::move(post_tuple_get_item); + return post_tuple_get_item; } // Convert value to expression. 
diff --git a/src/relay/transforms/let_list.h b/src/relay/transforms/let_list.h index 56875f6c16a16..f449d6c3b011e 100644 --- a/src/relay/transforms/let_list.h +++ b/src/relay/transforms/let_list.h @@ -79,7 +79,7 @@ class LetList { * * \return a Var that hold the inserted expr. */ - Var Push(Expr expr, Type ty) { return Push(Var("x", ty), expr); } + Var Push(Expr expr, Type ty) { return Push(Var::GenSym(ty), expr); } /*! * \brief insert a binding. diff --git a/src/relay/transforms/memory_alloc.cc b/src/relay/transforms/memory_alloc.cc index 00be629eabff1..25827d5e918da 100644 --- a/src/relay/transforms/memory_alloc.cc +++ b/src/relay/transforms/memory_alloc.cc @@ -69,6 +69,8 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { Function Rewrite(const Function& expr) { return Downcast(Mutate(expr)); } private: + using ExprMutator::VisitExpr_; + Expr VisitExpr_(const TupleNode* tuple_node) final { LetList& scope = scopes_.back(); Array new_fields; @@ -77,8 +79,10 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { for (auto field : tuple_node->fields) { auto new_field = Mutate(field); if (new_field->IsInstance()) { + SEScope se_scope = GetSEScope(field); + ICHECK(!se_scope->IsFullyUnconstrained()); Var const_var("const", Type(nullptr)); - new_field = scope.Push(const_var, new_field); + new_field = scope.Push(const_var, MaybeOnDeviceFixed(new_field, se_scope)); } new_fields.push_back(new_field); } @@ -89,7 +93,9 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { std::pair PreVisitLetBinding_(const Var& var, const Expr& value) final { Expr new_value = Mutate(value); - scopes_.back().Push(var, new_value); + SEScope se_scope = GetSEScope(value); + ICHECK(!se_scope->IsFullyUnconstrained()); + scopes_.back().Push(var, MaybeOnDeviceFixed(new_value, se_scope)); // Since we always need a let block on which to bind sub-expressions the rewritten bindings // are tracked in the current scopes. 
But return the rewritten binding anyway. return {var, new_value}; @@ -127,6 +133,7 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { VLOG(1) << "converting lowered call to DPS:" << std::endl << PrettyPrint(call); SEScope se_scope = GetSEScope(call); + ICHECK(!se_scope->IsFullyUnconstrained()); LetList& scope = scopes_.back(); std::vector new_args; @@ -176,7 +183,7 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { Expr invoke = InvokeTVMOp(call_lowered_props.lowered_func, ins, outs, Downcast(call_lowered_props.attrs.metadata.at("relay_attrs"))); - scope.Push(OnDevice(invoke, se_scope, /*is_fixed=*/true)); + scope.Push(MaybeOnDeviceFixed(invoke, se_scope)); return ToTupleType(ret_type, std::vector(outputs.begin(), outputs.end())); } @@ -192,8 +199,7 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { /*! Returns an \p alloc_tensor call for a tensor of \p shape and \p dtype over \p storage. */ inline Expr AllocTensor(const Expr& storage, tvm::relay::Expr shape, DataType dtype, Array assert_shape) { - Expr offset = OnDevice(MakeConstantScalar(DataType::Int(64), 0), host_se_scope_, - /*is_fixed=*/true); + Expr offset = MaybeOnDeviceFixed(MakeConstantScalar(DataType::Int(64), 0), host_se_scope_); return tvm::relay::AllocTensor(storage, std::move(offset), std::move(shape), dtype, assert_shape); } @@ -236,22 +242,20 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { CHECK(imm) << "expect static int shape"; int_shape.push_back(imm->value); } - Expr shape = OnDevice(MakeConstant(int_shape), host_se_scope_, /*is_fixed=*/true); - Expr size = OnDevice(ComputeStorage(type), host_se_scope_, /*is_fixed=*/true); + Expr shape = MaybeOnDeviceFixed(MakeConstant(int_shape), host_se_scope_); + Expr size = MaybeOnDeviceFixed(ComputeStorage(type), host_se_scope_); // Alignment is directly captured in the instruction rather than calculated, so we // don't want to wrap it with an "on_device". 
Expr alignment = ComputeAlignment(type->dtype); // Run type inference later to get the correct type. Var var("storage_" + name_hint, Type(nullptr)); - Expr value = OnDevice(AllocStorage(size, alignment, se_scope, type->dtype), se_scope, - /*is_fixed=*/true); - auto sto = scope->Push(var, value); + Expr value = AllocStorage(size, alignment, se_scope, type->dtype); + auto sto = scope->Push(var, MaybeOnDeviceFixed(value, se_scope)); // TODO(@jroesch): There is a bug with typing based on the constant shape. - auto tensor = OnDevice(AllocTensor(sto, shape, type->dtype, /*assert_shape=*/type->shape), - se_scope, /*is_fixed=*/true); + auto tensor = AllocTensor(sto, shape, type->dtype, /*assert_shape=*/type->shape); Var tensor_var("tensor_" + name_hint, Type(nullptr)); - return scope->Push(tensor_var, tensor); + return scope->Push(tensor_var, MaybeOnDeviceFixed(tensor, se_scope)); } /*! @@ -287,23 +291,24 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { if (state == tec::kNeedInputShape) { std::vector exprs = FromTupleType(ty, arg); for (size_t j = 0; j < exprs.size(); ++j) { - Expr sh_of = Mutate(ShapeOf(exprs[j])); // already accounts for device + Expr sh_of = Mutate(ShapeOf(exprs[j])); Var in_shape_var("in_shape_" + std::to_string(input_pos + j), Type(nullptr)); - shape_func_ins.push_back(scope->Push(in_shape_var, sh_of)); + shape_func_ins.push_back( + scope->Push(in_shape_var, MaybeOnDeviceFixed(sh_of, host_se_scope_))); input_pos++; } } else if (state == tec::kNeedInputData) { auto new_arg = Mutate(arg); // already accounts for device SEScope arg_se_scope = GetSEScope(arg); + ICHECK(!arg_se_scope->IsFullyUnconstrained()); // The dynamic shape function is expecting its data on the host/CPU, so insert a // device_copy otherwise. (We'll need to fuse & lower these copies in the same way // we fuse & lower other operators we insert for, eg, dynamic tensor size calculation.) 
- if (arg_se_scope != host_se_scope_) { - new_arg = OnDevice(DeviceCopy(new_arg, arg_se_scope, host_se_scope_), host_se_scope_, - /*is_fixed=*/true); - } + new_arg = MaybeDeviceCopy(MaybeOnDeviceFixed(new_arg, arg_se_scope), arg_se_scope, + host_se_scope_); Var in_shape_var("in_shape_" + std::to_string(input_pos), Type(nullptr)); - shape_func_ins.push_back(scope->Push(in_shape_var, new_arg)); + shape_func_ins.push_back( + scope->Push(in_shape_var, MaybeOnDeviceFixed(new_arg, host_se_scope_))); input_pos++; } else { // TODO(@jroesch): handle kNeedBoth @@ -322,20 +327,16 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { ICHECK(tensor_type_node); // Put the shape func on the host. This also ensures that everything between // shape_of and shape_func is similarly on the host. - Expr alloc = MakeStaticAllocation(scope, GetRef(tensor_type_node), host_se_scope_, - std::to_string(i)); - // TODO(mbs): Don't really need a fresh var here since alloc will always be a var. - Var shape_func_out_var("shape_func_out_" + std::to_string(i), Type(nullptr)); - alloc = scope->Push(shape_func_out_var, alloc); + Var alloc = MakeStaticAllocation(scope, GetRef(tensor_type_node), host_se_scope_, + "out_shape_" + std::to_string(i)); out_shapes.push_back(alloc); } // Represent the call in DPS form. 
- auto shape_call = OnDevice(InvokeTVMOp(prim_fn_var, Tuple(shape_func_ins), Tuple(out_shapes), - Downcast(attrs.metadata.at("relay_attrs"))), - host_se_scope_, /*is_fixed=*/true); + auto shape_call = InvokeTVMOp(prim_fn_var, Tuple(shape_func_ins), Tuple(out_shapes), + Downcast(attrs.metadata.at("relay_attrs"))); Var shape_func_var("shape_func", Type(nullptr)); - scope->Push(shape_func_var, shape_call); + scope->Push(shape_func_var, MaybeOnDeviceFixed(shape_call, host_se_scope_)); return out_shapes; } @@ -349,14 +350,12 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { for (size_t i = 0; i < out_shapes.size(); ++i) { auto out_shape = out_shapes[i]; auto out_type = out_types[i]; - auto size = OnDevice(ComputeStorageInRelay(out_shape, out_type), host_se_scope_, - /*is_fixed=*/true); + auto size = MaybeOnDeviceFixed(ComputeStorageInRelay(out_shape, out_type), host_se_scope_); // Alignment is directly captured in the instruction so don't wrap in "on_device". auto alignment = ComputeAlignment(out_type->dtype); Var sto_var("storage_" + std::to_string(i), Type(nullptr)); - auto val = OnDevice(AllocStorage(size, alignment, se_scope, out_type->dtype), se_scope, - /*is_fixed=*/true); - storages.push_back(scope->Push(sto_var, val)); + auto val = AllocStorage(size, alignment, se_scope, out_type->dtype); + storages.push_back(scope->Push(sto_var, MaybeOnDeviceFixed(val, se_scope))); } Array outs; @@ -364,17 +363,15 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { auto out_shape = out_shapes[i]; auto out_type = out_types[i]; auto storage = storages[i]; - auto alloc = OnDevice(AllocTensor(storage, out_shape, out_type->dtype, out_type->shape), - se_scope, /*is_fixed=*/true); + auto alloc = AllocTensor(storage, out_shape, out_type->dtype, out_type->shape); Var out_var("out_" + std::to_string(i), Type(nullptr)); - outs.push_back(scope->Push(out_var, alloc)); + outs.push_back(scope->Push(out_var, MaybeOnDeviceFixed(alloc, se_scope))); } Tuple 
tuple_outs(outs); auto call = InvokeTVMOp(func, ins, tuple_outs, Downcast(attrs.metadata.at("relay_attrs"))); - auto invoke = OnDevice(call, se_scope, /*is_fixed=*/true); - scope->Push(invoke); + scope->Push(MaybeOnDeviceFixed(call, se_scope)); return ToTupleType(ret_type, std::vector(tuple_outs->fields.begin(), tuple_outs->fields.end())); } @@ -398,7 +395,7 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { CHECK(imm) << "expect static int shape"; shape.push_back(imm->value); } - shape_expr = OnDevice(MakeConstant(shape), host_se_scope_, /*is_fixed=*/true); + shape_expr = MaybeOnDeviceFixed(MakeConstant(shape), host_se_scope_); } return ReshapeTensor(ins->fields[0], shape_expr, ret_ty->shape); } diff --git a/src/relay/transforms/partial_eval.cc b/src/relay/transforms/partial_eval.cc index 8f5e9e146d54b..28d1aa5532bf7 100644 --- a/src/relay/transforms/partial_eval.cc +++ b/src/relay/transforms/partial_eval.cc @@ -842,7 +842,7 @@ class PartialEvaluator : public ExprFunctor }); } - PStatic VisitFunc(const Function& func, LetList* ll, const Var& name = Var("x", Type())) { + PStatic VisitFunc(const Function& func, LetList* ll, const Var& name) { Func f = VisitFuncStatic(func, name); Function u_func = AsFunc(RegisterFuncId(DeDup(AnnotateFuncId(func)))); // TODO(@M.K.): we seems to reduce landin knot into letrec. 
@@ -851,7 +851,7 @@ class PartialEvaluator : public ExprFunctor } PStatic VisitExpr_(const FunctionNode* op, LetList* ll) final { - return VisitFunc(GetRef(op), ll); + return VisitFunc(GetRef(op), ll, Var::GenSym()); } struct ReflectError : Error { diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc index f958a600551e5..741de6d7ea9bb 100644 --- a/src/relay/transforms/to_a_normal_form.cc +++ b/src/relay/transforms/to_a_normal_form.cc @@ -211,15 +211,15 @@ class Fill : ExprFunctor, private transform::Lexi } Expr Atomic(const Expr& e, const Var& v) { - Expr annotated_expr = MaybeOnDevice(e, GetSEScope(e), /*is_fixed=*/true); + Expr annotated_expr = MaybeOnDeviceFixed(e, GetSEScope(e)); return v.defined() ? GetScope(e)->let_list->Push(v, annotated_expr) : annotated_expr; } // Bind expression `now` to var `v` if the original expression is in the include set, or if // v is already defined (e.g. coming from a Let expression). Otherwise return `now` directly Expr Compound(const Expr& orig, const Expr& now, const Var& v) { - Expr annotated_expr = MaybeOnDevice(now, GetSEScope(orig), /*is_fixed=*/true); - Var var = v.defined() ? v : Var(String("x"), Type()); + Expr annotated_expr = MaybeOnDeviceFixed(now, GetSEScope(orig)); + Var var = v.defined() ? v : Var::GenSym(); bool not_included = include_set_ && include_set_->find(orig) == include_set_->end(); if (!v.defined() && not_included) { return annotated_expr; @@ -230,14 +230,14 @@ class Fill : ExprFunctor, private transform::Lexi Expr VisitExpr_(const CallNode* c, const Var& v) final { OnDeviceProps props = GetOnDeviceProps(c); - if (props.body.defined() && props.is_fixed) { + if (props.body.defined() && props.is_fixed()) { // Keep track of expression device type for lexically enclosing sub-expressions. PushSEScope(props.se_scope); Expr body = VisitExpr(props.body, v); // We are done with this sub-expression. PopSEScope(); // Preserve the "on_device" annotations. 
- return OnDevice(body, props.se_scope, props.is_fixed); + return OnDeviceWithProps(body, props); } Expr e = GetRef(c); diff --git a/src/runtime/contrib/verilator/verilator_runtime.h b/src/runtime/contrib/verilator/verilator_runtime.h index 588b3f172a3ee..9ef17d7481ab7 100644 --- a/src/runtime/contrib/verilator/verilator_runtime.h +++ b/src/runtime/contrib/verilator/verilator_runtime.h @@ -94,7 +94,7 @@ class VerilatorRuntime : public JSONRuntimeBase { ~VerilatorRuntime(); - const char* type_key() const { return "verilator"; } + const char* type_key() const final { return "verilator"; } /*! \brief set verilator library */ void SetLibrary(const std::string& lib_name); diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index b613a03bfc5cd..44971c0bcee98 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -178,7 +178,7 @@ std::string Executable::GetConstants() const { for (size_t i = 0; i < constants.size(); ++i) { const auto& constant = constants[i]; auto ndarray = Downcast(constant); - oss << "VM Constant[" << i << "]: has shape " << ShapeString(ndarray.Shape(), ndarray->dtype) + oss << "VM Const[" << i << "]: has shape " << ShapeString(ndarray.Shape(), ndarray->dtype) << " on device index " << const_device_indexes[i] << std::endl; } return oss.str(); diff --git a/src/target/se_scope.cc b/src/target/se_scope.cc index 95d5a7de57752..8e6c6fe7f2a26 100644 --- a/src/target/se_scope.cc +++ b/src/target/se_scope.cc @@ -62,10 +62,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "memory_scope='" << node->memory_scope << "'"; } } -#if TVM_LOG_DEBUG - // We rely on object identity of SEScopes, so include the object address to help debugging. 
- p->stream << ", id=" << reinterpret_cast(ref.get()); -#endif p->stream << ")"; }); @@ -173,11 +169,9 @@ SEScope SEScopeCache::Make(DLDeviceType device_type, int virtual_device_id, Targ SEScope prototype(device_type, virtual_device_id, std::move(target), std::move(memory_scope)); auto itr = cache_.find(prototype); if (itr == cache_.end()) { - VLOG(1) << "added new scope " << prototype; cache_.emplace(prototype); return prototype; } else { - VLOG(1) << "reusing existing scope " << *itr; ICHECK_EQ(prototype->target.defined(), (*itr)->target.defined()); if (prototype->target.defined()) { ICHECK_EQ(prototype->target->host.defined(), (*itr)->target->host.defined()); diff --git a/src/target/target.cc b/src/target/target.cc index 792884061db6c..a5c493a582ab2 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -566,10 +566,6 @@ String TargetNode::ToDebugString() const { if (host.defined()) { os << ", host=" << GetHost().value()->ToDebugString(); } -#if TVM_LOG_DEBUG - // We depend on pointer equality so include that in the debug representation. 
- os << ", id=" << reinterpret_cast(this); -#endif os << ")"; return os.str(); } diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc index 630c00f8c1f10..3ddd59abddec1 100644 --- a/src/tir/transforms/lower_cross_thread_reduction.cc +++ b/src/tir/transforms/lower_cross_thread_reduction.cc @@ -207,7 +207,7 @@ class InThreadReducerMaker : private StmtMutator { if (res->thread_binding.defined()) { return res->body; } else { - return res; + return std::move(res); } } else { return Stmt{nullptr}; @@ -564,7 +564,7 @@ class CrossThreadReductionTransformer : public StmtMutator { } } } - return new_block; + return std::move(new_block); } Stmt VisitStmt_(const BlockRealizeNode* realize) final { diff --git a/tests/python/relay/test_pass_plan_devices.py b/tests/python/relay/test_pass_plan_devices.py index 4ef726fbeac2c..ead4128c6c795 100644 --- a/tests/python/relay/test_pass_plan_devices.py +++ b/tests/python/relay/test_pass_plan_devices.py @@ -23,6 +23,7 @@ import tvm from tvm import relay +from tvm.script import tir as T import tvm.testing import numpy as np @@ -45,6 +46,9 @@ GPU = tvm.target.make_se_scope(GPU_DEVICE, GPU_TARGET) # device_type=2 DEFAULT = GPU +CPU_LOCAL = tvm.target.make_se_scope(CPU_DEVICE, CPU_TARGET, memory_scope="local") +CPU_GLOBAL = tvm.target.make_se_scope(CPU_DEVICE, CPU_TARGET, memory_scope="global") + CTXT = tvm.transform.PassContext(config={"relay.fallback_device_type": DEFAULT.device_type_int}) core = tvm.IRModule() @@ -178,7 +182,7 @@ def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][1], meta[SEScope][1]], result_se_scope=meta[SEScope][1]) { %0 = add(%a, %b); - %1 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); + %1 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); %2 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %3 = add(%c, %d); 
subtract(%2, %3) @@ -225,7 +229,7 @@ def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][1], meta[SEScope][1]], result_se_scope=meta[SEScope][1]) { %0 = add(%a, %b); - %1 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); + %1 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); %2 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %3 = add(%c, %d); subtract(%2, %3) @@ -272,9 +276,9 @@ def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][1]) { %0 = add(%a, %b); - %1 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); + %1 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); %2 = add(%c, %d); - %3 = on_device(%2, se_scope=meta[SEScope][0], is_fixed=True); + %3 = on_device(%2, se_scope=meta[SEScope][0], constrain_result=True); %4 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %5 = device_copy(%3, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); subtract(%4, %5) @@ -318,8 +322,8 @@ def expected(): def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][1]) { %0 = add(%a, %b); - %1 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); - %2 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); + %1 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); + %2 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); %3 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %4 = device_copy(%2, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); subtract(%3, %4) @@ -367,8 +371,8 @@ def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], 
param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][1], meta[SEScope][1]], result_se_scope=meta[SEScope][1]) { %0 = add(%a, %b); - let %l = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); - let %r = add(%c, %d); + let %l = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); + let %r = on_device(add(%c, %d), se_scope=meta[SEScope][1], constrain_result=True); %1 = device_copy(%l, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); subtract(%1, %r) } @@ -473,7 +477,7 @@ def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], add(%x, %y) }; %1 = %f(%a, %b); - %2 = on_device(%1, se_scope=meta[SEScope][0], is_fixed=True); + %2 = on_device(%1, se_scope=meta[SEScope][0], constrain_result=True); %3 = device_copy(%2, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %4 = add(%c, %d); subtract(%3, %4) @@ -591,7 +595,7 @@ def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { add(%a, %b) }; - let %t = (%f, %x); + let %t = on_device((%f, %x), se_scope=meta[SEScope][0], constrain_result=True); %0 = %t.1; %1 = %t.0; %1(%0, %y) @@ -650,7 +654,7 @@ def ref(x): def test_shape_of(): metatable = {"SEScope": [HOST, GPU]} - # We need to use is_fixed=True in the on_device call so that the tensor will be on the GPU. Otherwise the + # We need to use constrain_result=True in the on_device call so that the tensor will be on the GPU. Otherwise the # result defaults to the result device for @main which is the CPU, thus forcing a copy. # TODO(mbs): Perhaps the defaulting heuristics are being too clever? 
def input(): @@ -658,7 +662,7 @@ def input(): """ #[version = "0.0.5"] def @main(%x: Tensor[(?, ?), float32]) { - %0 = on_device(%x, se_scope=meta[SEScope][1], is_fixed=True); + %0 = on_device(%x, se_scope=meta[SEScope][1], constrain_result=True); vm.shape_of(%0, dtype="int64") } """, @@ -744,8 +748,8 @@ def expected(): """ #[version = "0.0.5"] def @main(%sto: Storage[], param_se_scopes=[meta[SEScope][1]], result_se_scope=meta[SEScope][1]) { - %0 = on_device(0, se_scope=meta[SEScope][0], is_fixed=True); - %1 = on_device(meta[relay.Constant][0], se_scope=meta[SEScope][0], is_fixed=True); + %0 = on_device(0, se_scope=meta[SEScope][0], constrain_result=True); + %1 = on_device(meta[relay.Constant][0], se_scope=meta[SEScope][0], constrain_result=True); memory.alloc_tensor(%sto, %0, %1, const_shape=meta[relay.Constant][0], assert_shape=[]) } """, @@ -781,7 +785,7 @@ def expected(): #[version = "0.0.5"] def @main(%x: Tensor[(2, 8), float32], param_se_scopes=[meta[SEScope][1]], result_se_scope=meta[SEScope][1]) { - %0 = on_device(meta[relay.Constant][0], se_scope=meta[SEScope][0], is_fixed=True); + %0 = on_device(meta[relay.Constant][0], se_scope=meta[SEScope][0], constrain_result=True); vm.reshape_tensor(%x, %0, newshape=[2, 4, 2]) } """, @@ -861,9 +865,9 @@ def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[( param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][1]], result_se_scope=meta[SEScope][1]) { %0 = add(%x, %y); - %1 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); + %1 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); %2 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); - %3 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); + %3 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); %4 = subtract(%2, %z); %5 = device_copy(%3, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); add(%4, %5) @@ -908,7 +912,7 @@ def @main(%x: Tensor[(5, 
7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[( param_se_scopes=[meta[SEScope][1], meta[SEScope][1], meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { %0 = add(%x, %y); - %1 = on_device(%0, se_scope=meta[SEScope][1], is_fixed=True); + %1 = on_device(%0, se_scope=meta[SEScope][1], constrain_result=True); %2 = device_copy(%1, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]); subtract(%2, %z) } @@ -1010,13 +1014,13 @@ def @main(%data1: Tensor[(1, 64, 56, 56), float32], %data2: Tensor[(1, 64, 56, 5 param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { %0 = nn.conv2d(%data1, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); - %1 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); + %1 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); %2 = nn.conv2d(%data2, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); - %3 = on_device(%2, se_scope=meta[SEScope][0], is_fixed=True); + %3 = on_device(%2, se_scope=meta[SEScope][0], constrain_result=True); %4 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %5 = device_copy(%3, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %6 = add(%4, %5); - %7 = on_device(%6, se_scope=meta[SEScope][1], is_fixed=True); + %7 = on_device(%6, se_scope=meta[SEScope][1], constrain_result=True); %8 = device_copy(%7, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]); nn.conv2d(%8, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) } @@ -1061,11 +1065,11 @@ def expected(): def @main(%x: Tensor[(3, 3, 4), float32], param_se_scopes=[meta[SEScope][0]], result_se_scope=meta[SEScope][1]) { %0 = split(%x, indices_or_sections=3); - let %t = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); + let %t = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); %1 = %t.0; - %2 = on_device(%1, se_scope=meta[SEScope][0], is_fixed=True); + %2 = 
on_device(%1, se_scope=meta[SEScope][0], constrain_result=True); %3 = %t.1; - %4 = on_device(%3, se_scope=meta[SEScope][0], is_fixed=True); + %4 = on_device(%3, se_scope=meta[SEScope][0], constrain_result=True); %5 = device_copy(%2, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %6 = device_copy(%4, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); subtract(%5, %6) @@ -1129,14 +1133,14 @@ def expected(): def @main(%x: Tensor[(5, 7), float32], param_se_scopes=[meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { %0 = negative(%x); - %1 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); + %1 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); %2 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); - %3 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); + %3 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); %4 = device_copy(%3, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %5 = negative(%2); %6 = negative(%4); %7 = add(%5, %6); - %8 = on_device(%7, se_scope=meta[SEScope][1], is_fixed=True); + %8 = on_device(%7, se_scope=meta[SEScope][1], constrain_result=True); %9 = device_copy(%8, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]); negative(%9) } @@ -1199,14 +1203,14 @@ def expected(): def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], param_se_scopes=[meta[SEScope][1], meta[SEScope][1]], result_se_scope=meta[SEScope][0]) { %0 = add(%x, %y); - %1 = on_device(%0, se_scope=meta[SEScope][1], is_fixed=True); + %1 = on_device(%0, se_scope=meta[SEScope][1], constrain_result=True); %2 = device_copy(%1, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]); %3 = negative(%2); - %4 = on_device(%3, se_scope=meta[SEScope][0], is_fixed=True); + %4 = on_device(%3, se_scope=meta[SEScope][0], constrain_result=True); %5 = device_copy(%4, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %6 = negative(%0); %7 = 
add(%5, %6); - %8 = on_device(%7, se_scope=meta[SEScope][1], is_fixed=True); + %8 = on_device(%7, se_scope=meta[SEScope][1], constrain_result=True); %9 = device_copy(%8, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]); negative(%9) } @@ -1267,7 +1271,7 @@ def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][1], meta[SEScope][1]], result_se_scope=meta[SEScope][0]) { %0 = multiply(%c, %d); - %1 = on_device(%0, se_scope=meta[SEScope][1], is_fixed=True); + %1 = on_device(%0, se_scope=meta[SEScope][1], constrain_result=True); %2 = add(%a, %b); %3 = device_copy(%1, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]); subtract(%2, %3) @@ -1294,7 +1298,7 @@ def input(): #[version = "0.0.5"] def @main(%x: bool, %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) { let %f = fn (%a) { - %0 = on_device(%y, se_scope=meta[SEScope][0], is_fixed=True); + %0 = on_device(%y, se_scope=meta[SEScope][0], constrain_result=True); add(%a, %0) }; let %g = fn (%a1) { @@ -1326,11 +1330,11 @@ def @main(%x: bool, %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32], let %g = fn (%a1, param_se_scopes=[meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { subtract(%a1, %y) }; - let %h = if (%x) { + let %h = on_device(if (%x) { %f } else { %g - }; + }, se_scope=meta[SEScope][0], constrain_result=True); %h(%z) } """, @@ -1430,9 +1434,9 @@ def expected(): #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], param_se_scopes=[meta[SEScope][1], meta[SEScope][0]], result_se_scope=meta[SEScope][1]) { - let %r = ref(%x); + let %r = on_device(ref(%x), se_scope=meta[SEScope][1], constrain_result=True); %0 = device_copy(%y, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); - ref_write(%r, %0); + on_device(ref_write(%r, %0), se_scope=meta[SEScope][1], constrain_result=True); %1 = ref_read(%r); add(%x, %1) } @@ -1463,7 +1467,7 @@ def input(): 
Nil, } def @main(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32]) { - %0 = on_device(%y, se_scope=meta[SEScope][0], is_fixed=True); + %0 = on_device(%y, se_scope=meta[SEScope][0], constrain_result=True); %1 = Nil; %2 = Cons(%0, %1); let %l = Cons(%x, %2); @@ -1489,7 +1493,7 @@ def @main(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32], param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { %0 = Nil; %1 = Cons(%y, %0); - let %l = Cons(%x, %1); + let %l = on_device(Cons(%x, %1), se_scope=meta[SEScope][0], constrain_result=True); match? (%l) { Cons(%z, _) => %z } @@ -1507,6 +1511,72 @@ def ref(x, y): exercise(input(), expected(), ref, rands((5, 7), 2)) +def test_lowered(): + """ + Tests propagation of memory scopes from PrimFuncs and insertion + of device_copies to mediate any scope changes. + """ + + @T.prim_func + def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128], scope="local") + B = T.match_buffer(b, [128, 128], scope="local") + C = T.match_buffer(c, [128, 128], scope="local") + + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + metatable = {"SEScope": [CPU_LOCAL, CPU_GLOBAL]} + matmul_ty = relay.FuncType( + [relay.TensorType((128, 128), "float32"), relay.TensorType((128, 128), "float32")], + relay.TensorType((128, 128), "float32"), + ) + matmul_gv = relay.GlobalVar("matmul", type=matmul_ty) + mod = tvm.ir.IRModule() + mod[matmul_gv] = matmul + + def input(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x : Tensor[(128, 128), float32], + %y : Tensor[(128, 128), float32], + param_se_scopes=[meta[SEScope][0], meta[SEScope][1]], + result_se_scope=meta[SEScope][1]) { + call_lowered(@matmul, (%x, %y)) + } + """, + "from_string", + mod, + metatable, + ) + + def expected(): + return 
tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x : Tensor[(128, 128), float32], + %y : Tensor[(128, 128), float32], + param_se_scopes=[meta[SEScope][0], meta[SEScope][1]], + result_se_scope=meta[SEScope][1]) { + %0 = device_copy(%y, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]); + %1 = call_lowered(@matmul, (%x, %0)); + %2 = on_device(%1, se_scope=meta[SEScope][0], constrain_result=True); + device_copy(%2, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]) + } + """, + "from_string", + mod, + metatable, + ) + + exercise(input(), expected(), None, None) + + if __name__ == "__main__": import sys import pytest diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index a35f75bd1aae2..ea1f4ddb3b627 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -1040,8 +1040,8 @@ def @main(%a: Tensor[(5, 7), float32], # - The offset of the tensor within the storage (second arg) to alloc_tensor # Both should be on the CPU assert "VirtualDevice[0]: device type 1" in exe.virtual_devices - assert "Constant[0]: has shape int64[] on device index 0" in exe.constants - assert "Constant[1]: has shape int64[] on device index 0" in exe.constants + assert "Const[0]: has shape int64[] on device index 0" in exe.constants + assert "Const[1]: has shape int64[] on device index 0" in exe.constants @tvm.testing.requires_cuda @@ -1073,7 +1073,7 @@ def @main(%x: Tensor[(2, 8), float32], # The newshape annotation should have been turned into a constant on the CPU. assert "VirtualDevice[0]: device type 1" in exe.virtual_devices - assert "Constant[0]: has shape int64[3] on device index 0" in exe.constants + assert "Const[0]: has shape int64[3] on device index 0" in exe.constants @tvm.testing.requires_cuda