Skip to content

Commit

Permalink
[RELAY][Parser] Optimize relay parser to restore calls attrs (apache#…
Browse files Browse the repository at this point in the history
…7347)

* [RELAY][Parser] Optimize relay parser to restore attrs for non-Operator calls

* To avoid too much modification to the native class, only print out the attrs
  type key of non-Operator Call in relay printer. Then reconstruct the attrs object
  after parsing this attrs type key value in Relay parser.

* fix lint

* fix ci

* add test case
  • Loading branch information
domin1985 authored and trevor-m committed Mar 2, 2021
1 parent 889bd11 commit 76cecf2
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 5 deletions.
22 changes: 17 additions & 5 deletions src/parser/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1334,6 +1334,8 @@ class Parser {
case TokenType::kBoolean:
case TokenType::kStringLiteral:
return Match(next->token_type)->data;
case TokenType::kMetaReference:
return ParseMetaRef();
case TokenType::kLSquare: {
return ParseSequence<ObjectRef>(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare,
[&]() { return ParseAttributeValue(); });
Expand Down Expand Up @@ -1408,7 +1410,7 @@ class Parser {
auto last_meta = Lookahead(2)->token_type == TokenType::kCloseParen;
auto is_meta_attrs = is_meta_next && last_meta;

if (is_op && (is_pretty_attrs || is_meta_attrs)) {
if (is_pretty_attrs || is_meta_attrs) {
if (is_meta_attrs) {
auto meta_ref = ParseMetaRef();
if (meta_ref.as<BaseAttrsNode>()) {
Expand All @@ -1420,13 +1422,23 @@ class Parser {
}
} else {
auto raw_attrs = ParseAttrs();
auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, raw_attrs);
ICHECK(attr_obj.defined());
attrs = Downcast<Attrs>(attr_obj);
if (is_op && op_key.size()) {
auto attr_obj = tvm::ReflectionVTable::Global()->CreateObject(op_key, raw_attrs);
ICHECK(attr_obj.defined());
attrs = Downcast<Attrs>(attr_obj);
} else if (raw_attrs.count("attrs_type_key")) {
String attr_key = Downcast<String>(raw_attrs["attrs_type_key"]);
if (attr_key.size()) {
raw_attrs.erase("attrs_type_key");
auto tbl = tvm::ReflectionVTable::Global();
auto attr_obj = tbl->CreateObject(attr_key, raw_attrs);
ICHECK(attr_obj.defined());
attrs = Downcast<Attrs>(attr_obj);
}
}
}
return true;
}

return false;
});

Expand Down
5 changes: 5 additions & 0 deletions src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,11 @@ std::vector<Doc> RelayTextPrinter::PrintCallAttrs(const Attrs& attrs, const Expr
} else {
AttrPrinter printer(&docs, this);
const_cast<BaseAttrsNode*>(attrs.operator->())->VisitNonDefaultAttrs(&printer);
if (!op_node) {
// print call attr type key to restore expr for relay parser
std::string s = std::string(attrs->GetTypeKey());
printer.Visit("attrs_type_key", &s);
}
return docs;
}
}
Expand Down
39 changes: 39 additions & 0 deletions tests/python/relay/test_ir_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,45 @@ def test_load_prelude():
tvm.parser.parse(mod.astext())


def test_call_attrs():
def get_func(shape, dtype):
x0 = relay.var("data", shape=shape, dtype=dtype)
w0 = relay.var("weight", shape=shape, dtype=dtype)
a = relay.nn.dense(x0, w0)
b = relay.nn.relu(a)
d = relay.add(b, relay.const(1.0, dtype=dtype))
return relay.Function([x0, w0], d)

# build relay graph
shape = (2, 4)
dtype = "float32"
sub_func = get_func(shape, dtype)
p0 = relay.var("p0", shape=shape, dtype=dtype)
p1 = relay.var("p1", shape=shape, dtype=dtype)
attr = tvm.ir.make_node("attrs.TestAttrs", name="func_call_attrs")
call = relay.Call(sub_func, [p0, p1], attrs=attr)
func = relay.Function([p0, p1], call)

# build relay module
mod = tvm.IRModule()
mod["main"] = func
mod = tvm.relay.transform.InferType()(mod)

# assert equal
program = """
def @main(%p0: Tensor[(2, 4), float32], %p1: Tensor[(2, 4), float32]) {
%2 = fn (%data: Tensor[(2, 4), float32], %weight: Tensor[(2, 4), float32]) {
%0 = nn.dense(%data, %weight, units=None);
%1 = nn.relu(%0);
add(%1, 1f)
};
%2(%p0, %p1, name="func_call_attrs", attrs_type_key="attrs.TestAttrs")
}
"""
parsed = parse_module(program)
assert_graph_equal(parsed, mod)


def test_tokenize_inf():
x = relay.var("x", shape=(3, 4), dtype="float32")
y = relay.clip(x, -np.inf, np.inf)
Expand Down

0 comments on commit 76cecf2

Please sign in to comment.