Skip to content

Commit

Permalink
Enable use json for graph attr exchange (apache#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed May 26, 2018
1 parent 26026d4 commit 88520e1
Show file tree
Hide file tree
Showing 6 changed files with 71 additions and 29 deletions.
29 changes: 18 additions & 11 deletions nnvm/include/nnvm/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -248,27 +248,34 @@ NNVM_DLL int NNGraphFree(GraphHandle handle);
*/
NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol);
/*!
* \brief Get Set a std::string typed attribute to graph.
* \brief Get Set a attribute in json format.
* This feature allows pass graph attributes back and forth in reasonable speed.
*
* \param handle The graph handle.
* \param key The key to the attribute.
* \param value The value to be exposed.
* \param json_value The value need to be in format [type_name, value],
* Where type_name is a registered type string in C++ side via DMLC_JSON_ENABLE_ANY.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNGraphSetStrAttr(GraphHandle handle,
const char* key,
const char* value);
NNVM_DLL int NNGraphSetJSONAttr(GraphHandle handle,
const char* key,
const char* json_value);
/*!
* \brief Get Set a std::string typed attribute from graph attribute.
* \brief Get a serialized attrirbute from graph.
* This feature allows pass graph attributes back and forth in reasonable speed.
*
* \param handle The graph handle.
* \param key The key to the attribute.
* \param out The result attribute, can be NULL if the attribute do not exist.
* \param json_out The result attribute, can be NULL if the attribute do not exist.
* The json_out is an array of [type_name, value].
* Where the type_name is a registered type string in C++ side via DMLC_JSON_ENABLE_ANY.
* \param success Whether the result is contained in out.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNGraphGetStrAttr(SymbolHandle handle,
const char* key,
const char** out,
int *success);
NNVM_DLL int NNGraphGetJSONAttr(SymbolHandle handle,
const char* key,
const char** json_out,
int *success);
/*!
* \brief Apply pass on the src graph.
* \param src The source graph handle.
Expand Down
1 change: 1 addition & 0 deletions nnvm/python/nnvm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def _load_lib():
SymbolHandle = ctypes.c_void_p
GraphHandle = ctypes.c_void_p


#----------------------------
# helper function definition
#----------------------------
Expand Down
29 changes: 20 additions & 9 deletions nnvm/python/nnvm/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@

import ctypes
import sys
import json
from .base import _LIB
from .base import c_array, c_str, nn_uint, py_str, string_types
from .base import GraphHandle, SymbolHandle
from .base import check_call
from .symbol import Symbol


class Graph(object):
"""Graph is the graph object that can be used to apply optimization pass.
It contains additional graphwise attribute besides the internal symbol.
Expand All @@ -31,7 +33,7 @@ def __init__(self, handle):
def __del__(self):
check_call(_LIB.NNGraphFree(self.handle))

def attr(self, key):
def json_attr(self, key):
"""Get attribute string from the graph.
Parameters
Expand All @@ -46,24 +48,33 @@ def attr(self, key):
"""
ret = ctypes.c_char_p()
success = ctypes.c_int()
check_call(_LIB.NNGraphGetStrAttr(
check_call(_LIB.NNGraphGetJSONAttr(
self.handle, c_str(key), ctypes.byref(ret), ctypes.byref(success)))
if success.value != 0:
return py_str(ret.value)
json_str = py_str(ret.value)
return json.loads(json_str)[1]
else:
return None

def _set_attr(self, **kwargs):
def _set_json_attr(self, key, value, type_name=None):
"""Set the attribute of the symbol.
Parameters
----------
**kwargs
The attributes to set
key : string
The key of the attribute
value : value
The any type that can be dumped to json
type_name : string
The typename registered on c++ side.
"""
for k, v in kwargs.items():
check_call(_LIB.NNGraphSetStrAttr(
self.handle, c_str(k), c_str(v)))
if isinstance(value, string_types):
type_name = 'str'
elif type_name is None:
raise ValueError("Need to specify type_name")
json_value = json.dumps([type_name, value])
check_call(_LIB.NNGraphSetJSONAttr(
self.handle, c_str(key), c_str(json_value)))

@property
def symbol(self):
Expand Down
26 changes: 18 additions & 8 deletions nnvm/src/c_api/c_api_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <nnvm/symbolic.h>
#include <nnvm/graph.h>
#include <nnvm/pass.h>
#include <dmlc/json.h>
#include "./c_api_common.h"

using namespace nnvm;
Expand All @@ -34,26 +35,35 @@ int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol) {
API_END_HANDLE_ERROR(delete s);
}

int NNGraphSetStrAttr(GraphHandle handle,
const char* key,
const char* value) {
int NNGraphSetJSONAttr(GraphHandle handle,
const char* key,
const char* json_value) {
API_BEGIN();
Graph* g = static_cast<Graph*>(handle);
g->attrs[std::string(key)] = std::make_shared<any>(std::string(value));
std::string temp(json_value);
std::istringstream is(temp);
dmlc::JSONReader reader(&is);
nnvm::any value;
reader.Read(&value);
g->attrs[std::string(key)] = std::make_shared<any>(std::move(value));
API_END();
}

int NNGraphGetStrAttr(GraphHandle handle,
int NNGraphGetJSONAttr(GraphHandle handle,
const char* key,
const char** out,
const char** json_out,
int *success) {
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
API_BEGIN();
Graph* g = static_cast<Graph*>(handle);
std::string skey(key);
auto it = g->attrs.find(skey);
if (it != g->attrs.end()) {
const std::string& str = nnvm::get<std::string>(*it->second.get());
*out = str.c_str();
std::ostringstream os;
dmlc::JSONWriter writer(&os);
writer.Write(*it->second.get());
ret->ret_str = os.str();
*json_out = (ret->ret_str).c_str();
*success = 1;
} else {
*success = 0;
Expand Down
4 changes: 4 additions & 0 deletions nnvm/src/pass/saveload_json.cc
Original file line number Diff line number Diff line change
Expand Up @@ -203,5 +203,9 @@ NNVM_REGISTER_PASS(SaveJSON)
.set_change_graph(true)
.provide_graph_attr("json");


DMLC_JSON_ENABLE_ANY(std::string, str);
DMLC_JSON_ENABLE_ANY(std::vector<int>, list_int);

} // namespace pass
} // namespace nnvm
11 changes: 10 additions & 1 deletion nnvm/tests/python/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,18 @@ def test_json_pass():
y = sym.conv2d(data=x, name='conv', stride=(2,2))
g = graph.create(y)
ret = g.apply('SaveJSON')
ret._set_json_attr('json', ret.json_attr('json'))
g2 = ret.apply('LoadJSON')
assert g2.apply('SaveJSON').attr('json') == ret.attr('json')
assert g2.apply('SaveJSON').json_attr('json') == ret.json_attr('json')

def test_graph_json_attr():
x = sym.Variable('x')
y = sym.conv2d(data=x, name='conv', stride=(2,2))
g = graph.create(y)
g._set_json_attr('ilist', [1,2,3], 'list_int')
assert g.json_attr('ilist') == [1,2,3]


if __name__ == "__main__":
test_graph_json_attr()
test_json_pass()

0 comments on commit 88520e1

Please sign in to comment.