Skip to content

Commit

Permalink
Implement explicit IR representation of memory allocation
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch committed Oct 18, 2019
1 parent 2dac17d commit bf9555a
Show file tree
Hide file tree
Showing 41 changed files with 1,752 additions and 316 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ add_library(tvm_runtime SHARED ${RUNTIME_SRCS})
if(USE_RELAY_DEBUG)
  message(STATUS "Building Relay in debug mode...")
  # NOTE: set_target_properties REPLACES the COMPILE_DEFINITIONS property each
  # time it is called, so both definitions must be passed in a single call
  # (semicolon-separated list); two separate calls would keep only the last.
  set_target_properties(tvm PROPERTIES COMPILE_DEFINITIONS "USE_RELAY_DEBUG;DMLC_LOG_DEBUG")
else()
  set_target_properties(tvm PROPERTIES COMPILE_DEFINITIONS "NDEBUG")
endif(USE_RELAY_DEBUG)
Expand Down
37 changes: 37 additions & 0 deletions include/tvm/relay/attrs/annotation.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#define TVM_RELAY_ATTRS_ANNOTATION_H_

#include <tvm/attrs.h>
#include <tvm/relay/expr.h>
#include <string>

namespace tvm {
Expand Down Expand Up @@ -57,6 +58,42 @@ struct CastHintAttrs : public tvm::AttrsNode<CastHintAttrs> {
}
};

/*!
* \brief Options for the device annotation operators.
*/
/*!
 * \brief Options for the memory allocation (alloc_tensor) operator.
 */
struct AllocTensorAttrs : public tvm::AttrsNode<AllocTensorAttrs> {
  // Optional constant shape; when defined it is used to aid type inference.
  tvm::relay::Constant const_shape;
  // Optional shape to check the allocation against.
  Array<IndexExpr> assert_shape;
  // Element type of the tensor to allocate.
  DataType dtype;

  TVM_DECLARE_ATTRS(AllocTensorAttrs, "relay.attrs.AllocTensorAttrs") {
    TVM_ATTR_FIELD(dtype)
      .describe(
         "The dtype of the tensor to allocate.")
      .set_default(Float(32, 1));
    TVM_ATTR_FIELD(const_shape)
      .describe(
         "The shape of the tensor when it is constant, used to aid type inference.");
    TVM_ATTR_FIELD(assert_shape)
      .describe(
         "The shape to assert the allocated tensor against at runtime.");
  }
};

/*!
* \brief Options for the device annotation operators.
*/
/*!
 * \brief Options for the shape function operator.
 */
struct ShapeFuncAttrs : public tvm::AttrsNode<ShapeFuncAttrs> {
  // True when the shape function depends on the input values, not just shapes.
  bool dependent{false};

