diff --git a/deptry/import_parser.py b/deptry/import_parser.py index 775555c0..1dcb06ed 100644 --- a/deptry/import_parser.py +++ b/deptry/import_parser.py @@ -5,6 +5,8 @@ from deptry.notebook_import_extractor import NotebookImportExtractor +RECURSION_TYPES = [ast.If, ast.Try, ast.ExceptHandler, ast.FunctionDef, ast.ClassDef] + class ImportParser: """ @@ -57,9 +59,10 @@ def _get_import_nodes_from(self, root: Union[ast.Module, ast.If]): are defined within if/else or try/except statements. In that case, the ast.Import or ast.ImportFrom node is a child of an ast.If/Try/ExceptHandler node. """ + imports = [] for node in ast.iter_child_nodes(root): - if isinstance(node, ast.If) or isinstance(node, ast.Try) or isinstance(node, ast.ExceptHandler): + if any([isinstance(node, recursion_type) for recursion_type in RECURSION_TYPES]): imports += self._get_import_nodes_from(node) elif isinstance(node, ast.Import) or isinstance(node, ast.ImportFrom): imports += [node] diff --git a/tests/test_import_parser.py b/tests/test_import_parser.py index 622ee228..015dbae2 100644 --- a/tests/test_import_parser.py +++ b/tests/test_import_parser.py @@ -44,3 +44,28 @@ def test_import_parser_tryexcept(): """ ) assert set(imported_modules) == set(["numpy", "pandas", "click", "logging"]) + + +def test_import_parser_func(): + imported_modules = ImportParser().get_imported_modules_from_str( + """ +import pandas as pd +from numpy import random +def func(): + import click +""" + ) + assert set(imported_modules) == set(["numpy", "pandas", "click"]) + + +def test_import_parser_class(): + imported_modules = ImportParser().get_imported_modules_from_str( + """ +import pandas as pd +from numpy import random +class MyClass: + def __init__(self): + import click +""" + ) + assert set(imported_modules) == set(["numpy", "pandas", "click"])