Skip to content

Commit

Permalink
[Parser] Add support for parsing the any dimension. (#6277)
Browse files Browse the repository at this point in the history
* Add case for any dimensions

* Fix second test case
  • Loading branch information
jroesch authored Aug 14, 2020
1 parent ad0dbe0 commit 4b2c01a
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/parser/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1502,6 +1502,8 @@ class Parser {
tvm::PrimExpr dim;
if (Peek()->token_type == TokenType::kMetaReference) {
dim = Downcast<tvm::PrimExpr>(ParseMetaRef());
} else if (WhenMatch(TokenType::kQuestion)) {
dim = tvm::tir::Any();
} else {
dim = Downcast<tvm::PrimExpr>(Match(TokenType::kInteger)->data);
}
Expand Down Expand Up @@ -1585,8 +1587,7 @@ class Parser {
return ParseNonPrimitiveType(tok);
}
}
}
if (WhenMatch(TokenType::kUnderscore)) {
} else if (WhenMatch(TokenType::kUnderscore)) {
return IncompleteType();
} else {
this->diag_ctx->EmitFatal(Diagnostic::Error(tok->span)
Expand Down
28 changes: 28 additions & 0 deletions tests/python/relay/test_ir_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,16 @@ def test_tensor_type():
)
)

assert_parses_as(
"let %_ : Tensor[(?, 1), float32] = (); ()",
relay.Let(
relay.Var("_", relay.TensorType((tvm.tir.Any(), 1), "float32")),
UNIT,
UNIT
)
)



def test_function_type():
assert_parses_as(
Expand Down Expand Up @@ -678,6 +688,24 @@ def test_adt_defn():
mod
)

def test_adt_any():
code = """
type my_dtype {
my_cons(Tensor[(?, 1), uint16]),
}
"""
mod = parse_module(code)
items = mod.type_definitions.items()
global_type_var, type_data = items[0]
assert global_type_var.name_hint == "my_dtype"
ctors = type_data.constructors
assert len(ctors) == 1
my_cons = ctors[0]
assert my_cons.name_hint == "my_cons"
ty_shape = my_cons.inputs[0].shape
assert isinstance(ty_shape[0], tvm.tir.Any)
assert ty_shape[1] == 1


def test_empty_adt_defn():
mod = tvm.IRModule()
Expand Down

0 comments on commit 4b2c01a

Please sign in to comment.