Commit
Manifest memory allocations
jroesch committed Oct 29, 2019
1 parent 1853ea2 commit c4eb59f
Showing 26 changed files with 392 additions and 310 deletions.
2 changes: 1 addition & 1 deletion cmake/util/FindANTLR.cmake
@@ -61,5 +61,5 @@ macro(find_antlr use_antlr)
elseif(NOT ${use_antlr} STREQUAL "OFF")
set(ANTLR4 ${JAVA_PROGRAM} -jar ${use_antlr})
endif()
message(STATUS "ANTLR4="${ANTLR4})
message(STATUS "ANTLR4=${ANTLR4}")
endmacro(find_antlr)
4 changes: 2 additions & 2 deletions include/tvm/expr_operator.h
@@ -28,10 +28,10 @@
#ifndef TVM_EXPR_OPERATOR_H_
#define TVM_EXPR_OPERATOR_H_

#include <tvm/expr.h>
#include <tvm/ir.h>
#include <algorithm>
#include <type_traits>
#include "expr.h"
#include "ir.h"

namespace tvm {

37 changes: 0 additions & 37 deletions include/tvm/relay/attrs/annotation.h
@@ -25,7 +25,6 @@
#define TVM_RELAY_ATTRS_ANNOTATION_H_

#include <tvm/attrs.h>
#include <tvm/relay/expr.h>
#include <string>

namespace tvm {
@@ -58,42 +57,6 @@ struct CastHintAttrs : public tvm::AttrsNode<CastHintAttrs> {
}
};

/*!
* \brief Options for the device annotation operators.
*/
struct AllocTensorAttrs : public tvm::AttrsNode<AllocTensorAttrs> {
tvm::relay::Constant const_shape;
Array<IndexExpr> assert_shape;
DataType dtype;

TVM_DECLARE_ATTRS(AllocTensorAttrs, "relay.attrs.AllocTensorAttrs") {
TVM_ATTR_FIELD(dtype)
.describe(
"The virutal device/context type that an expression is annotated with.")
.set_default(Float(32, 1));
TVM_ATTR_FIELD(const_shape)
.describe(
"The virutal device/context type that an expression is annotated with.");
TVM_ATTR_FIELD(assert_shape)
.describe(
"The virutal device/context type that an expression is annotated with.");
}
};

/*!
* \brief Options for the device annotation operators.
*/
struct ShapeFuncAttrs : public tvm::AttrsNode<ShapeFuncAttrs> {
bool dependent{false};

TVM_DECLARE_ATTRS(ShapeFuncAttrs, "relay.attrs.ShapeFuncAttrs") {
TVM_ATTR_FIELD(dependent)
.describe(
"Wheather the shape function is input dependent.")
.set_default(false);
}
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_ANNOTATION_H_
17 changes: 9 additions & 8 deletions include/tvm/relay/attrs/memory.h
@@ -32,7 +32,7 @@ namespace tvm {
namespace relay {

/*!
* \brief Options for the device annotation operators.
* \brief Options for allocating tensors.
*/
struct AllocTensorAttrs : public tvm::AttrsNode<AllocTensorAttrs> {
tvm::relay::Constant const_shape;
@@ -46,24 +46,25 @@ struct AllocTensorAttrs : public tvm::AttrsNode<AllocTensorAttrs> {
.set_default(Float(32, 1));
TVM_ATTR_FIELD(const_shape)
.describe(
"The shape if constant used to aid in type inference.");
"The shape of constant used to aid in type inference.");
TVM_ATTR_FIELD(assert_shape)
.describe(
"The shape to cast the return type of the allocation to, used to specify the shape obtained via further analysis.");
"The shape to cast the return type of the allocation to, "\
"used to specify the shape obtained via further analysis.");
}
};

/*!
* \brief Options for the device annotation operators.
* \brief Options for the shape function operator.
*/
struct ShapeFuncAttrs : public tvm::AttrsNode<ShapeFuncAttrs> {
bool dependent{false};
Array<tvm::Integer> is_input;

TVM_DECLARE_ATTRS(ShapeFuncAttrs, "relay.attrs.ShapeFuncAttrs") {
TVM_ATTR_FIELD(dependent)
TVM_ATTR_FIELD(is_input)
.describe(
"Wheather the shape function is input dependent.")
.set_default(false);
"A bool indicating whether the shape function should"\
"expect shape or input in each position.");
}
};
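
For orientation, a minimal sketch of how an AllocTensorAttrs node like the one above might be filled in and attached to a call on the C++ side. The helper name MakeAllocTensorCall, the "memory.alloc_tensor" operator string, and the use of make_node/CallNode::make are assumptions about the surrounding Relay API of this era, not something this diff shows.

#include <tvm/relay/attrs/memory.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>

namespace tvm {
namespace relay {

// Hypothetical helper: build a call to an assumed "memory.alloc_tensor" operator,
// carrying an AllocTensorAttrs node whose dtype field is set explicitly.
inline Expr MakeAllocTensorCall(Expr storage, Expr shape, DataType dtype) {
  auto attrs = make_node<AllocTensorAttrs>();             // assumed node-construction helper
  attrs->dtype = dtype;                                   // field declared in the struct above
  static const Op& op = Op::Get("memory.alloc_tensor");   // assumed operator name
  return CallNode::make(op, {storage, shape}, Attrs(attrs), {});
}

}  // namespace relay
}  // namespace tvm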

13 changes: 13 additions & 0 deletions include/tvm/relay/expr_functor.h
@@ -230,18 +230,29 @@ class ExprMutator
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo_;
};

/*! \brief A helper class for matching and rewriting operators. */
template<typename R>
class OpMatch {
public:
using MatchFunc =
std::function<R(const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_args)>;

/*! \brief Match an operator with the given name.
* \param op_name The name of the operator to match.
* \param func The function to execute when it matches.
* \return A self-reference for builder style API.
*/
inline OpMatch& Match(const std::string& op_name, MatchFunc func) {
auto op = Op::Get(op_name);
match_map_.insert({op, func});
return *this;
}

/*! \brief Rewrite a call operation based on the operator and the registered
* match functions.
* \param call The call to rewrite.
* \return The result of rewriting.
*/
inline R operator()(const Call& call) {
auto it = match_map_.find(Downcast<Op>(call->op));
if (it != match_map_.end()) {
@@ -256,7 +267,9 @@ class OpMatch {
}

private:
/*! \brief The match function map. */
std::unordered_map<Op, MatchFunc, NodeHash, NodeEqual> match_map_;
/*! \brief An optional default case. */
MatchFunc default_;
};
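
A hedged usage sketch of the builder-style OpMatch helper documented above: register a handler per operator name, then apply the matcher to a call. The "memory.alloc_tensor" name and the trivial rewrite body are illustrative assumptions; only the Match/operator() signatures come from the class itself.

#include <tvm/relay/expr_functor.h>

namespace tvm {
namespace relay {

// Illustrative only: rewrite calls to one (assumed) operator name.
inline Expr RewriteAlloc(const Call& call) {
  OpMatch<Expr> matcher;
  matcher.Match("memory.alloc_tensor",
                [](const Array<Expr>& args, const Attrs& attrs,
                   const Array<Type>& type_args) -> Expr {
                  // Placeholder rewrite: just forward the first argument.
                  return args[0];
                });
  // Assumes call->op was registered above (or that a default case is installed).
  return matcher(call);
}

}  // namespace relay
}  // namespace tvm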

12 changes: 12 additions & 0 deletions include/tvm/runtime/object.h
@@ -283,6 +283,8 @@ class Object {
* \note The deleter will be called when ref_counter_ becomes zero.
*/
inline void DecRef();

private:
/*!
* \return The usage count of the cell.
* \note We use stl style naming to be consistent with known API in shared_ptr.
@@ -675,6 +677,16 @@ struct ObjectEqual {
operator bool() const { return data_ != nullptr; } \
using ContainerType = ObjectName;

#define TVM_DEFINE_OBJECT_REF_METHODS_MUT(TypeName, ParentType, ObjectName) \
TypeName() {} \
explicit TypeName( \
::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \
: ParentType(n) {} \
ObjectName* operator->() { \
return static_cast<ObjectName*>(data_.get()); \
} \
operator bool() const { return data_ != nullptr; } \
using ContainerType = ObjectName;

// Implementations details below
// Object reference counting.
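
A minimal sketch of what the new _MUT variant buys: it mirrors TVM_DEFINE_OBJECT_REF_METHODS but drops const from operator->, so fields of the underlying object can be mutated through the reference. The CounterObj/Counter pair and the registration macro usage below are illustrative assumptions, not part of this diff.

#include <tvm/runtime/object.h>

// Hypothetical object type; the _type_key and TVM_DECLARE_FINAL_OBJECT_INFO usage
// follow the usual object-protocol conventions and are assumptions here.
class CounterObj : public tvm::runtime::Object {
 public:
  int64_t value{0};
  static constexpr const char* _type_key = "example.Counter";
  TVM_DECLARE_FINAL_OBJECT_INFO(CounterObj, tvm::runtime::Object);
};

class Counter : public tvm::runtime::ObjectRef {
 public:
  // Non-const operator->: callers may write through the reference.
  TVM_DEFINE_OBJECT_REF_METHODS_MUT(Counter, tvm::runtime::ObjectRef, CounterObj);
};

// Usage sketch (make_object is assumed to be available from the runtime headers):
//   Counter c(tvm::runtime::make_object<CounterObj>());
//   c->value += 1;   // only compiles with the _MUT variant; the const variant
//                    // would yield a `const CounterObj*` here.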
20 changes: 14 additions & 6 deletions include/tvm/runtime/vm.h
@@ -159,6 +159,7 @@ struct Instruction {

union {
struct /* AllocTensor Operands */ {
/*! \brief The storage to allocate from. */
RegName storage;
/*! \brief The number of dimensions. */
uint32_t ndim;
@@ -168,6 +169,7 @@
DLDataType dtype;
} alloc_tensor;
struct /* AllocTensorReg Operands */ {
/*! \brief The storage to allocate from. */
RegName storage;
/*! \brief The register to read the shape out of. */
RegName shape_register;
@@ -257,8 +259,11 @@
RegName* free_vars;
};
struct /* AllocStorage Operands */ {
/*! \brief The size of the allocation. */
RegName allocation_size;
/*! \brief The alignment of the allocation. */
RegName alignment;
/*! \brief The hint of the dtype. */
DLDataType dtype_hint;
} alloc_storage;
};
@@ -282,30 +287,32 @@
static Instruction InvokePacked(Index packed_index, Index arity, Index output_size,
const std::vector<RegName>& args);
/*! \brief Construct an allocate tensor instruction with constant shape.
* \param storage The storage to allocate out of.
* \param shape The shape of the tensor.
* \param dtype The dtype of the tensor.
* \param dst The destination register.
* \return The allocate tensor instruction.
*/
static Instruction AllocTensor(RegName storage, const std::vector<int64_t>& shape, DLDataType dtype, RegName dst);
static Instruction AllocTensor(RegName storage,
const std::vector<int64_t>& shape, DLDataType dtype, RegName dst);
/*! \brief Construct an allocate tensor instruction with register.
* \param The storage to allocate out of.
* \param storage The storage to allocate out of.
* \param shape_register The register containing the shape.
* \param dtype The dtype of the tensor.
* \param dst The destination register.
* \return The allocate tensor instruction.
*/
static Instruction AllocTensorReg(RegName storage, RegName shape_register, DLDataType dtype, RegName dst);
static Instruction AllocTensorReg(RegName storage,
RegName shape_register, DLDataType dtype, RegName dst);
/*! \brief Construct an allocate datatype instruction.
* \param The storage to allocate out of.
* \param tag The datatype tag.
* \param num_fields The number of fields for the datatype.
* \param fields The registers containing the fields.
* \param dst The register name of the destination.
* \return The allocate instruction tensor.
*/
static Instruction AllocADT(Index tag, Index num_fields, const std::vector<RegName>& fields,
RegName dst);
RegName dst);
/*! \brief Construct an allocate closure instruction.
* \param func_index The index of the function table.
* \param num_freevar The number of free variables.
@@ -381,7 +388,8 @@
* \param dst The destination to place the storage.
* \return The alloc storage instruction.
*/
static Instruction AllocStorage(RegName size, RegName alignment, DLDataType dtype_hint, RegName dst);
static Instruction AllocStorage(RegName size, RegName alignment,
DLDataType dtype_hint, RegName dst);

Instruction();
Instruction(const Instruction& instr);
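
A hedged sketch of how the allocation-related constructors documented above might be chained when emitting bytecode: allocate a storage buffer first, then allocate a tensor out of that storage. The register numbers, shape, and dtype are illustrative only; the constructor signatures themselves come from this header.

#include <tvm/runtime/vm.h>
#include <vector>

// Build a two-instruction allocation sequence (illustrative register assignments:
// r0 = size, r1 = alignment, r2 = resulting storage, r3 = resulting tensor).
std::vector<tvm::runtime::vm::Instruction> MakeAllocSequence() {
  using tvm::runtime::vm::Instruction;
  DLDataType f32{kDLFloat, 32, 1};
  Instruction alloc_storage =
      Instruction::AllocStorage(/*size=*/0, /*alignment=*/1, /*dtype_hint=*/f32, /*dst=*/2);
  Instruction alloc_tensor =
      Instruction::AllocTensor(/*storage=*/2, /*shape=*/{1, 224, 224, 3}, f32, /*dst=*/3);
  return {alloc_storage, alloc_tensor};
}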
3 changes: 2 additions & 1 deletion python/tvm/relay/debug.py
@@ -27,7 +27,8 @@ def _debugger_init(expr, stack):

@register_func("relay.debug")
def _debug(*args):
import pdb; pdb.set_trace()
import pdb
pdb.set_trace()

# pylint: disable=unused-argument
@register_func("relay.debug_interp")
2 changes: 2 additions & 0 deletions python/tvm/relay/expr.py
@@ -315,6 +315,8 @@ def set_params(self, params):
if isinstance(value, NDArray):
params[key] = Constant(value)

return _expr.FunctionSetParams(self, params)

def set_attribute(self, name, ref):
return _expr.FunctionSetAttr(self, name, ref)
