diff --git a/CMakeLists.txt b/CMakeLists.txt index 3f18d673e4a2..032e0bc2af00 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -287,7 +287,6 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS src/tir/*.cc src/topi/*.cc src/driver/*.cc - src/parser/*.cc src/support/*.cc src/script/*.cc ) @@ -317,6 +316,7 @@ tvm_file_glob(GLOB RELAY_BACKEND_SRCS tvm_file_glob(GLOB_RECURSE RELAY_IR_SRCS src/relay/ir/*.cc src/relay/printer/*.cc + src/relay/parser/*.cc ) tvm_file_glob(GLOB_RECURSE RELAY_QNN_SRCS src/relay/qnn/*.cc diff --git a/include/tvm/ir/diagnostic.h b/include/tvm/ir/diagnostic.h index 41130a5be0aa..3b2407491f26 100644 --- a/include/tvm/ir/diagnostic.h +++ b/include/tvm/ir/diagnostic.h @@ -27,14 +27,12 @@ #define TVM_IR_DIAGNOSTIC_H_ #include -#include #include #include namespace tvm { -using tvm::parser::SourceMap; using tvm::runtime::TypedPackedFunc; /*! \brief The diagnostic level, controls the printing of the message. */ diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 78c09e81b16f..c8531c88465a 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -24,7 +24,7 @@ #ifndef TVM_IR_EXPR_H_ #define TVM_IR_EXPR_H_ -#include +#include #include #include #include diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 0a5bac182fd9..fdb44b11887c 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -27,8 +27,8 @@ #include #include #include +#include #include -#include #include #include #include @@ -60,7 +60,7 @@ class IRModuleNode : public Object { /*! \brief A map from global type vars to ADT type data. */ Map type_definitions; /*! \brief The source map for the module. */ - parser::SourceMap source_map; + SourceMap source_map; /* \brief Additional attributes storing meta-data about the module. */ DictAttrs attrs; /*! @@ -357,7 +357,7 @@ class IRModule : public ObjectRef { */ TVM_DLL explicit IRModule(Map functions, Map type_definitions = {}, - std::unordered_set import_set = {}, parser::SourceMap map = {}, + std::unordered_set import_set = {}, SourceMap map = {}, DictAttrs attrs = {}); /*! \brief default constructor */ diff --git a/include/tvm/ir/span.h b/include/tvm/ir/source_map.h similarity index 59% rename from include/tvm/ir/span.h rename to include/tvm/ir/source_map.h index b53ca2921fe7..536099f3114b 100644 --- a/include/tvm/ir/span.h +++ b/include/tvm/ir/source_map.h @@ -16,20 +16,25 @@ * specific language governing permissions and limitations * under the License. */ - /*! - * \file tvm/ir/span.h - * \brief Span information for debugging purposes. + * \file source_map.h + * \brief A map from source names to source code. */ -#ifndef TVM_IR_SPAN_H_ -#define TVM_IR_SPAN_H_ +#ifndef TVM_IR_SOURCE_MAP_H_ +#define TVM_IR_SOURCE_MAP_H_ #include #include +#include +#include +#include #include +#include +#include namespace tvm { + /*! * \brief The source name in the Span * \sa SourceNameNode, Span @@ -122,5 +127,84 @@ class Span : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(Span, ObjectRef, SpanNode); }; +/*! \brief A program source in any language. + * + * Could represent the source from an ML framework or a source + * representing a tvm::IRModule. + */ +class Source; + +class SourceNode : public Object { + public: + /*! \brief The source name. */ + SourceName source_name; + + /*! \brief The raw source. */ + String source; + + /*! \brief A mapping of line breaks into the raw source. */ + std::vector> line_map; + + // override attr visitor + void VisitAttrs(AttrVisitor* v) { + v->Visit("source_name", &source_name); + v->Visit("source", &source); + } + + static constexpr const char* _type_key = "Source"; + TVM_DECLARE_FINAL_OBJECT_INFO(SourceNode, Object); +}; + +class Source : public ObjectRef { + public: + TVM_DLL Source(SourceName src_name, std::string source); + TVM_DLL tvm::String GetLine(int line); + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Source, ObjectRef, SourceNode); +}; + +/*! + * \brief A mapping from a unique source name to source fragment. + */ +class SourceMap; +/*! + * \brief Stores locations in frontend source that generated a node. + */ +class SourceMapNode : public Object { + public: + /*! \brief The source mapping. */ + Map source_map; + + // override attr visitor + void VisitAttrs(AttrVisitor* v) { v->Visit("source_map", &source_map); } + + bool SEqualReduce(const SourceMapNode* other, SEqualReducer equal) const { + return equal(source_map, other->source_map); + } + + static constexpr const char* _type_key = "SourceMap"; + TVM_DECLARE_FINAL_OBJECT_INFO(SourceMapNode, Object); +}; + +class SourceMap : public ObjectRef { + public: + explicit SourceMap(Map source_map); + + explicit SourceMap(std::initializer_list> source_map) + : SourceMap(Map(source_map)) {} + + SourceMap() : SourceMap(Map()) {} + + void Add(const Source& source); + + SourceMapNode* operator->() { + ICHECK(get() != nullptr); + return static_cast(get_mutable()); + } + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SourceMap, ObjectRef, SourceMapNode); +}; + } // namespace tvm -#endif // TVM_IR_SPAN_H_ + +#endif // TVM_IR_SOURCE_MAP_H_ diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index 62328f6a074a..c6baf5e08be3 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -49,7 +49,7 @@ #ifndef TVM_IR_TYPE_H_ #define TVM_IR_TYPE_H_ -#include +#include #include #include #include diff --git a/include/tvm/parser/source_map.h b/include/tvm/parser/source_map.h deleted file mode 100644 index a160c22a2a2f..000000000000 --- a/include/tvm/parser/source_map.h +++ /dev/null @@ -1,119 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file source_map.h - * \brief A map from source names to source code. - */ -#ifndef TVM_PARSER_SOURCE_MAP_H_ -#define TVM_PARSER_SOURCE_MAP_H_ - -#include -#include -#include - -#include -#include -#include -#include - -namespace tvm { -namespace parser { - -/*! \brief A program source in any language. - * - * Could represent the source from an ML framework or a source - * representing a tvm::IRModule. - */ -class Source; - -class SourceNode : public Object { - public: - /*! \brief The source name. */ - SourceName source_name; - - /*! \brief The raw source. */ - String source; - - /*! \brief A mapping of line breaks into the raw source. */ - std::vector> line_map; - - // override attr visitor - void VisitAttrs(AttrVisitor* v) { - v->Visit("source_name", &source_name); - v->Visit("source", &source); - } - - static constexpr const char* _type_key = "Source"; - TVM_DECLARE_FINAL_OBJECT_INFO(SourceNode, Object); -}; - -class Source : public ObjectRef { - public: - TVM_DLL Source(SourceName src_name, std::string source); - TVM_DLL tvm::String GetLine(int line); - - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Source, ObjectRef, SourceNode); -}; - -/*! - * \brief A mapping from a unique source name to source fragment. - */ -class SourceMap; -/*! - * \brief Stores locations in frontend source that generated a node. - */ -class SourceMapNode : public Object { - public: - /*! \brief The source mapping. */ - Map source_map; - - // override attr visitor - void VisitAttrs(AttrVisitor* v) { v->Visit("source_map", &source_map); } - - bool SEqualReduce(const SourceMapNode* other, SEqualReducer equal) const { - return equal(source_map, other->source_map); - } - - static constexpr const char* _type_key = "SourceMap"; - TVM_DECLARE_FINAL_OBJECT_INFO(SourceMapNode, Object); -}; - -class SourceMap : public ObjectRef { - public: - TVM_DLL SourceMap(Map source_map); - - TVM_DLL SourceMap(std::initializer_list> source_map) - : SourceMap(Map(source_map)) {} - - TVM_DLL SourceMap() : SourceMap(Map()) {} - - void Add(const Source& source); - - SourceMapNode* operator->() { - ICHECK(get() != nullptr); - return static_cast(get_mutable()); - } - - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SourceMap, ObjectRef, SourceMapNode); -}; - -} // namespace parser -} // namespace tvm - -#endif // TVM_PARSER_SOURCE_MAP_H_ diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index 2825bcfc659a..a66b8044998b 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -24,7 +24,7 @@ #ifndef TVM_RELAY_BASE_H_ #define TVM_RELAY_BASE_H_ -#include +#include #include #include diff --git a/include/tvm/relay/error.h b/include/tvm/relay/error.h index be34e2b8ae1a..abe8278f2f5d 100644 --- a/include/tvm/relay/error.h +++ b/include/tvm/relay/error.h @@ -20,7 +20,6 @@ #define TVM_RELAY_ERROR_H_ #include -#include #include #include @@ -31,7 +30,7 @@ namespace tvm { namespace relay { /*! * \brief A wrapper around std::stringstream to build error. - * + *include/tvm/ir/type.h * Can be consumed by CompileError to construct an error. * * \code diff --git a/include/tvm/parser/parser.h b/include/tvm/relay/parser.h similarity index 86% rename from include/tvm/parser/parser.h rename to include/tvm/relay/parser.h index 0a73e1a2a532..6e33e7873f60 100644 --- a/include/tvm/parser/parser.h +++ b/include/tvm/relay/parser.h @@ -16,13 +16,9 @@ * specific language governing permissions and limitations * under the License. */ +#ifndef TVM_RELAY_PARSER_H_ +#define TVM_RELAY_PARSER_H_ -#ifndef TVM_PARSER_PARSER_H_ -#define TVM_PARSER_PARSER_H_ -/*! - * \file include/tvm/parser/parser.h - * \brief A parser for TVM IR. - */ #include #include #include @@ -32,7 +28,7 @@ #include namespace tvm { -namespace parser { +namespace relay { using MetaTable = Map>; @@ -45,9 +41,9 @@ IRModule ParseModule(const std::string& file_name, const std::string& file_conte * for all Relay sub-expressions. This improves error and debugging diagnostics downstream for * modules constructed programaticaly rather than textually. */ -transform::Pass AnnotateSpans(); +tvm::transform::Pass AnnotateSpans(); -} // namespace parser +} // namespace relay } // namespace tvm -#endif // TVM_PARSER_PARSER_H_ +#endif // TVM_RELAY_PARSER_H_ diff --git a/include/tvm/runtime/metadata_base.h b/include/tvm/runtime/metadata_base.h index 698f56d46d28..ca412a3b615c 100644 --- a/include/tvm/runtime/metadata_base.h +++ b/include/tvm/runtime/metadata_base.h @@ -24,7 +24,10 @@ #ifndef TVM_RUNTIME_METADATA_BASE_H_ #define TVM_RUNTIME_METADATA_BASE_H_ -#include +#include +#include +#include +#include #include #include diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index b84a83d55843..5df529b0532f 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -17,17 +17,23 @@ """Common base structures.""" import tvm._ffi import tvm.error -import tvm.runtime._ffi_node_api -from tvm.runtime import Object +from tvm._ffi import get_global_func, register_object +from tvm.runtime import Object, _ffi_node_api from . import _ffi_api, json_compact class Node(Object): - """Base class of all IR Nodes, implements astext function.""" + """Base class of all IR Nodes.""" -@tvm._ffi.register_object("SourceName") +@register_object("SourceMap") +class SourceMap(Object): + def add(self, name, content): + return get_global_func("SourceMapAdd")(self, name, content) + + +@register_object("SourceName") class SourceName(Object): """A identifier for a source location. @@ -38,10 +44,10 @@ class SourceName(Object): """ def __init__(self, name): - self.__init_handle_by_constructor__(_ffi_api.SourceName, name) + self.__init_handle_by_constructor__(_ffi_api.SourceName, name) # type: ignore # pylint: disable=no-member -@tvm._ffi.register_object("Span") +@register_object("Span") class Span(Object): """Specifies a location in a source program. @@ -59,11 +65,11 @@ class Span(Object): def __init__(self, source_name, line, end_line, column, end_column): self.__init_handle_by_constructor__( - _ffi_api.Span, source_name, line, end_line, column, end_column + _ffi_api.Span, source_name, line, end_line, column, end_column # type: ignore # pylint: disable=no-member ) -@tvm._ffi.register_object +@register_object class EnvFunc(Object): """Environment function. @@ -71,11 +77,11 @@ class EnvFunc(Object): """ def __call__(self, *args): - return _ffi_api.EnvFuncCall(self, *args) + return _ffi_api.EnvFuncCall(self, *args) # type: ignore # pylint: disable=no-member @property def func(self): - return _ffi_api.EnvFuncGetPackedFunc(self) + return _ffi_api.EnvFuncGetPackedFunc(self) # type: ignore # pylint: disable=no-member @staticmethod def get(name): @@ -86,7 +92,7 @@ def get(name): name : str The name of the function. """ - return _ffi_api.EnvFuncGet(name) + return _ffi_api.EnvFuncGet(name) # type: ignore # pylint: disable=no-member def load_json(json_str) -> Object: @@ -104,10 +110,10 @@ def load_json(json_str) -> Object: """ try: - return tvm.runtime._ffi_node_api.LoadJSON(json_str) + return _ffi_node_api.LoadJSON(json_str) except tvm.error.TVMError: json_str = json_compact.upgrade_json(json_str) - return tvm.runtime._ffi_node_api.LoadJSON(json_str) + return _ffi_node_api.LoadJSON(json_str) def save_json(node) -> str: @@ -123,7 +129,7 @@ def save_json(node) -> str: json_str : str Saved json string. """ - return tvm.runtime._ffi_node_api.SaveJSON(node) + return _ffi_node_api.SaveJSON(node) def structural_equal(lhs, rhs, map_free_vars=False): @@ -175,7 +181,7 @@ def structural_equal(lhs, rhs, map_free_vars=False): """ lhs = tvm.runtime.convert(lhs) rhs = tvm.runtime.convert(rhs) - return bool(tvm.runtime._ffi_node_api.StructuralEqual(lhs, rhs, False, map_free_vars)) + return bool(_ffi_node_api.StructuralEqual(lhs, rhs, False, map_free_vars)) # type: ignore # pylint: disable=no-member def get_first_structural_mismatch(lhs, rhs, map_free_vars=False): @@ -201,7 +207,7 @@ def get_first_structural_mismatch(lhs, rhs, map_free_vars=False): """ lhs = tvm.runtime.convert(lhs) rhs = tvm.runtime.convert(rhs) - mismatch = tvm.runtime._ffi_node_api.GetFirstStructuralMismatch(lhs, rhs, map_free_vars) + mismatch = _ffi_node_api.GetFirstStructuralMismatch(lhs, rhs, map_free_vars) # type: ignore # pylint: disable=no-member if mismatch is None: return None else: @@ -233,7 +239,7 @@ def assert_structural_equal(lhs, rhs, map_free_vars=False): """ lhs = tvm.runtime.convert(lhs) rhs = tvm.runtime.convert(rhs) - tvm.runtime._ffi_node_api.StructuralEqual(lhs, rhs, True, map_free_vars) + _ffi_node_api.StructuralEqual(lhs, rhs, True, map_free_vars) # type: ignore # pylint: disable=no-member def structural_hash(node, map_free_vars=False): @@ -275,4 +281,4 @@ def structural_hash(node, map_free_vars=False): -------- structrual_equal """ - return tvm.runtime._ffi_node_api.StructuralHash(node, map_free_vars) + return _ffi_node_api.StructuralHash(node, map_free_vars) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/parser.py b/python/tvm/parser.py new file mode 100644 index 000000000000..63c40deb2069 --- /dev/null +++ b/python/tvm/parser.py @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""The legacy TVM parser """ +# pylint: disable=import-outside-toplevel + + +def parse(*args, **kwargs): + """Deprecated, use `tvm.relay.parse` instead""" + from tvm.relay import parse as _impl + + return _impl(*args, **kwargs) + + +def parse_expr(*args, **kwargs): + """Deprecated, use `tvm.relay.parse_expr` instead""" + from tvm.relay import parse_expr as _impl + + return _impl(*args, **kwargs) + + +def fromtext(*args, **kwargs): + """Deprecated, use `tvm.relay.fromtext` instead""" + from tvm.relay import fromtext as _impl + + return _impl(*args, **kwargs) + + +def SpanCheck(*args, **kwargs): + """Deprecated, use `tvm.relay.SpanCheck` instead""" + from tvm.relay import SpanCheck as _impl + + return _impl(*args, **kwargs) diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 5e5d1d5f18d8..02eec18d3013 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -65,6 +65,9 @@ # Load Memory Passes from .transform import memory_plan +# Parser +from .parser import parse, parse_expr, fromtext, SpanCheck + # Required to traverse large programs setrecursionlimit(10000) diff --git a/python/tvm/parser/_ffi_api.py b/python/tvm/relay/_ffi_api_parser.py similarity index 91% rename from python/tvm/parser/_ffi_api.py rename to python/tvm/relay/_ffi_api_parser.py index 7fa3b78b72bb..731b926b5655 100644 --- a/python/tvm/parser/_ffi_api.py +++ b/python/tvm/relay/_ffi_api_parser.py @@ -14,8 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""FFI APIs for tvm.ir""" +"""FFI APIs for Relay parser.""" import tvm._ffi - -tvm._ffi._init_api("parser", __name__) +tvm._ffi._init_api("relay.parser", __name__) diff --git a/python/tvm/parser/__init__.py b/python/tvm/relay/parser.py similarity index 71% rename from python/tvm/parser/__init__.py rename to python/tvm/relay/parser.py index d75ad16ebab2..5e5f00a90eea 100644 --- a/python/tvm/parser/__init__.py +++ b/python/tvm/relay/parser.py @@ -15,25 +15,23 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name -"""The under development unified IR parsing infrastructure.""" -from .. import _ffi, Object -from . import _ffi_api - - -@_ffi.register_object("SourceMap") -class SourceMap(Object): - def add(self, name, content): - return _ffi.get_global_func("SourceMapAdd")(self, name, content) +"""The relay parser.""" +from . import _ffi_api_parser def parse(source, source_name="from_string", init_module=None, init_meta_table=None): if init_meta_table is None: init_meta_table = {} - return _ffi_api.ParseModuleInContext(source_name, source, init_module, init_meta_table) + return _ffi_api_parser.ParseModuleInContext( # type: ignore # pylint: disable=no-member + source_name, + source, + init_module, + init_meta_table, + ) def parse_expr(source): - return _ffi_api.ParseExpr("string", source) + return _ffi_api_parser.ParseExpr("string", source) # type: ignore # pylint: disable=no-member def fromtext(source, source_name="from_string"): @@ -42,4 +40,4 @@ def fromtext(source, source_name="from_string"): def SpanCheck(): """A debugging utility for reporting missing span information.""" - return _ffi_api.SpanCheck() + return _ffi_api_parser.SpanCheck() # type: ignore # pylint: disable=no-member diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs index ea257af1ebc0..8f71a8be2c7c 100644 --- a/rust/tvm/src/ir/module.rs +++ b/rust/tvm/src/ir/module.rs @@ -57,9 +57,9 @@ pub struct IRModuleNode { external! { // Parser functions - #[name("parser.ParseModule")] + #[name("relay.parser.ParseModule")] fn parse_module(file_name: TVMString, source: TVMString) -> IRModule; - #[name("parser.ParseExpr")] + #[name("relay.parser.ParseExpr")] fn parse_expression(file_name: TVMString, source: TVMString) -> IRModule; #[name("ir.IRModule")] fn module_new(funcs: Map, types: Map) -> IRModule; diff --git a/src/ir/diagnostic.cc b/src/ir/diagnostic.cc index 336575a93e97..6687a28d8c84 100644 --- a/src/ir/diagnostic.cc +++ b/src/ir/diagnostic.cc @@ -22,14 +22,12 @@ * \brief Implementation of DiagnosticContext and friends. */ #include -#include +#include #include namespace tvm { -using tvm::parser::Source; - // failed to check to argument arg0.dims[0] != 0 /* Diagnostic */ diff --git a/src/ir/module.cc b/src/ir/module.cc index b6923cd1e60d..22c6faf3d69d 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -16,16 +16,14 @@ * specific language governing permissions and limitations * under the License. */ - /*! * \file module.cc - * \brief The global module in Relay. + * \brief The global module in TVM. */ #include #include #include #include -#include #include #include @@ -36,8 +34,7 @@ namespace tvm { IRModule::IRModule(tvm::Map functions, tvm::Map type_definitions, - std::unordered_set import_set, parser::SourceMap source_map, - DictAttrs attrs) { + std::unordered_set import_set, SourceMap source_map, DictAttrs attrs) { auto n = make_object(); n->functions = std::move(functions); n->type_definitions = std::move(type_definitions); @@ -322,12 +319,14 @@ IRModule IRModule::FromExpr(const RelayExpr& expr, const Mapimport_set_.count(path) == 0) { this->import_set_.insert(path); std::fstream src_file(path, std::fstream::in); std::string file_contents{std::istreambuf_iterator(src_file), std::istreambuf_iterator()}; - auto mod_to_import = parser::ParseModule(path, file_contents, GetRef(this)); + auto mod_to_import = (*f)(path, file_contents, GetRef(this)); Update(mod_to_import); } } @@ -342,7 +341,9 @@ void IRModuleNode::ImportFromStd(const String& path) { std::unordered_set IRModuleNode::Imports() const { return this->import_set_; } IRModule IRModule::FromText(const String& text, const String& source_path) { - return tvm::parser::ParseModule(source_path, text); + static const auto* f = runtime::Registry::Get("relay.parser.ParseModule"); + ICHECK(f != nullptr) << "ValueError: Relay parser is not available"; + return (*f)(source_path, text, Optional()); } TVM_REGISTER_NODE_TYPE(IRModuleNode); diff --git a/src/ir/span.cc b/src/ir/source_map.cc similarity index 61% rename from src/ir/span.cc rename to src/ir/source_map.cc index 39f0044d16d3..8b913906ea42 100644 --- a/src/ir/span.cc +++ b/src/ir/source_map.cc @@ -17,11 +17,10 @@ * under the License. */ /*! - * \file span.cc - * \brief The span data structure. + * \file source_map.cc + * \brief The implementation of the source map data structure. */ -#include -#include +#include #include #include @@ -100,4 +99,72 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "Span(" << node->source_name << ", " << node->line << ", " << node->end_line << ", " << node->column << ", " << node->end_column << ")"; }); + +TVM_REGISTER_NODE_TYPE(SourceNode); + +/*! \brief Construct a source from a string. */ +Source::Source(SourceName src_name, std::string source) { + auto n = make_object(); + n->source_name = std::move(src_name); + n->source = std::move(source); + + int index = 0; + int length = 0; + n->line_map.push_back({index, length}); + // NB(@jroesch): + std::string source_str = n->source; + for (auto c : source_str) { + if (c == '\n') { + // Record the length of the line. + n->line_map.back().second = length; + // Bump past the newline. + index += 1; + // Record the start of the next line, and put placeholder for length. + n->line_map.push_back({index, 0}); + // Reset length to zero. + length = 0; + } else { + length += 1; + index += 1; + } + } + n->line_map.back().second = length; + + data_ = n; +} + +tvm::String Source::GetLine(int line) { + VLOG(1) << "Source::GetLine: line=" << line; + ICHECK(line - 1 < static_cast((*this)->line_map.size())) + << "requested line: " << line << "at index: " << (line - 1) + << "line_map size: " << (*this)->line_map.size() << "source: " << (*this)->source; + + // Adjust for zero indexing, now have (line_start, line_length); + auto range = (*this)->line_map.at(line - 1); + int line_start = range.first; + int line_length = range.second; + VLOG(1) << "Source::GetLine: line_start=" << line_start << " line_length=" << line_length; + // TODO(@jroesch): expose substring on tvm::String. + auto line_text = std::string((*this)->source).substr(line_start, line_length); + VLOG(1) << "Source::GetLine: line_text=" << line_text; + return line_text; +} + +TVM_REGISTER_NODE_TYPE(SourceMapNode); + +SourceMap::SourceMap(Map source_map) { + auto n = make_object(); + n->source_map = std::move(source_map); + data_ = std::move(n); +} + +void SourceMap::Add(const Source& source) { (*this)->source_map.Set(source->source_name, source); } + +TVM_REGISTER_GLOBAL("SourceMapAdd").set_body_typed([](SourceMap map, String name, String content) { + auto src_name = SourceName::Get(name); + Source source(src_name, content); + map.Add(source); + return src_name; +}); + } // namespace tvm diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 9a669493ccb7..66b06e6b505d 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -587,11 +587,12 @@ TVM_REGISTER_GLOBAL("transform.OverrideInstruments") Pass PrintIR(String header, bool show_meta_data) { auto pass_func = [header, show_meta_data](IRModule mod, const PassContext& ctx) { - if (const auto* f = runtime::Registry::Get("relay.PrintIR")) { - (*f)(mod, header, show_meta_data); - } else { - LOG(INFO) << "PrintIR(" << header << "):\n" << mod; + if (const auto* f = runtime::Registry::Get("relay.ir.PrintIR")) { + if ((*f)(mod, header, show_meta_data)) { + return mod; + } } + LOG(INFO) << "PrintIR(" << header << "):\n" << mod; return mod; }; return CreateModulePass(pass_func, 0, "PrintIR", {}); diff --git a/src/parser/source_map.cc b/src/parser/source_map.cc deleted file mode 100644 index 3c1329670c40..000000000000 --- a/src/parser/source_map.cc +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file source_map.cc - * \brief The implementation of the source map data structure. - */ -#include -#include - -namespace tvm { -namespace parser { - -TVM_REGISTER_NODE_TYPE(SourceNode); - -/*! \brief Construct a source from a string. */ -Source::Source(SourceName src_name, std::string source) { - auto n = make_object(); - n->source_name = std::move(src_name); - n->source = std::move(source); - - int index = 0; - int length = 0; - n->line_map.push_back({index, length}); - // NB(@jroesch): - std::string source_str = n->source; - for (auto c : source_str) { - if (c == '\n') { - // Record the length of the line. - n->line_map.back().second = length; - // Bump past the newline. - index += 1; - // Record the start of the next line, and put placeholder for length. - n->line_map.push_back({index, 0}); - // Reset length to zero. - length = 0; - } else { - length += 1; - index += 1; - } - } - n->line_map.back().second = length; - - data_ = n; -} - -tvm::String Source::GetLine(int line) { - VLOG(1) << "Source::GetLine: line=" << line; - ICHECK(line - 1 < static_cast((*this)->line_map.size())) - << "requested line: " << line << "at index: " << (line - 1) - << "line_map size: " << (*this)->line_map.size() << "source: " << (*this)->source; - - // Adjust for zero indexing, now have (line_start, line_length); - auto range = (*this)->line_map.at(line - 1); - int line_start = range.first; - int line_length = range.second; - VLOG(1) << "Source::GetLine: line_start=" << line_start << " line_length=" << line_length; - // TODO(@jroesch): expose substring on tvm::String. - auto line_text = std::string((*this)->source).substr(line_start, line_length); - VLOG(1) << "Source::GetLine: line_text=" << line_text; - return line_text; -} - -TVM_REGISTER_NODE_TYPE(SourceMapNode); - -SourceMap::SourceMap(Map source_map) { - auto n = make_object(); - n->source_map = std::move(source_map); - data_ = std::move(n); -} - -void SourceMap::Add(const Source& source) { (*this)->source_map.Set(source->source_name, source); } - -TVM_REGISTER_GLOBAL("SourceMapAdd").set_body_typed([](SourceMap map, String name, String content) { - auto src_name = SourceName::Get(name); - Source source(src_name, content); - map.Add(source); - return src_name; -}); - -} // namespace parser -} // namespace tvm diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 183a3094e473..4ff8a59b349e 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -25,7 +25,7 @@ #include "utils.h" -#include +#include #include #include #include diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index fb23c4cc082a..c29b3195a3fd 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -25,13 +25,13 @@ #include "compiler.h" #include -#include #include #include #include #include #include #include +#include #include #include #include diff --git a/src/relay/ir/base.cc b/src/relay/ir/base.cc index 5f913026080d..deedd283c2ff 100644 --- a/src/relay/ir/base.cc +++ b/src/relay/ir/base.cc @@ -39,22 +39,5 @@ Id::Id(String name_hint) { data_ = std::move(n); } -TVM_REGISTER_GLOBAL("ir.NodeSetSpan").set_body_typed([](ObjectRef node_ref, Span sp) { - if (auto* rn = node_ref.as()) { - rn->span = sp; - } else if (auto* rn = node_ref.as()) { - rn->span = sp; - } else if (auto* rn = node_ref.as()) { - rn->span = sp; - } else { - LOG(FATAL) << "Expect Type or RelayNode "; - } -}); - -TVM_REGISTER_GLOBAL("relay.PrintIR") - .set_body_typed([](ObjectRef mod, String header, bool show_metadata) { - LOG(INFO) << "PrintIR(" << header << "):\n" << AsText(mod, show_metadata); - }); - } // namespace relay } // namespace tvm diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index 3ff5eaa059c1..5d743d521777 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -123,6 +123,7 @@ const FunctionNode* AsOptimizableFunctionNode(const BaseFunc& base_func) { } return nullptr; } + TVM_REGISTER_GLOBAL("relay.ir.PrintRelayModule") .set_body_typed([](IRModule mod) -> Optional { for (const auto& it : mod->functions) { @@ -133,6 +134,17 @@ TVM_REGISTER_GLOBAL("relay.ir.PrintRelayModule") return NullOpt; }); +TVM_REGISTER_GLOBAL("relay.ir.PrintIR") + .set_body_typed([](IRModule mod, String header, bool show_metadata) -> bool { + for (const auto& it : mod->functions) { + if (it.second->IsInstance()) { + LOG(INFO) << "PrintIR(" << header << "):\n" << AsText(mod, show_metadata); + return true; + } + } + return false; + }); + TVM_REGISTER_GLOBAL("relay.ir.WarnIfMalformed") .set_body_typed([](const IRModule& mod, const BaseFunc& base_func) -> void { if (const auto* relay_func = base_func.as()) { diff --git a/src/parser/meta_ref.cc b/src/relay/parser/meta_ref.cc similarity index 98% rename from src/parser/meta_ref.cc rename to src/relay/parser/meta_ref.cc index 6b0e8d0c5966..cdc6929622dd 100644 --- a/src/parser/meta_ref.cc +++ b/src/relay/parser/meta_ref.cc @@ -30,7 +30,7 @@ #include namespace tvm { -namespace parser { +namespace relay { using tvm::relay::transform::CreateFunctionPass; using tvm::transform::PassContext; @@ -95,5 +95,5 @@ IRModule ExpandMetaRefs(const MetaTable& meta_table, const IRModule& mod) { return pass(mod, PassContext::Create()); } -} // namespace parser +} // namespace relay } // namespace tvm diff --git a/src/parser/meta_ref.h b/src/relay/parser/meta_ref.h similarity index 92% rename from src/parser/meta_ref.h rename to src/relay/parser/meta_ref.h index 483b7f726e07..bed67bea05a4 100644 --- a/src/parser/meta_ref.h +++ b/src/relay/parser/meta_ref.h @@ -22,20 +22,18 @@ * \brief A reference into the metadata section of the Relay text format. */ -#ifndef TVM_PARSER_META_REF_H_ -#define TVM_PARSER_META_REF_H_ +#ifndef TVM_RELAY_PARSER_META_REF_H_ +#define TVM_RELAY_PARSER_META_REF_H_ #include -#include #include #include +#include #include namespace tvm { -namespace parser { - -using namespace relay; +namespace relay { /*! * \brief Options for allocating storage. @@ -78,7 +76,7 @@ Expr MetaRef(std::string type_key, uint64_t node_index); relay::Function ExpandMetaRefs(const MetaTable& meta_table, const relay::Function& func); IRModule ExpandMetaRefs(const MetaTable& meta_table, const IRModule& mod); -} // namespace parser +} // namespace relay } // namespace tvm -#endif // TVM_PARSER_META_REF_H_ +#endif // TVM_RELAY_PARSER_META_REF_H_ diff --git a/src/parser/op_table.h b/src/relay/parser/op_table.h similarity index 93% rename from src/parser/op_table.h rename to src/relay/parser/op_table.h index 28c9cd7fc05f..6ff2c05476f4 100644 --- a/src/parser/op_table.h +++ b/src/relay/parser/op_table.h @@ -18,14 +18,13 @@ */ /*! - * \file token.h + * \file op_table.h * \brief A operator table for parsing. - * * Provides symbolic token sequences to map to TVM operators, with a given associativity and arity. */ -#ifndef TVM_PARSER_OP_TABLE_H_ -#define TVM_PARSER_OP_TABLE_H_ +#ifndef TVM_RELAY_PARSER_OP_TABLE_H_ +#define TVM_RELAY_PARSER_OP_TABLE_H_ #include #include @@ -38,7 +37,7 @@ #include "./tokenizer.h" namespace tvm { -namespace parser { +namespace relay { struct Rule { std::vector tokens; @@ -77,7 +76,7 @@ struct OperatorTable { } }; -OperatorTable DefaultOpTable() { +inline OperatorTable DefaultOpTable() { return OperatorTable( {Rule({TokenType::kStar}, Op::Get("multiply"), 12, 2, true), Rule({TokenType::kDivision}, Op::Get("divide"), 12, 2, true), @@ -91,6 +90,6 @@ OperatorTable DefaultOpTable() { Rule({TokenType::kBang, TokenType::kEqual}, Op::Get("not_equal"), 7, 2, true)}); } -} // namespace parser +} // namespace relay } // namespace tvm -#endif // TVM_PARSER_OP_TABLE_H_ +#endif // TVM_RELAY_PARSER_OP_TABLE_H_ diff --git a/src/parser/parser.cc b/src/relay/parser/parser.cc similarity index 99% rename from src/parser/parser.cc rename to src/relay/parser/parser.cc index fe89857f2709..ae7fc52cbead 100644 --- a/src/parser/parser.cc +++ b/src/relay/parser/parser.cc @@ -23,11 +23,12 @@ */ #include #include -#include #include #include #include +#include #include +#include #include #include #include @@ -35,18 +36,14 @@ #include -#include "../support/scalars.h" +#include "../../support/scalars.h" #include "./meta_ref.h" #include "./op_table.h" #include "./span_check.h" #include "./tokenizer.h" -#include "tvm/runtime/builtin_fp16.h" namespace tvm { -namespace parser { - -using namespace relay; -using Expr = relay::Expr; +namespace relay { /*! \brief The meta table maps from type key to a sequence of objects. */ using MetaTable = Map>; @@ -1948,22 +1945,6 @@ Expr ParseExpr(const std::string& file_name, const std::string& file_content) { return expr; } -TVM_REGISTER_GLOBAL("parser.ParseModuleInContext") - .set_body_typed([](const std::string& file_name, const std::string& file_content, - const Optional& init_module, const MetaTable& init_meta_table) { - return ParseModule(file_name, file_content, init_module, init_meta_table); - }); - -TVM_REGISTER_GLOBAL("parser.ParseModule") - .set_body_typed([](const std::string& file_name, const std::string& file_content) { - return ParseModule(file_name, file_content); - }); - -TVM_REGISTER_GLOBAL("parser.ParseExpr") - .set_body_typed([](tvm::String file_name, tvm::String file_content) { - return ParseExpr(file_name, file_content); - }); - /*! * \brief This pass pretty-prints mod then parses it back so as to establish spans and sources * for all Relay sub-expressions. This improves error and debugging diagnostics downstream for @@ -1978,7 +1959,29 @@ Pass AnnotateSpans() { return CreateModulePass(pass_func, 0, "AnnotateSpans", {}); } +TVM_REGISTER_GLOBAL("relay.parser.ParseModuleInContext") + .set_body_typed([](const std::string& file_name, const std::string& file_content, + const Optional& init_module, const MetaTable& init_meta_table) { + return ParseModule(file_name, file_content, init_module, init_meta_table); + }); + +TVM_REGISTER_GLOBAL("relay.parser.ParseModule").set_body([](TVMArgs args, TVMRetValue* ret) { + ICHECK(args.size() >= 2 && args.size() <= 4) << "Expected 2-4 arguments, but got " << args.size(); + if (args.size() == 2) { + *ret = ParseModule(args[0], args[1]); + } else if (args.size() == 3) { + *ret = ParseModule(args[0], args[1], args[2]); + } else { + *ret = ParseModule(args[0], args[1], args[2], args[3]); + } +}); + +TVM_REGISTER_GLOBAL("relay.parser.ParseExpr") + .set_body_typed([](tvm::String file_name, tvm::String file_content) { + return ParseExpr(file_name, file_content); + }); + TVM_REGISTER_GLOBAL("relay._transform.AnnotateSpans").set_body_typed(AnnotateSpans); -} // namespace parser +} // namespace relay } // namespace tvm diff --git a/src/parser/span_check.cc b/src/relay/parser/span_check.cc similarity index 96% rename from src/parser/span_check.cc rename to src/relay/parser/span_check.cc index 7fed3730d926..6bbf6317ad9f 100644 --- a/src/parser/span_check.cc +++ b/src/relay/parser/span_check.cc @@ -25,7 +25,7 @@ #include namespace tvm { -namespace parser { +namespace relay { using tvm::relay::transform::CreateFunctionPass; using tvm::transform::PassContext; @@ -101,7 +101,7 @@ Pass SpanCheck() { 0, "SpanCheck", {}); } -TVM_REGISTER_GLOBAL("parser.SpanCheck").set_body_typed([]() { return SpanCheck(); }); +TVM_REGISTER_GLOBAL("relay.parser.SpanCheck").set_body_typed([]() { return SpanCheck(); }); -} // namespace parser +} // namespace relay } // namespace tvm diff --git a/src/parser/span_check.h b/src/relay/parser/span_check.h similarity index 93% rename from src/parser/span_check.h rename to src/relay/parser/span_check.h index 0074c66d61f4..b85b4a497965 100644 --- a/src/parser/span_check.h +++ b/src/relay/parser/span_check.h @@ -21,9 +21,8 @@ * \file span_check.h * \brief Check that the Relay IR has correctly attached span information. */ - -#ifndef TVM_PARSER_SPAN_CHECK_H_ -#define TVM_PARSER_SPAN_CHECK_H_ +#ifndef TVM_RELAY_PARSER_SPAN_CHECK_H_ +#define TVM_RELAY_PARSER_SPAN_CHECK_H_ #include #include @@ -38,7 +37,7 @@ #include namespace tvm { -namespace parser { +namespace relay { using namespace tvm::relay; using tvm::transform::Pass; @@ -74,6 +73,6 @@ struct SpanChecker : ExprVisitor { Pass SpanCheck(); -} // namespace parser +} // namespace relay } // namespace tvm -#endif // TVM_PARSER_SPAN_CHECK_H_ +#endif // TVM_RELAY_PARSER_SPAN_CHECK_H_ diff --git a/src/parser/token.h b/src/relay/parser/token.h similarity index 93% rename from src/parser/token.h rename to src/relay/parser/token.h index 48a1bf70a250..7b11e701cf6e 100644 --- a/src/parser/token.h +++ b/src/relay/parser/token.h @@ -22,10 +22,11 @@ * \brief The definition of tokens for the TVM parser. */ -#ifndef TVM_PARSER_TOKEN_H_ -#define TVM_PARSER_TOKEN_H_ +#ifndef TVM_RELAY_PARSER_TOKEN_H_ +#define TVM_RELAY_PARSER_TOKEN_H_ -#include +#include +#include #include #include @@ -33,7 +34,7 @@ #include namespace tvm { -namespace parser { +namespace relay { using namespace runtime; @@ -97,7 +98,7 @@ enum class TokenType { kNull, }; -std::string ToString(const TokenType& token_type) { +inline std::string ToString(const TokenType& token_type) { switch (token_type) { case TokenType::kCommentStart: return "CommentStart"; @@ -219,7 +220,7 @@ std::string ToString(const TokenType& token_type) { } } -std::string Pretty(const TokenType& token_type) { +inline std::string Pretty(const TokenType& token_type) { switch (token_type) { case TokenType::kCommentStart: return "`/*`"; @@ -375,7 +376,7 @@ class Token : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(Token, ObjectRef, TokenNode); }; -Token::Token(Span span, TokenType token_type, ObjectRef data) { +inline Token::Token(Span span, TokenType token_type, ObjectRef data) { ObjectPtr n = make_object(); n->span = span; n->token_type = token_type; @@ -383,15 +384,17 @@ Token::Token(Span span, TokenType token_type, ObjectRef data) { data_ = std::move(n); } -Token Token::Null() { return Token(Span(SourceName(), 0, 0, 0, 0), TokenType::kNull); } +inline Token Token::Null() { return Token(Span(SourceName(), 0, 0, 0, 0), TokenType::kNull); } -int64_t Token::ToNumber() const { +inline int64_t Token::ToNumber() const { return Downcast(this->operator->()->data).IntValue(); } -std::string Token::ToString() const { return Downcast(this->operator->()->data); } +inline std::string Token::ToString() const { + return Downcast(this->operator->()->data); +} -Map> Token::ToMetadata() const { +inline Map> Token::ToMetadata() const { ObjectRef data = this->operator->()->data; if (data.defined()) { return Downcast>>(data); @@ -400,6 +403,6 @@ Map> Token::ToMetadata() const { } } -} // namespace parser +} // namespace relay } // namespace tvm -#endif // TVM_PARSER_TOKEN_H_ +#endif // TVM_RELAY_PARSER_TOKEN_H_ diff --git a/src/parser/tokenizer.h b/src/relay/parser/tokenizer.h similarity index 96% rename from src/parser/tokenizer.h rename to src/relay/parser/tokenizer.h index 505784e4bf70..04dcd3263e99 100644 --- a/src/parser/tokenizer.h +++ b/src/relay/parser/tokenizer.h @@ -18,11 +18,11 @@ */ /*! - * \file parser.h + * \file tokenizer.h * \brief A parser for TVM IR. */ -#ifndef TVM_PARSER_TOKENIZER_H_ -#define TVM_PARSER_TOKENIZER_H_ +#ifndef TVM_RELAY_PARSER_TOKENIZER_H_ +#define TVM_RELAY_PARSER_TOKENIZER_H_ #include #include @@ -34,12 +34,12 @@ #include #include -#include "../support/scalars.h" +#include "../../support/scalars.h" #include "./meta_ref.h" #include "./token.h" namespace tvm { -namespace parser { +namespace relay { using namespace runtime; @@ -54,20 +54,20 @@ static inline void rtrim(std::string& s) { // NOLINT(*) s.end()); } -bool IsDigit(char c) { return '0' <= c && c <= '9'; } +inline bool IsDigit(char c) { return '0' <= c && c <= '9'; } -bool IsWhitespace(char c) { return ' ' == c || c == '\t' || c == '\n'; } +inline bool IsWhitespace(char c) { return ' ' == c || c == '\t' || c == '\n'; } -bool IsNumeric(char c) { +inline bool IsNumeric(char c) { return (IsDigit(c) || c == '.' || c == 'e' || c == '-' || c == '+' || c == 'E') && !IsWhitespace(c); } -bool IsIdentLetter(char c) { +inline bool IsIdentLetter(char c) { return '_' == c || c == '/' || ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z'); } -bool IsIdent(char c) { return IsIdentLetter(c) || IsDigit(c); } +inline bool IsIdent(char c) { return IsIdentLetter(c) || IsDigit(c); } static std::unordered_map KEYWORD_TABLE = { {"let", TokenType::kLet}, {"fn", TokenType::kFn}, @@ -371,7 +371,7 @@ struct Tokenizer { int line = this->line; int col = this->col; auto next = Peek(); - VLOG(9) << "tvm::parser::TokenizeOnce: next=" << next; + VLOG(9) << "tvm::relay::TokenizeOnce: next=" << next; if (next == '\n') { auto token = NewToken(TokenType::kNewline); Next(); @@ -582,7 +582,7 @@ struct Tokenizer { } void Tokenize() { - VLOG(9) << "tvm::parser::Tokenize"; + VLOG(9) << "tvm::relay::Tokenize"; while (this->More()) { auto token = TokenizeOnce(); ICHECK(token.defined()); @@ -601,7 +601,7 @@ struct Tokenizer { tokens() {} }; -std::vector Condense(const std::vector& tokens, Token* table) { +inline std::vector Condense(const std::vector& tokens, Token* table) { std::vector out; bool found_metadata = false; @@ -680,7 +680,8 @@ std::vector Condense(const std::vector& tokens, Token* table) { return out; } -std::pair, Token> Tokenize(const DiagnosticContext& ctx, const Source& source) { +inline std::pair, Token> Tokenize(const DiagnosticContext& ctx, + const Source& source) { auto tokenizer = Tokenizer(ctx, source); tokenizer.Tokenize(); Token meta_table(Span(), TokenType::kUnknown, ObjectRef()); @@ -691,7 +692,7 @@ std::pair, Token> Tokenize(const DiagnosticContext& ctx, cons return {tokens, meta_table}; } -} // namespace parser +} // namespace relay } // namespace tvm -#endif // TVM_PARSER_TOKENIZER_H_ +#endif // TVM_RELAY_PARSER_TOKENIZER_H_ diff --git a/src/relay/printer/relay_text_printer.cc b/src/relay/printer/relay_text_printer.cc index cc86f9b56435..5b47c262fd48 100644 --- a/src/relay/printer/relay_text_printer.cc +++ b/src/relay/printer/relay_text_printer.cc @@ -41,9 +41,9 @@ #include #include "../../ir/attr_functor.h" -#include "../../parser/meta_ref.h" #include "../../support/scalars.h" #include "../analysis/dependency_graph.h" +#include "../parser/meta_ref.h" #include "doc.h" #include "meta_data.h" #include "text_printer.h" diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc index 168441d1708d..8b6600fbdfa9 100644 --- a/src/runtime/profiling.cc +++ b/src/runtime/profiling.cc @@ -23,7 +23,6 @@ */ #include -#include #include #include #include diff --git a/tests/cpp/relay/backend/aot/aot_lower_main_test.cc b/tests/cpp/relay/backend/aot/aot_lower_main_test.cc index 31166f1e6bb8..0157f031c214 100644 --- a/tests/cpp/relay/backend/aot/aot_lower_main_test.cc +++ b/tests/cpp/relay/backend/aot/aot_lower_main_test.cc @@ -20,7 +20,7 @@ #include "../../../../../src/relay/backend/aot/aot_lower_main.h" #include -#include +#include namespace tvm { namespace relay { @@ -37,7 +37,7 @@ TEST(AOTLowerMain, ExprAllocatorSkipNestedFunc) { %0(%x) } )"; - IRModule mod = parser::ParseModule("string", mod_text, {}, {}); + IRModule mod = ParseModule("string", mod_text, {}, {}); auto host_target = tvm::Target("llvm"); auto prim_target = tvm::Target(host_target, host_target); auto ctxt = tvm::transform::PassContext::Current(); diff --git a/tests/cpp/relay/collage/candidate_partition_test.cc b/tests/cpp/relay/collage/candidate_partition_test.cc index bc5d2d880a3b..d298a493c11f 100644 --- a/tests/cpp/relay/collage/candidate_partition_test.cc +++ b/tests/cpp/relay/collage/candidate_partition_test.cc @@ -20,9 +20,9 @@ #include "../../../src/relay/collage/candidate_partition.h" #include -#include #include #include +#include #include #include "../../../../src/relay/collage/mock_cost_estimator.h" @@ -37,7 +37,7 @@ namespace { // so not re-tested here. The only other non-trivial code is CandidatePartition::EstimateCost Function MakeTestFunction(const std::string& mod_text) { - IRModule mod = parser::ParseModule("string", mod_text, {}, {}); + IRModule mod = ParseModule("string", mod_text, {}, {}); mod = transform::CapturePostDfsIndexInSpans()(mod); auto func = Downcast(mod->Lookup("main")); LOG(INFO) << "------- input function -------"; diff --git a/tests/cpp/relay/collage/partition_rule_test.cc b/tests/cpp/relay/collage/partition_rule_test.cc index 51a4970c7ec0..effe0b1fa030 100644 --- a/tests/cpp/relay/collage/partition_rule_test.cc +++ b/tests/cpp/relay/collage/partition_rule_test.cc @@ -20,9 +20,9 @@ #include "../../../src/relay/collage/partition_rule.h" #include -#include #include #include +#include #include #include "../../../src/relay/collage/partition_spec.h" @@ -46,7 +46,7 @@ Function MakeTestFunction( } Map> metatable; metatable.Set("relay.Constant", constants); - IRModule mod = parser::ParseModule("string", mod_text, {}, metatable); + IRModule mod = ParseModule("string", mod_text, {}, metatable); mod = transform::CapturePostDfsIndexInSpans()(mod); auto func = Downcast(mod->Lookup("main")); LOG(INFO) << "------- input function -------"; diff --git a/tests/cpp/relay/df_pattern_rewrite_test.cc b/tests/cpp/relay/df_pattern_rewrite_test.cc index af09ae48aafd..374887c12a22 100644 --- a/tests/cpp/relay/df_pattern_rewrite_test.cc +++ b/tests/cpp/relay/df_pattern_rewrite_test.cc @@ -18,11 +18,11 @@ */ #include -#include #include #include #include #include +#include #include "../../../src/relay/transforms/simplify_expr.h" @@ -82,7 +82,7 @@ TEST(DFPatternRewrite, DeeplyNestedWithCallAttributes) { } )"; - IRModule module = parser::ParseModule("string", kModel); + IRModule module = ParseModule("string", kModel); DFPatternRewriteComposer composer; composer.AddRewrite(); Function in_function = Downcast(module->Lookup("main")); diff --git a/tests/cpp/relay/ir/indexed_graph_test.cc b/tests/cpp/relay/ir/indexed_graph_test.cc index 17ec68261684..486d027fbc21 100644 --- a/tests/cpp/relay/ir/indexed_graph_test.cc +++ b/tests/cpp/relay/ir/indexed_graph_test.cc @@ -20,9 +20,9 @@ #include "../../../src/relay/ir/indexed_graph.h" #include -#include #include #include +#include namespace tvm { namespace relay { @@ -81,7 +81,7 @@ IRModule TestRecursiveIRModule() { (%19, %20) // 51 } // 52 )"; - return parser::ParseModule("string", kModel, /*init_module=*/{}, metadata); + return ParseModule("string", kModel, /*init_module=*/{}, metadata); } TEST(IndexedGraph, RecursiveExprRegression) { @@ -179,7 +179,7 @@ IRModule TestUnusedLetBoundIRModule() { } } )"; - return parser::ParseModule("string", kModel); + return ParseModule("string", kModel); } TEST(IndexedGraph, UnusedLetVars) { diff --git a/tests/cpp/relay/transforms/device_domains_test.cc b/tests/cpp/relay/transforms/device_domains_test.cc index c5b2f26315b2..47e303996b3b 100644 --- a/tests/cpp/relay/transforms/device_domains_test.cc +++ b/tests/cpp/relay/transforms/device_domains_test.cc @@ -27,7 +27,7 @@ #include "../../../../src/relay/transforms/device_domains.h" #include -#include +#include #include namespace tvm { @@ -36,7 +36,7 @@ namespace transform { namespace { IRModule TestModule() { - return InferType()(tvm::parser::ParseModule("test", R"( + return InferType()(ParseModule("test", R"( #[version = "0.0.5"] def @f(%x : Tensor[(3, 7), float32], %y : Tensor[(3, 7), float32]) { add(%x, %y) diff --git a/tests/cpp/relay/with_fields_test.cc b/tests/cpp/relay/with_fields_test.cc index 48e04c259bb5..6114fa97a9fd 100644 --- a/tests/cpp/relay/with_fields_test.cc +++ b/tests/cpp/relay/with_fields_test.cc @@ -23,18 +23,18 @@ */ #include -#include #include #include #include +#include namespace tvm { namespace relay { namespace { IRModule TestIRModule() { - return parser::ParseModule("string", - R"( + return ParseModule("string", + R"( #[version = "0.0.5"] def @main(%data : Tensor[(1, 304, 128, 128), float32], %weight1 : Tensor[(304, 1, 3, 3), float32],