[Runtime] Allow for parameter sharing in GraphRuntime (apache#3384)
Summary:

In multi-threaded applications where we have multiple inferences on the
same model in parallel (consider e.g. a TTS system handling multiple
requests), it can be useful to share the parameters of a model amongst
these multiple instances. This improves the cache utilization behaviour
of the system, as multiple cores can use the same set of weights instead
of each evicting identical copies of the weights from a shared cache.

As the underlying `NDArray` instances in `data_entry_` implement a
ref-count-based sharing system, this is a simple modification of the
`GraphRuntime::LoadParams` logic to instead copy parameters from an
existing GraphRuntime instance. This is a little ugly in that we need
both the pre-existing GraphRuntime instance and the 'serialized'
params (since we need to know the set of names we should copy), but
without imposing additional assumptions (i.e. storing the set of param
names in GraphRuntime, and enforcing that shared param names are
identical to the parameters set in the preceding `LoadParams` call),
this seems unavoidable.
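
A rough sketch of the intended usage (here `graph`, `lib`, and `params` are assumed to come from an earlier `relay.build` call, and the worker count is illustrative; the unit test added in this commit exercises the same pattern):

import tvm
from tvm import relay
from tvm.contrib import graph_runtime

param_bytes = relay.save_param_dict(params)

# One instance owns the weights...
parent = graph_runtime.create(graph, lib, tvm.cpu(0))
parent.load_params(param_bytes)

# ...and the per-request instances borrow them instead of holding their own copies.
workers = [graph_runtime.create(graph, lib, tvm.cpu(0)) for _ in range(4)]
for w in workers:
    # param_bytes is re-parsed only to recover the parameter names.
    w.share_params(parent, param_bytes)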

Test Plan:

Unit test added.
ajtulloch authored and wweic committed Jun 27, 2019
1 parent 681d324 commit 48704e7
Showing 4 changed files with 99 additions and 4 deletions.
14 changes: 14 additions & 0 deletions python/tvm/contrib/graph_runtime.py
@@ -129,6 +129,7 @@ def __init__(self, module):
        self._get_input = module["get_input"]
        self._get_num_outputs = module["get_num_outputs"]
        self._load_params = module["load_params"]
        self._share_params = module["share_params"]

    def set_input(self, key=None, value=None, **params):
        """Set inputs to the module via kwargs
@@ -234,6 +235,19 @@ def load_params(self, params_bytes):
        """
        self._load_params(bytearray(params_bytes))

    def share_params(self, other, params_bytes):
        """Share parameters from a pre-existing GraphRuntime instance.

        Parameters
        ----------
        other : GraphRuntime
            The parent GraphRuntime from which this instance should share
            its parameters.
        params_bytes : bytearray
            The serialized parameter dict (used only for the parameter names).
        """
        self._share_params(other.module, bytearray(params_bytes))

    def __getitem__(self, key):
        """Get internal module function
34 changes: 34 additions & 0 deletions src/runtime/graph/graph_runtime.cc
@@ -210,6 +210,32 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) {
  }
}

void GraphRuntime::ShareParams(const GraphRuntime& other, dmlc::Stream* strm) {
  uint64_t header, reserved;
  CHECK(strm->Read(&header))
      << "Invalid parameters file format";
  CHECK(header == kTVMNDArrayListMagic)
      << "Invalid parameters file format";
  CHECK(strm->Read(&reserved))
      << "Invalid parameters file format";
  std::vector<std::string> names;
  CHECK(strm->Read(&names)) << "Invalid parameters file format";
  uint64_t sz;
  CHECK(strm->Read(&sz)) << "Invalid parameters file format";
  size_t size = static_cast<size_t>(sz);
  CHECK(size == names.size()) << "Invalid parameters file format";
  for (size_t i = 0; i < size; ++i) {
    int in_idx = GetInputIndex(names[i]);
    CHECK_GE(in_idx, 0) << "Found param for non-existent input: " << names[i];
    uint32_t eid = this->entry_id(input_nodes_[in_idx], 0);
    CHECK_LT(eid, data_entry_.size());
    // The local entry must be the sole owner of its NDArray before it is replaced.
    CHECK_EQ(data_entry_[eid].use_count(), 1);
    // Re-point this runtime's storage at the parent runtime's parameter NDArray.
    data_entry_[eid] = other.GetInput(GetInputIndex(names[i]));
    CHECK_GT(data_entry_[eid].use_count(), 1);
  }
  this->SetupOpExecs();
}

void GraphRuntime::SetupStorage() {
  // Grab saved optimization plan from graph.
  std::vector<TVMType> vtype;
@@ -411,6 +437,14 @@ PackedFunc GraphRuntime::GetFunction(
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        this->LoadParams(args[0].operator std::string());
      });
  } else if (name == "share_params") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        const auto& module = args[0].operator Module();
        // type_key() returns a const char*; compare contents, not pointers.
        CHECK_EQ(module.operator->()->type_key(), std::string("GraphRuntime"));
        const auto& param_blob = args[1].operator std::string();
        dmlc::MemoryStringStream strm(const_cast<std::string*>(&param_blob));
        this->ShareParams(dynamic_cast<const GraphRuntime&>(*module.operator->()), &strm);
      });
  } else {
    return PackedFunc();
  }
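The "share_params" entry registered above is reachable through the normal module-function interface, which is all the Python wrapper does. A minimal sketch, assuming `parent` and `child` were created by `graph_runtime.create` and `param_bytes` is the serialized dict from the earlier sketch:

# Low-level equivalent of child.share_params(parent, param_bytes):
fshare = child.module["share_params"]
fshare(parent.module, bytearray(param_bytes))
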
17 changes: 13 additions & 4 deletions src/runtime/graph/graph_runtime.h
@@ -171,10 +171,19 @@ class GraphRuntime : public ModuleNode {
   * \param param_blob A binary blob of parameter.
   */
  void LoadParams(const std::string& param_blob);

  /*!
   * \brief Share parameters from a pre-existing GraphRuntime instance.
   * \param other A GraphRuntime instance on which |LoadParams| was previously
   * called with the identical |param_blob|.
   * \param strm The input stream of the serialized parameter dict.
   */
  void ShareParams(const GraphRuntime& other, dmlc::Stream* strm);

  /*!
   * \brief Get total number of nodes.
   * \return Total number of nodes.
   */
  uint32_t GetNumOfNodes() const {
    return static_cast<uint32_t>(nodes_.size());
  }
38 changes: 38 additions & 0 deletions tests/python/unittest/test_runtime_graph.py
@@ -81,8 +81,46 @@ def check_remote():
        out = mod.get_output(0, out)
        np.testing.assert_equal(out.asnumpy(), a + 1)

    def check_sharing():
        from tvm import relay
        x = relay.var('x', shape=(1, 10))
        y = relay.var('y', shape=(1, 10))
        z = relay.add(x, y)
        func = relay.Function([x, y], z)

        x_in = np.ones((1, 10)).astype("float32")
        params = {'x': x_in}
        graph, lib, params = relay.build(func, target="llvm", params=params)

        if not tvm.module.enabled("llvm"):
            print("Skip because llvm is not enabled")
            return
        mod_shared = graph_runtime.create(graph, lib, tvm.cpu(0))
        mod_shared.load_params(relay.save_param_dict(params))
        num_mods = 10
        mods = [graph_runtime.create(graph, lib, tvm.cpu(0))
                for _ in range(num_mods)]

        for mod in mods:
            mod.share_params(mod_shared, relay.save_param_dict(params))

        a = np.random.uniform(size=(1, 10)).astype("float32")
        for mod in mods:
            mod.run(y=a)
            out = mod.get_output(0, tvm.nd.empty((1, 10)))
            np.testing.assert_equal(out.asnumpy(), x_in + a)

        # Explicitly delete the shared module and verify correctness.
        del mod_shared
        for mod in mods:
            mod.run(y=a)
            out = mod.get_output(0, tvm.nd.empty((1, 10)))
            np.testing.assert_equal(out.asnumpy(), x_in + a)
            del mod

    check_verify()
    check_remote()
    check_sharing()

if __name__ == "__main__":
    test_graph_simple()
