Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMScript] Add ObjectPath to LiteralDoc #13821

Merged
merged 2 commits into from
Jan 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 25 additions & 9 deletions include/tvm/script/printer/doc.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
#include <tvm/node/node.h>
#include <tvm/runtime/data_type.h>

#include <string>

namespace tvm {
namespace script {
namespace printer {
Expand Down Expand Up @@ -243,40 +245,54 @@ class LiteralDocNode : public ExprDocNode {
*/
class LiteralDoc : public ExprDoc {
protected:
explicit LiteralDoc(ObjectRef value);
LiteralDoc(ObjectRef value, ObjectPath object_path);
explicit LiteralDoc(ObjectRef value, const Optional<ObjectPath>& object_path);

public:
/*!
* \brief Create a LiteralDoc to represent None/null/empty value.
* \param p The object path
*/
static LiteralDoc None() { return LiteralDoc(ObjectRef(nullptr)); }
static LiteralDoc None(const Optional<ObjectPath>& p) {
return LiteralDoc(ObjectRef(nullptr), p);
}
/*!
* \brief Create a LiteralDoc to represent integer.
* \param v The integer value.
* \param p The object path
*/
static LiteralDoc Int(int64_t v) { return LiteralDoc(IntImm(DataType::Int(64), v)); }
static LiteralDoc Int(int64_t v, const Optional<ObjectPath>& p) {
return LiteralDoc(IntImm(DataType::Int(64), v), p);
}
/*!
* \brief Create a LiteralDoc to represent boolean.
* \param v The boolean value.
* \param p The object path
*/
static LiteralDoc Boolean(bool v) { return LiteralDoc(IntImm(DataType::Bool(), v)); }
static LiteralDoc Boolean(bool v, const Optional<ObjectPath>& p) {
return LiteralDoc(IntImm(DataType::Bool(), v), p);
}
/*!
* \brief Create a LiteralDoc to represent float.
* \param v The float value.
* \param p The object path
*/
static LiteralDoc Float(double v) { return LiteralDoc(FloatImm(DataType::Float(64), v)); }
static LiteralDoc Float(double v, const Optional<ObjectPath>& p) {
return LiteralDoc(FloatImm(DataType::Float(64), v), p);
}
/*!
* \brief Create a LiteralDoc to represent string.
* \param v The string value.
* \param p The object path
*/
static LiteralDoc Str(const String& v) { return LiteralDoc(v); }
static LiteralDoc Str(const String& v, const Optional<ObjectPath>& p) { return LiteralDoc(v, p); }
/*!
* \brief Create a LiteralDoc to represent string.
* \param v The string value.
* \param p The object path
*/
static LiteralDoc DataType(const DLDataType& v) {
return LiteralDoc::Str(runtime::DLDataType2String(v));
static LiteralDoc DataType(const runtime::DataType& v, const Optional<ObjectPath>& p) {
std::string dtype = v.is_void() ? "void" : runtime::DLDataType2String(v);
return LiteralDoc::Str(dtype, p);
}

TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LiteralDoc, ExprDoc, LiteralDocNode);
Expand Down
9 changes: 5 additions & 4 deletions include/tvm/script/printer/ir_docsifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,11 +259,12 @@ inline void FrameNode::ExitWithScope() {

template <class TDoc>
inline TDoc IRDocsifierNode::AsDoc(const ObjectRef& obj, const ObjectPath& path) const {
if (!obj.defined()) {
return Downcast<TDoc>(LiteralDoc::None());
if (obj.defined()) {
Doc d = IRDocsifier::vtable()(dispatch_tokens.back(), obj, path, GetRef<IRDocsifier>(this));
d->source_paths.push_back(path);
return Downcast<TDoc>(d);
}
return Downcast<TDoc>(
IRDocsifier::vtable()(dispatch_tokens.back(), obj, path, GetRef<IRDocsifier>(this)));
return Downcast<TDoc>(LiteralDoc::None(path));
}

inline void FrameNode::AddDispatchToken(const IRDocsifier& d, const String& token) {
Expand Down
32 changes: 26 additions & 6 deletions python/tvm/script/printer/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,17 +155,37 @@ class LiteralDoc(ExprDoc):

value: Union[str, IntImm, FloatImm, None]

def __init__(self, value: Union[str, float, bool, int, None]):
def __init__(
self,
value: Union[str, float, bool, int, None],
path: Optional[ObjectPath] = None,
):
if value is None:
self.__init_handle_by_constructor__(_ffi_api.LiteralDocNone) # type: ignore # pylint: disable=no-member
self.__init_handle_by_constructor__(_ffi_api.LiteralDocNone, path) # type: ignore # pylint: disable=no-member
elif isinstance(value, str):
self.__init_handle_by_constructor__(_ffi_api.LiteralDocStr, value) # type: ignore # pylint: disable=no-member
self.__init_handle_by_constructor__(
_ffi_api.LiteralDocStr, # type: ignore # pylint: disable=no-member
value,
path,
)
elif isinstance(value, float):
self.__init_handle_by_constructor__(_ffi_api.LiteralDocFloat, value) # type: ignore # pylint: disable=no-member
self.__init_handle_by_constructor__(
_ffi_api.LiteralDocFloat, # type: ignore # pylint: disable=no-member
value,
path,
)
elif isinstance(value, bool):
self.__init_handle_by_constructor__(_ffi_api.LiteralDocBoolean, value) # type: ignore # pylint: disable=no-member
self.__init_handle_by_constructor__(
_ffi_api.LiteralDocBoolean, # type: ignore # pylint: disable=no-member
value,
path,
)
elif isinstance(value, int):
self.__init_handle_by_constructor__(_ffi_api.LiteralDocInt, value) # type: ignore # pylint: disable=no-member
self.__init_handle_by_constructor__(
_ffi_api.LiteralDocInt, # type: ignore # pylint: disable=no-member
value,
path,
)
else:
raise TypeError(f"Unsupported type {type(value)} for LiteralDoc")

Expand Down
26 changes: 9 additions & 17 deletions src/script/printer/doc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,12 @@ StmtBlockDoc::StmtBlockDoc(Array<StmtDoc> stmts) {
this->data_ = std::move(n);
}

LiteralDoc::LiteralDoc(ObjectRef value) {
LiteralDoc::LiteralDoc(ObjectRef value, const Optional<ObjectPath>& object_path) {
ObjectPtr<LiteralDocNode> n = make_object<LiteralDocNode>();
n->value = value;
this->data_ = std::move(n);
}

LiteralDoc::LiteralDoc(ObjectRef value, ObjectPath object_path) {
ObjectPtr<LiteralDocNode> n = make_object<LiteralDocNode>();
n->value = value;
n->source_paths.push_back(object_path);
if (object_path.defined()) {
n->source_paths.push_back(object_path.value());
}
this->data_ = std::move(n);
}

Expand Down Expand Up @@ -250,15 +246,11 @@ TVM_REGISTER_GLOBAL("script.printer.StmtBlockDoc").set_body_typed([](Array<StmtD
});

TVM_REGISTER_NODE_TYPE(LiteralDocNode);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocNone").set_body_typed<LiteralDoc()>(LiteralDoc::None);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt")
.set_body_typed<LiteralDoc(int64_t)>(LiteralDoc::Int);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocBoolean")
.set_body_typed<LiteralDoc(bool)>(LiteralDoc::Boolean);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloat")
.set_body_typed<LiteralDoc(double)>(LiteralDoc::Float);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocStr")
.set_body_typed<LiteralDoc(const String&)>(LiteralDoc::Str);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocNone").set_body_typed(LiteralDoc::None);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt").set_body_typed(LiteralDoc::Int);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocBoolean").set_body_typed(LiteralDoc::Boolean);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloat").set_body_typed(LiteralDoc::Float);
TVM_REGISTER_GLOBAL("script.printer.LiteralDocStr").set_body_typed(LiteralDoc::Str);

TVM_REGISTER_NODE_TYPE(IdDocNode);
TVM_REGISTER_GLOBAL("script.printer.IdDoc").set_body_typed([](String name) { return IdDoc(name); });
Expand Down
18 changes: 9 additions & 9 deletions src/script/printer/ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,26 +63,26 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<GlobalVar>("", [](GlobalVar gv, ObjectPath p, IRDocsifier d) -> Doc {
return IR("GlobalVar")->Call({LiteralDoc::Str(gv->name_hint)});
return IR("GlobalVar")->Call({LiteralDoc::Str(gv->name_hint, p->Attr("name_hint"))});
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<Op>("", [](Op op, ObjectPath p, IRDocsifier d) -> Doc {
return IR("Op")->Call({LiteralDoc::Str(op->name)});
return IR("Op")->Call({LiteralDoc::Str(op->name, p->Attr("name"))});
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<TypeVar>("", [](TypeVar type_var, ObjectPath p, IRDocsifier d) -> Doc {
return IR("TypeVar")->Call({LiteralDoc::Str(type_var->name_hint), //
LiteralDoc::Str(TypeKind2String(type_var->kind))});
.set_dispatch<TypeVar>("", [](TypeVar var, ObjectPath p, IRDocsifier d) -> Doc {
return IR("TypeVar")->Call({LiteralDoc::Str(var->name_hint, p->Attr("name_hint")), //
LiteralDoc::Str(TypeKind2String(var->kind), p->Attr("kind"))});
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<GlobalTypeVar>( //
"", [](GlobalTypeVar type_var, ObjectPath p, IRDocsifier d) -> Doc {
"", [](GlobalTypeVar var, ObjectPath p, IRDocsifier d) -> Doc {
return IR("GlobalTypeVar")
->Call({LiteralDoc::Str(type_var->name_hint), //
LiteralDoc::Str(TypeKind2String(type_var->kind))});
->Call({LiteralDoc::Str(var->name_hint, p->Attr("name_hint")),
LiteralDoc::Str(TypeKind2String(var->kind), p->Attr("kind"))});
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
Expand All @@ -94,7 +94,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<TensorType>("", [](TensorType type, ObjectPath p, IRDocsifier d) -> Doc {
return IR("TensorType")
->Call({d->AsDoc<ExprDoc>(type->shape, p->Attr("shape")),
LiteralDoc::DataType(type->dtype)});
LiteralDoc::DataType(type->dtype, p->Attr("dtype"))});
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
Expand Down
2 changes: 1 addition & 1 deletion src/script/printer/ir/misc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace printer {

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<String>("", [](String s, ObjectPath p, IRDocsifier d) -> Doc {
return LiteralDoc::Str(s);
return LiteralDoc::Str(s, p);
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
Expand Down
1 change: 0 additions & 1 deletion src/script/printer/legacy_repr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,6 @@ TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)

TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
.set_dispatch<PrimFuncNode>([](const ObjectRef& ref, ReprLegacyPrinter* p) {
// TODO(tvm-team) redirect to Text printer once we have a good text format.
auto* node = static_cast<const PrimFuncNode*>(ref.get());
(*p) << "PrimFunc(" << node->params << ") ";
if (node->attrs.defined()) {
Expand Down
18 changes: 12 additions & 6 deletions src/script/printer/tir/block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,16 +118,20 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, //
lhs.reserve(m);
loop_var_doc.reserve(m);
std::string binding_type = "";
Array<ObjectPath> binding_paths;
for (int i : remap_vars_indices) {
tir::IterVar iter_var = block->iter_vars[i];
ObjectPath iter_var_p = block_p->Attr("iter_var")->ArrayIndex(i);
ObjectPath iter_var_p = block_p->Attr("iter_vars")->ArrayIndex(i);
lhs.push_back(DefineVar(iter_var->var, *frame, d));
loop_var_doc.push_back(d->AsDoc<ExprDoc>(realize->iter_values[i],
realize_p->Attr("iter_values")->ArrayIndex(i)));
binding_paths.push_back(iter_var_p->Attr("iter_type"));
binding_type += iter_var->iter_type == tir::IterVarType::kDataPar ? "S" : "R";
}
ExprDoc rhs = TIR("axis")->Attr("remap");
rhs = rhs->Call({LiteralDoc::Str(binding_type), ListDoc(loop_var_doc)});
ExprDoc binding_str = LiteralDoc::Str(binding_type, NullOpt);
binding_str->source_paths = std::move(binding_paths);
rhs = rhs->Call({binding_str, ListDoc(loop_var_doc)});
(*frame)->stmts.push_back(AssignDoc(TupleDoc(lhs), rhs, NullOpt));
remap_vars_indices.clear();
}
Expand Down Expand Up @@ -198,11 +202,13 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, //
Array<ExprDoc> kwargs_values;
if (!realize) {
kwargs_keys.push_back("no_realize");
kwargs_values.push_back(LiteralDoc::Boolean(true));
kwargs_values.push_back(LiteralDoc::Boolean(true, NullOpt));
}
return ScopeDoc(
NullOpt, TIR("block")->Call({LiteralDoc::Str(block->name_hint)}, kwargs_keys, kwargs_values),
(*frame)->stmts);
return ScopeDoc(NullOpt,
TIR("block") //
->Call({LiteralDoc::Str(block->name_hint, block_p->Attr("name_hint"))},
kwargs_keys, kwargs_values),
(*frame)->stmts);
}

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
Expand Down
15 changes: 9 additions & 6 deletions src/script/printer/tir/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ Map<String, ExprDoc> BufferAttrs(const tir::Buffer& buffer, const ObjectPath& p,
array_out_line_var_def(buffer->shape, p->Attr("shape"), "shape");
// Step 2. Handle `buffer.dtype`
if (buffer->dtype != Default::BufferDType()) {
kwargs.Set("dtype", LiteralDoc::DataType(buffer->dtype));
kwargs.Set("dtype", LiteralDoc::DataType(buffer->dtype, p->Attr("dtype")));
}
// Step 3. Handle `buffer.data`
implicit_var_def(buffer->data, p->Attr("data"), "data");
Expand All @@ -78,20 +78,22 @@ Map<String, ExprDoc> BufferAttrs(const tir::Buffer& buffer, const ObjectPath& p,
{
String scope = buffer.scope();
if (scope != "global") {
kwargs.Set("scope", LiteralDoc::Str(scope));
kwargs.Set(
"scope",
LiteralDoc::Str(scope, p->Attr("data")->Attr("type_annotation")->Attr("storage_scope")));
}
}
// Step 7. Handle `buffer.data_alignment`
if (buffer->data_alignment != runtime::kAllocAlignment) {
kwargs.Set("align", LiteralDoc::Int(buffer->data_alignment));
kwargs.Set("align", LiteralDoc::Int(buffer->data_alignment, p->Attr("data_alignment")));
}
// Step 8. Handle `buffer.offset_factor`
if (needs_print_factor || buffer->offset_factor != 1) {
kwargs.Set("offset_factor", LiteralDoc::Int(buffer->offset_factor));
kwargs.Set("offset_factor", LiteralDoc::Int(buffer->offset_factor, p->Attr("offset_factor")));
}
// Step 9. Handle `buffer.buffer_type`
if (buffer->buffer_type != tir::BufferType::kDefault) {
kwargs.Set("type", LiteralDoc::Str("auto"));
kwargs.Set("type", LiteralDoc::Str("auto", p->Attr("buffer_type")));
}
// Step 10. Handle `buffer.axis_separator`
if (!buffer->axis_separators.empty()) {
Expand Down Expand Up @@ -130,7 +132,8 @@ ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame&
const IRDocsifier& d) {
Map<String, ExprDoc> attrs = BufferAttrs(buffer, p, frame, d);
ExprDoc shape = attrs.Get("shape").value();
ExprDoc dtype = attrs.Get("dtype").value_or(LiteralDoc::DataType(buffer->dtype));
ExprDoc dtype =
attrs.Get("dtype").value_or(LiteralDoc::DataType(buffer->dtype, p->Attr("dtype")));
return TIR("Buffer")->Call({shape, dtype}, {}, {});
}

Expand Down
Loading