Skip to content

Commit

Permalink
Fix ir_text_printer
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch committed Aug 8, 2020
1 parent 2e7bdbd commit e3e5ef8
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 71 deletions.
9 changes: 6 additions & 3 deletions include/tvm/parser/source_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,21 @@ namespace parser {
* source of a TVM program.
*/
struct Source {
/*! \brief The source name. */
SourceName source_name;

/*! \brief The raw source. */
std::string source;
/*! \brief A mapping of line breaks into the raw source. */
std::vector<std::pair<int, int>> line_map;

/*! \brief An empty source. */
Source() : source(), line_map() {}
Source() : source_name(), source(), line_map() {}

/*! \brief Construct a source from a string. */
TVM_DLL explicit Source(const std::string& source);
TVM_DLL explicit Source(const SourceName& src_name, const std::string& source);

TVM_DLL Source(const Source& source) : source(source.source), line_map(source.line_map) {}
TVM_DLL Source(const Source& source) : source_name(source.source_name), source(source.source), line_map(source.line_map) {}

/*! \brief Generate an error message at a specific line and column with the
* annotated message.
Expand Down
109 changes: 89 additions & 20 deletions src/parser/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -517,9 +517,29 @@ class Parser {
return Bracket(TokenType::kLCurly, TokenType::kRCurly, parser);
}

Object ParseMetaRef() {
Consume(TokenType::kMetaReference);
LOG(FATAL) << "implement me";
ObjectRef ParseMetaRef() {
auto meta_ref = Match(TokenType::kMetaReference);
Call ref = Downcast<Call>(meta_ref->data);
auto attrs = ref->attrs.as<MetaRefAttrs>();
auto type_key = attrs->node_type_key;
auto index = attrs->node_index;
auto it = this->meta_table.find(type_key);
if (it != this->meta_table.end()) {
auto nodes = (*it).second;
if (index < nodes.size()) {
return nodes[index];
} else {
this->diag_ctx->Emit(
Diagnostic::Error(meta_ref->span)
<< "the node index `" << index << "` is out of bounds for `" << type_key << "`");
return ObjectRef();
}
} else {
this->diag_ctx->Emit(
Diagnostic::Error(meta_ref->span)
<< "no entry in the meta table for `" << type_key << "`");
return ObjectRef();
}
}
/*! \brief Parses a sequence beginning with a start token, separated by a separator token, and
* ending with a stop token.
Expand Down Expand Up @@ -607,8 +627,7 @@ class Parser {
auto mod = IRModule({}, types);

for (auto func : defs.funcs) {
auto function = ExpandMetaRefs(metadata, func.function);
mod->Add(func.global, function);
mod->Add(func.global, func.function);
}

return mod;
Expand Down Expand Up @@ -801,8 +820,14 @@ class Parser {
case TokenType::kFreeVar: {
Consume(TokenType::kFreeVar);
auto var_token = Match(TokenType::kLocal);
Match(TokenType::kColon);
auto type = ParseType();

Type type;
if (WhenMatch(TokenType::kColon)) {
type = ParseType();
} else {
type = IncompleteType();
}

BindFreeVar(var_token.ToString(), type);
break;
}
Expand Down Expand Up @@ -950,7 +975,7 @@ class Parser {

/*! Parse a function definition without a leading keyword or identifier.
*
* Handles things of the form [T1, ..., TN](arg1: U1, ..., argN, UN) -> Ret { body }.
* Handles things of the form [T1, ..., TN](arg1: U1, ..., argN : UN) -> Ret { body }.
*/
Function ParseFunctionDef() {
DLOG(INFO) << "Parser::ParseFunctionDef";
Expand All @@ -968,6 +993,8 @@ class Parser {
});
}

Map<String, ObjectRef> raw_attrs;

auto params =
ParseSequence<Var>(TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen, [&]() {
auto token = Match(TokenType::kLocal);
Expand All @@ -977,6 +1004,16 @@ class Parser {
type = ParseType();
}
return BindVar(string, type);
}, [&] {
auto is_ident = Lookahead(1)->token_type == TokenType::kIdentifier;
auto next_is_equal = Lookahead(2)->token_type == TokenType::kEqual;

if (is_ident && next_is_equal) {
raw_attrs = ParseAttrs();
return true;
}

return false;
});

Type ret_type;
Expand All @@ -990,7 +1027,12 @@ class Parser {
PopTypeScopes(1);
PopScopes(1);

return relay::Function(params, body, ret_type, generics);
// TODO(@jroesch): attributes should never be null, they should always be empty.
if (raw_attrs.size()) {
return relay::Function(params, body, ret_type, generics, DictAttrs(raw_attrs));
} else {
return relay::Function(params, body, ret_type, generics);
}
}

/*! \brief Parse an if-expression. */
Expand Down Expand Up @@ -1170,6 +1212,22 @@ class Parser {
return ParseSequence<ObjectRef>(TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare,
[&]() { return ParseAttributeValue(); });
}
case TokenType::kOpenParen: {
// TODO(@jroesch): need to figure out bracket vs. sequence
// return ParseSequence<ObjectRef>(TokenType::kOpenParen, TokenType::kComma, TokenType::kCloseParen,
// [&]() { return ParseAttributeValue(); });
return Bracket<ObjectRef>(TokenType::kOpenParen, TokenType::kCloseParen, [&]() { return ParseAttributeValue(); });
}
// TODO(@jroesch): not sure about this being the right way to handle nulls.
case TokenType::kIdentifier: {
if (auto text = next->data.as<tvm::StringObj>()) {
std::string id = GetRef<String>(text);
if (id == "nullptr") {
Match(TokenType::kIdentifier);
return ObjectRef();
}
}
}
default:
return ParseAtomicExpr();
}
Expand Down Expand Up @@ -1278,6 +1336,7 @@ class Parser {
}

