Skip to content

Commit

Permalink
[Relay] Fix memory leak in the interpreter (#4155)
Browse files Browse the repository at this point in the history
* save

lint

* address reviewer comment
  • Loading branch information
MarisaKirisame authored and icemelon committed Oct 24, 2019
1 parent b08fe81 commit 2e0dbaa
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 7 deletions.
26 changes: 26 additions & 0 deletions include/tvm/relay/interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,32 @@ class ClosureNode : public ValueNode {

RELAY_DEFINE_NODE_REF(Closure, ClosureNode, Value);

/*! \brief A Relay Recursive Closure. A closure that has a name. */
class RecClosure;

/*! \brief The container type of RecClosure. */
class RecClosureNode : public ValueNode {
public:
/*! \brief The closure. */
Closure clos;
/*! \brief variable the closure bind to. */
Var bind;

RecClosureNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("clos", &clos);
v->Visit("bind", &bind);
}

TVM_DLL static RecClosure make(Closure clos, Var bind);

static constexpr const char* _type_key = "relay.RecClosure";
TVM_DECLARE_NODE_TYPE_INFO(RecClosureNode, ValueNode);
};

RELAY_DEFINE_NODE_REF(RecClosure, RecClosureNode, Value);

/*! \brief A tuple value. */
class TupleValue;

Expand Down
5 changes: 5 additions & 0 deletions python/tvm/relay/backend/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ class Closure(Value):
"""A closure produced by the interpreter."""


@register_relay_node
class RecClosure(Value):
"""A recursive closure produced by the interpreter."""


@register_relay_node
class ConstructorValue(Value):
def __init__(self, tag, fields, constructor):
Expand Down
35 changes: 28 additions & 7 deletions src/relay/backend/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,27 @@ TVM_REGISTER_API("relay._make.Closure")

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ClosureNode>([](const ClosureNode* node, tvm::IRPrinter* p) {
p->stream << "ClosureNode(" << node->func << ")";
p->stream << "ClosureNode(" << node->func << ", " << node->env << ")";
});


// TODO(@jroesch): this doesn't support mutual letrec
/* Value Implementation */
RecClosure RecClosureNode::make(Closure clos, Var bind) {
NodePtr<RecClosureNode> n = make_node<RecClosureNode>();
n->clos = std::move(clos);
n->bind = std::move(bind);
return RecClosure(n);
}

TVM_REGISTER_API("relay._make.RecClosure")
.set_body_typed(RecClosureNode::make);

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RecClosureNode>([](const RecClosureNode* node, tvm::IRPrinter* p) {
p->stream << "RecClosureNode(" << node->clos << ")";
});

TupleValue TupleValueNode::make(tvm::Array<Value> value) {
NodePtr<TupleValueNode> n = make_node<TupleValueNode>();
n->fields = value;
Expand Down Expand Up @@ -281,7 +299,6 @@ class Interpreter :
return TupleValueNode::make(values);
}

// TODO(@jroesch): this doesn't support mutual letrec
inline Value MakeClosure(const Function& func, Var letrec_name = Var()) {
tvm::Map<Var, Value> captured_mod;
Array<Var> free_vars = FreeVars(func);
Expand All @@ -298,10 +315,8 @@ class Interpreter :

// We must use mutation here to build a self referential closure.
auto closure = ClosureNode::make(captured_mod, func);
auto mut_closure =
static_cast<ClosureNode*>(const_cast<Node*>(closure.get()));
if (letrec_name.defined()) {
mut_closure->env.Set(letrec_name, closure);
return RecClosureNode::make(closure, letrec_name);
}
return std::move(closure);
}
Expand Down Expand Up @@ -559,7 +574,7 @@ class Interpreter :
}

// Invoke the closure
Value Invoke(const Closure& closure, const tvm::Array<Value>& args) {
Value Invoke(const Closure& closure, const tvm::Array<Value>& args, const Var& bind = Var()) {
// Get a reference to the function inside the closure.
if (closure->func->IsPrimitive()) {
return InvokePrimitiveOp(closure->func, args);
Expand All @@ -575,12 +590,16 @@ class Interpreter :
locals.Set(func->params[i], args[i]);
}

// Add the var to value mappings from the Closure's modironment.
// Add the var to value mappings from the Closure's environment.
for (auto it = closure->env.begin(); it != closure->env.end(); ++it) {
CHECK_EQ(locals.count((*it).first), 0);
locals.Set((*it).first, (*it).second);
}

if (bind.defined()) {
locals.Set(bind, RecClosureNode::make(closure, bind));
}

return WithFrame<Value>(Frame(locals), [&]() { return Eval(func->body); });
}

Expand All @@ -607,6 +626,8 @@ class Interpreter :
if (const ClosureNode* closure_node = fn_val.as<ClosureNode>()) {
auto closure = GetRef<Closure>(closure_node);
return this->Invoke(closure, args);
} else if (const RecClosureNode* closure_node = fn_val.as<RecClosureNode>()) {
return this->Invoke(closure_node->clos, args, closure_node->bind);
} else {
LOG(FATAL) << "internal error: type error, expected function value in the call "
<< "position";
Expand Down

0 comments on commit 2e0dbaa

Please sign in to comment.