  TVM_DECLARE_ATTRS(ShapeFuncAttrs, "relay.attrs.ShapeFuncAttrs") {
    TVM_ATTR_FIELD(dependent)
      .describe(
         "Whether the shape function is input dependent.")
      .set_default(false);
  }
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_ANNOTATION_H_
6 changes: 6 additions & 0 deletions include/tvm/relay/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ namespace relay {
(*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \
}

/*!
 * \brief Invoke the interpreter debug hook registered under
 *        "relay.debug_interp" in the global registry, passing file/line
 *        context plus the caller's arguments. CHECK-fails (hard error) if
 *        the Python-side hook has not been registered.
 */
#define RELAY_DEBUG_INTERP(...) \
{ auto fdebug = runtime::Registry::Get("relay.debug_interp"); \
  CHECK(fdebug) << "Could not find Relay Python debugger function."; \
  (*fdebug)("RELAY_DEBUG", __FILE__, __LINE__, __VA_ARGS__); \
}

/*!
* \brief We always used NodeRef for referencing nodes.
*
Expand Down
30 changes: 30 additions & 0 deletions include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,36 @@ class ExprMutator
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo_;
};

/*!
 * \brief A pattern-match helper that dispatches on the operator of a Call.
 *
 * Handlers are registered by operator name via Match(); operator() looks up
 * the called op and invokes the matching handler, falling back to the
 * default handler when one is set.
 *
 * \tparam R The result type returned by each handler.
 */
template<typename R>
class OpMatch {
 public:
  /*! \brief Handler signature: receives the call's args, attrs and type args. */
  using MatchFunc =
      std::function<R(const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_args)>;

  /*!
   * \brief Register a handler for the operator with the given name.
   * \param op_name The registry name of the operator.
   * \param func The handler to invoke for calls to that operator.
   * \return *this, to allow chained registration.
   */
  inline OpMatch& Match(const std::string& op_name, MatchFunc func) {
    auto op = Op::Get(op_name);
    match_map_.insert({op, func});
    return *this;
  }

  /*!
   * \brief Dispatch on the operator of \p call.
   * \param call The call expression; its op must be an Op node.
   * \return The handler's result.
   */
  inline R operator()(const Call& call) {
    auto it = match_map_.find(Downcast<Op>(call->op));
    if (it != match_map_.end()) {
      return it->second(call->args, call->attrs, call->type_args);
    }
    if (default_ != nullptr) {
      return default_(call->args, call->attrs, call->type_args);
    }
    LOG(FATAL) << "unexpected operation " << call->op;
    // Unreachable: LOG(FATAL) aborts/throws. Present so a non-void function
    // does not fall off the end (UB and a compiler warning otherwise).
    return R();
  }

 private:
  /*! \brief Map from operator to its registered handler. */
  std::unordered_map<Op, MatchFunc, NodeHash, NodeEqual> match_map_;
  /*! \brief Fallback handler used when no operator matches; may be empty. */
  MatchFunc default_;
};

/*!
* \brief recursively visit the ir in post DFS order node, apply fvisit
* Each node is guaranteed to be visited only once.
Expand Down
9 changes: 8 additions & 1 deletion include/tvm/relay/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class ModuleNode : public RelayNode {
/*! \brief A map from global type vars to ADT type data. */
tvm::Map<GlobalTypeVar, TypeData> type_definitions;


ModuleNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final {
Expand All @@ -76,7 +77,8 @@ class ModuleNode : public RelayNode {
}

TVM_DLL static Module make(tvm::Map<GlobalVar, Function> global_funcs,
tvm::Map<GlobalTypeVar, TypeData> global_type_defs);
tvm::Map<GlobalTypeVar, TypeData> global_type_defs,
std::unordered_set<std::string> imports = {});

/*!
* \brief Add a function to the global environment.
Expand Down Expand Up @@ -235,6 +237,11 @@ class ModuleNode : public RelayNode {
*/
TVM_DLL void ImportFromStd(const std::string& path);

/*!
* \brief The set of imported files.
*/
TVM_DLL std::unordered_set<std::string> Imports() const;

/*! \brief Construct a module from a standalone expression.
*
* Allows one to optionally pass a global function map and
Expand Down
14 changes: 14 additions & 0 deletions include/tvm/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ enum struct ObjectTag {
kClosure = 1U,
/*! \brief The tag of a structure. */
kDatatype = 2U,
/*! \brief An external resource. */
kExternal = 3U,
};

std::ostream& operator<<(std::ostream& os, const ObjectTag&);
Expand Down Expand Up @@ -308,6 +310,7 @@ class ObjectPtr {
struct TensorCell;
struct DatatypeCell;
struct ClosureCell;
struct ExternalCell;

/*!
* \brief A managed object in the TVM runtime.
Expand All @@ -334,10 +337,13 @@ class Object {
static Object Tuple(const std::vector<Object>& fields);
/*! \brief Construct a closure object. */
static Object Closure(size_t func_index, const std::vector<Object>& free_vars);
/*! \brief Construct an external object wrapping opaque data. */
static Object External(void* ext_data);

ObjectPtr<TensorCell> AsTensor() const;
ObjectPtr<DatatypeCell> AsDatatype() const;
ObjectPtr<ClosureCell> AsClosure() const;
ObjectPtr<ExternalCell> AsExt() const;
};

/*! \brief An object containing an NDArray. */
Expand Down Expand Up @@ -369,6 +375,14 @@ struct ClosureCell : public ObjectCell {
: ObjectCell(ObjectTag::kClosure), func_index(func_index), free_vars(free_vars) {}
};

/*! \brief An object cell wrapping an opaque external resource. */
struct ExternalCell : public ObjectCell {
  /*! \brief Opaque pointer to externally managed data; ownership stays with the caller. */
  void *ext_data;

