From 33fa64e254a6114c533a25e4c9b9b6f0985a6da0 Mon Sep 17 00:00:00 2001 From: ubospica Date: Sat, 25 Feb 2023 13:55:00 +0000 Subject: [PATCH 1/2] finished --- src/script/printer/ir_docsifier.cc | 10 +-- .../test_tvmscript_printer_metadata.py | 66 +++++++++++++++++++ 2 files changed, 71 insertions(+), 5 deletions(-) create mode 100644 tests/python/unittest/test_tvmscript_printer_metadata.py diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc index 936534480ffb..fd5003073afb 100644 --- a/src/script/printer/ir_docsifier.cc +++ b/src/script/printer/ir_docsifier.cc @@ -56,11 +56,11 @@ ExprDoc IRDocsifierNode::AddMetadata(const ObjectRef& obj) { ICHECK(obj.defined()) << "TypeError: Cannot add nullptr to metadata"; String key = obj->GetTypeKey(); Array& array = metadata[key]; - int index = array.size(); - array.push_back(obj); - return IdDoc("metadata") // - [{LiteralDoc::Str(key, NullOpt)}] // - [{LiteralDoc::Int(index, NullOpt)}]; + int index = std::find(array.begin(), array.end(), obj) - array.begin(); + if (index == static_cast(array.size())) { + array.push_back(obj); + } + return IdDoc("metadata")[{LiteralDoc::Str(key, NullOpt)}][{LiteralDoc::Int(index, NullOpt)}]; } bool IRDocsifierNode::IsVarDefined(const ObjectRef& obj) const { return obj2info.count(obj); } diff --git a/tests/python/unittest/test_tvmscript_printer_metadata.py b/tests/python/unittest/test_tvmscript_printer_metadata.py new file mode 100644 index 000000000000..46f324cede07 --- /dev/null +++ b/tests/python/unittest/test_tvmscript_printer_metadata.py @@ -0,0 +1,66 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +import tvm.testing +from tvm.script.parser import ir as I +from tvm.script.parser import tir as T + + +def _assert_print(obj, expected): + assert obj.script(verbose_expr=True).strip() == expected.strip() + + +def test_str_metadata(): + str_imm = T.StringImm("aaa\nbbb\n") + + @I.ir_module + class Module: + @T.prim_func + def foo() -> None: + A = str_imm + B = str_imm + + @T.prim_func + def foo1() -> None: + A = str_imm + + _assert_print( + Module, + """ +# from tvm.script import ir as I +# from tvm.script import tir as T + +@I.ir_module +class Module: + @T.prim_func + def foo(): + A: T.handle = metadata["tir.StringImm"][0] + B: T.handle = metadata["tir.StringImm"][0] + T.evaluate(0) + + @T.prim_func + def foo1(): + A: T.handle = metadata["tir.StringImm"][0] + T.evaluate(0) + + +# Metadata omitted. Use show_meta=True in script() method to show it.""", + ) + + +if __name__ == "__main__": + tvm.testing.main() From 55cd6d955eeb24cfe3a1cc3b4de3e18540a5ede7 Mon Sep 17 00:00:00 2001 From: ubospica Date: Sat, 25 Feb 2023 15:31:21 +0000 Subject: [PATCH 2/2] finished --- .../test_tvmscript_printer_metadata.py | 31 ++++--------------- 1 file changed, 6 insertions(+), 25 deletions(-) diff --git a/tests/python/unittest/test_tvmscript_printer_metadata.py b/tests/python/unittest/test_tvmscript_printer_metadata.py index 46f324cede07..a57f4c71f7dd 100644 --- a/tests/python/unittest/test_tvmscript_printer_metadata.py +++ b/tests/python/unittest/test_tvmscript_printer_metadata.py @@ -20,11 +20,9 @@ from tvm.script.parser import tir as T -def _assert_print(obj, expected): - assert obj.script(verbose_expr=True).strip() == expected.strip() - - def test_str_metadata(): + # This test is to check we reuse the existing metadata element for the same tir.StringImm + # So metadata["tir.StringImm"][0] will occur in the printed script for three times str_imm = T.StringImm("aaa\nbbb\n") @I.ir_module @@ -38,27 +36,10 @@ def foo() -> None: def foo1() -> None: A = str_imm - _assert_print( - Module, - """ -# from tvm.script import ir as I -# from tvm.script import tir as T - -@I.ir_module -class Module: - @T.prim_func - def foo(): - A: T.handle = metadata["tir.StringImm"][0] - B: T.handle = metadata["tir.StringImm"][0] - T.evaluate(0) - - @T.prim_func - def foo1(): - A: T.handle = metadata["tir.StringImm"][0] - T.evaluate(0) - - -# Metadata omitted. Use show_meta=True in script() method to show it.""", + printed_str = Module.script(verbose_expr=True) + assert ( + printed_str.count('metadata["tir.StringImm"][0]') == 3 + and printed_str.count('metadata["tir.StringImm"][1]') == 0 )