[Relay][Params] Add APIs for storing and retrieving parameters from individual functions. (apache#4194)

* Add support for attaching params

* Fix types

* Fix test
jroesch authored and kevinthesun committed Oct 30, 2019
1 parent 81452aa commit d121208
Showing 4 changed files with 76 additions and 3 deletions.
14 changes: 13 additions & 1 deletion include/tvm/relay/expr.h
@@ -274,6 +274,19 @@ class FunctionNode : public ExprNode {
tvm::Array<TypeVar> ty_params,
tvm::Attrs attrs = Attrs());

/*!
* \brief Attach the function's parameters to its attributes for use in analysis.
* \return The function with its parameters attached.
*/
Function SetParams(const tvm::Map<Var, Constant>& parameters) const;

/*!
* \brief Retrieve the function's parameters.
*
* \return The function's parameters.
*/
tvm::Map<Var, Constant> GetParams() const;

static constexpr const char* _type_key = "relay.Function";
TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode);
};
@@ -284,7 +297,6 @@ RELAY_DEFINE_NODE_REF(Function, FunctionNode, Expr);
TVM_DLL NodeRef FunctionGetAttr(const Function& func, const std::string& key);
TVM_DLL Function FunctionSetAttr(const Function& func, const std::string& key, const NodeRef& data);


/*!
* \brief Call corresponds to operator invocation.
* Corresponds to the operator in computational graph terminology.
12 changes: 12 additions & 0 deletions python/tvm/relay/expr.py
@@ -27,6 +27,7 @@
from .._ffi import base as _base
from .. import nd as _nd
from .. import convert
from ..ndarray import NDArray

# will be registered afterwards
_op_make = None
@@ -305,6 +306,17 @@ def __call__(self, *args):
"""
return Call(self, args, None, None)

def get_params(self):
return _expr.FunctionGetParams(self)

def set_params(self, params):
for key in params:
value = params[key]
if isinstance(value, NDArray):
params[key] = Constant(value)

return _expr.FunctionSetParams(self, params)


@register_relay_node
class Call(Expr):
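A minimal usage sketch of the two new Python methods above (not part of the commit; it mirrors the test_function_attrs test added below). NDArray values passed to set_params are wrapped into relay.Constant, and get_params returns a tvm Map from Var to Constant:

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(5, 2))
fn = relay.Function([x], x)

# Bind a weight to the parameter; the NDArray is converted to a relay.Constant.
fn = fn.set_params({x: tvm.nd.array(np.zeros((5, 2), dtype="float32"))})

params = fn.get_params()          # Map of Var -> Constant
print(params[x].data.asnumpy())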
20 changes: 20 additions & 0 deletions src/relay/ir/expr.cc
@@ -159,6 +159,26 @@ bool FunctionNode::IsPrimitive() const {
return pval && pval->value != 0;
}

Function FunctionNode::SetParams(const tvm::Map<Var, Constant>& parameters) const {
return FunctionSetAttr(GetRef<Function>(this), "__params__", parameters);
}

TVM_REGISTER_API("relay._expr.FunctionSetParams")
.set_body_typed<Function(const Function&, const tvm::Map<Var, Constant>&)>(
[](const Function& func, const tvm::Map<Var, Constant>& parameters) {
return func->SetParams(parameters);
});

tvm::Map<Var, Constant> FunctionNode::GetParams() const {
auto node_ref = FunctionGetAttr(GetRef<Function>(this), "__params__");
return Downcast<tvm::Map<Var, Constant>>(node_ref);
}

TVM_REGISTER_API("relay._expr.FunctionGetParams")
.set_body_typed<tvm::Map<Var, Constant>(const Function&)>([](const Function& func) {
return func->GetParams();
});

NodeRef FunctionGetAttr(const Function& func, const std::string& key) {
if (!func->attrs.defined()) { return NodeRef(); }

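As a side note, each TVM_REGISTER_API name above is also retrievable as a global packed function, and the Python Function.set_params/get_params methods are thin wrappers over exactly these entries, with the Python dict converted to a tvm::Map by the FFI. A hedged sketch, not from the commit, assuming the standard global-function lookup:

import numpy as np
import tvm
from tvm import relay

set_params = tvm.get_global_func("relay._expr.FunctionSetParams")
get_params = tvm.get_global_func("relay._expr.FunctionGetParams")

v = relay.var("v", shape=(3,))
fn = relay.Function([v], v)
const = relay.Constant(tvm.nd.array(np.zeros((3,), dtype="float32")))

fn = set_params(fn, {v: const})   # equivalent to fn.set_params({v: const})
print(get_params(fn))             # Map containing {v: const}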
33 changes: 31 additions & 2 deletions tests/python/relay/test_ir_nodes.py
@@ -20,7 +20,7 @@
from tvm.expr import *
from tvm.relay import op
from tvm.relay.analysis import graph_equal

import numpy as np

def check_json_roundtrip(node):
json_str = tvm.save_json(node)
@@ -160,7 +160,6 @@ def test_global_var():
str(gv)
check_json_roundtrip(gv)


def test_function():
param_names = ['a', 'b', 'c', 'd']
params = tvm.convert([relay.Var(n) for n in param_names])
@@ -175,6 +174,34 @@ def test_function():
str(fn)
check_json_roundtrip(fn)

def test_function_attrs():
param_names = ['a', 'b', 'c', 'd']
params = tvm.convert([relay.var(n, shape=(5, 2)) for n in param_names])
ret_type = relay.TupleType(tvm.convert([]))
body = relay.Tuple(tvm.convert([]))
type_params = tvm.convert([])
fn = relay.Function(params, body, ret_type, type_params)
model_params = {}
for param in params[:1]:
cty = param.type_annotation
tensor = np.random.rand(*[int(sh) for sh in cty.shape]).astype(cty.dtype)
model_params[param] = tvm.nd.array(tensor)
fn = fn.set_params(model_params)
assert fn.params == params
assert fn.body == body
assert fn.type_params == type_params
assert fn.span == None
str(fn)
check_json_roundtrip(fn)
json_str = tvm.save_json(fn)
fn_after = tvm.load_json(json_str)
model_params_after = fn_after.get_params()
after_keys = [item[0] for item in model_params_after.items()]
for key1, key2 in zip(model_params, after_keys):
assert key1.name_hint == key2.name_hint
p1 = model_params[key1]
p2 = model_params_after[key2]
np.testing.assert_allclose(p1.data.asnumpy(), p2.data.asnumpy())

def test_call():
op = relay.Var('f')
@@ -257,9 +284,11 @@ def test_conv2d_attrs():
test_local_var()
test_global_var()
test_function()
test_function_attrs()
test_call()
test_let()
test_if()
test_tuple_get_item()
test_op()
test_conv2d_attrs()
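For reference, a condensed version of the flow exercised by test_function_attrs above: because the parameters are stored in the function's attributes, they survive a save_json/load_json round trip together with the function itself. A sketch, not part of the commit:

import numpy as np
import tvm
from tvm import relay

w = relay.var("w", shape=(5, 2))
weight = tvm.nd.array(np.ones((5, 2), dtype="float32"))
fn = relay.Function([w], w).set_params({w: weight})

fn_after = tvm.load_json(tvm.save_json(fn))
for var, const in fn_after.get_params().items():
    print(var.name_hint, const.data.asnumpy().shape)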