  // explicit: prevent accidental implicit conversion from any void*.
  explicit ExternalCell(void *ext_data)
    : ObjectCell(ObjectTag::kExternal), ext_data(ext_data) {}
};

/*! \brief Extract the NDArray from a tensor object. */
NDArray ToNDArray(const Object& obj);

Expand Down
22 changes: 19 additions & 3 deletions include/tvm/runtime/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ enum class Opcode {
GetTag = 13U,
LoadConsti = 14U,
Fatal = 15U,
AllocStorage = 16U,
};

/*! \brief A single virtual machine instruction.
Expand All @@ -89,6 +90,7 @@ struct Instruction {

union {
struct /* AllocTensor Operands */ {
RegName storage;
/*! \brief The number of dimensions. */
uint32_t ndim;
/*! \brief The shape of tensor. */
Expand All @@ -97,6 +99,7 @@ struct Instruction {
DLDataType dtype;
} alloc_tensor;
struct /* AllocTensorReg Operands */ {
RegName storage;
/*! \brief The register to read the shape out of. */
RegName shape_register;
/*! \brief The datatype of tensor to be allocated. */
Expand Down Expand Up @@ -184,6 +187,11 @@ struct Instruction {
/*! \brief The free variables as an array. */
RegName* free_vars;
};
struct /* AllocStorage Operands */ {
RegName allocation_size;
RegName alignment;
TVMType dtype_hint;
} alloc_storage;
};

/*! \brief Construct a return instruction.
Expand All @@ -193,7 +201,7 @@ struct Instruction {
static Instruction Ret(RegName return_reg);
/*! \brief Construct a fatal instruction.
* \return The fatal instruction.
* */
* */
static Instruction Fatal();
/*! \brief Construct a invoke packed instruction.
* \param packed_index The index of the packed function.
Expand All @@ -210,15 +218,17 @@ struct Instruction {
* \param dst The destination register.
* \return The allocate tensor instruction.
*/
static Instruction AllocTensor(std::vector<int64_t> shape, DLDataType dtype, RegName dst);
static Instruction AllocTensor(RegName storage, std::vector<int64_t> shape, DLDataType dtype, RegName dst);
/*! \brief Construct an allocate tensor instruction with register.
* \param storage The storage to allocate out of.
* \param shape_register The register containing the shape.
* \param dtype The dtype of the tensor.
* \param dst The destination register.
* \return The allocate tensor instruction.
*/
static Instruction AllocTensorReg(RegName shape_register, DLDataType dtype, RegName dst);
static Instruction AllocTensorReg(RegName storage, RegName shape_register, DLDataType dtype, RegName dst);
/*! \brief Construct an allocate datatype instruction.
* \param storage The storage to allocate out of.
* \param tag The datatype tag.
* \param num_fields The number of fields for the datatype.
* \param fields The registers containing the fields.
Expand Down Expand Up @@ -295,6 +305,12 @@ struct Instruction {
*/
static Instruction Move(RegName src, RegName dst);

/*! \brief Allocate a storage block.
* \param size The size of the allocation.
* \return The alloc storage instruction.
*/
static Instruction AllocStorage(RegName size, RegName alignment, TVMType dtype_hint, RegName dst);

Instruction();
Instruction(const Instruction& instr);
Instruction& operator=(const Instruction& instr);
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@
from . import qnn

from .scope_builder import ScopeBuilder
# Load Memory pass
from . import memory_alloc

# Required to traverse large programs
setrecursionlimit(10000)
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/relay/backend/compile_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ def lower(self, source_func, target=None):
msg += "--------------------------\n"
raise RuntimeError(msg)

def lower_shape_func(self, source_func, target=None):
    """Lower the shape function for a source function.

    Parameters
    ----------
    source_func : tvm.relay.Function
        The source function whose shape function is lowered.

    target : str, optional
        The target passed to the cache key; forwarded to the C++ backend.

    Returns
    -------
    The lowered shape function, as produced by
    _backend._CompileEngineLowerShapeFunc.
    """
    key = _get_cache_key(source_func, target)
    return _backend._CompileEngineLowerShapeFunc(self, key)

def jit(self, source_func, target=None):
"""JIT a source_func to a tvm.Function.
Expand Down
6 changes: 5 additions & 1 deletion python/tvm/relay/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,13 @@ def _debugger_init(expr, stack):
import pdb
pdb.set_trace()

# pylint: disable=unused-argument
@register_func("relay.debug")
def _debug(*args):
    """Debug hook registered as "relay.debug"; drops into pdb when invoked."""
    import pdb; pdb.set_trace()

# pylint: disable=unused-argument
@register_func("relay.debug_interp")
def _debug_interp(*args):
_, _, _, ist = args
print("Relay Debugger")
print(" You can manipulate the expression under evaluation with the name `expr`.")
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,9 @@ def __call__(self, *args):
"""
return Call(self, args, None, None)

def set_attribute(self, name, ref):
    """Return a copy of this function with the attribute `name` set to `ref`.

    Parameters
    ----------
    name : str
        The attribute name.

    ref : object
        The attribute value.

    Returns
    -------
    The function with the attribute set, as returned by
    _expr.FunctionSetAttr.
    """
    return _expr.FunctionSetAttr(self, name, ref)


@register_relay_node
class Call(Expr):
Expand Down
5 changes: 4 additions & 1 deletion python/tvm/relay/expr_functor.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,10 @@ def visit_constructor(self, con):
return con

def visit_match(self, m):
    """Rewrite a match expression, preserving the completeness flag.

    The scraped diff showed both the old and new bodies; this keeps the
    updated form, which forwards `m.complete` so the rewritten Match does
    not lose its exhaustiveness annotation.
    """
    return Match(
        self.visit(m.data),
        [Clause(c.lhs, self.visit(c.rhs)) for c in m.clauses],
        m.complete)

def visit_ref_create(self, r):
return RefCreate(self.visit(r.value))
Expand Down
Loading

0 comments on commit bf9555a

Please sign in to comment.