diff --git a/include/tvm/relay/attrs/call.h b/include/tvm/relay/attrs/call.h index 2b02c6a5edac..167a593ff377 100644 --- a/include/tvm/relay/attrs/call.h +++ b/include/tvm/relay/attrs/call.h @@ -39,7 +39,9 @@ struct CallLoweredAttrs : public tvm::AttrsNode { Map metadata; TVM_DECLARE_ATTRS(CallLoweredAttrs, "relay.attrs.CallLoweredAttrs") { - TVM_ATTR_FIELD(metadata).describe("Metadata attached to the lowered function call."); + TVM_ATTR_FIELD(metadata) + .describe("Metadata attached to the lowered function call.") + .set_default(Map()); } }; diff --git a/include/tvm/relay/attrs/on_device.h b/include/tvm/relay/attrs/on_device.h index 405926e209c6..b1f1e6a6dc45 100644 --- a/include/tvm/relay/attrs/on_device.h +++ b/include/tvm/relay/attrs/on_device.h @@ -19,7 +19,7 @@ /*! * \file tvm/relay/attrs/on_device.h - * \brief Attribute for the on device annotation. + * \brief Attribute for the "on_device" annotation (ie operator). */ #ifndef TVM_RELAY_ATTRS_ON_DEVICE_H_ #define TVM_RELAY_ATTRS_ON_DEVICE_H_ @@ -33,9 +33,9 @@ namespace tvm { namespace relay { /*! - * \brief Attributes for the "on_device" special operator. + * \brief Attributes for the "on_device" annotation (ie operator). * - * The Relay call (aka 'annotation'): + * The Relay call: * \code * on_device(sub_expr, se_scope=S) * \endcode @@ -54,44 +54,48 @@ namespace relay { * multiply(device_copy(add(%x, %y), src_se_scope=GPU, dst_se_scope=CPU), %z) * \endcode * - * The Relay call - * \code - * on_device(sub_expr, se_scope=S, is_fixed=True) - * \endcode - * is similar to the above, however the annotation itself must appear in an expression on the - * same \p SEScope \p S. The compiler will check the \p SEScopes are consistent, and will not - * insert any "device_copy" call. This form of annotation shouldn't be necessary in user programs. - * However it is needed by the \p PlanDevices pass to fully specify the results of device planning - * so that the pass is idempotent. 
- * - * E.g.: The following program is equivalent to the above: - * \code - * let %a = on_device(add(%x, %y), se_scope=GPU, is_fixed=True) - * multiply(device_copy(%a, src_se_scope=GPU, dst_se_scope=CPU), %z) - * \endcode - * The "on_device" annotation with \p is_fixed=True indicates unambiguously that \p %a is stored - * on the GPU. + * The \p constrain_body (default true) and \p constrain_result (default false) fields can be + * used by passes for finer-grained control over how the \p SEScope constraint should be applied. */ struct OnDeviceAttrs : public tvm::AttrsNode { /*! - * \brief (Virtual) \p SEScope on which the result of the argument expression should be stored. + * \brief The \p SEScope constraint to apply to the body, result, or both body and result + * of the "on_device" call. */ SEScope se_scope = SEScope::FullyUnconstrained(); + + /*! + * \brief If false (the default), the result of the "on_device" call is not constrained to be + * \p se_scope. + */ + bool constrain_result = false; + + /*! + * \brief If true (the default), the body of the "on_device" call is constrained to be \p + * se_scope. + */ + bool constrain_body = true; + + /*! + * \brief Returns true if both the body and result are constrained. + */ + bool is_fixed() const { return constrain_result && constrain_body; } + /*! - * \brief If true, the result \p SEScope must also be \p se_scope, and device planning should - * not insert any "device_copy" calls to respect this annotation. - * - * This is used by the device planning pass itself when annotating the planned program. + * \brief Returns true only if the body is constrained (the 'normal' case). 
*/ - bool is_fixed = false; + bool is_normal() const { return !constrain_result && constrain_body; } TVM_DECLARE_ATTRS(OnDeviceAttrs, "relay.attrs.OnDeviceAttrs") { TVM_ATTR_FIELD(se_scope) - .describe("The (virtual) device and scope holding the expression result.") + .describe("The (virtual) device to constrain to.") .set_default(SEScope::FullyUnconstrained()); - TVM_ATTR_FIELD(is_fixed) - .describe("If true, do not insert a \"device_copy\" call to respect this annotation.") + TVM_ATTR_FIELD(constrain_result) + .describe("Whether the constraint applies to the overall expression") .set_default(false); + TVM_ATTR_FIELD(constrain_body) + .describe("Whether the constraint applies to the body sub-expression.") + .set_default(true); } }; diff --git a/python/tvm/relay/op/annotation/annotation.py b/python/tvm/relay/op/annotation/annotation.py index cf70dc6e267e..cb4e628ebc92 100644 --- a/python/tvm/relay/op/annotation/annotation.py +++ b/python/tvm/relay/op/annotation/annotation.py @@ -31,29 +31,35 @@ def _make_se_scope(device): raise ValueError("expecting a Device or device name, but received a %s" % (type(device))) -def on_device(data, device, is_fixed=False): - """Annotates an expression with the device type on which its result should be stored. +def on_device(body, device, constrain_result=False, constrain_body=True): + """Annotates a body expression with device constraints. The constraint influences + how the body is compiled, where the body is evaluated, and where the result of + evaluation is stored. + + Note that the defaults for the constrain_body and constrain_result parameters should + almost never need to be overridden by the user. These parameters are exposed here + to help unit tests exercise the PlanDevices pass machinery. Parameters ---------- - data : tvm.relay.Expr + body : tvm.relay.Expr The expression to be annotated. device : Union[:py:class:`Device`, str] The device to annotate with. 
- is_fixed : bool - If false (the default), a device_copy - If true, the annotation does not imply a device_copy may be inserted to - reconcile the device of the data argument with the device for the context of the - annotated expression. + constrain_result : bool + If false (the default), the result of the on_device is not constrained to be on device. + + constrain_body : bool + If true (the default), the body of the on_device is constrained to be on device. Returns ------- result : tvm.relay.Expr The annotated expression. """ - return _make.OnDevice(data, _make_se_scope(device), is_fixed) + return _make.OnDevice(body, _make_se_scope(device), constrain_result, constrain_body) def function_on_device(function, param_devices, result_device): diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index ed5f5f62af94..d0c2cfebbbd8 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -502,12 +502,6 @@ Doc RelayTextPrinter::VisitExpr_(const FunctionNode* op) { Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) { Doc doc; doc << "@" << op->name_hint; -#if TVM_LOG_DEBUG - if (op->checked_type_.defined()) { - doc << " /* type=" << PrintType(op->checked_type_, /*meta=*/false) << " */"; - } - doc << " /* id=" << reinterpret_cast(op) << " */"; -#endif return doc; } @@ -521,6 +515,11 @@ Doc RelayTextPrinter::VisitExpr_(const CallNode* op) { for (const Expr& arg : op->args) { args.push_back(Print(arg)); } +#if TVM_LOG_DEBUG + for (const Type& type_arg : op->type_args) { + args.push_back(Print(type_arg)); + } +#endif for (const Doc& d : PrintCallAttrs(op->attrs, op->op)) { args.push_back(d); } diff --git a/src/printer/text_printer.cc b/src/printer/text_printer.cc index 5acb9bd3f1dc..4d4113fef694 100644 --- a/src/printer/text_printer.cc +++ b/src/printer/text_printer.cc @@ -27,6 +27,7 @@ #include +#include #include namespace tvm { @@ -36,49 +37,71 @@ static const char* kSemVer = "0.0.5"; Doc 
TextPrinter::PrintMod(const IRModule& mod) { Doc doc; int counter = 0; + + // We'll print in alphabetical order to make a/b diffs easier to work with. + // type definitions + std::vector tyvars; for (const auto& kv : mod->type_definitions) { + tyvars.emplace_back(kv.first); + } + std::sort(tyvars.begin(), tyvars.end(), + [](const GlobalTypeVar& left, const GlobalTypeVar& right) { + return left->name_hint < right->name_hint; + }); + for (const auto& tyvar : tyvars) { if (counter++ != 0) { doc << Doc::NewLine(); } - doc << relay_text_printer_.Print(kv.second); + doc << relay_text_printer_.Print(mod->type_definitions[tyvar]); doc << Doc::NewLine(); } + // functions + std::vector vars; for (const auto& kv : mod->functions) { - if (kv.second.as()) { + vars.emplace_back(kv.first); + } + std::sort(vars.begin(), vars.end(), [](const GlobalVar& left, const GlobalVar& right) { + return left->name_hint < right->name_hint; + }); + for (const auto& var : vars) { + const BaseFunc& base_func = mod->functions[var]; + if (base_func.as()) { relay_text_printer_.dg_ = - relay::DependencyGraph::Create(&relay_text_printer_.arena_, kv.second); + relay::DependencyGraph::Create(&relay_text_printer_.arena_, base_func); } if (counter++ != 0) { doc << Doc::NewLine(); } - if (kv.second.as()) { + if (base_func.as()) { std::ostringstream os; - os << "def @" << kv.first->name_hint; -#if TVM_LOG_DEBUG - os << " /* id=" << reinterpret_cast(kv.first.get()) << " */"; -#endif - doc << relay_text_printer_.PrintFunc(Doc::Text(os.str()), kv.second); - } else if (kv.second.as()) { - doc << "@" << kv.first->name_hint; -#if TVM_LOG_DEBUG - doc << " /* id=" << reinterpret_cast(kv.first.get()) << " */"; -#endif - doc << " = " << tir_text_printer_.PrintPrimFunc(Downcast(kv.second)); + os << "def @" << var->name_hint; + doc << relay_text_printer_.PrintFunc(Doc::Text(os.str()), base_func); + } else if (base_func.as()) { + doc << "@" << var->name_hint; + doc << " = " << 
tir_text_printer_.PrintPrimFunc(Downcast(base_func)); } doc << Doc::NewLine(); } + #if TVM_LOG_DEBUG // attributes + // TODO(mbs): Make this official, including support from parser. if (mod->attrs.defined() && !mod->attrs->dict.empty()) { - doc << "attributes {" << Doc::NewLine(); + std::vector keys; for (const auto& kv : mod->attrs->dict) { - doc << " '" << kv.first << "' = " << PrettyPrint(kv.second) << Doc::NewLine(); + keys.emplace_back(kv.first); + } + std::sort(keys.begin(), keys.end()); + doc << "attributes {" << Doc::NewLine(); + for (const auto& key : keys) { + doc << " '" << key << "' = " << PrettyPrint(mod->attrs->dict[key]) << Doc::NewLine(); } doc << "}" << Doc::NewLine(); } #endif + return doc; } diff --git a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc index c41399e314ef..89b325f51a0c 100644 --- a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc +++ b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc @@ -116,10 +116,10 @@ class ConvertAddToSubtract : public MixedModeMutator { // Since we are replacing the Relay function with a call to a TIR function, we must use the // call_lowered op. 
- auto call_lowered_attrs = make_object(); - call_lowered_attrs->metadata.Set("relay_attrs", call->attrs); - return CallLowered(std::move(new_global_var), call->args, - std::move(Attrs(call_lowered_attrs)), call->type_args, call->span); + CallLoweredAttrs attrs; + attrs.metadata.Set("relay_attrs", call->attrs); + ICHECK(call->type_args.empty()) << "lowered functions cannot be polymorphic"; + return CallLowered(std::move(new_global_var), call->args, std::move(attrs), call->span); } } @@ -144,5 +144,4 @@ transform::Pass RelayToTIR() { } // namespace example_target_hooks } // namespace contrib } // namespace relay - } // namespace tvm diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index b47bc401b37f..528df647fe4a 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -560,8 +560,7 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { * to the TIR implementation, and attributes to attach to the call to identify it as * a TIR call. */ - Expr MakeLoweredCall(Function func, Array visited_args, Array type_args, Span span, - Target target) { + Expr MakeLoweredCall(Function func, Array visited_args, Span span, Target target) { CCacheKey key = CCacheKey(func, target); CachedFunc cfunc = compiler_->Lower(key, module_name_); ICHECK(cfunc.defined()); @@ -594,16 +593,16 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { func_with_metadata = WithAttr(func_with_metadata, tvm::attr::kTarget, cfunc->target); this->process_fn_(func_with_metadata); - auto call_lowered_attrs = make_object(); + CallLoweredAttrs call_lowered_attrs; // Non-External Relay Function // TODO(mbs): "reshape" cleanup. 
if (!opt_compiler && func->HasNonzeroAttr(attr::kReshapeOnly)) { - call_lowered_attrs->metadata.Set(attr::kReshapeOnly, tvm::Integer(1)); + call_lowered_attrs.metadata.Set(attr::kReshapeOnly, tvm::Integer(1)); } - call_lowered_attrs->metadata.Set("relay_attrs", func->attrs); - call_lowered_attrs->metadata.Set("all_prim_fn_vars", all_prim_fn_vars); + call_lowered_attrs.metadata.Set("relay_attrs", func->attrs); + call_lowered_attrs.metadata.Set("all_prim_fn_vars", all_prim_fn_vars); if (IsDynamic(func->ret_type)) { // Also lower the companion dynamic shape function. @@ -616,12 +615,12 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { // Capture the shape function's global var and parameters 'states' in call // annotations so calling convention can be recovered. // TODO(mbs): Shape cleanup. - call_lowered_attrs->metadata.Set("prim_shape_fn_var", lowered_shape_func->prim_fn_var); - call_lowered_attrs->metadata.Set("prim_shape_fn_states", - lowered_shape_func->shape_func_param_states); - call_lowered_attrs->metadata.Set( - "prim_shape_fn_num_inputs", Integer(static_cast(lowered_shape_func->inputs.size()))); - call_lowered_attrs->metadata.Set( + call_lowered_attrs.metadata.Set("prim_shape_fn_var", lowered_shape_func->prim_fn_var); + call_lowered_attrs.metadata.Set("prim_shape_fn_states", + lowered_shape_func->shape_func_param_states); + call_lowered_attrs.metadata.Set("prim_shape_fn_num_inputs", + Integer(static_cast(lowered_shape_func->inputs.size()))); + call_lowered_attrs.metadata.Set( "prim_shape_fn_num_outputs", Integer(static_cast(lowered_shape_func->outputs.size()))); Array all_prim_shape_fn_vars; @@ -629,11 +628,11 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { CHECK(kv.second.as()) << "must be a prim fn"; all_prim_shape_fn_vars.push_back(kv.first); } - call_lowered_attrs->metadata.Set("all_prim_shape_fn_vars", all_prim_shape_fn_vars); + call_lowered_attrs.metadata.Set("all_prim_shape_fn_vars", all_prim_shape_fn_vars); } - 
return CallLowered(cfunc->prim_fn_var, std::move(visited_args), Attrs(call_lowered_attrs), - type_args, std::move(span)); + return CallLowered(cfunc->prim_fn_var, std::move(visited_args), std::move(call_lowered_attrs), + std::move(span)); } std::pair PreVisitLetBinding_(const Var& var, const Expr& value) final { @@ -729,12 +728,13 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { }); ICHECK(!IsDynamic(call_node->checked_type())); - auto call_lowered_attrs = make_object(); - call_lowered_attrs->metadata.Set("relay_attrs", primitive_func->attrs); + CallLoweredAttrs call_lowered_attrs; + call_lowered_attrs.metadata.Set("relay_attrs", primitive_func->attrs); process_fn_(func_with_metadata); - return CallLowered(call_node->op, std::move(new_args), Attrs(std::move(call_lowered_attrs)), - call_node->type_args, call_node->span); + ICHECK(call_node->type_args.empty()) << "lowered functions cannot be polymorphic"; + return CallLowered(prim_func_var, std::move(new_args), std::move(call_lowered_attrs), + call_node->span); } // Typical case: call to fused primitive Relay Function. @@ -754,8 +754,8 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { // Lower the primitive function for that target. 
Function function = Downcast(primitive_func); - return MakeLoweredCall(function, std::move(new_args), call_node->type_args, call_node->span, - target); + ICHECK(call_node->type_args.empty()) << "lowered functions cannot be polymorphic"; + return MakeLoweredCall(function, std::move(new_args), call_node->span, target); } IRModule module_; diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index b7c0999ecc72..93b2bcb8d7ef 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -301,7 +301,8 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { size_t NewRegister() { return registers_num_++; } inline void Emit(const Instruction& instr) { - VLOG(2) << "VMCompiler::Emit: instr=" << instr; + size_t instruction_index = instructions_.size(); + VLOG(2) << "instruction[" << instruction_index << "] = " << instr; ICHECK((int)instr.op < 100) << "Invalid opcode " << (int)instr.op; switch (instr.op) { case Opcode::AllocADT: @@ -336,10 +337,9 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { * in emitted code. Note that the host device is always at index 0. 
*/ Index GetDeviceIndex(const SEScope& se_scope) { - VLOG(2) << "getting device index for " << se_scope; + ICHECK(!se_scope->IsFullyUnconstrained()); auto itr = std::find(context_->se_scopes_.begin(), context_->se_scopes_.end(), se_scope); if (itr != context_->se_scopes_.end()) { - VLOG(2) << "reusing existing scope"; return std::distance(context_->se_scopes_.begin(), itr); } @@ -367,7 +367,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { ICHECK(se_scope != host_se_scope_); Index index = context_->se_scopes_.size(); - VLOG(2) << "adding new scope"; + VLOG(2) << "se_scope[" << index << "] = " << se_scope; context_->se_scopes_.push_back(se_scope); return index; @@ -378,11 +378,13 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { void VisitExpr_(const ConstantNode* const_node) final { // Check the shape is valid NDArray data = const_node->data; - size_t konst_idx = context_->constants.size(); + size_t const_index = context_->constants.size(); auto con = GetRef(const_node); - context_->const_device_indexes.push_back(GetDeviceIndex(GetSEScope(con))); + Index device_index = GetDeviceIndex(GetSEScope(con)); + VLOG(2) << "constant[" << const_index << "] on device[" << device_index << "]"; + context_->const_device_indexes.push_back(device_index); context_->constants.push_back(const_node->data); - Emit(Instruction::LoadConst(konst_idx, NewRegister())); + Emit(Instruction::LoadConst(const_index, NewRegister())); } void VisitExpr_(const VarNode* var_node) final { @@ -872,6 +874,7 @@ void VMCompiler::Lower(IRModule mod, TargetMap targets, tvm::Target target_host) // The first device is always for the host. CHECK(context_.se_scopes_.empty()); + VLOG(2) << "se_scope[0] = " << config_->host_se_scope << " (host)"; context_.se_scopes_.push_back(config_->host_se_scope); // Run the optimizations necessary to target the VM. 
diff --git a/src/relay/op/call/call.cc b/src/relay/op/call/call.cc index 6f2d3a8ddbad..ab8e7e12d213 100644 --- a/src/relay/op/call/call.cc +++ b/src/relay/op/call/call.cc @@ -63,23 +63,23 @@ bool CallLoweredRel(const Array& types, int num_inputs, const Attrs& attrs const Op& CallLoweredOp() { return Op::Get("call_lowered"); } -Expr CallLowered(Expr func, Array inputs, Attrs attrs, Array type_args, Span span) { - // Right now, call_lowered only supports func being a global var pointing to the lowered - // function. - ICHECK(func.as()) - << "Function to call should be GlobalVarNode, but got:" << std::endl - << PrettyPrint(func); - ICHECK(attrs.as()) - << "Expected attributes to be CallLoweredAttrs, but got " << attrs->GetTypeKey(); - return Call(CallLoweredOp(), {std::move(func), Tuple(std::move(inputs))}, std::move(attrs), - std::move(type_args), std::move(span)); +Call CallLowered(GlobalVar lowered_func, Array args, CallLoweredAttrs call_lowered_attrs, + Span span) { + auto attrs = make_object(std::move(call_lowered_attrs)); + return Call(CallLoweredOp(), {std::move(lowered_func), Tuple(std::move(args))}, + Attrs(std::move(attrs)), /*type_args=*/{}, std::move(span)); } TVM_REGISTER_GLOBAL("relay.op.call_lowered") - .set_body_typed([](Expr func, Array inputs, Attrs attrs, Array type_args, - Span span) { - const TupleNode* tuple_node = inputs.as(); - return CallLowered(func, tuple_node->fields, attrs, type_args, span); + .set_body_typed([](Expr lowered_func, Array args, Attrs attrs, Span span) { + const auto* lowered_func_node = lowered_func.as(); + ICHECK(lowered_func_node) << "Function to call should be GlobalVarNode, but got:" << std::endl + << PrettyPrint(lowered_func); + const auto* call_lowered_attrs = attrs.as(); + ICHECK(call_lowered_attrs) << "Expected attributes to be CallLoweredAttrs, but got " + << attrs->GetTypeKey(); + return CallLowered(GetRef(lowered_func_node), std::move(args), *call_lowered_attrs, + std::move(span)); }); 
RELAY_REGISTER_OP("call_lowered") @@ -105,10 +105,12 @@ CallLoweredProps GetCallLoweredProps(const CallNode* call_node) { ICHECK(tuple_args) << "Expected second arg to call_lowered to be a Tuple of input arguments."; ICHECK(call_node->attrs.defined()) << "Expecting call_lowered to have attributes."; - const auto* attrs = call_node->attrs.as(); - ICHECK(attrs) << "Expected call_lowered op to have CallLoweredAttrs, but found " - << call_node->attrs->GetTypeKey(); - return CallLoweredProps{GetRef(function_node), tuple_args->fields, *attrs}; + const auto* call_lowered_attrs = call_node->attrs.as(); + ICHECK(call_lowered_attrs) << "Expected call_lowered op to have CallLoweredAttrs, but found " + << call_node->attrs->GetTypeKey(); + // If the call_node has type_args then they are for the polymorphic 'call_lowered' operator + // itself which expects the function type and argument type as parameters. + return {GetRef(function_node), tuple_args->fields, *call_lowered_attrs}; } return {}; } @@ -116,8 +118,9 @@ CallLoweredProps GetCallLoweredProps(const CallNode* call_node) { Call GetAnyCall(const CallNode* call_node) { CallLoweredProps props = GetCallLoweredProps(call_node); if (props.lowered_func.defined()) { - auto attrs = make_object(props.attrs); - return Call(std::move(props.lowered_func), props.arguments, Attrs(std::move(attrs)), + auto call_lowered_attrs = make_object(props.attrs); + return Call(std::move(props.lowered_func), std::move(props.arguments), + Attrs(std::move(call_lowered_attrs)), /*type_args=*/{}, call_node->span); } else { return GetRef(call_node); diff --git a/src/relay/op/call/call.h b/src/relay/op/call/call.h index a2d30eaa8763..6193c9249ee2 100644 --- a/src/relay/op/call/call.h +++ b/src/relay/op/call/call.h @@ -32,23 +32,33 @@ namespace tvm { namespace relay { -/*! - * \brief Helper to construct a Relay call with the call_lowered op. - * \param func Lowered function to call with call_lowered. - * \param inputs Arguments to be passed to the function. 
- * \param attrs Function attributes, should be TIRCallAttrs. - * \param type_args Type arguments for the call. - * \param span TVM span for propogating debugging info. - * \return - */ -Expr CallLowered(Expr func, Array inputs, Attrs attrs, Array type_args, Span span); - /*! * \brief Returns the Relay call_lowered op. Use this helper to avoid extraneous calls to * Registry::Get. */ const Op& CallLoweredOp(); +/*! + * \brief Helper to construct a Relay call with the "call_lowered" op. + * + * The callee must: + * - Be a global bound to a PrimFunc or an externally defined function. + * - Accept only tensor arguments and return tensor results. + * - Arguments and results correspond to the flattened form (see FlattenTupleType) of the + * Relay Function type. + * - Return results by output pointer, ie use DPS. + * The arguments remain in Relay form (ie not flattened). + * The result remains in Relay form (ie returned from the call and not flattened). + * + * \param lowered_func Lowered function to call with call_lowered. + * \param args Arguments to be passed to the function. + * \param call_lowered_attrs Function attributes. + * \param span TVM span for propagating debugging info. + * \return + */ +Call CallLowered(GlobalVar lowered_func, Array args, CallLoweredAttrs call_lowered_attrs, + Span span); + /*! * \brief Lowered function and the arguments to call it with. */ @@ -57,7 +67,7 @@ struct CallLoweredProps { GlobalVar lowered_func; /*! \brief Array of the arguments to call lowered_func with. */ Array arguments; - /*! \brief Arguments from the call_lowered op. */ + /*! \brief Attributes from the call_lowered op. 
*/ CallLoweredAttrs attrs; }; diff --git a/src/relay/op/memory/on_device.cc b/src/relay/op/memory/on_device.cc index 9541d4122a2f..ae5ef33da6d0 100644 --- a/src/relay/op/memory/on_device.cc +++ b/src/relay/op/memory/on_device.cc @@ -29,8 +29,6 @@ #include #include #include -#include -#include #include "../../transforms/infer_layout_utils.h" #include "../type_relations.h" @@ -45,53 +43,74 @@ const Op& OnDeviceOp() { return op; } -Expr OnDevice(Expr expr, SEScope se_scope, bool is_fixed) { - ICHECK(!se_scope->IsFullyUnconstrained()); +Call OnDevice(Expr body, SEScope se_scope, bool constrain_result, bool constrain_body) { + ICHECK((!constrain_result && !constrain_body) || !se_scope->IsFullyUnconstrained()); auto attrs = make_object(); - attrs->se_scope = std::move(se_scope); - attrs->is_fixed = is_fixed; - Span span = expr->span; - return Call(OnDeviceOp(), {std::move(expr)}, Attrs(std::move(attrs)), /*type_args=*/{}, + attrs->se_scope = + (constrain_result || constrain_body) ? std::move(se_scope) : SEScope::FullyUnconstrained(); + attrs->constrain_result = constrain_result; + attrs->constrain_body = constrain_body; + Span span = body->span; // about to be moved + return Call(OnDeviceOp(), {std::move(body)}, Attrs(std::move(attrs)), /*type_args=*/{}, std::move(span)); } TVM_REGISTER_GLOBAL("relay.op.annotation._make.OnDevice").set_body_typed(OnDevice); -Expr MaybeOnDevice(Expr expr, SEScope se_scope, bool is_fixed) { +Expr MaybeOnDevice(Expr body, SEScope se_scope, bool constrain_result, bool constrain_body) { if (se_scope->IsFullyUnconstrained()) { // Nothing to annotate with. - return expr; + return body; } - if (expr->IsInstance() || expr->IsInstance()) { + if (body->IsInstance() || body->IsInstance()) { // These operators are device polymorphic so no annotation is required. 
- return expr; + return body; } - if (expr->IsInstance() || expr->IsInstance()) { + if (body->IsInstance() || body->IsInstance()) { // The device can be recovered from the binding site of the global or local variable. - return expr; + return body; } - if (expr->IsInstance()) { + if (body->IsInstance()) { // If a primitive function then it is device polymorphic. Otherwise the device is captured // by the function's "result_se_scope" attribute. - return expr; + return body; } - OnDeviceProps props = GetOnDeviceProps(expr); + OnDeviceProps props = GetOnDeviceProps(body); if (props.body.defined()) { - // Don't nest on_devices. - // If the inner and outer device types differ then we need to be careful: - // - If the inner on_device is_fixed then it disagrees with the outer. - // - If the outer on_device is_fixed then it implies a hidden device_copy - // Otherwise just use the inner device type and ignore the outer. - ICHECK(props.se_scope == se_scope || (!is_fixed && !props.is_fixed)); - return OnDevice(props.body, se_scope, is_fixed || props.is_fixed); + // The user is asking for + // on_device(on_device(body, se_scope=inner), se_scope=outer) + // ^ ^ ^ + // outer middle inner + // First recover the implied constraints (if any) for outer and inner, and check they don't + // contradict. + const SEScope& inner = props.se_scope; + const SEScope& outer = se_scope; + bool constrain_outer = constrain_result; + bool constrain_inner = props.constrain_body; + if (constrain_outer && constrain_inner) { + ICHECK(inner == outer) + << "Cannot constrain result and body of nested on_device calls to different SEScopes"; + } + // There are two possible ways the middle sub-expression may be constrained, check they don't + // contradict. 
+ bool constrain_middle_via_outer = constrain_body; + bool constrain_middle_via_inner = props.constrain_result; + if (constrain_middle_via_outer && constrain_middle_via_inner) { + ICHECK(inner == outer) + << "Cannot constrain intermediate result of nested on_device calls to different SEScopes"; + } + // We can now ignore the intermediate constraints, if any. + return OnDevice(props.body, (constrain_inner || constrain_outer) ? outer : inner, + constrain_outer, constrain_inner); + } else { + return OnDevice(body, std::move(se_scope), constrain_result, constrain_body); } - return OnDevice(expr, std::move(se_scope), is_fixed); } RELAY_REGISTER_OP("on_device") .describe(R"code(Annotate an expression with device type)code" TVM_ADD_FILELINE) .set_num_inputs(1) - .add_argument("data", "Tensor", "The input data.") + .add_argument("body", "Expr", "The sub-expression to be annotated.") .set_support_level(10) .add_type_rel("Identity", IdentityRel) .set_attrs_type_key("relay.attrs.OnDeviceAttrs") @@ -106,14 +125,8 @@ OnDeviceProps GetOnDeviceProps(const CallNode* call_node) { ICHECK(call_node->attrs.defined()) << "on_device requires attributes"; const auto* on_device_attrs = call_node->attrs.as(); ICHECK(on_device_attrs != nullptr) << "on_device requires OnDeviceAttrs"; - // Follow nesting: - // on_device(on_device(expr, se_scope=S), se_scope=T) == {expr, S} - auto inner = GetOnDeviceProps(call_node->args[0]); - if (inner.body.defined()) { - return {inner.body, inner.se_scope, on_device_attrs->is_fixed || inner.is_fixed}; - } else { - return {call_node->args[0], on_device_attrs->se_scope, on_device_attrs->is_fixed}; - } + return {call_node->args[0], on_device_attrs->se_scope, on_device_attrs->constrain_result, + on_device_attrs->constrain_body}; } return {}; } diff --git a/src/relay/op/memory/on_device.h b/src/relay/op/memory/on_device.h index a7b6cb7cf52a..bac6695ac35b 100644 --- a/src/relay/op/memory/on_device.h +++ b/src/relay/op/memory/on_device.h @@ -39,14 +39,50 @@ 
namespace relay { const Op& OnDeviceOp(); /*! - * \brief Wraps \p expr in an "on_device" CallNode for \p se_scope and \p is_fixed. + * \brief Wraps \p body in an "on_device" CallNode for \p se_scope. * * See \p OnDeviceAttrs for an overview. */ -Expr OnDevice(Expr expr, SEScope se_scope, bool is_fixed); +Call OnDevice(Expr body, SEScope se_scope, bool constrain_result = false, + bool constrain_body = true); + +/*! \brief Result of \p GetOnDeviceProps. */ +struct OnDeviceProps { + Expr body; // = null + SEScope se_scope = SEScope::FullyUnconstrained(); + bool constrain_result = false; + bool constrain_body = false; + + OnDeviceProps() = default; + + OnDeviceProps(Expr body, SEScope se_scope, bool constrain_result, bool constrain_body) + : body(std::move(body)), + se_scope(std::move(se_scope)), + constrain_result(constrain_result), + constrain_body(constrain_body) {} + + bool is_fixed() const { return constrain_result && constrain_body; } + bool is_normal() const { return !constrain_result && constrain_body; } +}; + +/*! + * \brief As for OnDevice, but taking all fields other than \p body from \p props. + */ +inline Call OnDeviceWithProps(Expr body, const OnDeviceProps& props) { + return OnDevice(std::move(body), props.se_scope, props.constrain_result, props.constrain_body); +} /*! - * \brief Wraps \p expr in an "on_device" CallNode for \p se_scope and \p is_fixed if the + * \brief As for OnDevice, but don't constrain the body or result to any particular virtual device. + * This allows a "device_copy" when required. + */ +inline Call OnDeviceCopyOk(Expr body) { + return OnDevice(std::move(body), SEScope::FullyUnconstrained(), + /*constrain_result=*/false, /*constrain_body=*/false); +} + +/*! + * \brief Wraps \p expr in an "on_device" CallNode for \p se_scope and \p constraint if the * \p SEScope for \p expr cannot otherwise be recovered by the lexical scoping convention. 
* This means we will NOT wrap if: * - \p se_scope is full unconstrained, which signals there are no device annotations @@ -57,33 +93,34 @@ Expr OnDevice(Expr expr, SEScope se_scope, bool is_fixed); * - \p expr is a global var. The device is on the function attributes the global is bound to. * - \p expr is a local var. The device is tracked by the device aware visitors for us. * - \p expr is a constructor. These are device polymorphic. - * + * Nested on_device calls will never be constructed, they are instead merged on-the-fly. */ -Expr MaybeOnDevice(Expr expr, SEScope se_scope, bool is_fixed); - -/*! \brief Result of \p GetOnDeviceProps. */ -struct OnDeviceProps { - Expr body; // = null - SEScope se_scope = SEScope::FullyUnconstrained(); - bool is_fixed = false; +Expr MaybeOnDevice(Expr body, SEScope se_scope, bool constrain_result = false, + bool constrain_body = true); - OnDeviceProps() = default; +/*! \brief As for MaybeOnDevice, but with both body and result constrained. */ +inline Expr MaybeOnDeviceFixed(Expr body, SEScope se_scope) { + return MaybeOnDevice(std::move(body), std::move(se_scope), /*constrain_result=*/true, + /*constrain_body=*/true); +} - OnDeviceProps(Expr body, SEScope se_scope, bool isFixed) - : body(std::move(body)), se_scope(std::move(se_scope)), is_fixed(isFixed) {} -}; +/*! \brief As for MaybeOnDevice, but with fields other than body taken from \p props. */ +inline Expr MaybeOnDeviceWithProps(Expr body, const OnDeviceProps& props) { + return MaybeOnDevice(std::move(body), props.se_scope, props.constrain_result, + props.constrain_body); +} /*! - * \brief Returns the body expression, \p SEScope, and is_fixed field for \p call_node if it + * \brief Returns the body expression, \p SEScope, and constraint field for \p call_node if it * is an "on_device" CallNode. Otherwise returns the null expression, the unconstrained - * \p SEScope, and false. + * \p SEScope, and \p kBody. */ OnDeviceProps GetOnDeviceProps(const CallNode* call_node); /*! 
- * \brief Returns the body expression, \p SEScope, and is_fixed field for \p expr if it is an + * \brief Returns the body expression, \p SEScope, and constraint field for \p expr if it is an * "on_device" CallNode. Otherwise returns the null expression, the unconstrained \p SEScope, - * and \p false. + * and \p kBody. */ OnDeviceProps GetOnDeviceProps(const Expr& expr); diff --git a/src/relay/transforms/device_aware_visitors.cc b/src/relay/transforms/device_aware_visitors.cc index e3d5a821c58e..29965d2dac97 100644 --- a/src/relay/transforms/device_aware_visitors.cc +++ b/src/relay/transforms/device_aware_visitors.cc @@ -36,11 +36,12 @@ namespace transform { LexicalOnDeviceMixin::LexicalOnDeviceMixin(const Optional& maybe_mod) { if (maybe_mod) { - for (const auto& pair : maybe_mod.value()->functions) { - if (const auto* function_node = pair.second.as()) { + for (const auto& kv : maybe_mod.value()->functions) { + if (const auto* function_node = kv.second.as()) { SEScope se_scope = GetFunctionResultSEScope(function_node); if (!se_scope->IsFullyUnconstrained()) { - global_var_se_scopes_.emplace(pair.first, se_scope); + VLOG(2) << "global '" << kv.first->name_hint << "' has scope " << se_scope; + global_var_se_scopes_.emplace(kv.first, se_scope); } } } @@ -49,7 +50,7 @@ LexicalOnDeviceMixin::LexicalOnDeviceMixin(const Optional& maybe_mod) SEScope LexicalOnDeviceMixin::GetSEScope(const Expr& expr) const { OnDeviceProps props = GetOnDeviceProps(expr); - if (props.body.defined() && props.is_fixed) { + if (props.body.defined() && props.is_fixed()) { return props.se_scope; } else if (const auto* var_node = expr.as()) { // Lookup variable binding. 
@@ -175,7 +176,7 @@ void DeviceAwareExprVisitor::VisitExpr_(const LetNode* let_node) { void DeviceAwareExprVisitor::VisitExpr_(const CallNode* call_node) { OnDeviceProps props = GetOnDeviceProps(call_node); - if (props.body.defined() && props.is_fixed) { + if (props.body.defined() && props.is_fixed()) { // Entering lexical scope of fixed "on_device" call. PushSEScope(props.se_scope); VisitExpr(props.body); @@ -266,13 +267,13 @@ Expr DeviceAwareExprMutator::VisitExpr_(const LetNode* let_node) { Expr DeviceAwareExprMutator::VisitExpr_(const CallNode* call_node) { OnDeviceProps props = GetOnDeviceProps(call_node); - if (props.body.defined() && props.is_fixed) { + if (props.body.defined() && props.is_fixed()) { // Entering lexical scope of fixed "on_device" call. PushSEScope(props.se_scope); Expr expr = VisitExpr(props.body); // Leaving lexical scope of "on_device" call. PopSEScope(); - return MaybeOnDevice(expr, props.se_scope, props.is_fixed); + return MaybeOnDeviceWithProps(expr, props); } else { return DeviceAwareVisitExpr_(call_node); } diff --git a/src/relay/transforms/device_aware_visitors.h b/src/relay/transforms/device_aware_visitors.h index 8cdf0db74ebd..044cda85c579 100644 --- a/src/relay/transforms/device_aware_visitors.h +++ b/src/relay/transforms/device_aware_visitors.h @@ -146,7 +146,10 @@ class DeviceAwareExprFunctor : public ExprFunctorparams[i], GetFunctionParamSEScope(function_node, i)); } // Entering scope of function body. - PushSEScope(GetFunctionResultSEScope(function_node)); + SEScope se_scope = GetFunctionResultSEScope(function_node); + VLOG(2) << "entering " << se_scope << " for function:" << std::endl + << PrettyPrint(GetRef(function_node)); + PushSEScope(se_scope); EnterFunctionBody(); DeviceAwareVisitExpr_(function_node); @@ -154,6 +157,8 @@ class DeviceAwareExprFunctor : public ExprFunctor(function_node)); // Function parameters go out of scope. 
for (size_t i = 0; i < function_node->params.size(); ++i) { PopBoundVar(function_node->params[i]); @@ -168,7 +173,9 @@ class DeviceAwareExprFunctor : public ExprFunctor()) { // Let-bound var (in pre visited version) goes into scope. // (We'll just assume this is a letrec.) - PushBoundVar(inner_let_node->var, GetSEScope(inner_let_node->value)); + SEScope se_scope = GetSEScope(inner_let_node->value); + VLOG(2) << "var '" << inner_let_node->var->name_hint() << "' has scope " << se_scope; + PushBoundVar(inner_let_node->var, se_scope); PreVisitLetBinding_(inner_let_node->var, inner_let_node->value); bindings.emplace_back(inner_let_node); expr = inner_let_node->body; @@ -187,12 +194,16 @@ class DeviceAwareExprFunctor : public ExprFunctor(call_node)); PushSEScope(props.se_scope); VisitExpr(props.body); // Leaving lexical scope of "on_device" call. PopSEScope(); + VLOG(2) << "leaving " << props.se_scope << " for on_device:" << std::endl + << PrettyPrint(GetRef(call_node)); } else { DeviceAwareVisitExpr_(call_node); } diff --git a/src/relay/transforms/device_domains.cc b/src/relay/transforms/device_domains.cc index dafd8bc814bb..76697d8437f4 100644 --- a/src/relay/transforms/device_domains.cc +++ b/src/relay/transforms/device_domains.cc @@ -199,16 +199,23 @@ DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) { DeviceCopyProps device_copy_props = GetDeviceCopyProps(call.get()); CallLoweredProps call_lowered_props = GetCallLoweredProps(call.get()); + // TODO(mbs): Support call_lowered to PrimFuncs. + ICHECK(!call_lowered_props.lowered_func.defined()); if (on_device_props.body.defined()) { - // on_device(expr, se_scope=, is_fixed=false) - // on_device : fn():?x? 
- // - // on_device(expr, se_scope=, is_fixed=true) - // on_device: fn(): - args_and_result.emplace_back( - ForSEScope(on_device_props.body->checked_type(), on_device_props.se_scope)); - if (on_device_props.is_fixed) { - args_and_result.emplace_back(args_and_result.front()); + // By default: + // on_device(expr, se_scope=) + // on_device : fn():?x? + // However we'll interpret the constrain_body and constrain_result fields to decide + // on free vs constrained domains for the argument and result respectively. + if (on_device_props.constrain_body) { + args_and_result.emplace_back( + ForSEScope(on_device_props.body->checked_type(), on_device_props.se_scope)); + } else { + args_and_result.emplace_back(Free(on_device_props.body->checked_type())); + } + if (on_device_props.constrain_result) { + args_and_result.emplace_back( + ForSEScope(on_device_props.body->checked_type(), on_device_props.se_scope)); } else { args_and_result.emplace_back(Free(on_device_props.body->checked_type())); } @@ -286,11 +293,9 @@ DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) { args_and_result.emplace_back(param_domain); } args_and_result.emplace_back(result_domain); - } else if (call_lowered_props.lowered_func.defined()) { - return DomainFor(call_lowered_props.lowered_func); } else { // We still need to handle the case where the function / op is not lowered - // because the device planner runs before and after lowering. + // because the device planner runs both before and after lowering. 
return DomainFor(call->op); } auto domain = MakeHigherOrderDomain(std::move(args_and_result)); @@ -313,6 +318,33 @@ void DeviceDomains::UnifyExprExact(const Expr& lhs, const Expr& rhs) { } } +void DeviceDomains::OptionalUnifyExprExact(const Expr& lhs, const Expr& rhs) { + auto lhs_domain = DomainFor(lhs); + auto rhs_domain = DomainFor(rhs); + // Snapshot + std::unordered_map domain_to_equiv_snapshot = domain_to_equiv_; + if (UnifyOrNull(lhs_domain, rhs_domain) == nullptr) { + // Rollback + domain_to_equiv_ = domain_to_equiv_snapshot; + VLOG(2) << "Unable to unify SEScopes for expression:" << std::endl + << PrettyPrint(lhs) << std::endl + << "with scope:" << std::endl + << ToString(lhs_domain) << std::endl + << "and expression:" << std::endl + << PrettyPrint(rhs) << std::endl + << "with scope:" << std::endl + << ToString(rhs_domain) << std::endl + << ". Leaving scopes non-unified."; + } else { + VLOG(2) << "Unified SEScopes for expression:" << std::endl + << PrettyPrint(lhs) << std::endl + << "and expression:" << std::endl + << PrettyPrint(rhs) << std::endl + << "to scope:" << std::endl + << ToString(lhs_domain); + } +} + void DeviceDomains::UnifyExprExact(const Expr& expr, const DeviceDomainPtr& expected_domain) { auto actual_domain = DomainFor(expr); if (UnifyOrNull(actual_domain, expected_domain) == nullptr) { diff --git a/src/relay/transforms/device_domains.h b/src/relay/transforms/device_domains.h index 997fef3456c7..223c7d42bfa1 100644 --- a/src/relay/transforms/device_domains.h +++ b/src/relay/transforms/device_domains.h @@ -62,6 +62,16 @@ class DeviceDomains; * result_domain() = * result_domain(fn(D1, ..., Dn):Dr) = result_domain(Dr) * \endcode + * + * TODO(mbs): We currently don't allow sub-SEScope constraints. Eg for a function we can + * express that the argument and result SEScopes must be exactly equal, but we cannot express + * that though the devices and targets for arguments and results must be equal, it is ok for + * memory scopes to differ. 
At the moment we can get away with this since we run PlanDevices + * twice: once with all memory scopes unconstrained, then again with just memory scopes as + * the new property to flow. However we're on thin ice here and better would be to allow + * constraints on SEScopes to be exploded into their device/target component and their + * memory scope component. Should we fold layout constraints into SEScopes then they would + * probably be grouped with memory scopes. */ class DeviceDomain { public: @@ -177,7 +187,8 @@ class DeviceDomains { /*! * \brief Unifies \p lhs and \p rhs, returning the most-bound of the two. Returns null if - * \p lhs and \p rhs are not unifiable. + * \p lhs and \p rhs are not unifiable, in which case the constraint system may be left in + * a partially modified state. */ // TODO(mbs): I don't think we need an occurs check since the program is well-typed, but // given we have refs to functions I'm prepared to be surprised. @@ -237,6 +248,12 @@ class DeviceDomains { */ void UnifyExprExact(const Expr& lhs, const Expr& rhs); + /*! + * \brief Attempts to unify the domains for expressions \p lhs and \p rhs, however if they + * cannot be unified then returns with no change to the unification system. + */ + void OptionalUnifyExprExact(const Expr& lhs, const Expr& rhs); + /*! * \brief Unifies the domain for \p expr with \p expected_domain. * diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc index 8ea5f5dac0a4..a85233de17e5 100644 --- a/src/relay/transforms/device_planner.cc +++ b/src/relay/transforms/device_planner.cc @@ -20,6 +20,10 @@ /*! * \file src/relay/transforms/device_planner.cc * \brief Determines a unique \p SEScope to hold the result of every Relay sub-expression. + * This pass can be run multiple times, and can be run both before and after lowering. + * + * TODO(mbs): Rename SEScope |-> VirtualDevice, and use 'virtual device' (or just 'device') + * throughout. 
* * We say a Relay expression E is 'on device D' if the result of executing E is stored on D. * We represent D by an \p SEScope, which means we can track anywhere from an arbitrary device @@ -29,17 +33,21 @@ * Note that 'stored on device D' is almost but not quite the same as 'executes on device D', * see below. * - * This pass assumes the module already contains some "on_device" and/or "device_copy" CallNodes: - * - "device_copy" CallNodes (with a \p DeviceCopyAttrs attribute) specify a 'src_se_scope' and - * 'dst_se_scope' \p SEScopes, which constrain the argument and context of the call - * respectively. It is ok if source and destination devices are the same, such no-op copies - * will be removed after accounting for the device preference. - * - "on_device" CallNodes (with a \p OnDeviceAttrs attribute) specify an 'se_scope', which - * constrains the argument of the call, but (usually, see below) leaves the context + * This pass works by collecting and solving device constraints, using defaulting heuristics to + * resolve any remaining undetermined devices, and encoding the results on the output in a form + * that's reasonably friendly to downstream passes. + * + * Specific \p SEScopes flow into the constraints from five places: + * - Existing "device_copy" CallNodes (with a \p DeviceCopyAttrs attribute) specify a + * 'src_se_scope' and 'dst_se_scope' \p SEScope. Those constrain the argument and context of + * the call respectively. It is ok if source and destination devices are the same, such no-op + * copies will be removed after accounting for the device preference. + * - Existing "on_device" CallNodes (with a \p OnDeviceAttrs attribute) specify an 'se_scope', + * which constrains the argument of the call, but (usually, see below) leaves the context * unconstrained. These are called 'annotations' in the rest of the code, have no operational - * significance by themselves, but may trigger the insertion of a new "device_copy". 
- * - In two situations the result of an "on_device" CallNode may also be constrained to the - * given device: + * significance by themselves, but may trigger the insertion of a new "device_copy" call by + * this pass. In two situations the result of an "on_device" CallNode may also be constrained + * to the given 'se_scope': * - The "on_device" call occurs at the top-level of a function body, or occurs as an * immediately let-bound expression. In this situation the extra degree of freedom in * the function result and let-binding leads to surprising device copies, so we simply * - The \p OnDeviceAttrs has an \p is_fixed field of \p true, which indicates we inserted * it ourselves during an earlier invocation of this pass. This helps make this pass * idempotent. + * - Some special operators require their arguments or results to be on the 'host' (typically + * a CPU) \p SEScope, see below. + * - Annotations left over from a previous run of this pass, such as 'param_se_scopes' and + * 'result_se_scope' function attributes we introduce below. This is so the pass is idempotent + * and can be re-run to flow additional memory scope constraints. * * We proceed in four phases: * @@ -61,14 +74,13 @@ * * Phase 1 * ------- - * We flow constraints from the "on_device" and "device_copy" calls (and some special ops, see - * below) to all other Relay sub-expressions. (For idempotence we also respect any existing - * "param_se_scopes" and "result_se_scope" function attributes we introduce below.) + * We flow constraints from the "on_device" and "device_copy" calls, + * and some special ops, to all other Relay sub-expressions. * * For a primitive such as \code add(e1, e2) \endcode all arguments and results must be on the * same device. However each call site can use a different device. In other words primitives are * 'device polymorphic' since we compile and execute them for each required device. ADT constructors - * are similarly polymorphic. 
+ * are similarly polymorphic, but require all constructor args to be on the same device. * * For most Relay expressions the device for the overall expression is the same as the device * for its sub-expressions. E.g. each field of a tuple must be on the same device as the tuple * @@ -99,6 +111,8 @@ * - Unconstrained let-bound expression devices default to the device for the overall let. * TODO(mbs): These are very simple minded heuristics, and ultimately we'd like to treat the * assignment of the remaining unconstrained sub-expressions as an optimization problem in itself. + * This requires a formal notion of 'choicepoint' inside the compiler which can integrate with + * automation. * * Phase 3 * ------- * - Additional "device_copy" CallNodes where a copy is required in order to respect the * intent of the original "on_device" CallNodes. - * - Additional "on_device" CallNodes where the device type of an expression does not match - * that of the lexically enclosing "on_device" CallNode or function attribute. In practice + * - Additional "on_device" CallNodes where the device type of an expression is not trivially + * implied by the lexically enclosing "on_device" CallNode or function attribute. In practice * this means "on_device" CallNodes may appear in two places: - * - On a let-bound expression if its device differs from the overall let expression. + * - On let-bound expressions. It is tempting to elide the "on_device" if the let-bound value + * has the same device as the overall let expression. However this would mean passes which + * inline let-bound values, such as FoldConstant and DeadCodeElimination, would need to use + * a DeviceAware visitor which in turn requires the expression to be in ANF to avoid + * deep recursion. To minimize disruption we always include the "on_device" so that it + * can follow the inline. * - On a call argument if its device differs from the call result. 
In particular, the * argument to a "device_copy" call will always be wrapped in an "on_device". (That may * seem pedantic but simplifies downstream handling.) @@ -125,15 +144,14 @@ * passes must preserve the lexical scoping of the "on_device" CallNodes. E.g. conversion * to ANF must respect the lexical scoping convention: * \code - * f(on_device(g(h(a, b), c), se_scope=CPU)) - * ==> - * let %x0 = on_device(h(a, b), se_scope=CPU) - * let %x1 = on_device(g(%x0), se_scope=CPU) - * f(on_device(%x1, se_scope=CPU)) + * f(on_device(g(h(a, b), c), se_scope=CPU)) + * ==> + * let %x0 = on_device(h(a, b), se_scope=CPU) + * let %x1 = on_device(g(%x0), se_scope=CPU) + * f(on_device(%x1, se_scope=CPU)) * \endcode * - * This pass can be run before FuseOps it can use device-specific fusion rules. - * TODO(mbs): We also need to support running after FuseOps. + * This pass can be run before FuseOps so that it can use device-specific fusion rules. * * 'Stored on' vs 'Executes on' * ---------------------------- @@ -227,17 +245,10 @@ * | * `-- Mark's stamp of completeness :-) * - * TODO(mbs): - * * Proper diagnostics for unification failure using spans. - * * Support running the pass post FuseOps (so need to understand primitive functions, both - * outlines and lined) and post the VM transforms (probably need to support more intrinsic - * forms?). - * * Don't hardcode the 'CPU' device for shape funcs etc, and distinguish between the default - * device for primitives vs the default device for the rest of Relay. - * * We may want some 'device polymorphism' for Relay functions. Eg it's ok for the function - * to be called with params/result on different (virtual) device ids provided the target and - * memory scopes are consistent. - * * Switch to expr.CopyWith(...) form once implemented to avoid unnecessary copies. + * TODO(mbs): Proper diagnostics for unification failure using spans. + * TODO(mbs): We may want some 'device polymorphism' for Relay functions. 
Eg it's ok for the + * function to be called with params/result on different (virtual) device ids provided the target + * and memory scopes are consistent. */ #include @@ -252,6 +263,8 @@ #include #include #include +#include +#include #include @@ -266,27 +279,32 @@ namespace transform { namespace { -/****** -******* Phase 0 -*******/ +/* =============== Phase 0 =============== */ /*! * \brief Rewrites "on_device" calls to handle some special cases. * - * \code - * let %x = on_device(e, se_scope=d) - * ==> let %x = on_device(e, se_scope=d, is_fixed=True) + * - Don't let the device for %x remain unconstrained: + * \code + * let %x = on_device(e, se_scope=d) + * ==> let %x = on_device(e, se_scope=d, constraint=kBoth) + * \endcode * - * fn(%x) { on_device(e, se_scope=d) } - * ==> fn(%x) { on_device(e, se_scope=d, is_fixed=True) + * - Don't let the function result remain unconstrained: + * \code + * fn(%x) { on_device(e, se_scope=d) } + * ==> fn(%x) { on_device(e, se_scope=d, constraint=kBoth) + * \endcode * - * on_device(e).0 - * ==> on_device(e.0) - * \endcode + * - Project-then-copy rather than copy-then-project: + * \code + * on_device(e).0 + * ==> on_device(e.0) + * \endcode */ class RewriteOnDevices : public ExprMutator { public: - RewriteOnDevices() = default; + explicit RewriteOnDevices(IRModule mod) : mod_(std::move(mod)) {} private: Expr VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final { @@ -294,11 +312,11 @@ class RewriteOnDevices : public ExprMutator { OnDeviceProps props = GetOnDeviceProps(tuple); Expr tuple_get_item = WithFields(GetRef(tuple_get_item_node), std::move(tuple)); - if (props.body.defined() && !props.is_fixed) { + if (props.body.defined() && props.is_normal()) { VLOG(2) << "wrapping tuple get item:" << std::endl << PrettyPrint(GetRef(tuple_get_item_node)) << std::endl << "with \"on_device\" for SEScope " << props.se_scope; - return OnDevice(tuple_get_item, props.se_scope, /*is_fixed=*/false); + return 
OnDeviceWithProps(tuple_get_item, props); } else { return tuple_get_item; } @@ -311,19 +329,19 @@ class RewriteOnDevices : public ExprMutator { Let inner_let = GetRef(inner_let_node); Expr value = VisitExpr(inner_let_node->value); OnDeviceProps props = GetOnDeviceProps(value); - if (props.body.defined() && !props.is_fixed) { + if (props.body.defined() && props.is_normal()) { VLOG(2) << "revising let-bound expression of let:" << std::endl << PrettyPrint(expr) << std::endl << "to be fixed to SEScope " << props.se_scope; - value = OnDevice(props.body, props.se_scope, /*is_fixed=*/true); + value = MaybeOnDeviceFixed(props.body, props.se_scope); } bindings.emplace_back(inner_let, value); expr = inner_let_node->body; } expr = VisitExpr(expr); for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { - expr = WithFields(/*let=*/std::move(std::get<0>(*itr)), /*var = unchanged*/ {}, - /*value=*/std::move(std::get<1>(*itr)), /*body=*/std::move(expr)); + expr = WithFields(/*let=*/std::move(std::get<0>(*itr)), /*opt_var=*/{}, + /*opt_value=*/std::move(std::get<1>(*itr)), /*opt_body=*/std::move(expr)); } return expr; } @@ -331,20 +349,20 @@ class RewriteOnDevices : public ExprMutator { Expr VisitExpr_(const FunctionNode* function_node) final { Expr body = VisitExpr(function_node->body); OnDeviceProps props = GetOnDeviceProps(body); - if (props.body.defined() && !props.is_fixed) { + if (props.body.defined() && props.is_normal()) { VLOG(2) << "revising body of function:" << std::endl << PrettyPrint(GetRef(function_node)) << std::endl << "to be fixed to SEScope " << props.se_scope; - body = OnDevice(props.body, props.se_scope, /*is_fixed=*/true); + body = MaybeOnDeviceFixed(props.body, props.se_scope); } - return WithFields(GetRef(function_node), std::move(function_node->params), - std::move(body)); + return WithFields(GetRef(function_node), function_node->params, std::move(body)); } + + /*! \brief Module we are rewriting, so we can lookup global definitions. 
*/ + IRModule mod_; }; -/****** -******* Phase 1 -*******/ +/* =============== Phase 1 =============== */ /* * \brief Collects the system of device constraints for all sub-expressions in a module. @@ -374,10 +392,15 @@ class DeviceAnalyzer : public ExprVisitor { */ std::unique_ptr Analyze() { VLOG_CONTEXT << "DeviceAnalyzer"; - for (const auto& pair : mod_->functions) { - VLOG(2) << "collecting constraints for '" << PrettyPrint(pair.first) << "'"; - domains_->UnifyExprExact(pair.first, pair.second); - VisitExpr(pair.second); + for (const auto& kv : mod_->functions) { + // The global variable and what it is bound to must obviously agree on domain. + if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) { + VLOG(2) << "collecting constraints from Relay Function '" << kv.first->name_hint << "'"; + domains_->UnifyExprExact(kv.first, kv.second); + VisitExpr(GetRef(function_node)); + } else { + VLOG(2) << "skipping '" << kv.first->name_hint << "'"; + } } return std::move(domains_); } @@ -386,16 +409,19 @@ class DeviceAnalyzer : public ExprVisitor { void VisitExpr_(const CallNode* call_node) final { auto call = GetRef(call_node); + // We don't care if the call is in pre- or post-lowered form. + auto vanilla_call = GetAnyCall(call_node); + // Find the higher-order domain for the callee. See DomainForCallee for the special rules // for primitives. - VisitExpr(call_node->op); + VisitExpr(vanilla_call->op); auto func_domain = domains_->DomainForCallee(call); // higher-order // Build the domain for the function implied by its arguments and call context. 
- ICHECK_EQ(func_domain->function_arity(), call_node->args.size()); + ICHECK_EQ(func_domain->function_arity(), vanilla_call->args.size()) << PrettyPrint(call); std::vector args_and_result_domains; - args_and_result_domains.reserve(call_node->args.size() + 1); - for (const auto& arg : call_node->args) { + args_and_result_domains.reserve(vanilla_call->args.size() + 1); + for (const auto& arg : vanilla_call->args) { args_and_result_domains.emplace_back(domains_->DomainFor(arg)); VisitExpr(arg); } @@ -416,9 +442,9 @@ class DeviceAnalyzer : public ExprVisitor { LOG(FATAL) << "Function parameters and result SEScopes do not match those of call. Call:" << std::endl << PrettyPrint(call) << std::endl - << "with function scopes:" << std::endl + << "with function virtual devices:" << std::endl << domains_->ToString(func_domain) << std::endl - << "and implied call scopes:" << std::endl + << "and implied call virtual devices:" << std::endl << domains_->ToString(implied_domain); } @@ -494,9 +520,9 @@ class DeviceAnalyzer : public ExprVisitor { << "Function SEScopes are incompatible with its \"on_device\" annotation. Function:" << std::endl << PrettyPrint(function) << std::endl - << "with function scopes:" << std::endl + << "with function virtual devices:" << std::endl << domains_->ToString(func_domain) << std::endl - << "and annotation scopes:" << std::endl + << "and annotation virtual devices:" << std::endl << domains_->ToString(annotation_domain); } } @@ -621,9 +647,52 @@ class DeviceAnalyzer : public ExprVisitor { std::unique_ptr domains_; }; -/****** -******* Phase 2 -*******/ +/* =============== Phase 2 =============== */ + +/*! + * \brief Calls to 'free' "on_device" annotations (ie where both constrain_body=false and + * constrain_result=false) indicate a device_copy is allowed if required, but no particular + * device is imposed on the body or the context. At this stage we can attempt to unify the + * body and device contexts. 
In this way we can avoid the defaulting rules in \p DeviceDefaulter + * from choosing default devices which are only going to induce a device copy. + * + * TODO(mbs): The order in which we encounter the "on_device" calls can influence the final global + * device assignment. However we visit global functions in hash map order. + */ +class FreeOnDeviceDefaulter : public ExprVisitor { + public: + FreeOnDeviceDefaulter(IRModule mod, std::unique_ptr domains) + : mod_(std::move(mod)), domains_(std::move(domains)) {} + + std::unique_ptr Default() { + VLOG_CONTEXT << "FreeOnDeviceDefaulter"; + VLOG(0) << "unifying free on_device annotations"; + for (const auto& kv : mod_->functions) { + if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) { + VLOG(2) << "unifying for '" << kv.first->name_hint << "'"; + VisitExpr(GetRef(function_node)); + } else { + VLOG(2) << "skipping '" << kv.first->name_hint << "'"; + } + } + return std::move(domains_); + } + + private: + void VisitExpr_(const CallNode* call_node) final { + auto call = GetRef(call_node); + OnDeviceProps props = GetOnDeviceProps(call_node); + ExprVisitor::VisitExpr_(call_node); + if (props.body.defined() && !props.constrain_body && !props.constrain_result) { + domains_->OptionalUnifyExprExact(call, props.body); + } + } + + /*! \brief The module we are processing. */ + IRModule mod_; + /*! \brief The domains for all expressions. */ + std::unique_ptr domains_; +}; /*! * \brief Ensures every sub-expression in a module has a device type, using both the global @@ -634,10 +703,11 @@ class DeviceAnalyzer : public ExprVisitor { * def @main(%x, %y, %z) { * let %a = add(%x, %y); * multiply(%a, on_device(%z, se_scope=d)) + * } * \endcode * we know the parameter \p %z must be on device \p d, but the devices for \p %x and \p %y, * and the device for the function result, are still 'free'. 
The global 'default' device type - * is first used to 'fix' \p @main's result type, which in turn 'fixes' \p %x and \p %y, which + * is first used to 'fix' \p main's result type, which in turn 'fixes' \p %x and \p %y, which * in turn 'fixes' the device on which the \p add and \p multiply are executed. * * TODO(mbs): I think this is deterministic? We do however visit the top-level defs in hashmap @@ -651,9 +721,13 @@ class DeviceDefaulter : public ExprVisitor { std::unique_ptr Default() { VLOG_CONTEXT << "DeviceDefaulter"; VLOG(0) << "defaulting to SEScope " << domains_->config()->default_primitive_se_scope; - for (const auto& pair : mod_->functions) { - VLOG(2) << "defaulting devices for '" << PrettyPrint(pair.first) << "'"; - VisitExpr(pair.second); + for (const auto& kv : mod_->functions) { + if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) { + VLOG(2) << "defaulting devices for '" << kv.first->name_hint << "'"; + VisitExpr(GetRef(function_node)); + } else { + VLOG(2) << "skipping '" << kv.first->name_hint << "'"; + } } return std::move(domains_); } @@ -678,8 +752,12 @@ class DeviceDefaulter : public ExprVisitor { void VisitExpr_(const CallNode* call_node) final { auto call = GetRef(call_node); + + // We don't care if the call is pre- or post-lowered. + auto vanilla_call = GetAnyCall(call_node); + auto func_domain = domains_->DomainForCallee(call); // higher-order - ICHECK_EQ(func_domain->function_arity(), call_node->args.size()); + ICHECK_EQ(func_domain->function_arity(), vanilla_call->args.size()); if (!domains_->IsFullyConstrained(func_domain)) { // For calls to Relay functions this step is identical to that for VisitExpr_(FunctionNode*) // above. But for calls to primitives we may still need to force free domains to be @@ -720,9 +798,7 @@ class DeviceDefaulter : public ExprVisitor { std::unique_ptr domains_; }; -/****** -******* Phase 3 -*******/ +/* =============== Phase 3 =============== */ /*! 
* \brief Inserts missing "device_copy" CallNodes, and ensures the device type of every @@ -749,6 +825,8 @@ class DeviceDefaulter : public ExprVisitor { * keeps downstream processing simple. The "on_device" calls should be removed before code gen, * which is easily done on-the-fly. * + * - Update memory scopes in PrimFunc buffer maps. + * * For example, we'll end up with programs that look like: * \code * def @main(%x, %y, param_se_scopes=[...], result_se_scope=...) { @@ -767,9 +845,14 @@ class DeviceCapturer : public ExprMutator { VLOG_CONTEXT << "CaptureDevices"; IRModule result(/*functions=*/{}, mod_->type_definitions, mod_->Imports(), mod_->source_map, mod_->attrs); - for (const auto& pair : mod_->functions) { - VLOG(2) << "capturing devices for '" << PrettyPrint(pair.first) << "'"; - result->Add(pair.first, Downcast(Mutate(pair.second))); + for (const auto& kv : mod_->functions) { + if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) { + VLOG(2) << "capturing devices for Relay Function '" << kv.first->name_hint << "'"; + result->Add(kv.first, Downcast(Mutate(GetRef(function_node)))); + } else { + VLOG(2) << "skipping '" << kv.first->name_hint << "'"; + result->Add(kv.first, kv.second); + } } return result; } @@ -825,6 +908,11 @@ class DeviceCapturer : public ExprMutator { Expr VisitExpr_(const CallNode* call_node) final { auto call = GetRef(call_node); + + // We don't care if the call is pre- or post-lowered + // (However we'll preserve the form in the result below.) + auto vanilla_call = GetAnyCall(call_node); + SEScope call_se_scope = GetSEScope(call); auto on_device_props = GetOnDeviceProps(call_node); @@ -844,7 +932,8 @@ class DeviceCapturer : public ExprMutator { // match. 
return VisitExpr(device_copy_props.body); } else { - return VisitChild(/*lexical_se_scope=*/dst_se_scope, + return VisitChild(/*lexical_se_scope=*/ + dst_se_scope, /*expected_se_scope=*/dst_se_scope, /*child_se_scope=*/src_se_scope, device_copy_props.body); } @@ -854,7 +943,7 @@ class DeviceCapturer : public ExprMutator { auto func_domain = domains_->DomainForCallee(call); // higher-order VLOG(2) << "considering call:" << std::endl << PrettyPrint(call) << std::endl - << "in scope " << call_se_scope << " with function domain:" << std::endl + << "in scope " << call_se_scope << " with function virtual devices:" << std::endl << domains_->ToString(func_domain); SEScope result_se_scope = domains_->ResultSEScope(func_domain); ICHECK(!result_se_scope->IsFullyUnconstrained()); @@ -863,25 +952,32 @@ class DeviceCapturer : public ExprMutator { Expr op = VisitChild( /*lexical_se_scope=*/call_se_scope, /*expected_se_scope=*/call_se_scope, - /*child_se_scope=*/result_se_scope, call_node->op); + /*child_se_scope=*/result_se_scope, vanilla_call->op); // Each argument can be on the device for the corresponding function parameter. However if // any of those differ from the overall call device then wrap them in an "on_device" to // help downstream transforms track devices lexically. 
Array args; - args.reserve(call_node->args.size()); - ICHECK_EQ(func_domain->function_arity(), call->args.size()); - for (size_t i = 0; i < call_node->args.size(); ++i) { + args.reserve(vanilla_call->args.size()); + ICHECK_EQ(func_domain->function_arity(), vanilla_call->args.size()); + for (size_t i = 0; i < vanilla_call->args.size(); ++i) { SEScope param_se_scope = domains_->ResultSEScope(func_domain->function_param(i)); ICHECK(!param_se_scope->IsFullyUnconstrained()) << "for parameter " << i << " for call:" << std::endl << PrettyPrint(call); args.push_back(VisitChild(/*lexical_se_scope=*/call_se_scope, /*expected_se_scope=*/param_se_scope, - /*child_se_scope=*/GetSEScope(call_node->args[i]), - call_node->args[i])); + /*child_se_scope=*/GetSEScope(vanilla_call->args[i]), + vanilla_call->args[i])); + } + + if (call_node->op == CallLoweredOp()) { + Call new_call = + CallLowered(Downcast(op), args, /*call_lowered_attrs=*/{}, /*span=*/{}); + return WithFields(call, std::move(new_call->op), std::move(new_call->args)); + } else { + return WithFields(call, std::move(op), std::move(args)); } - return WithFields(GetRef(call_node), std::move(op), std::move(args)); } Expr VisitExpr_(const LetNode* let_node) final { @@ -895,11 +991,11 @@ class DeviceCapturer : public ExprMutator { // We have a device transition which needs to be handled. break; } - // The let-bound value can be on a different device than the overall let. However if those - // devices don't agree wrap the let-bound value in an "on_device" to help downstream - // transforms track devices lexically. + // The let-bound value can be on a different device than the overall let. + // By using the fully-unconstrained SEScope for the 'lexical' scope we'll force the let-bound + // value to *always* be wrapped by an "on_device" (see introductory comment for motivation.) 
Expr value = - VisitChild(/*lexical_se_scope=*/let_se_scope, + VisitChild(/*lexical_se_scope=*/SEScope::FullyUnconstrained(), /*expected_se_scope=*/GetSEScope(inner_let_node->var), /*child_se_scope=*/GetSEScope(inner_let_node->value), inner_let_node->value); bindings.emplace_back(inner_let_node->var, value, inner_let_node->span); @@ -967,7 +1063,7 @@ class DeviceCapturer : public ExprMutator { OnDeviceProps props = GetOnDeviceProps(expr); Expr true_expr = props.body.defined() ? props.body : expr; ICHECK(domains_->contains(true_expr)); - // If expr is higher order we'll return only the result domain's SEScope. + // If expr is higher order we'll return only the result domain's device. SEScope se_scope = domains_->ResultSEScope(domains_->DomainFor(true_expr)); ICHECK(!se_scope->IsFullyUnconstrained()) << "no SEScope was determined for expression:" << std::endl @@ -1003,7 +1099,6 @@ class DeviceCapturer : public ExprMutator { */ Expr VisitChild(const SEScope& lexical_se_scope, const SEScope& expected_se_scope, const SEScope& child_se_scope, const Expr& child) { - ICHECK(!lexical_se_scope->IsFullyUnconstrained()); ICHECK(!expected_se_scope->IsFullyUnconstrained()); if (child->IsInstance() || child->IsInstance()) { // Primitive operators and contructors don't need to be rewritten and can have a @@ -1012,19 +1107,19 @@ class DeviceCapturer : public ExprMutator { } Expr result = VisitExpr(child); if (child_se_scope != expected_se_scope) { - VLOG(2) << "creating " << DeviceCopyOp()->name << " from scope " << child_se_scope - << " to scope " << expected_se_scope << " for:" << std::endl + VLOG(2) << "creating " << DeviceCopyOp()->name << " from virtual device " << child_se_scope + << " to virtual device " << expected_se_scope << " for:" << std::endl << PrettyPrint(result); // Also wrap the child in an "on_device" so downstream transforms can track devices // lexically. 
- result = MaybeOnDevice(result, child_se_scope, /*is_fixed=*/true); + result = MaybeOnDeviceFixed(result, child_se_scope); result = DeviceCopy(result, child_se_scope, expected_se_scope); } if (expected_se_scope != lexical_se_scope) { - VLOG(2) << "creating " << OnDeviceOp()->name << " for scope " << expected_se_scope + VLOG(2) << "creating " << OnDeviceOp()->name << " for virtual device " << expected_se_scope << " for:" << std::endl << PrettyPrint(result); - result = MaybeOnDevice(result, expected_se_scope, /*is_fixed=*/true); + result = MaybeOnDeviceFixed(result, expected_se_scope); } return result; } @@ -1048,7 +1143,7 @@ class DeviceCapturer : public ExprMutator { /*! \brief Rewrite the "on_device" calls (and implicitly re-type-check). */ tvm::transform::Pass Rewrite() { auto pass_func = [](Function f, IRModule m, transform::PassContext ctxt) { - return Downcast(RewriteOnDevices().Mutate(f)); + return Downcast(RewriteOnDevices(std::move(m)).Mutate(f)); }; return tvm::relay::transform::CreateFunctionPass(pass_func, 0, "PlanDevicesRewrite", {}); } @@ -1065,6 +1160,7 @@ tvm::transform::Pass PlanDevicesCore(CompilationConfig config) { // Choose sensible default devices for every sub-expression if otherwise unconstrained // by existing "on_device" or "device_copy" calls. + domains = FreeOnDeviceDefaulter(mod, std::move(domains)).Default(); domains = DeviceDefaulter(mod, std::move(domains)).Default(); VLOG(3) << "Domains after defaulting: " << std::endl << domains->ToString(); @@ -1078,9 +1174,7 @@ tvm::transform::Pass PlanDevicesCore(CompilationConfig config) { } // namespace -/****** -******* Overall composite Pass -*******/ +/* =============== Driver =============== */ // This function is declared in the public . 
tvm::transform::Pass PlanDevices(CompilationConfig config) { diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 05ee9d5ad592..831d28b48540 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -145,7 +145,7 @@ class ConstantFolder : public MixedModeMutator { Expr Rewrite_(const CallNode* pre_call_node, const Expr& post) final { Call pre_call = GetRef(pre_call_node); if (inside_primitive_) { - return pre_call; + return std::move(pre_call); } Call post_call = Downcast(post); @@ -215,12 +215,12 @@ class ConstantFolder : public MixedModeMutator { OnDeviceProps props = GetOnDeviceProps(post_tuple_get_item_node->tuple); if (props.body.defined()) { // (on_device((x, y, z), se_scope=D).1 ==> on_device(y, se_scope=D) - return MaybeOnDevice(result, props.se_scope, props.is_fixed); + return MaybeOnDeviceWithProps(result, props); } else { return result; } } - return std::move(post_tuple_get_item); + return post_tuple_get_item; } // Convert value to expression. 
diff --git a/src/relay/transforms/memory_alloc.cc b/src/relay/transforms/memory_alloc.cc index 00be629eabff..25827d5e918d 100644 --- a/src/relay/transforms/memory_alloc.cc +++ b/src/relay/transforms/memory_alloc.cc @@ -69,6 +69,8 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { Function Rewrite(const Function& expr) { return Downcast(Mutate(expr)); } private: + using ExprMutator::VisitExpr_; + Expr VisitExpr_(const TupleNode* tuple_node) final { LetList& scope = scopes_.back(); Array new_fields; @@ -77,8 +79,10 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { for (auto field : tuple_node->fields) { auto new_field = Mutate(field); if (new_field->IsInstance()) { + SEScope se_scope = GetSEScope(field); + ICHECK(!se_scope->IsFullyUnconstrained()); Var const_var("const", Type(nullptr)); - new_field = scope.Push(const_var, new_field); + new_field = scope.Push(const_var, MaybeOnDeviceFixed(new_field, se_scope)); } new_fields.push_back(new_field); } @@ -89,7 +93,9 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { std::pair PreVisitLetBinding_(const Var& var, const Expr& value) final { Expr new_value = Mutate(value); - scopes_.back().Push(var, new_value); + SEScope se_scope = GetSEScope(value); + ICHECK(!se_scope->IsFullyUnconstrained()); + scopes_.back().Push(var, MaybeOnDeviceFixed(new_value, se_scope)); // Since we always need a let block on which to bind sub-expressions the rewritten bindings // are tracked in the current scopes. But return the rewritten binding anyway. 
return {var, new_value}; @@ -127,6 +133,7 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { VLOG(1) << "converting lowered call to DPS:" << std::endl << PrettyPrint(call); SEScope se_scope = GetSEScope(call); + ICHECK(!se_scope->IsFullyUnconstrained()); LetList& scope = scopes_.back(); std::vector new_args; @@ -176,7 +183,7 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { Expr invoke = InvokeTVMOp(call_lowered_props.lowered_func, ins, outs, Downcast(call_lowered_props.attrs.metadata.at("relay_attrs"))); - scope.Push(OnDevice(invoke, se_scope, /*is_fixed=*/true)); + scope.Push(MaybeOnDeviceFixed(invoke, se_scope)); return ToTupleType(ret_type, std::vector(outputs.begin(), outputs.end())); } @@ -192,8 +199,7 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { /*! Returns an \p alloc_tensor call for a tensor of \p shape and \p dtype over \p storage. */ inline Expr AllocTensor(const Expr& storage, tvm::relay::Expr shape, DataType dtype, Array assert_shape) { - Expr offset = OnDevice(MakeConstantScalar(DataType::Int(64), 0), host_se_scope_, - /*is_fixed=*/true); + Expr offset = MaybeOnDeviceFixed(MakeConstantScalar(DataType::Int(64), 0), host_se_scope_); return tvm::relay::AllocTensor(storage, std::move(offset), std::move(shape), dtype, assert_shape); } @@ -236,22 +242,20 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { CHECK(imm) << "expect static int shape"; int_shape.push_back(imm->value); } - Expr shape = OnDevice(MakeConstant(int_shape), host_se_scope_, /*is_fixed=*/true); - Expr size = OnDevice(ComputeStorage(type), host_se_scope_, /*is_fixed=*/true); + Expr shape = MaybeOnDeviceFixed(MakeConstant(int_shape), host_se_scope_); + Expr size = MaybeOnDeviceFixed(ComputeStorage(type), host_se_scope_); // Alignment is directly captured in the instruction rather than calculated, so we // don't want to wrap it with an "on_device". 
Expr alignment = ComputeAlignment(type->dtype); // Run type inference later to get the correct type. Var var("storage_" + name_hint, Type(nullptr)); - Expr value = OnDevice(AllocStorage(size, alignment, se_scope, type->dtype), se_scope, - /*is_fixed=*/true); - auto sto = scope->Push(var, value); + Expr value = AllocStorage(size, alignment, se_scope, type->dtype); + auto sto = scope->Push(var, MaybeOnDeviceFixed(value, se_scope)); // TODO(@jroesch): There is a bug with typing based on the constant shape. - auto tensor = OnDevice(AllocTensor(sto, shape, type->dtype, /*assert_shape=*/type->shape), - se_scope, /*is_fixed=*/true); + auto tensor = AllocTensor(sto, shape, type->dtype, /*assert_shape=*/type->shape); Var tensor_var("tensor_" + name_hint, Type(nullptr)); - return scope->Push(tensor_var, tensor); + return scope->Push(tensor_var, MaybeOnDeviceFixed(tensor, se_scope)); } /*! @@ -287,23 +291,24 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { if (state == tec::kNeedInputShape) { std::vector exprs = FromTupleType(ty, arg); for (size_t j = 0; j < exprs.size(); ++j) { - Expr sh_of = Mutate(ShapeOf(exprs[j])); // already accounts for device + Expr sh_of = Mutate(ShapeOf(exprs[j])); Var in_shape_var("in_shape_" + std::to_string(input_pos + j), Type(nullptr)); - shape_func_ins.push_back(scope->Push(in_shape_var, sh_of)); + shape_func_ins.push_back( + scope->Push(in_shape_var, MaybeOnDeviceFixed(sh_of, host_se_scope_))); input_pos++; } } else if (state == tec::kNeedInputData) { auto new_arg = Mutate(arg); // already accounts for device SEScope arg_se_scope = GetSEScope(arg); + ICHECK(!arg_se_scope->IsFullyUnconstrained()); // The dynamic shape function is expecting its data on the host/CPU, so insert a // device_copy otherwise. (We'll need to fuse & lower these copies in the same way // we fuse & lower other operators we insert for, eg, dynamic tensor size calculation.) 
- if (arg_se_scope != host_se_scope_) { - new_arg = OnDevice(DeviceCopy(new_arg, arg_se_scope, host_se_scope_), host_se_scope_, - /*is_fixed=*/true); - } + new_arg = MaybeDeviceCopy(MaybeOnDeviceFixed(new_arg, arg_se_scope), arg_se_scope, + host_se_scope_); Var in_shape_var("in_shape_" + std::to_string(input_pos), Type(nullptr)); - shape_func_ins.push_back(scope->Push(in_shape_var, new_arg)); + shape_func_ins.push_back( + scope->Push(in_shape_var, MaybeOnDeviceFixed(new_arg, host_se_scope_))); input_pos++; } else { // TODO(@jroesch): handle kNeedBoth @@ -322,20 +327,16 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { ICHECK(tensor_type_node); // Put the shape func on the host. This also ensures that everything between // shape_of and shape_func is similarly on the host. - Expr alloc = MakeStaticAllocation(scope, GetRef(tensor_type_node), host_se_scope_, - std::to_string(i)); - // TODO(mbs): Don't really need a fresh var here since alloc will always be a var. - Var shape_func_out_var("shape_func_out_" + std::to_string(i), Type(nullptr)); - alloc = scope->Push(shape_func_out_var, alloc); + Var alloc = MakeStaticAllocation(scope, GetRef(tensor_type_node), host_se_scope_, + "out_shape_" + std::to_string(i)); out_shapes.push_back(alloc); } // Represent the call in DPS form. 
- auto shape_call = OnDevice(InvokeTVMOp(prim_fn_var, Tuple(shape_func_ins), Tuple(out_shapes), - Downcast(attrs.metadata.at("relay_attrs"))), - host_se_scope_, /*is_fixed=*/true); + auto shape_call = InvokeTVMOp(prim_fn_var, Tuple(shape_func_ins), Tuple(out_shapes), + Downcast(attrs.metadata.at("relay_attrs"))); Var shape_func_var("shape_func", Type(nullptr)); - scope->Push(shape_func_var, shape_call); + scope->Push(shape_func_var, MaybeOnDeviceFixed(shape_call, host_se_scope_)); return out_shapes; } @@ -349,14 +350,12 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { for (size_t i = 0; i < out_shapes.size(); ++i) { auto out_shape = out_shapes[i]; auto out_type = out_types[i]; - auto size = OnDevice(ComputeStorageInRelay(out_shape, out_type), host_se_scope_, - /*is_fixed=*/true); + auto size = MaybeOnDeviceFixed(ComputeStorageInRelay(out_shape, out_type), host_se_scope_); // Alignment is directly captured in the instruction so don't wrap in "on_device". auto alignment = ComputeAlignment(out_type->dtype); Var sto_var("storage_" + std::to_string(i), Type(nullptr)); - auto val = OnDevice(AllocStorage(size, alignment, se_scope, out_type->dtype), se_scope, - /*is_fixed=*/true); - storages.push_back(scope->Push(sto_var, val)); + auto val = AllocStorage(size, alignment, se_scope, out_type->dtype); + storages.push_back(scope->Push(sto_var, MaybeOnDeviceFixed(val, se_scope))); } Array outs; @@ -364,17 +363,15 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { auto out_shape = out_shapes[i]; auto out_type = out_types[i]; auto storage = storages[i]; - auto alloc = OnDevice(AllocTensor(storage, out_shape, out_type->dtype, out_type->shape), - se_scope, /*is_fixed=*/true); + auto alloc = AllocTensor(storage, out_shape, out_type->dtype, out_type->shape); Var out_var("out_" + std::to_string(i), Type(nullptr)); - outs.push_back(scope->Push(out_var, alloc)); + outs.push_back(scope->Push(out_var, MaybeOnDeviceFixed(alloc, se_scope))); } Tuple 
tuple_outs(outs); auto call = InvokeTVMOp(func, ins, tuple_outs, Downcast(attrs.metadata.at("relay_attrs"))); - auto invoke = OnDevice(call, se_scope, /*is_fixed=*/true); - scope->Push(invoke); + scope->Push(MaybeOnDeviceFixed(call, se_scope)); return ToTupleType(ret_type, std::vector(tuple_outs->fields.begin(), tuple_outs->fields.end())); } @@ -398,7 +395,7 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { CHECK(imm) << "expect static int shape"; shape.push_back(imm->value); } - shape_expr = OnDevice(MakeConstant(shape), host_se_scope_, /*is_fixed=*/true); + shape_expr = MaybeOnDeviceFixed(MakeConstant(shape), host_se_scope_); } return ReshapeTensor(ins->fields[0], shape_expr, ret_ty->shape); } diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc index f958a600551e..c955269e3412 100644 --- a/src/relay/transforms/to_a_normal_form.cc +++ b/src/relay/transforms/to_a_normal_form.cc @@ -211,14 +211,14 @@ class Fill : ExprFunctor, private transform::Lexi } Expr Atomic(const Expr& e, const Var& v) { - Expr annotated_expr = MaybeOnDevice(e, GetSEScope(e), /*is_fixed=*/true); + Expr annotated_expr = MaybeOnDeviceFixed(e, GetSEScope(e)); return v.defined() ? GetScope(e)->let_list->Push(v, annotated_expr) : annotated_expr; } // Bind expression `now` to var `v` if the original expression is in the include set, or if // v is already defined (e.g. coming from a Let expression). Otherwise return `now` directly Expr Compound(const Expr& orig, const Expr& now, const Var& v) { - Expr annotated_expr = MaybeOnDevice(now, GetSEScope(orig), /*is_fixed=*/true); + Expr annotated_expr = MaybeOnDeviceFixed(now, GetSEScope(orig)); Var var = v.defined() ? 
v : Var(String("x"), Type()); bool not_included = include_set_ && include_set_->find(orig) == include_set_->end(); if (!v.defined() && not_included) { @@ -230,14 +230,14 @@ class Fill : ExprFunctor, private transform::Lexi Expr VisitExpr_(const CallNode* c, const Var& v) final { OnDeviceProps props = GetOnDeviceProps(c); - if (props.body.defined() && props.is_fixed) { + if (props.body.defined() && props.is_fixed()) { // Keep track of expression device type for lexically enclosing sub-expressions. PushSEScope(props.se_scope); Expr body = VisitExpr(props.body, v); // We are done with this sub-expression. PopSEScope(); // Preserve the "on_device" annotations. - return OnDevice(body, props.se_scope, props.is_fixed); + return OnDeviceWithProps(body, props); } Expr e = GetRef(c); diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index b613a03bfc5c..44971c0bcee9 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -178,7 +178,7 @@ std::string Executable::GetConstants() const { for (size_t i = 0; i < constants.size(); ++i) { const auto& constant = constants[i]; auto ndarray = Downcast(constant); - oss << "VM Constant[" << i << "]: has shape " << ShapeString(ndarray.Shape(), ndarray->dtype) + oss << "VM Const[" << i << "]: has shape " << ShapeString(ndarray.Shape(), ndarray->dtype) << " on device index " << const_device_indexes[i] << std::endl; } return oss.str(); diff --git a/src/target/se_scope.cc b/src/target/se_scope.cc index 95d5a7de5775..8e6c6fe7f2a2 100644 --- a/src/target/se_scope.cc +++ b/src/target/se_scope.cc @@ -62,10 +62,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "memory_scope='" << node->memory_scope << "'"; } } -#if TVM_LOG_DEBUG - // We rely on object identity of SEScopes, so include the object address to help debugging. 
- p->stream << ", id=" << reinterpret_cast(ref.get()); -#endif p->stream << ")"; }); @@ -173,11 +169,9 @@ SEScope SEScopeCache::Make(DLDeviceType device_type, int virtual_device_id, Targ SEScope prototype(device_type, virtual_device_id, std::move(target), std::move(memory_scope)); auto itr = cache_.find(prototype); if (itr == cache_.end()) { - VLOG(1) << "added new scope " << prototype; cache_.emplace(prototype); return prototype; } else { - VLOG(1) << "reusing existing scope " << *itr; ICHECK_EQ(prototype->target.defined(), (*itr)->target.defined()); if (prototype->target.defined()) { ICHECK_EQ(prototype->target->host.defined(), (*itr)->target->host.defined()); diff --git a/src/target/target.cc b/src/target/target.cc index 792884061db6..a5c493a582ab 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -566,10 +566,6 @@ String TargetNode::ToDebugString() const { if (host.defined()) { os << ", host=" << GetHost().value()->ToDebugString(); } -#if TVM_LOG_DEBUG - // We depend on pointer equality so include that in the debug representation. - os << ", id=" << reinterpret_cast(this); -#endif os << ")"; return os.str(); } diff --git a/tests/cpp/relay/op/memory/on_device_test.cc b/tests/cpp/relay/op/memory/on_device_test.cc new file mode 100644 index 000000000000..45d4f881c454 --- /dev/null +++ b/tests/cpp/relay/op/memory/on_device_test.cc @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "../../../../../src/relay/op/memory/on_device.h" + +#include + +#include + +namespace tvm { +namespace relay { + +TEST(OnDeviceOp, Name) { EXPECT_EQ(OnDeviceOp()->name, "on_device"); } + +TEST(OnDevice, Default) { + Var body("x", {}); + SEScope se_scope = SEScope::ForDeviceType(kDLCPU, 3); + Call call = OnDevice(body, se_scope); + EXPECT_EQ(call->op, OnDeviceOp()); + EXPECT_EQ(call->args.size(), 1); + EXPECT_EQ(call->args[0], body); + const auto* attrs = call->attrs.as(); + ASSERT_TRUE(attrs != nullptr); + EXPECT_EQ(attrs->se_scope, se_scope); + EXPECT_FALSE(attrs->constrain_result); + EXPECT_TRUE(attrs->constrain_body); +} + +TEST(OnDevice, Fixed) { + Var body("x", {}); + SEScope se_scope = SEScope::ForDeviceType(kDLCPU, 3); + Call call = OnDevice(body, se_scope, /*constrain_result=*/true); + const auto* attrs = call->attrs.as(); + ASSERT_TRUE(attrs != nullptr); + EXPECT_TRUE(attrs->constrain_result); + EXPECT_TRUE(attrs->constrain_body); +} + +TEST(OnDevice, Free) { + Var body("x", {}); + SEScope se_scope = SEScope::ForDeviceType(kDLCPU, 3); + Call call = OnDevice(body, se_scope, /*constrain_result=*/false, /*constrain_body=*/false); + const auto* attrs = call->attrs.as(); + ASSERT_TRUE(attrs != nullptr); + EXPECT_FALSE(attrs->constrain_result); + EXPECT_FALSE(attrs->constrain_body); +} + +TEST(GetOnDeviceProps, Correct) { + Var body("x", {}); + SEScope se_scope = SEScope::ForDeviceType(kDLCPU, 3); + Call call = OnDevice(body, se_scope, /*constrain_result=*/true, /*constrain_body=*/false); + OnDeviceProps props = 
GetOnDeviceProps(call); + ASSERT_TRUE(props.body.defined()); + ASSERT_EQ(props.se_scope, se_scope); + ASSERT_TRUE(props.constrain_result); + ASSERT_FALSE(props.constrain_body); +} + +TEST(MaybeOnDevice, Wrapped) { + SEScope se_scope = SEScope::ForDeviceType(kDLCPU, 3); + Var body("x", {}); + Call inner = OnDevice(body, se_scope); + Call outer = OnDevice(inner, se_scope); + OnDeviceProps props = GetOnDeviceProps(outer); + ASSERT_TRUE(props.body.defined()); + ASSERT_EQ(props.se_scope, se_scope); + ASSERT_FALSE(props.constrain_result); + ASSERT_TRUE(props.constrain_body); +} + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/op/annotation/test_annotation.py b/tests/python/relay/op/annotation/test_annotation.py index 8ba91976523a..5ad2a59e39ab 100644 --- a/tests/python/relay/op/annotation/test_annotation.py +++ b/tests/python/relay/op/annotation/test_annotation.py @@ -30,7 +30,8 @@ def test_on_device_via_string(): assert call.attrs.se_scope.virtual_device_id == 0 assert call.attrs.se_scope.target is None assert call.attrs.se_scope.memory_scope == "" - assert not call.attrs.is_fixed + assert call.attrs.constrain_body + assert not call.attrs.constrain_result def test_on_device_via_device(): @@ -44,11 +45,20 @@ def test_on_device_invalid_device(): pytest.raises(ValueError, lambda: relay.annotation.on_device(x, "bogus")) -def test_on_device_is_fixed(): +def test_on_device_fixed(): x = relay.Var("x") - call = relay.annotation.on_device(x, "cuda", True) + call = relay.annotation.on_device(x, "cuda", constrain_result=True) assert call.attrs.se_scope.device_type_int == 2 # ie kDLCUDA - assert call.attrs.is_fixed + assert call.attrs.constrain_body + assert call.attrs.constrain_result + + +def test_on_device_free(): + x = relay.Var("x") + call = relay.annotation.on_device(x, "cuda", constrain_result=False, constrain_body=False) + assert call.attrs.se_scope.device_type_int == -1 # ie kInvalidDeviceType + assert not call.attrs.constrain_body + assert not 
call.attrs.constrain_result def test_function_on_device(): diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py index 3a5f458d5970..7b7ef0ce920f 100644 --- a/tests/python/relay/test_pass_fold_constant.py +++ b/tests/python/relay/test_pass_fold_constant.py @@ -29,7 +29,7 @@ def annot_func(f): def annot_expr(e): """Returns e wrapped with an on_device annotation.""" - return relay.op.annotation.on_device(e, tvm.cpu(), is_fixed=True) + return relay.op.annotation.on_device(e, tvm.cpu(), constrain_result=True) def run_opt_pass(expr, opt_pass): diff --git a/tests/python/relay/test_pass_plan_devices.py b/tests/python/relay/test_pass_plan_devices.py index 4ef726fbeac2..6bf103ea0c2a 100644 --- a/tests/python/relay/test_pass_plan_devices.py +++ b/tests/python/relay/test_pass_plan_devices.py @@ -45,6 +45,9 @@ GPU = tvm.target.make_se_scope(GPU_DEVICE, GPU_TARGET) # device_type=2 DEFAULT = GPU +CPU_SCOPE_A = tvm.target.make_se_scope(CPU_DEVICE, CPU_TARGET, memory_scope="scopeA") +CPU_SCOPE_B = tvm.target.make_se_scope(CPU_DEVICE, CPU_TARGET, memory_scope="scopeB") + CTXT = tvm.transform.PassContext(config={"relay.fallback_device_type": DEFAULT.device_type_int}) core = tvm.IRModule() @@ -178,7 +181,7 @@ def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][1], meta[SEScope][1]], result_se_scope=meta[SEScope][1]) { %0 = add(%a, %b); - %1 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); + %1 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); %2 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %3 = add(%c, %d); subtract(%2, %3) @@ -225,7 +228,7 @@ def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][1], meta[SEScope][1]], result_se_scope=meta[SEScope][1]) { %0 = add(%a, %b); - %1 = on_device(%0, 
se_scope=meta[SEScope][0], is_fixed=True); + %1 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); %2 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %3 = add(%c, %d); subtract(%2, %3) @@ -272,9 +275,9 @@ def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][1]) { %0 = add(%a, %b); - %1 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); + %1 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); %2 = add(%c, %d); - %3 = on_device(%2, se_scope=meta[SEScope][0], is_fixed=True); + %3 = on_device(%2, se_scope=meta[SEScope][0], constrain_result=True); %4 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %5 = device_copy(%3, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); subtract(%4, %5) @@ -318,8 +321,8 @@ def expected(): def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][1]) { %0 = add(%a, %b); - %1 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); - %2 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); + %1 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); + %2 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); %3 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %4 = device_copy(%2, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); subtract(%3, %4) @@ -367,8 +370,8 @@ def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][1], meta[SEScope][1]], result_se_scope=meta[SEScope][1]) { %0 = add(%a, %b); - let %l = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); - let %r = add(%c, %d); + let %l = on_device(%0, se_scope=meta[SEScope][0], 
constrain_result=True); + let %r = on_device(add(%c, %d), se_scope=meta[SEScope][1], constrain_result=True); %1 = device_copy(%l, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); subtract(%1, %r) } @@ -473,7 +476,7 @@ def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], add(%x, %y) }; %1 = %f(%a, %b); - %2 = on_device(%1, se_scope=meta[SEScope][0], is_fixed=True); + %2 = on_device(%1, se_scope=meta[SEScope][0], constrain_result=True); %3 = device_copy(%2, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %4 = add(%c, %d); subtract(%3, %4) @@ -591,7 +594,7 @@ def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { add(%a, %b) }; - let %t = (%f, %x); + let %t = on_device((%f, %x), se_scope=meta[SEScope][0], constrain_result=True); %0 = %t.1; %1 = %t.0; %1(%0, %y) @@ -650,7 +653,7 @@ def ref(x): def test_shape_of(): metatable = {"SEScope": [HOST, GPU]} - # We need to use is_fixed=True in the on_device call so that the tensor will be on the GPU. Otherwise the + # We need to use constrain_result=True in the on_device call so that the tensor will be on the GPU. Otherwise the # result defaults to the result device for @main which is the CPU, thus forcing a copy. # TODO(mbs): Perhaps the defaulting heuristics are being too clever? 
def input(): @@ -658,7 +661,7 @@ def input(): """ #[version = "0.0.5"] def @main(%x: Tensor[(?, ?), float32]) { - %0 = on_device(%x, se_scope=meta[SEScope][1], is_fixed=True); + %0 = on_device(%x, se_scope=meta[SEScope][1], constrain_result=True); vm.shape_of(%0, dtype="int64") } """, @@ -744,8 +747,8 @@ def expected(): """ #[version = "0.0.5"] def @main(%sto: Storage[], param_se_scopes=[meta[SEScope][1]], result_se_scope=meta[SEScope][1]) { - %0 = on_device(0, se_scope=meta[SEScope][0], is_fixed=True); - %1 = on_device(meta[relay.Constant][0], se_scope=meta[SEScope][0], is_fixed=True); + %0 = on_device(0, se_scope=meta[SEScope][0], constrain_result=True); + %1 = on_device(meta[relay.Constant][0], se_scope=meta[SEScope][0], constrain_result=True); memory.alloc_tensor(%sto, %0, %1, const_shape=meta[relay.Constant][0], assert_shape=[]) } """, @@ -781,7 +784,7 @@ def expected(): #[version = "0.0.5"] def @main(%x: Tensor[(2, 8), float32], param_se_scopes=[meta[SEScope][1]], result_se_scope=meta[SEScope][1]) { - %0 = on_device(meta[relay.Constant][0], se_scope=meta[SEScope][0], is_fixed=True); + %0 = on_device(meta[relay.Constant][0], se_scope=meta[SEScope][0], constrain_result=True); vm.reshape_tensor(%x, %0, newshape=[2, 4, 2]) } """, @@ -861,9 +864,9 @@ def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[( param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][1]], result_se_scope=meta[SEScope][1]) { %0 = add(%x, %y); - %1 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); + %1 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); %2 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); - %3 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); + %3 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); %4 = subtract(%2, %z); %5 = device_copy(%3, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); add(%4, %5) @@ -908,7 +911,7 @@ def @main(%x: Tensor[(5, 
7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[( param_se_scopes=[meta[SEScope][1], meta[SEScope][1], meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { %0 = add(%x, %y); - %1 = on_device(%0, se_scope=meta[SEScope][1], is_fixed=True); + %1 = on_device(%0, se_scope=meta[SEScope][1], constrain_result=True); %2 = device_copy(%1, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]); subtract(%2, %z) } @@ -1010,13 +1013,13 @@ def @main(%data1: Tensor[(1, 64, 56, 56), float32], %data2: Tensor[(1, 64, 56, 5 param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { %0 = nn.conv2d(%data1, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); - %1 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); + %1 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); %2 = nn.conv2d(%data2, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); - %3 = on_device(%2, se_scope=meta[SEScope][0], is_fixed=True); + %3 = on_device(%2, se_scope=meta[SEScope][0], constrain_result=True); %4 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %5 = device_copy(%3, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %6 = add(%4, %5); - %7 = on_device(%6, se_scope=meta[SEScope][1], is_fixed=True); + %7 = on_device(%6, se_scope=meta[SEScope][1], constrain_result=True); %8 = device_copy(%7, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]); nn.conv2d(%8, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) } @@ -1061,11 +1064,11 @@ def expected(): def @main(%x: Tensor[(3, 3, 4), float32], param_se_scopes=[meta[SEScope][0]], result_se_scope=meta[SEScope][1]) { %0 = split(%x, indices_or_sections=3); - let %t = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); + let %t = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); %1 = %t.0; - %2 = on_device(%1, se_scope=meta[SEScope][0], is_fixed=True); + %2 = 
on_device(%1, se_scope=meta[SEScope][0], constrain_result=True); %3 = %t.1; - %4 = on_device(%3, se_scope=meta[SEScope][0], is_fixed=True); + %4 = on_device(%3, se_scope=meta[SEScope][0], constrain_result=True); %5 = device_copy(%2, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %6 = device_copy(%4, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); subtract(%5, %6) @@ -1129,14 +1132,14 @@ def expected(): def @main(%x: Tensor[(5, 7), float32], param_se_scopes=[meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { %0 = negative(%x); - %1 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); + %1 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); %2 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); - %3 = on_device(%0, se_scope=meta[SEScope][0], is_fixed=True); + %3 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); %4 = device_copy(%3, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %5 = negative(%2); %6 = negative(%4); %7 = add(%5, %6); - %8 = on_device(%7, se_scope=meta[SEScope][1], is_fixed=True); + %8 = on_device(%7, se_scope=meta[SEScope][1], constrain_result=True); %9 = device_copy(%8, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]); negative(%9) } @@ -1199,14 +1202,14 @@ def expected(): def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], param_se_scopes=[meta[SEScope][1], meta[SEScope][1]], result_se_scope=meta[SEScope][0]) { %0 = add(%x, %y); - %1 = on_device(%0, se_scope=meta[SEScope][1], is_fixed=True); + %1 = on_device(%0, se_scope=meta[SEScope][1], constrain_result=True); %2 = device_copy(%1, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]); %3 = negative(%2); - %4 = on_device(%3, se_scope=meta[SEScope][0], is_fixed=True); + %4 = on_device(%3, se_scope=meta[SEScope][0], constrain_result=True); %5 = device_copy(%4, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); %6 = negative(%0); %7 = 
add(%5, %6); - %8 = on_device(%7, se_scope=meta[SEScope][1], is_fixed=True); + %8 = on_device(%7, se_scope=meta[SEScope][1], constrain_result=True); %9 = device_copy(%8, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]); negative(%9) } @@ -1267,7 +1270,7 @@ def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][1], meta[SEScope][1]], result_se_scope=meta[SEScope][0]) { %0 = multiply(%c, %d); - %1 = on_device(%0, se_scope=meta[SEScope][1], is_fixed=True); + %1 = on_device(%0, se_scope=meta[SEScope][1], constrain_result=True); %2 = add(%a, %b); %3 = device_copy(%1, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]); subtract(%2, %3) @@ -1294,7 +1297,7 @@ def input(): #[version = "0.0.5"] def @main(%x: bool, %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) { let %f = fn (%a) { - %0 = on_device(%y, se_scope=meta[SEScope][0], is_fixed=True); + %0 = on_device(%y, se_scope=meta[SEScope][0], constrain_result=True); add(%a, %0) }; let %g = fn (%a1) { @@ -1326,11 +1329,11 @@ def @main(%x: bool, %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32], let %g = fn (%a1, param_se_scopes=[meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { subtract(%a1, %y) }; - let %h = if (%x) { + let %h = on_device(if (%x) { %f } else { %g - }; + }, se_scope=meta[SEScope][0], constrain_result=True); %h(%z) } """, @@ -1430,9 +1433,9 @@ def expected(): #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], param_se_scopes=[meta[SEScope][1], meta[SEScope][0]], result_se_scope=meta[SEScope][1]) { - let %r = ref(%x); + let %r = on_device(ref(%x), se_scope=meta[SEScope][1], constrain_result=True); %0 = device_copy(%y, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); - ref_write(%r, %0); + on_device(ref_write(%r, %0), se_scope=meta[SEScope][1], constrain_result=True); %1 = ref_read(%r); add(%x, %1) } @@ -1463,7 +1466,7 @@ def input(): 
Nil, } def @main(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32]) { - %0 = on_device(%y, se_scope=meta[SEScope][0], is_fixed=True); + %0 = on_device(%y, se_scope=meta[SEScope][0], constrain_result=True); %1 = Nil; %2 = Cons(%0, %1); let %l = Cons(%x, %2); @@ -1489,7 +1492,7 @@ def @main(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32], param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { %0 = Nil; %1 = Cons(%y, %0); - let %l = Cons(%x, %1); + let %l = on_device(Cons(%x, %1), se_scope=meta[SEScope][0], constrain_result=True); match? (%l) { Cons(%z, _) => %z } @@ -1507,6 +1510,76 @@ def ref(x, y): exercise(input(), expected(), ref, rands((5, 7), 2)) +def test_free_on_device(): + """Tests that the 'free' form of on_device (ie with constrain_body=False) can be used to allow + a device_copy to be inserted if necessary, but otherwise does not prevent the flow of + device information.""" + metatable = { + "SEScope": [ + CPU, # no memory scope constraint + CPU_SCOPE_A, # constrain to scopeA + CPU_SCOPE_B, # constrain to scopeB + ] + } + + # Everything defaults to GPU + def input(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @on_scope_b(%x: Tensor[(5, 7), float32], + param_se_scopes=[meta[SEScope][2]], + result_se_scope=meta[SEScope][2]) -> Tensor[(5, 7), float32] { + %x + } + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], + param_se_scopes=[meta[SEScope][0], meta[SEScope][1], meta[SEScope][2]], + result_se_scope=meta[SEScope][1]) { + // %a's memory scope is unconstrained, so will take on "scopeB" and on_device has no effect + %0 = @on_scope_b(on_device(%a, se_scope=meta[SEScope][0], constrain_body=False)); + // %b's memory scope is "scopeA", so will require a "scopeA"->"scopeB" copy. + %1 = @on_scope_b(on_device(%b, se_scope=meta[SEScope][0], constrain_body=False)); + // %c's memory scope is "scopeB", so no copy required. 
+ %2 = @on_scope_b(on_device(%c, se_scope=meta[SEScope][0], constrain_body=False)); + // result's memory scope is on "scopeA", so will require a "scopeB"->"scopeA" copy. + %3 = add(add(%0, %1), %2); + on_device(%3, se_scope=meta[SEScope][0], constrain_body=False) + } + """, + "from_string", + None, + metatable, + ) + + def expected(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @on_scope_b(%x: Tensor[(5, 7), float32], + param_se_scopes=[meta[SEScope][2]], + result_se_scope=meta[SEScope][2]) -> Tensor[(5, 7), float32] { + %x + } + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], + param_se_scopes=[meta[SEScope][2], meta[SEScope][1], meta[SEScope][2]], + result_se_scope=meta[SEScope][1]) { + %0 = @on_scope_b(%a); + %1 = device_copy(%b, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][2]); + %2 = @on_scope_b(%1); + %3 = @on_scope_b(%c); + %4 = add(add(%0, %2), %3); + %5 = on_device(%4, se_scope=meta[SEScope][2], constrain_result=True); + device_copy(%5, src_se_scope=meta[SEScope][2], dst_se_scope=meta[SEScope][1]) + } + """, + "from_string", + None, + metatable, + ) + + exercise(input(), expected(), None, None) + + if __name__ == "__main__": import sys import pytest diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index a35f75bd1aae..ea1f4ddb3b62 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -1040,8 +1040,8 @@ def @main(%a: Tensor[(5, 7), float32], # - The offset of the tensor within the storage (second arg) to alloc_tensor # Both should be on the CPU assert "VirtualDevice[0]: device type 1" in exe.virtual_devices - assert "Constant[0]: has shape int64[] on device index 0" in exe.constants - assert "Constant[1]: has shape int64[] on device index 0" in exe.constants + assert "Const[0]: has shape int64[] on device index 0" in exe.constants + assert "Const[1]: has shape int64[] on device index 0" in exe.constants 
@tvm.testing.requires_cuda @@ -1073,7 +1073,7 @@ def @main(%x: Tensor[(2, 8), float32], # The newshape annotation should have been turned into a constant on the CPU. assert "VirtualDevice[0]: device type 1" in exe.virtual_devices - assert "Constant[0]: has shape int64[3] on device index 0" in exe.constants + assert "Const[0]: has shape int64[3] on device index 0" in exe.constants @tvm.testing.requires_cuda