Expr GetOp(const std::string& op_name, const Token& tok) {
DLOG(INFO) << "op_name=" << op_name << " token=" << tok;
try {
return Op::Get(op_name);
} catch (dmlc::Error e) {
Expand Down Expand Up @@ -1335,6 +1394,7 @@ class Parser {
return Expr(ctor.value());
} else {
auto idents = ParseHierName();
CHECK_NE(idents.size(), 0);
std::stringstream op_name;
int i = 0;
int periods = idents.size() - 1;
Expand All @@ -1354,8 +1414,6 @@ class Parser {
}
case TokenType::kMetaReference: {
return Downcast<Expr>(ParseMetaRef());
Consume(TokenType::kMetaReference);
return Downcast<Expr>(next->data);
}
case TokenType::kFn: {
Consume(TokenType::kFn);
Expand Down Expand Up @@ -1408,7 +1466,8 @@ class Parser {
Array<String> ParseHierName() {
Array<String> idents;
while (Peek()->token_type == TokenType::kIdentifier) {
idents.push_back(Peek().ToString());
auto name = Peek().ToString();
idents.push_back(name);
Consume(TokenType::kIdentifier);

if (Peek()->token_type == TokenType::kPeriod) {
Expand All @@ -1426,8 +1485,14 @@ class Parser {
Array<tvm::PrimExpr> ParseShape() {
auto dims = ParseSequence<tvm::PrimExpr>(TokenType::kOpenParen, TokenType::kComma,
TokenType::kCloseParen, [&]() {
auto tok = Match(TokenType::kInteger);
return Downcast<tvm::PrimExpr>(tok->data);
tvm::PrimExpr dim;
if (Peek()->token_type == TokenType::kMetaReference) {
dim = Downcast<tvm::PrimExpr>(ParseMetaRef());
} else {
dim = Downcast<tvm::PrimExpr>(Match(TokenType::kInteger)->data);
}

return dim;
});
return dims;
}
Expand Down Expand Up @@ -1565,10 +1630,12 @@ class Parser {
IRModule ParseModule(std::string file_name, std::string file_content) {
DLOG(INFO) << "ParseModule";
SourceName src_name = SourceName::Get(file_name);
Source src(file_content);
Source src(src_name, file_content);
DiagnosticContext ctx(src);
auto tokens = Tokenize(&ctx, src_name, file_content);
Parser parser(&ctx, src_name, tokens, DefaultOpTable(), Source(file_content));
auto tokens_and_table = Tokenize(&ctx, src_name, file_content);
auto tokens = tokens_and_table.first;
auto meta_data_table = tokens_and_table.second;
Parser parser(&ctx, src_name, tokens, DefaultOpTable(), src, meta_data_table.ToMetadata());
auto mod = parser.ParseModule();
// NB(@jroesch): it is very important that we render any errors before we proceed
// if there were any errors which allow the parser to proceed we must render them
Expand All @@ -1580,10 +1647,12 @@ IRModule ParseModule(std::string file_name, std::string file_content) {
Expr ParseExpr(std::string file_name, std::string file_content) {
DLOG(INFO) << "ParseExpr";
SourceName src_name = SourceName::Get(file_name);
Source src(file_content);
Source src(src_name, file_content);
DiagnosticContext ctx(src);
auto tokens = Tokenize(&ctx, src_name, file_content);
Parser parser(&ctx, src_name, tokens, DefaultOpTable(), Source(file_content));
auto tokens_and_table = Tokenize(&ctx, src_name, file_content);
auto tokens = tokens_and_table.first;
auto meta_data_table = tokens_and_table.second;
Parser parser(&ctx, src_name, tokens, DefaultOpTable(), src, meta_data_table.ToMetadata());
parser.ParseSemVer(false);
parser.PushScope();
auto expr = parser.ParseExpr();
Expand Down
2 changes: 1 addition & 1 deletion src/parser/source_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ namespace tvm {
namespace parser {

/*! \brief Construct a source from a string. */
Source::Source(const std::string& source) : source(source) {
Source::Source(const SourceName& src_name, const std::string& source) : source_name(src_name), source(source) {
int index = 0;
int length = 0;
line_map.push_back({index, length});
Expand Down
7 changes: 6 additions & 1 deletion src/parser/token.h
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,12 @@ int64_t Token::ToNumber() const { return Downcast<tvm::Integer>(this->operator->
std::string Token::ToString() const { return Downcast<tvm::String>(this->operator->()->data); }

Map<String, Array<ObjectRef>> Token::ToMetadata() const {
return Downcast<Map<String, Array<ObjectRef>>>(this->operator->()->data);
ObjectRef data = this->operator->()->data;
if (data.defined()) {
return Downcast<Map<String, Array<ObjectRef>>>(data);
} else {
return Map<String, Array<ObjectRef>>({});
}
}

} // namespace parser
Expand Down
19 changes: 15 additions & 4 deletions src/parser/tokenizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -533,12 +533,22 @@ struct Tokenizer {
tokens() {}
};

std::vector<Token> Condense(const std::vector<Token>& tokens) {
std::vector<Token> Condense(const std::vector<Token>& tokens, Token* table) {
std::vector<Token> out;
bool found_metadata = false;

for (size_t i = 0; i < tokens.size(); i++) {
auto current = tokens.at(i);
switch (current->token_type) {
case TokenType::kMetadata: {
if (!found_metadata) {
found_metadata = true;
*table = current;
} else {
LOG(FATAL) << "duplicate metadata section";
}
continue;
}
case TokenType::kPercent: {
auto next = tokens.at(i + 1);
if (next->token_type == TokenType::kIdentifier) {
Expand Down Expand Up @@ -602,15 +612,16 @@ std::vector<Token> Condense(const std::vector<Token>& tokens) {
return out;
}

std::vector<Token> Tokenize(DiagnosticContext* ctx, const SourceName& source_name,
std::pair<std::vector<Token>, Token> Tokenize(DiagnosticContext* ctx, const SourceName& source_name,
const std::string& source) {
auto tokenizer = Tokenizer(ctx, source_name, source);
tokenizer.Tokenize();
auto tokens = Condense(tokenizer.tokens);
Token meta_table(Span(), TokenType::kUnknown, ObjectRef());
auto tokens = Condense(tokenizer.tokens, &meta_table);
for (auto token : tokens) {
CHECK(token.defined());
}
return tokens;
return { tokens, meta_table };
}

} // namespace parser
Expand Down
7 changes: 4 additions & 3 deletions tests/python/relay/test_ir_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,10 @@ def test_vars():
assert op.name == "nn.global_avg_pool2d"

def test_meta_ref():
meta_op = parse_text("meta[type_key][1337]")
assert meta_op.attrs.node_type_key == "type_key"
assert meta_op.attrs.node_index == 1337
with pytest.raises(tvm.error.DiagnosticError):
meta_op = parse_text("meta[type_key][1337]")
assert meta_op.attrs.node_type_key == "type_key"
assert meta_op.attrs.node_index == 1337


def test_let():
Expand Down
Loading

0 comments on commit e3e5ef8

Please sign in to comment.