diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index e0d34c87dda7..21760bdc8dbf 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -66,6 +66,15 @@ class TargetNode : public Object { /*! \return The Optional typed target host of the TargetNode */ TVM_DLL Optional GetHost() const; + /*! + * \brief Returns a human readable representation of \p Target which includes all fields, + * especially the host. Useful for diagnostic messages and debugging. + * + * TODO(mbs): The ReprPrinter version should perhaps switch to this form, however currently + * code depends on str() and << being the same. + */ + String ToDebugString() const; + void VisitAttrs(AttrVisitor* v) { v->Visit("kind", &kind); v->Visit("tag", &tag); diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 486799603354..092d5b61eeec 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -1955,7 +1955,8 @@ TVM_REGISTER_GLOBAL("parser.ParseExpr") TVM_REGISTER_GLOBAL("relay._transform.AnnotateSpans").set_body_typed([]() { return CreateModulePass( [](const IRModule& mod, const PassContext& ctx) { - auto text = AsText(mod, true); + String text = AsText(mod, /*show_meta_data=*/true); + VLOG(1) << "AnnotateSpans intermediate text:" << std::endl << text; return ParseModule("GeneratedSource", text); }, 0, "AnnotateSpans", {}); diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 9eca038e5c93..7454cfdf336e 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -37,6 +37,7 @@ #include #include #include +#include #include #include "../ir/attr_functor.h" @@ -120,9 +121,6 @@ Doc RelayTextPrinter::Print(const ObjectRef& node, bool meta, bool try_inline) { return PrintPattern(Downcast(node), meta); } else if (node.as()) { return PrintMod(Downcast(node)); - } else if (!show_meta_data_ && node.as()) { - // Show attributes in readable form. 
- return PrintAttrs(Downcast(node)); } else { // default module. std::ostringstream os; @@ -444,7 +442,7 @@ Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) { for (Var param : fn->params) { params.push_back(AllocVar(param)); } - for (const Doc& d : PrintFuncAttrs(fn->attrs)) { + for (const Doc& d : PrintDictAttrs(fn->attrs)) { params.push_back(d); } doc << Doc::Concat(params) << ") "; @@ -684,8 +682,10 @@ Doc RelayTextPrinter::VisitType_(const TensorTypeNode* node) { Doc doc; doc << "Tensor[("; std::vector shapes; - for (ObjectRef shape : node->shape) { - shapes.push_back(PrintAttr(shape)); + for (const PrimExpr& prim_expr : node->shape) { + // Though not bound within an attribute the attribute visitor will handle the PrimExprs we + // care about. + shapes.push_back(PrintAttributeValue(prim_expr)); } doc << Doc::Concat(shapes); return doc << "), " << PrintDType(node->dtype) << "]"; @@ -766,34 +766,18 @@ Doc RelayTextPrinter::VisitType_(const TypeDataNode* node) { // Overload of Attr printing functions //------------------------------------ -Doc RelayTextPrinter::PrintAttr(const ObjectRef& value, bool meta) { - if (value.defined()) { - Doc printed_attr; - if (value.as()) { - printed_attr << "?"; - } else if (auto str_obj = value.as()) { - printed_attr << Doc::StrLiteral(GetRef(str_obj)); - } else if (meta) { - printed_attr = meta_->GetMetaNode(Downcast(value)); - } else { - printed_attr = VisitAttr(value); - } - return printed_attr; - } else { - return Doc::Text("None"); - } -} - Doc RelayTextPrinter::VisitAttrDefault_(const Object* op) { - return PrintAttr(GetRef(op), /*meta=*/true); + // Since we don't have any overload for a specific attribute type we'll need to force + // the meta[...] representation to avoid infinite regress. 
+ return PrintAttributeValue(GetRef(op), /*force_meta=*/true); } Doc RelayTextPrinter::VisitAttr_(const ArrayNode* op) { Doc doc; doc << "["; std::vector arr_vals; - for (auto val : *op) { - arr_vals.push_back(PrintAttr(val)); + for (const auto& val : *op) { + arr_vals.push_back(PrintAttributeValue(val)); } doc << Doc::Concat(arr_vals); doc << "]"; @@ -831,6 +815,7 @@ class RelayTextPrinter::AttrPrinter : public AttrVisitor { doc << key << "=" << *value << "f"; docs->push_back(doc); } + void Visit(const char* key, int64_t* value) final { PrintKV(key, *value); } void Visit(const char* key, uint64_t* value) final { PrintKV(key, *value); } void Visit(const char* key, int* value) final { PrintKV(key, *value); } @@ -844,7 +829,7 @@ class RelayTextPrinter::AttrPrinter : public AttrVisitor { LOG(FATAL) << "do not allow NDarray as argument"; } void Visit(const char* key, runtime::ObjectRef* obj) final { - PrintKV(key, parent_->PrintAttr(*obj)); + PrintKV(key, parent_->PrintAttributeValue(*obj)); } private: @@ -852,50 +837,126 @@ class RelayTextPrinter::AttrPrinter : public AttrVisitor { RelayTextPrinter* parent_; }; -Doc RelayTextPrinter::PrintAttrs(const Attrs& attrs) { - std::vector docs; - AttrPrinter printer(&docs, this); - const_cast(attrs.operator->())->VisitNonDefaultAttrs(&printer); - Doc doc; - doc << "{" << Doc::Concat(docs) << "}"; - - return doc; +void RelayTextPrinter::AppendGenericAttrs(std::vector* docs, const Attrs& attrs, + bool include_type_key) { + if (!attrs.defined()) { + return; + } + AttrPrinter printer(docs, this); + // Need to drop const cast since in general VisitNonDefaultAttrs can mutate, but in this + // case we are read-only. 
+ const_cast(attrs.get())->VisitNonDefaultAttrs(&printer); + if (include_type_key) { + std::string s = attrs->GetTypeKey(); + printer.Visit("attrs_type_key", &s); + } } std::vector RelayTextPrinter::PrintCallAttrs(const Attrs& attrs, const Expr& op) { std::vector docs; - if (!attrs.defined()) return docs; + if (!attrs.defined()) { + return docs; + } const auto* op_node = op.as(); if (show_meta_data_ && op_node && (attrs->type_index() != op_node->attrs_type_index)) { - // fallback + // The parser can only understand calls with attributes if they match the operator's + // declared attribute type. If that's not the case fall back to the meta[...] representation. + docs.push_back(meta_->GetMetaNode(attrs)); + } else { + AppendGenericAttrs(&docs, attrs, /*include_type_key=*/!op_node); + } + return docs; +} + +std::vector RelayTextPrinter::PrintDictAttrs(const DictAttrs& dict_attrs) { + if (!dict_attrs.defined()) { + return {}; + } + return PrintDictAttrs(dict_attrs->dict); +} + +std::vector RelayTextPrinter::PrintDictAttrs(const Map& dict_attrs) { + std::vector docs; + if (!dict_attrs.defined()) { + return docs; + } + for (const auto& k : dict_attrs) { Doc doc; - doc << meta_->GetMetaNode(attrs); + doc << k.first << "=" << PrintAttributeValue(k.second); docs.push_back(doc); - return docs; - } else { - // Show attributes in readable form. 
- AttrPrinter printer(&docs, this); - const_cast(attrs.operator->())->VisitNonDefaultAttrs(&printer); - if (!op_node) { - // print call attr type key to restore expr for relay parser - std::string s = std::string(attrs->GetTypeKey()); - printer.Visit("attrs_type_key", &s); + } + return docs; +} + +Doc RelayTextPrinter::PrintAttributeValue(const ObjectRef& value, bool force_meta) { + if (value.defined()) { + Doc printed_attr; + if (value.as()) { + printed_attr << "?"; + } else if (auto str_obj = value.as()) { + printed_attr << Doc::StrLiteral(GetRef(str_obj)); + } else if (force_meta) { + printed_attr = meta_->GetMetaNode(Downcast(value)); + } else if (const auto* se_scope_node = value.as()) { + if (show_meta_data_) { + printed_attr = meta_->GetMetaNode(GetRef(se_scope_node)); + } else { + // Special case: The ReprPrinter for SEScopeNodes is much easier to work with while + // debugging. + std::ostringstream os; + os << GetRef(se_scope_node); + return Doc::Text(os.str()); + } + } else if (const auto* base_attr_node = value.as()) { + if (show_meta_data_) { + printed_attr = meta_->GetMetaNode(GetRef(base_attr_node)); + } else { + // Special case: The non-meta form for attributes are much easier to work with while + // debugging. + printed_attr = PrintAttrsAsAttributeValue(GetRef(base_attr_node)); + } + } else if (const auto* base_map_node = value.as()) { + if (show_meta_data_) { + printed_attr = meta_->GetMetaNode(GetRef(base_map_node)); + } else { + // Special case: Show maps fields as key=value pairs to help debugging. 
+ printed_attr << PrintMapAsAttributeValue(GetRef>(base_map_node)); + } + } else if (const auto* global_var_node = value.as()) { + if (show_meta_data_) { + printed_attr = meta_->GetMetaNode(GetRef(global_var_node)); + } else { + printed_attr << "'" << global_var_node->name_hint << "'"; + } + } else { + printed_attr = VisitAttr(value); } - return docs; + return printed_attr; + } else { + return Doc::Text("None"); } } -std::vector RelayTextPrinter::PrintFuncAttrs(const Attrs& attrs) { +Doc RelayTextPrinter::PrintAttrsAsAttributeValue(const Attrs& attrs) { std::vector docs; - if (!attrs.defined()) return docs; - const auto* dict_attrs = attrs.as(); - ICHECK(dict_attrs); - for (const auto& k : dict_attrs->dict) { + AppendGenericAttrs(&docs, attrs, /*include_type_key=*/false); + Doc doc; + doc << "{" << Doc::Concat(docs) << "}"; + return doc; +} + +Doc RelayTextPrinter::PrintMapAsAttributeValue(const Map& map) { + std::vector docs; + for (const auto& k : map) { Doc doc; - doc << k.first << "=" << Print(k.second); + doc << PrintAttributeValue(k.first); + doc << "="; + doc << PrintAttributeValue(k.second); docs.push_back(doc); } - return docs; + Doc doc; + doc << "{" << Doc::Concat(docs) << "}"; + return doc; } Doc RelayTextPrinter::PrintSpan(const Span& span) { diff --git a/src/printer/text_printer.cc b/src/printer/text_printer.cc index b8533a5d8801..444cb0828c94 100644 --- a/src/printer/text_printer.cc +++ b/src/printer/text_printer.cc @@ -58,6 +58,7 @@ Doc TextPrinter::PrintMod(const IRModule& mod) { os << "def @" << kv.first->name_hint; doc << relay_text_printer_.PrintFunc(Doc::Text(os.str()), kv.second); } else if (kv.second.as()) { + doc << "@" << kv.first->name_hint << " = "; doc << tir_text_printer_.PrintPrimFunc(Downcast(kv.second)); } doc << Doc::NewLine(); diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index 316d59631782..ebd667ae2ac7 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -77,9 +77,42 @@ class 
RelayTextPrinter : public ExprFunctor, // numbers to be reused and prevents hoisted vars from escaping too far Doc PrintScope(const ObjectRef& node); Doc PrintFinal(const ObjectRef& node); - Doc PrintAttrs(const Attrs& attrs); + + /*! + * \brief Appends to \p docs the entries of \p attrs printed using the generic attribute + * visitor, as a sequence of key=value entries, if any. + */ + void AppendGenericAttrs(std::vector* docs, const Attrs& attrs, bool include_type_key); + + /*! + * \brief Returns \p attrs printed as a sequence of key=value entries, if any. + * This is used for call attributes. + */ std::vector PrintCallAttrs(const Attrs& attrs, const Expr& op); - std::vector PrintFuncAttrs(const Attrs& attrs); + + /*! + * \brief Returns \p dict_attrs printed as a sequence of key=value entries, if any. + * This is used for function definition attributes. + */ + std::vector PrintDictAttrs(const DictAttrs& dict_attrs); + std::vector PrintDictAttrs(const Map& dict_attrs); + + /*! + * \brief Returns \p value printed as the rhs of an attribute key=value entry. If \p force_meta + * is true then value is printed in meta[...] form irrespective of the show_meta_data_ flag. + */ + Doc PrintAttributeValue(const ObjectRef& value, bool force_meta = false); + + /*! + * \brief Returns \p attrs printed as a self-contained value, ie wrapped in braces. + */ + Doc PrintAttrsAsAttributeValue(const Attrs& attrs); + + /*! + * \brief Returns \p map printed as a self-contained value, ie wrapped in braces. 
+ */ + Doc PrintMapAsAttributeValue(const Map& map); + Doc PrintSpan(const Span& span); Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false); @@ -162,7 +195,6 @@ class RelayTextPrinter : public ExprFunctor, //------------------------------------ // Overload of Attr printing functions //------------------------------------ - Doc PrintAttr(const ObjectRef& value, bool meta = false); Doc VisitAttrDefault_(const Object* op) final; Doc VisitAttr_(const ArrayNode* op) final; Doc VisitAttr_(const tir::IntImmNode* op) final; diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index 302c4491cebe..e479af1b2fe9 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -71,6 +72,8 @@ Doc TIRTextPrinter::Print(const ObjectRef& node) { return PrintString(node.as()); } else if (node->IsInstance()) { return PrintBufferRegion(node.as()); + } else if (node->IsInstance()) { + return Doc::Text(node.as()->ToDebugString()); } else { return this->meta_->GetMetaNode(node); } diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index a596e09907d5..13b855624461 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -34,6 +34,7 @@ #include #include #include +#include #include "../op/annotation/annotation.h" #include "../transforms/pass_utils.h" @@ -292,8 +293,11 @@ InterpreterState::InterpreterState(Expr current_expr, InterpreterState::Stack st class Interpreter : public ExprFunctor, PatternFunctor { public: - Interpreter(IRModule unified_mod, Device device, Target target) - : unified_mod_(unified_mod), device_(device), target_(target), debug_op_(Op::Get("debug")) {} + Interpreter(IRModule unified_mod, CompilationConfig config, Device device) + : unified_mod_(unified_mod), + config_(std::move(config)), + device_(device), + debug_op_(Op::Get("debug")) {} template T 
WithFrame(const Frame& fr, const std::function& f) { @@ -386,12 +390,12 @@ class Interpreter : public ExprFunctor, per_target_module_std_map = backend::TargetModuleMapToTargetStrModuleMap(per_target_module); auto mod_itr = per_target_module_std_map.find(target); ICHECK(mod_itr != per_target_module_std_map.end()) - << "No target module for target '" << target->str() << "'"; + << "No target module for target " << target->ToDebugString(); const IRModule& target_module = (*mod_itr).second; for (const auto& var : all_tir_fn_vars) { ICHECK(target_module->ContainGlobalVar(var->name_hint)) - << "No global var for '" << var->name_hint << "' in module for target '" << target->str() - << "'"; + << "No global var for '" << var->name_hint << "' in module for target " + << target->ToDebugString(); lowered_projected_mod->Add(var, target_module->Lookup(var->name_hint)); } @@ -407,8 +411,9 @@ class Interpreter : public ExprFunctor, // Extract all the packed functions. for (const auto& var : all_tir_fn_vars) { PackedFunc packed_func = runtime_module.GetFunction(var->name_hint); - ICHECK(packed_func != nullptr) << "No packed function for global var '" << var->name_hint - << "' in compiled module for target '" << target->str() << "'"; + ICHECK(packed_func != nullptr) + << "No packed function for global var '" << var->name_hint + << "' in compiled module for target " << target->ToDebugString(); compiled_packed_funcs_.emplace(std::make_pair(target, var->name_hint), packed_func); } @@ -734,9 +739,11 @@ class Interpreter : public ExprFunctor, Downcast(attrs->metadata.at("prim_shape_fn_num_outputs"))->value); } - return InvokePrimitiveOp(GetRef(gvn), all_prim_fn_vars, target_, - prim_shape_fn_var, all_prim_shape_fn_vars, prim_shape_fn_states, - num_shape_inputs, num_shape_outputs, cpu_target_, args); + ICHECK(config_->optional_homogeneous_target.defined()); + return InvokePrimitiveOp(GetRef(gvn), all_prim_fn_vars, + config_->optional_homogeneous_target, prim_shape_fn_var, + 
all_prim_shape_fn_vars, prim_shape_fn_states, num_shape_inputs, + num_shape_outputs, config_->host_se_scope->target, args); } } @@ -884,13 +891,11 @@ class Interpreter : public ExprFunctor, // Cached packed functions for the primitives and shape functions, keyed by target and // global var name. std::unordered_map, PackedFunc, PairHash> compiled_packed_funcs_; + /*! \brief Compilation config describing the available targets. */ + CompilationConfig config_; // Unique device on which primitives (but not shape functions) will be executed. // (For simplicity we only run the interpreter on a single device.) Device device_; - // Unique target describing how to compile for primitives (but not shape functions). - Target target_; - // Default 'CPU' target for shape primitives. - Target cpu_target_{"llvm"}; // Call stack. Stack stack_; // The distinguished 'debug' operator, which is handled specially. @@ -898,25 +903,21 @@ class Interpreter : public ExprFunctor, }; /*! - * Lowers all calls to primitives in \p mod appropriate for device and target. Returns the + * Lowers all calls to primitives in \p mod appropriate for \p config. Returns the * rewritten \p mod and target-specific modules containing bindings for all TIR primitive * functions needed by the rewritten module. */ -IRModule Prepare(IRModule mod, Device device, Target target) { - // Things to initialize to pass into tec::LowerTEPass - // We only have one device-specific target. - tec::TargetMap targets = {{device.device_type, target}}; - if (device.device_type != kDLCPU) { - // However some primitives (eg dynamic shape functions) must always execute on the CPU, - // so make sure we have a target for that. 
- targets.emplace(kDLCPU, Target("llvm")); +IRModule Prepare(IRModule mod, CompilationConfig config) { + tec::TargetMap tec_target_map; + for (const auto& pair : config->legacy_target_map) { + tec_target_map.emplace(static_cast(pair.first->value), pair.second); } - // Run minimal transforms on module to establish invariants needed by interpreter. transform::Sequential seq( {transform::SimplifyInference(), // Figure out which devices should be used to execute. - transform::PlanDevices(device.device_type), + // TODO(mbs): Should ignore all existing annotations when constant folding + transform::PlanDevices(config->default_primitive_se_scope->device_type()), // FuseOps will mark wrapped calls to prim-ops with the 'Primitive' // attribute. transform::FuseOps(/*fuse_opt_level=*/0), @@ -926,7 +927,8 @@ IRModule Prepare(IRModule mod, Device device, Target target) { transform::EtaExpand( /*expand_constructor=*/true, /*expand_global_var=*/false), transform::InferType(), - tec::LowerTEPass(targets, /*module_name=*/"intrp", [](Function func) { /* no-op */ })}); + tec::LowerTEPass(tec_target_map, /*module_name=*/"intrp", + [](Function func) { /* no-op */ })}); transform::PassContext pass_ctx = transform::PassContext::Current(); With ctx(pass_ctx); @@ -979,7 +981,15 @@ class NeedsPreparationVisitor : public ExprVisitor { TypedPackedFunc)> EvalFunction(IRModule mod, Expr expr, Device device, Target target) { VLOG_CONTEXT << "EvalFunction"; - VLOG(1) << "evaling module:\n" << PrettyPrint(mod) << "and expression:\n" << PrettyPrint(expr); + VLOG(1) << "evaling module:" << std::endl + << PrettyPrint(mod) << "and expression:" << std::endl + << PrettyPrint(expr); + + ICHECK_EQ(device.device_type, target->kind->device_type); + TargetMap targets; + targets.Set(device.device_type, target); + CompilationConfig config(transform::PassContext::Current(), targets, + /*optional_host_target_arg=*/{}); // // Step 1: Prepare mod. 
@@ -1024,9 +1034,9 @@ TypedPackedFunc)> EvalFunction(IRModule mod, Expr expr, De // and can just eval it directly. expr_to_eval = expr; } - IRModule lowered_mod = Prepare(mod_with_expr, device, target); + IRModule lowered_mod = Prepare(mod_with_expr, config); - std::shared_ptr intrp = std::make_shared(lowered_mod, device, target); + std::shared_ptr intrp = std::make_shared(lowered_mod, config, device); // // Step 2: Evaluate target function to a closure. @@ -1065,12 +1075,18 @@ TypedPackedFunc)> EvalFunction(IRModule mod, Expr expr, De ObjectRef Eval(Expr expr, Map type_definitions, std::unordered_set import_set, Device device, Target target) { + ICHECK_EQ(device.device_type, target->kind->device_type); + TargetMap targets; + targets.Set(device.device_type, target); + CompilationConfig config(transform::PassContext::Current(), targets, + /*optional_host_target_arg=*/{}); + std::pair mod_and_global = IRModule::FromExprInContext(expr, /*global_funcs=*/{}, type_definitions, import_set); - IRModule mod = Prepare(mod_and_global.first, device, target); + IRModule mod = Prepare(mod_and_global.first, config); - Interpreter intrp(mod, device, target); + Interpreter intrp(mod, config, device); Expr expr_to_eval = mod->GetGlobalVar(mod_and_global.second->name_hint); if (expr.as() == nullptr) { // TODO(mbs): IRModule::FromExpr will implicitly close over the free vars of expr diff --git a/src/target/compilation_config.cc b/src/target/compilation_config.cc index b3491d656625..37f6e1e3d15a 100644 --- a/src/target/compilation_config.cc +++ b/src/target/compilation_config.cc @@ -61,32 +61,48 @@ void CompilationConfigNode::EstablishDefaultSEScopes(const transform::PassContex if (host_target.defined()) { CHECK(!host_target->host.defined()) << "Host targets are not expected to have hosts"; host_device_type = static_cast(host_target->kind->device_type); - if (host_device_type != kDLCPU) { - LOG(WARNING) << "Using the given host target '" << host_target << "' of non-CPU device type " - 
<< host_device_type << " for all host operations and data"; - } else { - LOG(INFO) << "Using the given host target '" << host_target << "' of device type " - << host_device_type << " for all host operations and data"; + LOG(INFO) << "Using the given host target " << host_target->ToDebugString() + << " of device type " << host_device_type << " for the host target"; + for (const auto& primitive_target : primitive_targets) { + if (primitive_target->host.defined() && + !StructuralEqual()(primitive_target->host, host_target)) { + LOG(WARNING) << "The primitive target " << primitive_target->ToDebugString() + << " already has a host which disagrees with the desired host target. It " + "will be ignored."; + } } + } else if (primitive_targets.size() == 1 && primitive_targets.front()->host.defined()) { + host_target = primitive_targets.front()->GetHost().value(); + CHECK(!host_target->host.defined()) << "Host targets are not expected to have hosts"; + host_device_type = static_cast(host_target->kind->device_type); + LOG(INFO) << "Using the host of the unique primitive target, namely " + << host_target->ToDebugString() << " of device type " << host_device_type + << " for the host target"; } else if (primitive_targets.size() == 1 && primitive_targets.front()->kind->device_type == kDLCPU) { // In the homogenous case without an explicit host target just use the given target so long as - // it's a CPU. However make sure we 'forget' any host it may already have. + // it's a CPU. host_device_type = kDLCPU; - host_target = Target(primitive_targets.front()); - LOG(INFO) << "Using the unique target '" << host_target << "' of device type " - << host_device_type << " for all host operations and data"; + host_target = primitive_targets.front(); + LOG(INFO) << "Using the unique primitive target " << host_target->ToDebugString() + << " of device type " << host_device_type << " for the host target"; } else { // Fallback. 
host_device_type = kDLCPU; // Even if the list of available targets already includes one for kDLCPU we won't use it - // since its options may not be appropriate for host code (eg shape functions). Instead, - // create a fresh default Target. + // in the heterogeneous case since its options may not be appropriate for host code + // (eg shape functions). Instead, create a fresh default Target. host_target = MakeDefaultTarget(host_device_type); - LOG(WARNING) << "Using the default host target '" << host_target << "' of device type " - << host_device_type << " for all host operations and data"; + LOG(WARNING) << "Using the default target " << host_target->ToDebugString() + << " of device type " << host_device_type << " for the host target"; } ICHECK(host_target.defined()); + ICHECK(!host_target->host.defined()); + + if (host_device_type != kDLCPU) { + // I think we're on thin ice here until we've audited the code base for assumed kDLCPU. + LOG(WARNING) << "The host target is not a CPU."; + } // // Establish the host SEScope. 
@@ -112,24 +128,19 @@ void CompilationConfigNode::EstablishDefaultSEScopes(const transform::PassContex Optional opt_fallback_dev = pass_ctx->GetConfig("relay.fallback_device_type"); if (opt_fallback_dev) { const int64_t v = opt_fallback_dev.value()->value; - if (v <= 0) { - LOG(FATAL) - << "The 'relay.fallback_device_type' pass attribute is set to an invalid device type " - << v; - default_primitive_device_type = kDLCPU; - } else { - default_primitive_device_type = static_cast(v); - LOG(INFO) << "Using the 'relay.fallback_device_type' pass attribute " - << default_primitive_device_type - << " as the default device type for all primitive operations"; - } + CHECK_GT(v, 0) + << "The 'relay.fallback_device_type' pass attribute is set to an invalid device type " << v; + default_primitive_device_type = static_cast(v); + LOG(INFO) << "Using the 'relay.fallback_device_type' pass attribute " + << default_primitive_device_type + << " as the default device type for all primitive operations"; } else if (primitive_targets.size() == 1) { // In the homogeneous case there's no free choice. default_primitive_device_type = static_cast(primitive_targets.front()->kind->device_type); - LOG(INFO) << "Using the unique target '" << primitive_targets.front() << "' of device type " - << default_primitive_device_type - << " as the default device type for all primitive operations"; + LOG(INFO) << "Using the device type " << default_primitive_device_type + << " of the unique primitive target as the default device type for all primitive " + "operations"; } else { // Fallback. Note that we'll require a primitive Target of kDLCPU device_type to be given // and won't manufacture one out of thin air. @@ -154,6 +165,7 @@ void CompilationConfigNode::EstablishDefaultSEScopes(const transform::PassContex return Target("llvm"); } else { // LLVM is not available. + // TODO(mbs): Already deprecated? 
return Target("stackvm"); } } else { @@ -178,10 +190,10 @@ CompilationConfig::CompilationConfig(const transform::PassContext& pass_ctx, auto node = make_object(); for (const auto& pair : legacy_target_map_arg) { - VLOG(0) << "Available primitive target " << pair.first << " = '" << pair.second << "'"; + VLOG(0) << "Available primitive target " << pair.first << " = " << pair.second->ToDebugString(); } if (optional_host_target_arg.defined()) { - VLOG(0) << "Available host target '" << optional_host_target_arg << "'"; + VLOG(0) << "Available host target " << optional_host_target_arg->ToDebugString(); } // Capture the arguments in our representation. @@ -210,8 +222,8 @@ CompilationConfig::CompilationConfig(const transform::PassContext& pass_ctx, node->primitive_targets.size() == 1 ? *node->primitive_targets.begin() : Target(); for (const auto& target : node->primitive_targets) { - LOG(INFO) << "Target '" << target << "' of device type " << target->kind->device_type - << " is available for primitives"; + LOG(INFO) << "Target " << target->ToDebugString() << " of device type " + << target->kind->device_type << " is available for primitives"; } LOG(INFO) << "Using default primitive scope " << node->default_primitive_se_scope; LOG(INFO) << "Using host scope " << node->host_se_scope; diff --git a/src/target/se_scope.cc b/src/target/se_scope.cc index 150a883cb565..95d5a7de5775 100644 --- a/src/target/se_scope.cc +++ b/src/target/se_scope.cc @@ -52,7 +52,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) if (need_sep) { p->stream << ", "; } - p->stream << "target='" << node->target << "'"; + p->stream << "target=" << node->target->ToDebugString(); need_sep = true; } if (!node->memory_scope.empty()) { @@ -62,13 +62,17 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "memory_scope='" << node->memory_scope << "'"; } } +#if TVM_LOG_DEBUG + // We rely on object identity of SEScopes, so include the object address to help debugging. 
+ p->stream << ", id=" << reinterpret_cast(ref.get()); +#endif p->stream << ")"; }); SEScope::SEScope(DLDeviceType device_type, int virtual_device_id, Target target, MemoryScope memory_scope) { ICHECK(!target.defined() || device_type == target->kind->device_type) - << "target '" << target << "' has device type " << target->kind->device_type + << "target " << target->ToDebugString() << " has device type " << target->kind->device_type << " but scope has device type " << device_type; auto node = make_object(); node->device_type_int = device_type; @@ -173,7 +177,7 @@ SEScope SEScopeCache::Make(DLDeviceType device_type, int virtual_device_id, Targ cache_.emplace(prototype); return prototype; } else { - VLOG(1) << "reusing '" << *itr << "' for '" << prototype << "'"; + VLOG(1) << "reusing existing scope " << *itr; ICHECK_EQ(prototype->target.defined(), (*itr)->target.defined()); if (prototype->target.defined()) { ICHECK_EQ(prototype->target->host.defined(), (*itr)->target->host.defined()); diff --git a/src/target/target.cc b/src/target/target.cc index d1c85c583b3b..6f5e8ee67b30 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -457,6 +457,7 @@ const std::string& TargetNode::str() const { if (Optional attrs_str = TargetInternal::StringifyAttrsToRaw(attrs)) { os << ' ' << attrs_str.value(); } + str_repr_ = os.str(); } return str_repr_; @@ -531,6 +532,48 @@ Optional TargetNode::GetHost() const { return GetRef>(this->host.as()); } +String TargetNode::ToDebugString() const { + std::ostringstream os; + os << "Target("; + os << "kind='" << kind->name << "'"; + if (!tag.empty()) { + os << ", tag='" << tag << "'"; + } + if (!keys.empty()) { + os << ", keys={"; + bool first = true; + for (const auto& key : keys) { + if (!first) { + os << ", "; + } + os << "'" << key << "'"; + first = false; + } + os << "}"; + } + if (!attrs.empty()) { + os << ", attrs={"; + bool first = true; + for (const auto& pair : attrs) { + if (!first) { + os << ", "; + } + os << "'" << 
pair.first << "': " << pair.second; + first = false; + } + os << "}"; + } + if (host.defined()) { + os << ", host=" << GetHost().value()->ToDebugString(); + } +#if TVM_LOG_DEBUG + // We depend on pointer equality so include that in the debug representation. + os << ", id=" << reinterpret_cast(this); +#endif + os << ")"; + return os.str(); +} + bool TargetNode::SEqualReduce(const TargetNode* other, SEqualReducer equal) const { return equal(kind.get(), other->kind.get()) && equal(host, other->host) && equal(tag, other->tag) && equal(keys, other->keys) && equal(attrs, other->attrs); diff --git a/tests/cpp/target/compilation_config_test.cc b/tests/cpp/target/compilation_config_test.cc index 5c2b7990a498..ae5f5d0c3dc4 100644 --- a/tests/cpp/target/compilation_config_test.cc +++ b/tests/cpp/target/compilation_config_test.cc @@ -40,13 +40,13 @@ CompilationConfig TestCompilationConfig() { return CompilationConfig(pass_ctx, legacy_target_map, TestDefaultCpuTarget()); } -TEST(CompilationConfig, Constructor_Homogeneous_DefaultHost) { +TEST(CompilationConfig, Constructor_Homogeneous_FallbackCPUHost) { transform::PassContext pass_ctx = transform::PassContext::Create(); Target host_target = TestDefaultCpuTarget(); Target cuda_target = TestCudaTarget(); TargetMap legacy_target_map; legacy_target_map.Set(Integer(static_cast(kDLCUDA)), cuda_target); - CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target=*/{}); + CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target_arg=*/{}); SEScope expected_default_primitive_se_scope(kDLCUDA, 0, Target::WithHost(cuda_target, host_target)); @@ -68,7 +68,31 @@ TEST(CompilationConfig, Constructor_Homogeneous_DefaultHost) { Target::WithHost(cuda_target, host_target))); } -TEST(CompilationConfig, Constructor_Hetrogeneous_DefaultHost) { +TEST(CompilationConfig, Constructor_Homegenoous_InnerHost) { + transform::PassContext pass_ctx = transform::PassContext::Create(); + Target host_target = TestCpuTarget(); 
+ Target cuda_target = Target::WithHost(TestCudaTarget(), host_target); + TargetMap legacy_target_map; + legacy_target_map.Set(Integer(static_cast(kDLCUDA)), cuda_target); + CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target_arg=*/{}); + + EXPECT_TRUE(StructuralEqual()(config->host_target, host_target)); +} + +TEST(CompilationConfig, Constructor_Homogenous_CPUHost) { + transform::PassContext pass_ctx = transform::PassContext::Create(); + Target cpu_target = TestCpuTarget(); + TargetMap legacy_target_map; + legacy_target_map.Set(Integer(static_cast(kDLCPU)), cpu_target); + CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target_arg=*/{}); + + EXPECT_TRUE(StructuralEqual()(config->host_target, cpu_target)); + ASSERT_TRUE(config->optional_homogeneous_target.defined()); + EXPECT_TRUE(StructuralEqual()(config->optional_homogeneous_target, + Target::WithHost(cpu_target, cpu_target))); +} + +TEST(CompilationConfig, Constructor_Hetrogeneous_FallbackCPUHost) { transform::PassContext pass_ctx = transform::PassContext::Create(); pass_ctx->config.Set("relay.fallback_device_type", Integer(static_cast(kDLCUDA))); Target host_target = TestDefaultCpuTarget(); @@ -77,7 +101,7 @@ TEST(CompilationConfig, Constructor_Hetrogeneous_DefaultHost) { TargetMap legacy_target_map; legacy_target_map.Set(Integer(static_cast(kDLCPU)), cpu_target); legacy_target_map.Set(Integer(static_cast(kDLCUDA)), cuda_target); - CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target=*/{}); + CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target_arg=*/{}); SEScope expected_default_primitive_se_scope(kDLCUDA, 0, Target::WithHost(cuda_target, host_target)); @@ -123,7 +147,7 @@ TEST(CompilationConfig, Constructor_InvalidAttribute) { TargetMap legacy_target_map; legacy_target_map.Set(Integer(static_cast(kDLCUDA)), TestCudaTarget()); EXPECT_ANY_THROW( - CompilationConfig config(pass_ctx, legacy_target_map, 
/*optional_host_target=*/{})); + CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target_arg=*/{})); } TEST(CompilationConfig, Constructor_NoMatchingPrimitiveTarget) { @@ -132,7 +156,7 @@ TEST(CompilationConfig, Constructor_NoMatchingPrimitiveTarget) { TargetMap legacy_target_map; legacy_target_map.Set(Integer(static_cast(kDLCUDA)), TestCudaTarget()); EXPECT_ANY_THROW( - CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target=*/{})); + CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target_arg=*/{})); } TEST(CompilationConfig, Constructor_DefaultNoMatchingPrimitiveTarget) { @@ -141,7 +165,7 @@ TEST(CompilationConfig, Constructor_DefaultNoMatchingPrimitiveTarget) { legacy_target_map.Set(Integer(static_cast(kDLCUDA)), TestCudaTarget()); legacy_target_map.Set(Integer(static_cast(kDLExtDev)), TestExtDevTarget()); EXPECT_ANY_THROW( - CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target=*/{})); + CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target_arg=*/{})); } TEST(CompilationConfig, CanonicalSEScope) { diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index 2834bba9248b..21c460fa0371 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -254,7 +254,7 @@ def test_null_attribute(): z = relay.Function([x], y) z = z.with_attr("TestAttribute", None) txt = astext(z) - assert "TestAttribute=(nullptr)" in txt + assert "TestAttribute=None" in txt def test_span():