Skip to content

Commit

Permalink
Better host handling in CompilationConfig & debug printing (apache#9460)
Browse files Browse the repository at this point in the history
(This is a bit of a grab bag in preparation for apache#9326
which I'm trying to minimize)

While switching the device planner to use SEScopes I had a lot
of trouble with Target's not matching up.
- If no explicit host target is given but the given
  TargetMap has targets with hosts, try to use those
  to establish the host_target.
- Make sure both the 'legacy' TargetMap representation
  and the newer representation agree to pointer equality on
  their targets.
- Make sure the Interpreter uses the target from CompilationConfig
  since it's been normalized.

To debug the above:
- When in pretty printing with show_meta_data_ false give as much
  detail on SEScopes, Targets and call attributes as possible.
  That needed some rework in the relay_text_printer.cc.
- Ditto for critical 'target' attribute on PrimFuncs.
- Also added a Target::ToDebugString so I could see the
  host fields along with everything else since a lot of problems
  were caused by a mismatch of 'the same' Target with and without
  a host. (Tried using that for the ReprPrinter but broken unit
  tests.)

Note that the codebase assumes Targets are compared by ObjectPtrEquality,
yet CheckAndUpdateHostConsistency (I count 65 call sites) changes the targets.
Ultimately CompilationConfig or it's ultimate replacement should ensure we munge
targets only once at the 'main' entry points.
  • Loading branch information
mbs-octoml authored and yangulei committed Jan 11, 2022
1 parent 42f4112 commit 479ec76
Show file tree
Hide file tree
Showing 12 changed files with 340 additions and 134 deletions.
9 changes: 9 additions & 0 deletions include/tvm/target/target.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,15 @@ class TargetNode : public Object {
/*! \return The Optional<Target> typed target host of the TargetNode */
TVM_DLL Optional<Target> GetHost() const;

/*!
* \brief Returns a human readable representation of \p Target which includes all fields,
* especially the host. Useful for diagnostic messages and debugging.
*
* TODO(mbs): The ReprPrinter version should perhaps switch to this form, however currently
* code depends on str() and << being the same.
*/
String ToDebugString() const;

void VisitAttrs(AttrVisitor* v) {
v->Visit("kind", &kind);
v->Visit("tag", &tag);
Expand Down
3 changes: 2 additions & 1 deletion src/parser/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1955,7 +1955,8 @@ TVM_REGISTER_GLOBAL("parser.ParseExpr")
TVM_REGISTER_GLOBAL("relay._transform.AnnotateSpans").set_body_typed([]() {
return CreateModulePass(
[](const IRModule& mod, const PassContext& ctx) {
auto text = AsText(mod, true);
String text = AsText(mod, /*show_meta_data=*/true);
VLOG(1) << "AnnotateSpans intermediate text:" << std::endl << text;
return ParseModule("GeneratedSource", text);
},
0, "AnnotateSpans", {});
Expand Down
173 changes: 117 additions & 56 deletions src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/target/se_scope.h>
#include <tvm/tir/function.h>

#include "../ir/attr_functor.h"
Expand Down Expand Up @@ -120,9 +121,6 @@ Doc RelayTextPrinter::Print(const ObjectRef& node, bool meta, bool try_inline) {
return PrintPattern(Downcast<Pattern>(node), meta);
} else if (node.as<IRModuleNode>()) {
return PrintMod(Downcast<IRModule>(node));
} else if (!show_meta_data_ && node.as<BaseAttrsNode>()) {
// Show attributes in readable form.
return PrintAttrs(Downcast<Attrs>(node));
} else {
// default module.
std::ostringstream os;
Expand Down Expand Up @@ -444,7 +442,7 @@ Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) {
for (Var param : fn->params) {
params.push_back(AllocVar(param));
}
for (const Doc& d : PrintFuncAttrs(fn->attrs)) {
for (const Doc& d : PrintDictAttrs(fn->attrs)) {
params.push_back(d);
}
doc << Doc::Concat(params) << ") ";
Expand Down Expand Up @@ -684,8 +682,10 @@ Doc RelayTextPrinter::VisitType_(const TensorTypeNode* node) {
Doc doc;
doc << "Tensor[(";
std::vector<Doc> shapes;
for (ObjectRef shape : node->shape) {
shapes.push_back(PrintAttr(shape));
for (const PrimExpr& prim_expr : node->shape) {
// Though not bound within an attribute the attribute visitor will handle the PrimExprs we
// care about.
shapes.push_back(PrintAttributeValue(prim_expr));
}
doc << Doc::Concat(shapes);
return doc << "), " << PrintDType(node->dtype) << "]";
Expand Down Expand Up @@ -766,34 +766,18 @@ Doc RelayTextPrinter::VisitType_(const TypeDataNode* node) {
// Overload of Attr printing functions
//------------------------------------

Doc RelayTextPrinter::PrintAttr(const ObjectRef& value, bool meta) {
if (value.defined()) {
Doc printed_attr;
if (value.as<tvm::tir::AnyNode>()) {
printed_attr << "?";
} else if (auto str_obj = value.as<tvm::StringObj>()) {
printed_attr << Doc::StrLiteral(GetRef<String>(str_obj));
} else if (meta) {
printed_attr = meta_->GetMetaNode(Downcast<ObjectRef>(value));
} else {
printed_attr = VisitAttr(value);
}
return printed_attr;
} else {
return Doc::Text("None");
}
}

Doc RelayTextPrinter::VisitAttrDefault_(const Object* op) {
return PrintAttr(GetRef<ObjectRef>(op), /*meta=*/true);
// Since we don't have any overload for a specific attribute type we'll need to force
// the meta[...] representation to avoid infinite regress.
return PrintAttributeValue(GetRef<ObjectRef>(op), /*force_meta=*/true);
}

Doc RelayTextPrinter::VisitAttr_(const ArrayNode* op) {
Doc doc;
doc << "[";
std::vector<Doc> arr_vals;
for (auto val : *op) {
arr_vals.push_back(PrintAttr(val));
for (const auto& val : *op) {
arr_vals.push_back(PrintAttributeValue(val));
}
doc << Doc::Concat(arr_vals);
doc << "]";
Expand Down Expand Up @@ -831,6 +815,7 @@ class RelayTextPrinter::AttrPrinter : public AttrVisitor {
doc << key << "=" << *value << "f";
docs->push_back(doc);
}

void Visit(const char* key, int64_t* value) final { PrintKV(key, *value); }
void Visit(const char* key, uint64_t* value) final { PrintKV(key, *value); }
void Visit(const char* key, int* value) final { PrintKV(key, *value); }
Expand All @@ -844,58 +829,134 @@ class RelayTextPrinter::AttrPrinter : public AttrVisitor {
LOG(FATAL) << "do not allow NDarray as argument";
}
void Visit(const char* key, runtime::ObjectRef* obj) final {
PrintKV(key, parent_->PrintAttr(*obj));
PrintKV(key, parent_->PrintAttributeValue(*obj));
}

private:
std::vector<Doc>* docs;
RelayTextPrinter* parent_;
};

Doc RelayTextPrinter::PrintAttrs(const Attrs& attrs) {
std::vector<Doc> docs;
AttrPrinter printer(&docs, this);
const_cast<BaseAttrsNode*>(attrs.operator->())->VisitNonDefaultAttrs(&printer);
Doc doc;
doc << "{" << Doc::Concat(docs) << "}";

return doc;
void RelayTextPrinter::AppendGenericAttrs(std::vector<Doc>* docs, const Attrs& attrs,
bool include_type_key) {
if (!attrs.defined()) {
return;
}
AttrPrinter printer(docs, this);
// Need to drop cost cast since in general VisitNonDefaultAttrs can mutate, but in this
// case we are read-only.
const_cast<BaseAttrsNode*>(attrs.get())->VisitNonDefaultAttrs(&printer);
if (include_type_key) {
std::string s = attrs->GetTypeKey();
printer.Visit("attrs_type_key", &s);
}
}

std::vector<Doc> RelayTextPrinter::PrintCallAttrs(const Attrs& attrs, const Expr& op) {
std::vector<Doc> docs;
if (!attrs.defined()) return docs;
if (!attrs.defined()) {
return docs;
}
const auto* op_node = op.as<OpNode>();
if (show_meta_data_ && op_node && (attrs->type_index() != op_node->attrs_type_index)) {
// fallback
// The parser can only understand calls with attributes if they match the operator's
// declared attribute type. If that's not the case fall back to the meta[...] representation.
docs.push_back(meta_->GetMetaNode(attrs));
} else {
AppendGenericAttrs(&docs, attrs, /*include_type_key=*/!op_node);
}
return docs;
}

std::vector<Doc> RelayTextPrinter::PrintDictAttrs(const DictAttrs& dict_attrs) {
if (!dict_attrs.defined()) {
return {};
}
return PrintDictAttrs(dict_attrs->dict);
}

std::vector<Doc> RelayTextPrinter::PrintDictAttrs(const Map<String, ObjectRef>& dict_attrs) {
std::vector<Doc> docs;
if (!dict_attrs.defined()) {
return docs;
}
for (const auto& k : dict_attrs) {
Doc doc;
doc << meta_->GetMetaNode(attrs);
doc << k.first << "=" << PrintAttributeValue(k.second);
docs.push_back(doc);
return docs;
} else {
// Show attributes in readable form.
AttrPrinter printer(&docs, this);
const_cast<BaseAttrsNode*>(attrs.operator->())->VisitNonDefaultAttrs(&printer);
if (!op_node) {
// print call attr type key to restore expr for relay parser
std::string s = std::string(attrs->GetTypeKey());
printer.Visit("attrs_type_key", &s);
}
return docs;
}

Doc RelayTextPrinter::PrintAttributeValue(const ObjectRef& value, bool force_meta) {
if (value.defined()) {
Doc printed_attr;
if (value.as<tvm::tir::AnyNode>()) {
printed_attr << "?";
} else if (auto str_obj = value.as<tvm::StringObj>()) {
printed_attr << Doc::StrLiteral(GetRef<String>(str_obj));
} else if (force_meta) {
printed_attr = meta_->GetMetaNode(Downcast<ObjectRef>(value));
} else if (const auto* se_scope_node = value.as<SEScopeNode>()) {
if (show_meta_data_) {
printed_attr = meta_->GetMetaNode(GetRef<ObjectRef>(se_scope_node));
} else {
// Special case: The ReprPrinter for SEScopeNodes is much easier to work with while
// debugging.
std::ostringstream os;
os << GetRef<SEScope>(se_scope_node);
return Doc::Text(os.str());
}
} else if (const auto* base_attr_node = value.as<BaseAttrsNode>()) {
if (show_meta_data_) {
printed_attr = meta_->GetMetaNode(GetRef<ObjectRef>(base_attr_node));
} else {
// Special case: The non-meta form for attributes are much easier to work with while
// debugging.
printed_attr = PrintAttrsAsAttributeValue(GetRef<Attrs>(base_attr_node));
}
} else if (const auto* base_map_node = value.as<MapNode>()) {
if (show_meta_data_) {
printed_attr = meta_->GetMetaNode(GetRef<ObjectRef>(base_map_node));
} else {
// Special case: Show maps fields as key=value pairs to help debugging.
printed_attr << PrintMapAsAttributeValue(GetRef<Map<ObjectRef, ObjectRef>>(base_map_node));
}
} else if (const auto* global_var_node = value.as<GlobalVarNode>()) {
if (show_meta_data_) {
printed_attr = meta_->GetMetaNode(GetRef<ObjectRef>(global_var_node));
} else {
printed_attr << "'" << global_var_node->name_hint << "'";
}
} else {
printed_attr = VisitAttr(value);
}
return docs;
return printed_attr;
} else {
return Doc::Text("None");
}
}

std::vector<Doc> RelayTextPrinter::PrintFuncAttrs(const Attrs& attrs) {
Doc RelayTextPrinter::PrintAttrsAsAttributeValue(const Attrs& attrs) {
std::vector<Doc> docs;
if (!attrs.defined()) return docs;
const auto* dict_attrs = attrs.as<DictAttrsNode>();
ICHECK(dict_attrs);
for (const auto& k : dict_attrs->dict) {
AppendGenericAttrs(&docs, attrs, /*include_type_key=*/false);
Doc doc;
doc << "{" << Doc::Concat(docs) << "}";
return doc;
}

Doc RelayTextPrinter::PrintMapAsAttributeValue(const Map<ObjectRef, ObjectRef>& map) {
std::vector<Doc> docs;
for (const auto& k : map) {
Doc doc;
doc << k.first << "=" << Print(k.second);
doc << PrintAttributeValue(k.first);
doc << "=";
doc << PrintAttributeValue(k.second);
docs.push_back(doc);
}
return docs;
Doc doc;
doc << "{" << Doc::Concat(docs) << "}";
return doc;
}

Doc RelayTextPrinter::PrintSpan(const Span& span) {
Expand Down
1 change: 1 addition & 0 deletions src/printer/text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ Doc TextPrinter::PrintMod(const IRModule& mod) {
os << "def @" << kv.first->name_hint;
doc << relay_text_printer_.PrintFunc(Doc::Text(os.str()), kv.second);
} else if (kv.second.as<tir::PrimFuncNode>()) {
doc << "@" << kv.first->name_hint << " = ";
doc << tir_text_printer_.PrintPrimFunc(Downcast<tir::PrimFunc>(kv.second));
}
doc << Doc::NewLine();
Expand Down
38 changes: 35 additions & 3 deletions src/printer/text_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,42 @@ class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
// numbers to be reused and prevents hoisted vars from escaping too far
Doc PrintScope(const ObjectRef& node);
Doc PrintFinal(const ObjectRef& node);
Doc PrintAttrs(const Attrs& attrs);

/*!
* \brief Returns \p attrs printed using the generic attribute visitor, as a sequence
* of key=value entries, if any.
*/
void AppendGenericAttrs(std::vector<Doc>* docs, const Attrs& attrs, bool include_type_key);

/*!
* \brief Returns \p attrs printed as a sequence of key=value entries, if any.
* This is used for call attributes.
*/
std::vector<Doc> PrintCallAttrs(const Attrs& attrs, const Expr& op);
std::vector<Doc> PrintFuncAttrs(const Attrs& attrs);

/*!
* \brief Returns \p dict_attrs printed as a sequence of key=value entries, if any.
* This is used for function definition attributes.
*/
std::vector<Doc> PrintDictAttrs(const DictAttrs& dict_attrs);
std::vector<Doc> PrintDictAttrs(const Map<String, ObjectRef>& dict_attrs);

/*!
* \brief Returns \p value printed as the rhs of an attribute key=value entry. If \p force_meta
* is true then value is printed in meta[...] for irrespective of the show_meta_data_ flag.
*/
Doc PrintAttributeValue(const ObjectRef& value, bool force_meta = false);

/*!
* \brief Returns \p attrs printed as a self-contained value, ie wrapped in braces.
*/
Doc PrintAttrsAsAttributeValue(const Attrs& attrs);

/*!
* \brief Returns \p map printed as a self-contained value, ie wrapped in braces.
*/
Doc PrintMapAsAttributeValue(const Map<ObjectRef, ObjectRef>& map);

Doc PrintSpan(const Span& span);

Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false);
Expand Down Expand Up @@ -162,7 +195,6 @@ class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
//------------------------------------
// Overload of Attr printing functions
//------------------------------------
Doc PrintAttr(const ObjectRef& value, bool meta = false);
Doc VisitAttrDefault_(const Object* op) final;
Doc VisitAttr_(const ArrayNode* op) final;
Doc VisitAttr_(const tir::IntImmNode* op) final;
Expand Down
3 changes: 3 additions & 0 deletions src/printer/tir_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/ir/type.h>
#include <tvm/ir/type_functor.h>
#include <tvm/node/serialization.h>
#include <tvm/target/target.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
Expand Down Expand Up @@ -71,6 +72,8 @@ Doc TIRTextPrinter::Print(const ObjectRef& node) {
return PrintString(node.as<StringObj>());
} else if (node->IsInstance<BufferRegionNode>()) {
return PrintBufferRegion(node.as<BufferRegionNode>());
} else if (node->IsInstance<TargetNode>()) {
return Doc::Text(node.as<TargetNode>()->ToDebugString());
} else {
return this->meta_->GetMetaNode(node);
}
Expand Down
Loading

0 comments on commit 479ec76

Please sign in to comment.