diff --git a/auto_type_annotate.py b/auto_type_annotate.py index da0bfe7..6fc5c31 100644 --- a/auto_type_annotate.py +++ b/auto_type_annotate.py @@ -52,6 +52,13 @@ def _args(node: ast.AsyncFunctionDef | ast.FunctionDef) -> Generator[ast.arg]: yield subnode +def _is_abstract(node: ast.AST) -> bool: + return ( + isinstance(node, ast.Attribute) and node.attr == 'abstractmethod' or + isinstance(node, ast.Name) and node.id == 'abstractmethod' + ) + + class FindUntyped(ast.NodeVisitor): def __init__(self) -> None: self._mod: list[Mod] = [] @@ -73,7 +80,10 @@ def visit_FunctionDef( self, node: ast.AsyncFunctionDef | ast.FunctionDef, ) -> None: - if not self._in_func[-1]: + if ( + not self._in_func[-1] and + not any(_is_abstract(dec) for dec in node.decorator_list) + ): args = tuple(_args(node)) if node.name == '__init__' and len(args) > 1: missing_annotation = False diff --git a/tests/auto_type_annotate_test.py b/tests/auto_type_annotate_test.py index ef672a3..abd393e 100644 --- a/tests/auto_type_annotate_test.py +++ b/tests/auto_type_annotate_test.py @@ -90,6 +90,28 @@ def test_find_untyped_async_def(): assert _find_untyped(src) == [(_MOD, 'f')] +def test_find_untyped_skips_abstract(): + src = '''\ +import abc + +class C: + @abc.abstractmethod + def f(self): pass +''' + assert _find_untyped(src) == [] + + +def test_find_untyped_skips_abstract_from_imported(): + src = '''\ +from abc import abstractmethod + +class C: + @abstractmethod + def f(self): pass +''' + assert _find_untyped(src) == [] + + @contextlib.contextmanager def _dmypy(): subprocess.check_call((sys.executable, '-m', 'mypy.dmypy', 'run', '.'))