diff --git a/CMakeLists.txt b/CMakeLists.txt index 306a8be30858..46de8f5d07fa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -281,6 +281,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS src/parser/*.cc src/printer/*.cc src/support/*.cc + src/script/*.cc ) tvm_file_glob(GLOB CODEGEN_SRCS diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h new file mode 100644 index 000000000000..67c27bd45a1d --- /dev/null +++ b/include/tvm/script/printer/doc.h @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_PRINTER_DOC_H_ +#define TVM_SCRIPT_PRINTER_DOC_H_ + +#include +#include +#include + +namespace tvm { +namespace script { +namespace printer { + +/*! + * \brief The base class of all Doc. + * + * Doc is an intermediate representation between IR from TVM + * and the TVMScript code. + * During printing, IR graph is first translated into Doc tree, + * then the Doc tree is translated to the target language in + * text format. + * + * \sa Doc + */ +class DocNode : public Object { + public: + void VisitAttrs(AttrVisitor* v) {} + + static constexpr const char* _type_key = "script.printer.Doc"; + TVM_DECLARE_BASE_OBJECT_INFO(DocNode, Object); + + public: + virtual ~DocNode() = default; +}; + +/*! + * \brief Reference type of DocNode. + * + * \sa DocNode + */ +class Doc : public ObjectRef { + protected: + Doc() = default; + + public: + virtual ~Doc() = default; + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Doc, ObjectRef, DocNode); +}; + +/*! + * \brief The base class of expression doc. + * + * \sa ExprDoc + */ +class ExprDocNode : public DocNode { + public: + void VisitAttrs(AttrVisitor* v) { DocNode::VisitAttrs(v); } + + static constexpr const char* _type_key = "script.printer.ExprDoc"; + TVM_DECLARE_BASE_OBJECT_INFO(ExprDocNode, DocNode); +}; + +/*! + * \brief Reference type of ExprDocNode. + * + * \sa ExprDocNode + */ +class ExprDoc : public Doc { + protected: + ExprDoc() = default; + + public: + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ExprDoc, Doc, ExprDocNode); +}; + +/*! + * \brief Doc that represents literal value. + * + * \sa LiteralDoc + */ +class LiteralDocNode : public ExprDocNode { + public: + /*! + * \brief the internal representation of the literal value. + * + * Possible actual types: + * - IntImm (integer or boolean) + * - FloatImm + * - String + * - null + */ + ObjectRef value; + + void VisitAttrs(AttrVisitor* v) { + ExprDocNode::VisitAttrs(v); + v->Visit("value", &value); + } + + static constexpr const char* _type_key = "script.printer.LiteralDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(LiteralDocNode, ExprDocNode); +}; + +/*! + * \brief Reference type of LiteralDocNode. + * + * \sa LiteralDocNode + */ +class LiteralDoc : public ExprDoc { + protected: + explicit LiteralDoc(ObjectRef value); + + public: + /*! + * \brief Create a LiteralDoc to represent None/null/empty value. + */ + static LiteralDoc None() { return LiteralDoc(ObjectRef(nullptr)); } + + /*! + * \brief Create a LiteralDoc to represent integer. + * \param v The integer value. + */ + static LiteralDoc Int(int v) { return LiteralDoc(IntImm(DataType::Int(64), v)); } + + /*! + * \brief Create a LiteralDoc to represent boolean. + * \param v The boolean value. + */ + static LiteralDoc Boolean(bool v) { return LiteralDoc(IntImm(DataType::Bool(), v)); } + + /*! + * \brief Create a LiteralDoc to represent float. + * \param v The float value. + */ + static LiteralDoc Float(double v) { return LiteralDoc(FloatImm(DataType::Float(64), v)); } + + /*! + * \brief Create a LiteralDoc to represent string. + * \param v The string value. + */ + static LiteralDoc Str(const String& v) { return LiteralDoc(v); } + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LiteralDoc, ExprDoc, LiteralDocNode); +}; + +} // namespace printer +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_PRINTER_DOC_H_ diff --git a/include/tvm/script/printer/doc_printer.h b/include/tvm/script/printer/doc_printer.h new file mode 100644 index 000000000000..6bf502fab910 --- /dev/null +++ b/include/tvm/script/printer/doc_printer.h @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_PRINTER_DOC_PRINTER_H_ +#define TVM_SCRIPT_PRINTER_DOC_PRINTER_H_ + +#include + +namespace tvm { +namespace script { +namespace printer { + +/*! + * \brief Convert Doc into Python script. + * + * This function unpacks the DocPrinterOptions into function arguments + * to be FFI friendly. + * + * \param doc the doc to be converted + * \param indent_spaces the number of spaces used for indention + */ +String DocToPythonScript(Doc doc, int indent_spaces = 4); + +} // namespace printer +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_PRINTER_DOC_PRINTER_H_ diff --git a/python/tvm/script/printer/__init__.py b/python/tvm/script/printer/__init__.py new file mode 100644 index 000000000000..84ab7b0ba836 --- /dev/null +++ b/python/tvm/script/printer/__init__.py @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +TVMScript Unified Printer + +This package provides a set of APIs to print supported TVM IR into TVMScript +in a roundtrippable way. + +https://github.com/apache/tvm-rfcs/blob/main/rfcs/0074-tvmscript-unified-printer.md +""" + +from . import _ffi_api diff --git a/python/tvm/script/printer/_ffi_api.py b/python/tvm/script/printer/_ffi_api.py new file mode 100644 index 000000000000..baa639fe2d67 --- /dev/null +++ b/python/tvm/script/printer/_ffi_api.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.script.printer""" +import tvm._ffi + +tvm._ffi._init_api("script.printer", __name__) diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py new file mode 100644 index 000000000000..f6179d7351b2 --- /dev/null +++ b/python/tvm/script/printer/doc.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Doc types for TVMScript Unified Printer""" + +import tvm._ffi +from tvm.runtime import Object + +from . import _ffi_api + + +class Doc(Object): + """Base class of all Docs""" + + +class ExprDoc(Object): + """Base class of all expression Docs""" + + +@tvm._ffi.register_object("script.printer.LiteralDoc") +class LiteralDoc(ExprDoc): + """Doc that represents literal value""" + + def __init__(self, value): + if value is None: + self.__init_handle_by_constructor__(_ffi_api.LiteralDocNone) # type: ignore + elif isinstance(value, str): + self.__init_handle_by_constructor__(_ffi_api.LiteralDocStr, value) # type: ignore + elif isinstance(value, float): + self.__init_handle_by_constructor__(_ffi_api.LiteralDocFloat, value) # type: ignore + elif isinstance(value, bool): + self.__init_handle_by_constructor__(_ffi_api.LiteralDocBoolean, value) # type: ignore + elif isinstance(value, int): + self.__init_handle_by_constructor__(_ffi_api.LiteralDocInt, value) # type: ignore + else: + raise TypeError(f"Unsupported type {type(value)} for LiteralDoc") diff --git a/python/tvm/script/printer/doc_printer.py b/python/tvm/script/printer/doc_printer.py new file mode 100644 index 000000000000..404632b44c07 --- /dev/null +++ b/python/tvm/script/printer/doc_printer.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Functions to print doc into text format""" + +from . import _ffi_api +from .doc import Doc + + +def to_python_script(doc: Doc, indent_spaces: int = 4) -> str: + """ + Convert Doc into Python script. + + Parameters + ---------- + doc : Doc + The doc to convert into Python script + indent_spaces : int + The number of indent spaces to use in the output + + Returns + ------- + script : str + The text representation of Doc in Python syntax + """ + return _ffi_api.DocToPythonScript(doc, indent_spaces) # type: ignore diff --git a/src/script/printer/base_doc_printer.cc b/src/script/printer/base_doc_printer.cc new file mode 100644 index 000000000000..f6874ba1a2ee --- /dev/null +++ b/src/script/printer/base_doc_printer.cc @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "./base_doc_printer.h" + +namespace tvm { +namespace script { +namespace printer { + +DocPrinter::DocPrinter(int indent_spaces) : indent_spaces_(indent_spaces) {} + +void DocPrinter::Append(const Doc& doc) { PrintDoc(doc); } + +String DocPrinter::GetString() const { + std::string text = output_.str(); + if (!text.empty() && text.back() != '\n') { + text.push_back('\n'); + } + return text; +} + +void DocPrinter::PrintDoc(const Doc& doc) { + if (const auto* doc_node = doc.as()) { + PrintTypedDoc(GetRef(doc_node)); + } else { + LOG(FATAL) << "Do not know how to print " << doc->GetTypeKey(); + throw; + } +} + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/base_doc_printer.h b/src/script/printer/base_doc_printer.h new file mode 100644 index 000000000000..128fcef2ea32 --- /dev/null +++ b/src/script/printer/base_doc_printer.h @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_PRINTER_BASE_DOC_PRINTER_H_ +#define TVM_SCRIPT_PRINTER_BASE_DOC_PRINTER_H_ + +#include +#include + +#include +#include +#include + +namespace tvm { +namespace script { +namespace printer { + +/*! + * \brief DocPrinter is responsible for printing Doc tree into text format + * \details This is the base class for translating Doc into string. + * Each target language needs to have its subclass of DocPrinter + * to define the actual logic of printing Doc. + * + * \sa Doc + */ +class DocPrinter { + public: + /*! + * \brief The constructor of DocPrinter + * + * \param options the option for printer + */ + explicit DocPrinter(int indent_spaces = 4); + virtual ~DocPrinter() = default; + + /*! + * \brief Append a doc into the final content + * + * \param doc the Doc to be printed + * + * \sa GetString + */ + void Append(const Doc& doc); + + /*! + * \brief Get the printed string of all Doc appended + * + * The content of each Doc in the returned string will + * appear in the same order as they are appended. + * + * \sa Append + */ + String GetString() const; + + protected: + /*! + * \brief Get the printed string + * + * It will dispatch to the PrintTypedDoc method based on + * the actual type of Doc. + * + * \sa PrintTypedDoc + */ + void PrintDoc(const Doc& doc); + + /*! + * \brief Virtual method to print a LiteralDoc + */ + virtual void PrintTypedDoc(const LiteralDoc& doc) = 0; + + /*! + * \brief Increase the indent level of any content to be + * printed after this call + */ + void IncreaseIndent() { indent_ += indent_spaces_; } + + /*! + * \brief Decrease the indent level of any content to be + * printed after this call + */ + void DecreaseIndent() { indent_ -= indent_spaces_; } + + /*! + * \brief Add a new line into the output stream + * + * \sa output_ + */ + std::ostream& NewLine() { + output_ << "\n"; + output_ << std::string(indent_, ' '); + return output_; + } + + /*! + * \brief The output stream of printer + * + * All printed content will be stored in this stream and returned + * when GetString is called. + * + * \sa GetString + */ + std::ostringstream output_; + + private: + /*! \brief the number of spaces for one level of indentation */ + int indent_spaces_ = 4; + + /*! \brief the current level of indent */ + int indent_ = 0; +}; + +} // namespace printer +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_PRINTER_BASE_DOC_PRINTER_H_ diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc new file mode 100644 index 000000000000..e54adbd36b4c --- /dev/null +++ b/src/script/printer/doc.cc @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include + +namespace tvm { +namespace script { +namespace printer { + +LiteralDoc::LiteralDoc(ObjectRef value) { + ObjectPtr n = make_object(); + n->value = value; + this->data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(DocNode); +TVM_REGISTER_NODE_TYPE(ExprDocNode); +TVM_REGISTER_NODE_TYPE(LiteralDocNode); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocNone").set_body_typed(LiteralDoc::None); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt").set_body_typed(LiteralDoc::Int); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocBoolean").set_body_typed(LiteralDoc::Boolean); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloat").set_body_typed(LiteralDoc::Float); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocStr").set_body_typed(LiteralDoc::Str); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/python_doc_printer.cc b/src/script/printer/python_doc_printer.cc new file mode 100644 index 000000000000..cd816e4f7010 --- /dev/null +++ b/src/script/printer/python_doc_printer.cc @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include + +#include "../../support/str_escape.h" +#include "./base_doc_printer.h" + +namespace tvm { +namespace script { +namespace printer { + +class PythonDocPrinter : public DocPrinter { + public: + explicit PythonDocPrinter(int indent_spaces = 4) : DocPrinter(indent_spaces) {} + + protected: + using DocPrinter::PrintDoc; + + void PrintTypedDoc(const LiteralDoc& doc) final; +}; + +void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) { + const ObjectRef& value = doc->value; + if (!value.defined()) { + output_ << "None"; + } else if (const auto* int_imm = value.as()) { + if (int_imm->dtype.is_bool()) { + output_ << (int_imm->value ? "True" : "False"); + } else { + output_ << int_imm->value; + } + } else if (const auto* float_imm = value.as()) { + // TODO(yelite): Make float number printing roundtrippable + output_.precision(17); + output_ << float_imm->value; + } else if (const auto* string_obj = value.as()) { + output_ << "\"" << support::StrEscape(string_obj->data, string_obj->size) << "\""; + } else { + LOG(FATAL) << "TypeError: Unsupported literal value type: " << value->GetTypeKey(); + } +} + +String DocToPythonScript(Doc doc, int indent_spaces) { + PythonDocPrinter printer(indent_spaces); + printer.Append(doc); + return printer.GetString(); +} + +TVM_REGISTER_GLOBAL("script.printer.DocToPythonScript").set_body_typed(DocToPythonScript); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/tests/python/unittest/test_tvmscript_printer_doc.py b/tests/python/unittest/test_tvmscript_printer_doc.py new file mode 100644 index 000000000000..6330d33bf25a --- /dev/null +++ b/tests/python/unittest/test_tvmscript_printer_doc.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +from tvm.tir import IntImm +from tvm.script.printer.doc import LiteralDoc + + +@pytest.mark.parametrize( + "value", + [None, "test", 0, 1, -2, 0.0, 1.5, -1.3, True, False], +) +def test_literal_doc_construction(value): + doc = LiteralDoc(value) + if isinstance(value, float): + # FloatImm cannot be compared with Python's float directly + assert float(doc.value) == pytest.approx(value) + else: + assert doc.value == value diff --git a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py new file mode 100644 index 000000000000..55b5e88c88c8 --- /dev/null +++ b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +from tvm.script.printer.doc_printer import to_python_script +from tvm.script.printer.doc import LiteralDoc + + +def format_script(s: str) -> str: + """ + Remove leading and trailing blank lines, and make the minimum idention 0 + """ + s = s.strip("\n") + non_empty_lines = [line for line in s.splitlines() if line and not line.isspace()] + line_indents = [len(line) - len(line.lstrip(" ")) for line in non_empty_lines] + spaces_to_remove = min(line_indents) + return "\n".join(line[spaces_to_remove:] for line in s.splitlines()) + + +@pytest.mark.parametrize( + "doc,expected", + [ + (LiteralDoc(None), "None"), + (LiteralDoc(True), "True"), + (LiteralDoc(False), "False"), + (LiteralDoc("test"), '"test"'), + (LiteralDoc(""), '""'), + (LiteralDoc('""'), r'"\"\""'), + (LiteralDoc("\n\t\\test\r"), r'"\n\t\\test\r"'), + # TODO: fix the roundatrippable problem caused by utf8 + pytest.param(LiteralDoc("\x88"), r'"\x88"', marks=pytest.mark.xfail), + (LiteralDoc(0), "0"), + (LiteralDoc(-1), "-1"), + (LiteralDoc(3.25), "3.25"), + (LiteralDoc(-0.5), "-0.5"), + ], +) +def test_print_literal_doc(doc, expected): + assert to_python_script(doc).rstrip("\n") == format_script(expected) diff --git a/tests/scripts/task_mypy.sh b/tests/scripts/task_mypy.sh index 1ef7db589432..f165adfe1bc4 100755 --- a/tests/scripts/task_mypy.sh +++ b/tests/scripts/task_mypy.sh @@ -32,6 +32,9 @@ mypy --check-untyped-defs python/tvm/tir/analysis/ echo "Checking MyPy Type defs in the transform package." mypy --check-untyped-defs python/tvm/tir/transform/ +echo "Checking MyPy Type defs in the tvmscript printer package." +mypy --check-untyped-defs python/tvm/script/printer + echo "Checking MyPy Type defs in the TIR package with unittest" MYPYPATH=$TVM_PATH/python mypy --check-untyped-defs tests/python/unittest/test_tvmscript_type.py