Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay][Params] Add APIs for storing and retrieving parameters from individual functions. #4194

Merged
merged 3 commits into from
Oct 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,19 @@ class FunctionNode : public ExprNode {
tvm::Array<TypeVar> ty_params,
tvm::Attrs attrs = Attrs());

/*!
* \brief Attach the function's parameters to its attributes for use in analysis.
* \return The function with its parameters attached.
*/
Function SetParams(const tvm::Map<Var, Constant>& parameters) const;

/*!
* \brief Retrieve the function's parameters.
*
* \return The function's parameter.
*/
tvm::Map<Var, Constant> GetParams() const;

static constexpr const char* _type_key = "relay.Function";
TVM_DECLARE_NODE_TYPE_INFO(FunctionNode, ExprNode);
};
Expand All @@ -284,7 +297,6 @@ RELAY_DEFINE_NODE_REF(Function, FunctionNode, Expr);
TVM_DLL NodeRef FunctionGetAttr(const Function& func, const std::string& key);
TVM_DLL Function FunctionSetAttr(const Function& func, const std::string& key, const NodeRef& data);


/*!
* \brief Call corresponds to operator invocation.
* Corresponds to the operator in computational graph terminology.
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .._ffi import base as _base
from .. import nd as _nd
from .. import convert
from ..ndarray import NDArray

# will be registered afterwards
_op_make = None
Expand Down Expand Up @@ -305,6 +306,17 @@ def __call__(self, *args):
"""
return Call(self, args, None, None)

def get_params(self):
return _expr.FunctionGetParams(self)

def set_params(self, params):
for key in params:
value = params[key]
if isinstance(value, NDArray):
params[key] = Constant(value)

return _expr.FunctionSetParams(self, params)


@register_relay_node
class Call(Expr):
Expand Down
20 changes: 20 additions & 0 deletions src/relay/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,26 @@ bool FunctionNode::IsPrimitive() const {
return pval && pval->value != 0;
}

Function FunctionNode::SetParams(const tvm::Map<Var, Constant>& parameters) const {
return FunctionSetAttr(GetRef<Function>(this), "__params__", parameters);
}

TVM_REGISTER_API("relay._expr.FunctionSetParams")
.set_body_typed<Function(const Function&, const tvm::Map<Var, Constant>&)>(
[](const Function& func, const tvm::Map<Var, Constant>& parameters) {
return func->SetParams(parameters);
});

tvm::Map<Var, Constant> FunctionNode::GetParams() const {
auto node_ref = FunctionGetAttr(GetRef<Function>(this), "__params__");
return Downcast<tvm::Map<Var, Constant>>(node_ref);
}

TVM_REGISTER_API("relay._expr.FunctionGetParams")
.set_body_typed<tvm::Map<Var, Constant>(const Function&)>([](const Function& func) {
return func->GetParams();
});

NodeRef FunctionGetAttr(const Function& func, const std::string& key) {
if (!func->attrs.defined()) { return NodeRef(); }

Expand Down
33 changes: 31 additions & 2 deletions tests/python/relay/test_ir_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from tvm.expr import *
from tvm.relay import op
from tvm.relay.analysis import graph_equal

import numpy as np

def check_json_roundtrip(node):
json_str = tvm.save_json(node)
Expand Down Expand Up @@ -160,7 +160,6 @@ def test_global_var():
str(gv)
check_json_roundtrip(gv)


def test_function():
param_names = ['a', 'b', 'c', 'd']
params = tvm.convert([relay.Var(n) for n in param_names])
Expand All @@ -175,6 +174,34 @@ def test_function():
str(fn)
check_json_roundtrip(fn)

def test_function_attrs():
param_names = ['a', 'b', 'c', 'd']
params = tvm.convert([relay.var(n, shape=(5, 2)) for n in param_names])
ret_type = relay.TupleType(tvm.convert([]))
body = relay.Tuple(tvm.convert([]))
type_params = tvm.convert([])
fn = relay.Function(params, body, ret_type, type_params)
model_params = {}
for param in params[:1]:
cty = param.type_annotation
tensor = np.random.rand(*[int(sh) for sh in cty.shape]).astype(cty.dtype)
model_params[param] = tvm.nd.array(tensor)
fn = fn.set_params(model_params)
assert fn.params == params
assert fn.body == body
assert fn.type_params == type_params
assert fn.span == None
str(fn)
check_json_roundtrip(fn)
json_str = tvm.save_json(fn)
fn_after = tvm.load_json(json_str)
model_params_after = fn_after.get_params()
after_keys = [item[0] for item in model_params_after.items()]
for key1, key2 in zip(model_params, after_keys):
assert key1.name_hint == key2.name_hint
p1 = model_params[key1]
p2 = model_params_after[key2]
np.testing.assert_allclose(p1.data.asnumpy(), p2.data.asnumpy())

def test_call():
op = relay.Var('f')
Expand Down Expand Up @@ -257,9 +284,11 @@ def test_conv2d_attrs():
test_local_var()
test_global_var()
test_function()
test_function_attrs()
test_call()
test_let()
test_if()
test_tuple_get_item()
test_op()
test_conv2d_attrs()