Skip to content

Commit

Permalink
Fix json parsing behavior (#83)
Browse files Browse the repository at this point in the history
  • Loading branch information
piiswrong authored and tqchen committed Nov 22, 2016
1 parent 008aef3 commit e693722
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 12 deletions.
20 changes: 14 additions & 6 deletions include/dmlc/parameter.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ enum ParamInitOption {
/*! \brief allow unknown parameters */
kAllowUnknown,
/*! \brief need to match exact parameters */
kAllMatch
kAllMatch,
/*! \brief allow unmatched hidden field with format __*__ */
kAllowHidden
};
} // namespace parameter
/*!
Expand Down Expand Up @@ -122,11 +124,11 @@ struct Parameter {
*/
template<typename Container>
inline void Init(const Container &kwargs,
parameter::ParamInitOption option = parameter::kAllowUnknown) {
parameter::ParamInitOption option = parameter::kAllowHidden) {
PType::__MANAGER__()->RunInit(static_cast<PType*>(this),
kwargs.begin(), kwargs.end(),
NULL,
option == parameter::kAllowUnknown);
option);
}
/*!
* \brief initialize the parameter by keyword arguments.
Expand All @@ -143,7 +145,7 @@ struct Parameter {
std::vector<std::pair<std::string, std::string> > unknown;
PType::__MANAGER__()->RunInit(static_cast<PType*>(this),
kwargs.begin(), kwargs.end(),
&unknown, true);
&unknown, parameter::kAllowUnknown);
return unknown;
}
/*!
Expand Down Expand Up @@ -369,7 +371,7 @@ class ParamManager {
RandomAccessIterator begin,
RandomAccessIterator end,
std::vector<std::pair<std::string, std::string> > *unknown_args,
bool allow_unknown) const {
parameter::ParamInitOption option) const {
std::set<FieldAccessEntry*> selected_args;
for (RandomAccessIterator it = begin; it != end; ++it) {
FieldAccessEntry *e = Find(it->first);
Expand All @@ -381,7 +383,13 @@ class ParamManager {
if (unknown_args != NULL) {
unknown_args->push_back(*it);
} else {
if (!allow_unknown) {
if (option != parameter::kAllowUnknown) {
if (option == parameter::kAllowHidden &&
it->first.length() > 4 &&
it->first.find("__") == 0 &&
it->first.rfind("__") == it->first.length()-2) {
continue;
}
std::ostringstream os;
os << "Cannot find argument \'" << it->first << "\', Possible Arguments:\n";
os << "----------------\n";
Expand Down
2 changes: 1 addition & 1 deletion src/core/symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
namespace nnvm {

namespace symbol_constants {
const char *kNamespaceSeparator = "_";
const char *kNamespaceSeparator = "$";
} // namespace symbol_constants

// auxililary version attribute in variable.
Expand Down
13 changes: 9 additions & 4 deletions src/pass/saveload_json.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,6 @@ struct JSONNode {
if (op_type_str != "null") {
try {
node->attrs.op = Op::Get(op_type_str);
// rebuild attribute parser
if (node->op()->attr_parser != nullptr) {
node->op()->attr_parser(&(node->attrs));
}
} catch (const dmlc::Error &err) {
std::ostringstream os;
os << "Failed loading Op " << node->attrs.name
Expand Down Expand Up @@ -163,6 +159,10 @@ Graph LoadJSON(Graph src) {
<< "Load JSON require json to be presented.";
const std::string &json_str =
nnvm::get<std::string>(*src.attrs.at("json"));
bool no_parse = false;
if (src.attrs.count("load_json_no_parse")) {
no_parse = nnvm::get<bool>(*src.attrs.at("load_json_no_parse"));
}
std::istringstream is(json_str);
dmlc::JSONReader reader(&is);
JSONGraph jgraph;
Expand All @@ -179,6 +179,11 @@ Graph LoadJSON(Graph src) {
for (uint32_t nid : n.control_deps) {
n.node->control_deps.push_back(jgraph.nodes[nid].node);
}
// rebuild attribute parser
if (!no_parse && n.node->op() != nullptr &&
n.node->op()->attr_parser != nullptr) {
n.node->op()->attr_parser(&(n.node->attrs));
}
}
// consistent check
for (uint32_t nid : jgraph.arg_nodes) {
Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def test_compose():
assert y.list_attr()['gpu'] == '1'
z = y.get_internals()
assert z['add_output'].list_output_names() == ['add_output']
assert y.list_attr(recursive=True)['add_gpu'] == '2'
assert y.list_attr(recursive=True)['add$gpu'] == '2'

def test_default_input():
x = sym.Variable('x')
Expand Down

0 comments on commit e693722

Please sign in to comment.