Skip to content

Commit

Permalink
[Fixbug][Hidet Script] Fix a bug that hidet script does not recognize…
Browse files Browse the repository at this point in the history
… return type (#329)

Fix hidet script transpiler to run recognize the following code
```python
def test_unroll():
    from hidet.lang import printf, attrs
    from hidet.ir.dtypes import float32x8, float32
    from hidet.ir import primitives
    from hidet.lang import address, cast

    with hidet.script_module() as script_module:

        @hidet.script
        def example() -> float32x8:
            attrs.func_kind = 'cpu_internal'
            return primitives.cpu.avx_f32x8_setzero()

        @hidet.script
        def main():
            attrs.func_kind = 'cpu_kernel'

            a = example()
            a_unpacked = cast(address(a), ~float32)
            for i in range(8):
                printf("%f ", a_unpacked[i])
            printf("\n")

    func = script_module.build()
    func()

    return func
```
Previously, hidet can not recognize the `float32x8` return type.
  • Loading branch information
yaoyaoding authored Jul 25, 2023
1 parent fcdec26 commit b356f3d
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 1 deletion.
2 changes: 2 additions & 0 deletions python/hidet/ir/tools/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ def visit_Function(self, func: Function):
else:
head_doc += NewLine()
head_doc += ')'
if not func.ret_type.is_void():
head_doc += ' -> ' + self(func.ret_type) + ':'

# attributes
attr_doc = Doc()
Expand Down
3 changes: 2 additions & 1 deletion python/hidet/lang/transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,8 @@ def visit_FunctionDef(self, func_def: FunctionDef):
# the default return type is void
ret_type = ir.VoidType()
else:
ret_type = self.visit(func_def.returns)
# ret_type = self.visit(func_def.returns)
ret_type = self.func_annotations['return']
if not isinstance(ret_type, ir.BaseType):
if ret_type is bool:
ret_type = ir.data_type('bool')
Expand Down
46 changes: 46 additions & 0 deletions tests/script/test_return_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import hidet


def test_return_type():
from hidet.lang import printf, attrs
from hidet.ir.dtypes import float32x8, float32
from hidet.ir import primitives
from hidet.lang import address, cast

with hidet.script_module() as script_module:

@hidet.script
def example() -> float32x8:
attrs.func_kind = 'cpu_internal'
return primitives.cpu.avx_f32x8_setzero()

@hidet.script
def main():
attrs.func_kind = 'cpu_kernel'

a = example()
a_unpacked = cast(address(a), ~float32)
for i in range(8):
printf("%f ", a_unpacked[i])
printf("\n")

func = script_module.build()
func()

return func


if __name__ == '__main__':
pytest.main([__file__])

0 comments on commit b356f3d

Please sign in to comment.