Skip to content

Commit

Permalink
Merge pull request #3 from msaelices/types-error-handling
Browse files Browse the repository at this point in the history
Types error handling
  • Loading branch information
msaelices committed Sep 14, 2023
2 parents cddc724 + 3b6eace commit 9e52c69
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 3 deletions.
5 changes: 5 additions & 0 deletions py2mojo/converters/functiondef.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from tokenize_rt import Token

from ..exceptions import ParseException
from ..helpers import (
ast_to_offset,
get_annotation_type,
Expand Down Expand Up @@ -63,6 +64,10 @@ def convert_functiondef(node: ast.FunctionDef, rules: RuleSet = 0) -> Iterable:
)
continue

if rules.convert_def_to_fn and not arg.annotation:
raise ParseException(
node, 'For converting a def function to fn, the declaration needs to be fully type annotated'
)
curr_type = get_annotation_type(arg.annotation)
new_type = get_mojo_type(curr_type)

Expand Down
7 changes: 7 additions & 0 deletions py2mojo/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import ast


class ParseException(Exception):
def __init__(self, node: ast.AST, msg: str):
self.node = node
self.msg = msg
29 changes: 29 additions & 0 deletions py2mojo/helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import ast
import re

import astor
from rich import print
from rich.text import Text
from tokenize_rt import UNIMPORTANT_WS, Offset, Token


Expand Down Expand Up @@ -115,3 +118,29 @@ def get_mojo_type(curr_type: str) -> str:
curr_type = pattern.sub(replacement, curr_type)

return curr_type


def highlight_code_at_position(code: str, line: int, column: int, end_column: int) -> Text:
lines = code.splitlines()
highlighted = Text()

for idx, source_line in enumerate(lines):
if idx + 1 == line:
# Highlight the specific column in the given line
highlighted.append(source_line[:column], style='white')
for i in range(column, min(end_column, len(source_line))):
highlighted.append(source_line[i], style='bold black on yellow')
highlighted.append(source_line[end_column + 1 :], style='white')
else:
highlighted.append(source_line, style='white')
highlighted.append('\n')

return highlighted


def display_error(node: ast.AST, message: str):
src = astor.to_source(node)

highlighted_src = highlight_code_at_position(src, 1, node.col_offset, node.end_col_offset)
print('[bold red]Error:[/bold red]', message)
print(highlighted_src)
9 changes: 7 additions & 2 deletions py2mojo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from tokenize_rt import Token, reversed_enumerate, src_to_tokens, tokens_to_src

from .converters import convert_assignment, convert_functiondef, convert_classdef
from .helpers import fixup_dedent_tokens
from .exceptions import ParseException
from .helpers import display_error, fixup_dedent_tokens
from .rules import get_rules, RuleSet


Expand Down Expand Up @@ -113,7 +114,11 @@ def main(argv: Sequence[str] | None = None) -> int:

rules = get_rules(args)

annotated_source = convert_to_mojo(source, rules)
try:
annotated_source = convert_to_mojo(source, rules)
except ParseException as exc:
display_error(exc.node, exc.msg)
sys.exit(1)

if source != annotated_source:
print(f'Rewriting {filename}' if args.inplace else f'Rewriting {filename} into {mojo_filename}')
Expand Down
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
tokenize_rt>=5.2.0
tokenize_rt>=5.2.0
rich>=13.5.2
astor>=0.8.1
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import pytest

# We want pytest assert introspection in the helpers
pytest.register_assert_rewrite('helpers')
14 changes: 14 additions & 0 deletions tests/test_functiondef.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

from helpers import validate
from py2mojo.exceptions import ParseException
from py2mojo.rules import RuleSet


Expand Down Expand Up @@ -72,3 +73,16 @@ class Point:
def __init__(inout self, x: Int, y: Int) -> Int: ...
''',
)


def test_functiondef_non_fully_annotated_functions():
validate(
'''def add(x, y): return x + y''',
'''def add(x, y): return x + y''',
)
with pytest.raises(ParseException):
validate(
'''def add(x, y): return x + y''',
'''def add(x, y): return x + y''',
rules=RuleSet(convert_def_to_fn=True),
)

0 comments on commit 9e52c69

Please sign in to comment.