Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 33 additions & 51 deletions docs/arch/runtime.rst
Original file line number Diff line number Diff line change
Expand Up @@ -206,69 +206,51 @@ adding the code back to the central repo. To ease the speed of dispatching, we a
Since usually one ``Object`` could be referenced in multiple places in the language, we use a shared_ptr to keep
track of reference. We use ``ObjectRef`` class to represent a reference to the ``Object``.
We can roughly view ``ObjectRef`` class as shared_ptr to the ``Object`` container.
We can also define subclass ``ObjectRef`` to hold each subtypes of ``Object``. Each subclass of ``Object`` needs to define the VisitAttr function.
We can also define subclass ``ObjectRef`` to hold each subtypes of ``Object``. Each subclass of ``Object`` needs to define the
RegisterReflection function.

.. code:: c

class AttrVisitor {
public:
virtual void Visit(const char* key, double* value) = 0;
virtual void Visit(const char* key, int64_t* value) = 0;
virtual void Visit(const char* key, uint64_t* value) = 0;
virtual void Visit(const char* key, int* value) = 0;
virtual void Visit(const char* key, bool* value) = 0;
virtual void Visit(const char* key, std::string* value) = 0;
virtual void Visit(const char* key, void** value) = 0;
virtual void Visit(const char* key, Type* value) = 0;
virtual void Visit(const char* key, ObjectRef* value) = 0;
// ...
};

class BaseAttrsNode : public Object {
public:
virtual void VisitAttrs(AttrVisitor* v) {}
// ...
};

Each ``Object`` subclass will override this to visit its members. Here is an example implementation of TensorNode.
Each ``Object`` subclass will override this to register its members. Here is an example implementation of IntImmNode.

.. code:: c

class TensorNode : public Object {
public:
/*! \brief The shape of the tensor */
Array<Expr> shape;
/*! \brief data type in the content of the tensor */
Type dtype;
/*! \brief the source operation, can be None */
Operation op;
/*! \brief the output index from source operation */
int value_index{0};
/*! \brief constructor */
TensorNode() {}

void VisitAttrs(AttrVisitor* v) final {
v->Visit("shape", &shape);
v->Visit("dtype", &dtype);
v->Visit("op", &op);
v->Visit("value_index", &value_index);
}
};

In the above examples, both ``Operation`` and ``Array<Expr>`` are ObjectRef.
The VisitAttrs gives us a reflection API to visit each member of the object.
class IntImmNode : public PrimExprNode {
public:
/*! \brief the Internal value. */
int64_t value;

static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<IntImmNode>().def_ro("value", &IntImmNode::value);
}

bool SEqualReduce(const IntImmNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype) && equal(value, other->value);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(value);
}

static constexpr const char* _type_key = "IntImm";
TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode);
};
// in cc file
TVM_FFI_STATIC_INIT_BLOCK({ IntImmNode::RegisterReflection(); });

The RegisterReflection gives us a reflection API to register each member of the object.
We can use this function to visit the node and serialize any language object recursively.
It also allows us to get members of an object easily in front-end language.
For example, in the following code, we accessed the op field of the TensorNode.
For example, we can access the value field of the IntImmNode.

.. code:: python

import tvm
from tvm import te

x = te.placeholder((3,4), name="x")
# access the op field of TensorNode
print(x.op.name)
x = tvm.tir.IntImm("int32", 1)
# access the value field of IntImmNode
print(x.value)

New ``Object`` can be added to C++ without changing the front-end runtime, making it easy to make extensions to the compiler stack.
Note that this is not the fastest way to expose members to front-end language, but might be one of the simplest
Expand Down
32 changes: 22 additions & 10 deletions include/tvm/ir/diagnostic.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#ifndef TVM_IR_DIAGNOSTIC_H_
#define TVM_IR_DIAGNOSTIC_H_

#include <tvm/ffi/reflection/reflection.h>
#include <tvm/ir/module.h>

