Skip to content

Commit

Permalink
Change the def keyword to fn if the function has no arguments with --…
Browse files Browse the repository at this point in the history
…level=1
  • Loading branch information
msaelices committed Sep 13, 2023
1 parent 15a3b1d commit 1288fec
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 8 deletions.
6 changes: 3 additions & 3 deletions py2mojo/converters/functiondef.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,6 @@ def _add_declaration(tokens: list, i: int, level: int, declaration: str) -> None

def convert_functiondef(node: ast.FunctionDef, level: int = 0) -> Iterable:
"""Converts the annotation of the given function definition."""
if not node.args.args:
return

if level > 0:
offset = ast_to_offset(node)
yield (
Expand All @@ -51,6 +48,9 @@ def convert_functiondef(node: ast.FunctionDef, level: int = 0) -> Iterable:
),
)

if not node.args.args:
return

for arg in node.args.args:
if arg.arg == 'self':
yield (
Expand Down
23 changes: 18 additions & 5 deletions tests/test_functiondef.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,25 @@

from helpers import validate

# parametrize the converted types

def test_functiondef_with_no_params():
validate(
'def main(): print("Hello world!")',
'def main(): print("Hello world!")',
)
validate(
'def main(): print("Hello world!")',
'fn main(): print("Hello world!")',
level=1,
)


@pytest.mark.parametrize(
'python_type, mojo_type',
[
('int', 'Int'),
('float', 'Float64'),
]
],
)
def test_functiondef_with_basic_types(python_type, mojo_type):
validate(
Expand All @@ -27,7 +39,7 @@ def test_functiondef_with_basic_types(python_type, mojo_type):
[
('int', 'Int'),
('float', 'Float64'),
]
],
)
def test_functiondef_with_list_types(python_type, mojo_type):
validate(
Expand All @@ -47,13 +59,14 @@ def test_functiondef_with_list_types(python_type, mojo_type):
'def concat(l1: list, l2: list) -> list: return l1 + l2', # no changed
)


def test_functiondef_inside_classes():
validate(
'''
'''
class Point:
def __init__(self, x: int, y: int) -> int: ...
''',
'''
'''
class Point:
def __init__(inout self, x: Int, y: Int) -> Int: ...
''',
Expand Down

0 comments on commit 1288fec

Please sign in to comment.