diff --git a/python/hidet/ir/tools/printer.py b/python/hidet/ir/tools/printer.py index 2f95b2fa0..29fb903e2 100644 --- a/python/hidet/ir/tools/printer.py +++ b/python/hidet/ir/tools/printer.py @@ -75,6 +75,8 @@ def visit_Function(self, func: Function): else: head_doc += NewLine() head_doc += ')' + if not func.ret_type.is_void(): + head_doc += ' -> ' + self(func.ret_type) + ':' # attributes attr_doc = Doc() diff --git a/python/hidet/lang/transpiler.py b/python/hidet/lang/transpiler.py index a62a855e6..a674ce292 100644 --- a/python/hidet/lang/transpiler.py +++ b/python/hidet/lang/transpiler.py @@ -501,7 +501,8 @@ def visit_FunctionDef(self, func_def: FunctionDef): # the default return type is void ret_type = ir.VoidType() else: - ret_type = self.visit(func_def.returns) + # ret_type = self.visit(func_def.returns) + ret_type = self.func_annotations['return'] if not isinstance(ret_type, ir.BaseType): if ret_type is bool: ret_type = ir.data_type('bool') diff --git a/tests/script/test_return_type.py b/tests/script/test_return_type.py new file mode 100644 index 000000000..ced7cd7f7 --- /dev/null +++ b/tests/script/test_return_type.py @@ -0,0 +1,46 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import hidet + + +def test_return_type(): + from hidet.lang import printf, attrs + from hidet.ir.dtypes import float32x8, float32 + from hidet.ir import primitives + from hidet.lang import address, cast + + with hidet.script_module() as script_module: + + @hidet.script + def example() -> float32x8: + attrs.func_kind = 'cpu_internal' + return primitives.cpu.avx_f32x8_setzero() + + @hidet.script + def main(): + attrs.func_kind = 'cpu_kernel' + + a = example() + a_unpacked = cast(address(a), ~float32) + for i in range(8): + printf("%f ", a_unpacked[i]) + printf("\n") + + func = script_module.build() + func() + + return func + + +if __name__ == '__main__': + pytest.main([__file__])