#include <sstream>
Expand Down Expand Up @@ -65,13 +66,16 @@ class DiagnosticNode : public Object {
/*! \brief The diagnostic message. */
String message;

// override attr visitor
void VisitAttrs(AttrVisitor* v) {
v->Visit("level", &level);
v->Visit("span", &span);
v->Visit("message", &message);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<DiagnosticNode>()
.def_ro("level", &DiagnosticNode::level)
.def_ro("span", &DiagnosticNode::span)
.def_ro("message", &DiagnosticNode::message);
}

static constexpr bool _type_has_method_visit_attrs = false;

bool SEqualReduce(const DiagnosticNode* other, SEqualReducer equal) const {
return equal(this->level, other->level) && equal(this->span, other->span) &&
equal(this->message, other->message);
Expand Down Expand Up @@ -165,8 +169,12 @@ class DiagnosticRendererNode : public Object {
public:
ffi::TypedFunction<void(DiagnosticContext ctx)> renderer;

// override attr visitor
void VisitAttrs(AttrVisitor* v) {}
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<DiagnosticRendererNode>().def_ro("renderer", &DiagnosticRendererNode::renderer);
}

static constexpr bool _type_has_method_visit_attrs = false;

static constexpr const char* _type_key = "DiagnosticRenderer";
TVM_DECLARE_FINAL_OBJECT_INFO(DiagnosticRendererNode, Object);
Expand Down Expand Up @@ -199,11 +207,15 @@ class DiagnosticContextNode : public Object {
/*! \brief The renderer set for the context. */
DiagnosticRenderer renderer;

void VisitAttrs(AttrVisitor* v) {
v->Visit("module", &module);
v->Visit("diagnostics", &diagnostics);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<DiagnosticContextNode>()
.def_ro("module", &DiagnosticContextNode::module)
.def_ro("diagnostics", &DiagnosticContextNode::diagnostics);
}

static constexpr bool _type_has_method_visit_attrs = false;

bool SEqualReduce(const DiagnosticContextNode* other, SEqualReducer equal) const {
return equal(module, other->module) && equal(diagnostics, other->diagnostics);
}
Expand Down
8 changes: 7 additions & 1 deletion include/tvm/ir/env_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#define TVM_IR_ENV_FUNC_H_

#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/reflection.h>
#include <tvm/node/reflection.h>

#include <string>
Expand All @@ -48,7 +49,12 @@ class EnvFuncNode : public Object {
/*! \brief constructor */
EnvFuncNode() {}

void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); }
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<EnvFuncNode>().def_ro("name", &EnvFuncNode::name);
}

static constexpr bool _type_has_method_visit_attrs = false;

bool SEqualReduce(const EnvFuncNode* other, SEqualReducer equal) const {
// name uniquely identifies the env function.
Expand Down
5 changes: 5 additions & 0 deletions include/tvm/ir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,11 @@ class BaseFuncNode : public RelaxExprNode {
return LinkageType::kInternal;
}

static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<BaseFuncNode>().def_ro("attrs", &BaseFuncNode::attrs);
}

static constexpr const char* _type_key = "BaseFunc";
static constexpr const uint32_t _type_child_slots = 2;
TVM_DECLARE_BASE_OBJECT_INFO(BaseFuncNode, RelaxExprNode);
Expand Down
22 changes: 17 additions & 5 deletions include/tvm/ir/global_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#ifndef TVM_IR_GLOBAL_INFO_H_
#define TVM_IR_GLOBAL_INFO_H_

#include <tvm/ffi/reflection/reflection.h>
#include <tvm/ir/expr.h>
#include <tvm/target/target.h>

Expand Down Expand Up @@ -68,12 +69,17 @@ class VDeviceNode : public GlobalInfoNode {
*/
int vdevice_id;
MemoryScope memory_scope;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("target", &target);
v->Visit("vdevice_id", &vdevice_id);
v->Visit("memory_scope", &memory_scope);

static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<VDeviceNode>()
.def_ro("target", &VDeviceNode::target)
.def_ro("vdevice_id", &VDeviceNode::vdevice_id)
.def_ro("memory_scope", &VDeviceNode::memory_scope);
}

static constexpr bool _type_has_method_visit_attrs = false;

TVM_DLL bool SEqualReduce(const VDeviceNode* other, SEqualReducer equal) const {
return equal(target, other->target) && equal(vdevice_id, other->vdevice_id) &&
equal(memory_scope, other->memory_scope);
Expand Down Expand Up @@ -103,7 +109,13 @@ class VDevice : public GlobalInfo {
*/
class DummyGlobalInfoNode : public GlobalInfoNode {
public:
void VisitAttrs(tvm::AttrVisitor* v) {}
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<DummyGlobalInfoNode>();
}

static constexpr bool _type_has_method_visit_attrs = false;

static constexpr const char* _type_key = "DummyGlobalInfo";

TVM_DLL bool SEqualReduce(const DummyGlobalInfoNode* other, SEqualReducer equal) const {
Expand Down
8 changes: 7 additions & 1 deletion include/tvm/ir/global_var_supply.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <string>
#include <unordered_map>

#include "tvm/ffi/reflection/reflection.h"
#include "tvm/ir/expr.h"
#include "tvm/ir/module.h"
#include "tvm/ir/name_supply.h"
Expand Down Expand Up @@ -75,7 +76,12 @@ class GlobalVarSupplyNode : public Object {
*/
void ReserveGlobalVar(const GlobalVar& var, bool allow_conflict = false);

void VisitAttrs(AttrVisitor* v) {}
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<GlobalVarSupplyNode>();
}

static constexpr bool _type_has_method_visit_attrs = false;

/*! \brief The NameSupply used to generate unique name hints to GlobalVars. */
NameSupply name_supply_;
Expand Down
8 changes: 7 additions & 1 deletion include/tvm/ir/instrument.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#ifndef TVM_IR_INSTRUMENT_H_
#define TVM_IR_INSTRUMENT_H_

#include <tvm/ffi/reflection/reflection.h>
#include <tvm/ffi/string.h>
#include <tvm/node/reflection.h>

Expand Down Expand Up @@ -136,7 +137,12 @@ class PassInstrumentNode : public Object {
*/
virtual void RunAfterPass(const IRModule& mod, const transform::PassInfo& info) const = 0;

void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); }
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<PassInstrumentNode>().def_ro("name", &PassInstrumentNode::name);
}

static constexpr bool _type_has_method_visit_attrs = false;

static constexpr const char* _type_key = "instrument.PassInstrument";
TVM_DECLARE_BASE_OBJECT_INFO(PassInstrumentNode, Object);
Expand Down
19 changes: 12 additions & 7 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#include <tvm/ffi/container/array.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/reflection/reflection.h>
#include <tvm/ffi/string.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/function.h>
Expand Down Expand Up @@ -77,7 +78,7 @@ class IRModuleNode : public Object {
*
* \return The result
*
* \tparam TOBjectRef the expected object type.
* \tparam TObjectRef the expected object type.
* \throw Error if the key exists but the value does not match TObjectRef
*
* \code
Expand Down Expand Up @@ -129,14 +130,18 @@ class IRModuleNode : public Object {

IRModuleNode() : source_map() {}

void VisitAttrs(AttrVisitor* v) {
v->Visit("functions", &functions);
v->Visit("global_var_map_", &global_var_map_);
v->Visit("source_map", &source_map);
v->Visit("attrs", &attrs);
v->Visit("global_infos", &global_infos);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<IRModuleNode>()
.def_ro("functions", &IRModuleNode::functions)
.def_ro("global_var_map_", &IRModuleNode::global_var_map_)
.def_ro("source_map", &IRModuleNode::source_map)
.def_ro("attrs", &IRModuleNode::attrs)
.def_ro("global_infos", &IRModuleNode::global_infos);
}

static constexpr bool _type_has_method_visit_attrs = false;

TVM_DLL bool SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const;

TVM_DLL void SHashReduce(SHashReducer hash_reduce) const;
Expand Down
7 changes: 4 additions & 3 deletions include/tvm/ir/name_supply.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,15 @@
#ifndef TVM_IR_NAME_SUPPLY_H_
#define TVM_IR_NAME_SUPPLY_H_

#include <tvm/ffi/reflection/reflection.h>
#include <tvm/ir/expr.h>

#include <algorithm>
#include <cctype>
#include <string>
#include <unordered_map>
#include <utility>

#include "tvm/ir/expr.h"

namespace tvm {

/*!
Expand Down Expand Up @@ -80,7 +81,7 @@ class NameSupplyNode : public Object {
*/
bool ContainsName(const String& name, bool add_prefix = true);

void VisitAttrs(AttrVisitor* v) {}
static constexpr bool _type_has_method_visit_attrs = false;

// Prefix for all GlobalVar names. It can be empty.
std::string prefix_;
Expand Down
21 changes: 13 additions & 8 deletions include/tvm/ir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#define TVM_IR_OP_H_

#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/reflection.h>
#include <tvm/ir/attrs.h>
#include <tvm/ir/env_func.h>
#include <tvm/ir/expr.h>
Expand Down Expand Up @@ -90,16 +91,20 @@ class OpNode : public RelaxExprNode {
*/
int32_t support_level = 10;

void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("op_type", &op_type);
v->Visit("description", &description);
v->Visit("arguments", &arguments);
v->Visit("attrs_type_key", &attrs_type_key);
v->Visit("num_inputs", &num_inputs);
v->Visit("support_level", &support_level);
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<OpNode>()
.def_ro("name", &OpNode::name)
.def_ro("op_type", &OpNode::op_type)
.def_ro("description", &OpNode::description)
.def_ro("arguments", &OpNode::arguments)
.def_ro("attrs_type_key", &OpNode::attrs_type_key)
.def_ro("num_inputs", &OpNode::num_inputs)
.def_ro("support_level", &OpNode::support_level);
}

static constexpr bool _type_has_method_visit_attrs = false;

bool SEqualReduce(const OpNode* other, SEqualReducer equal) const {
// pointer equality is fine as there is only one op with the same name.
return this == other;
Expand Down
Loading
Loading