diff --git a/tests/unit/test_identify.py b/tests/unit/test_identify.py index 08a7e7269..8515d459e 100644 --- a/tests/unit/test_identify.py +++ b/tests/unit/test_identify.py @@ -5,10 +5,49 @@ def imports_in_code(code: str, **kwargs) -> List[identify.Import]: - return list(identify.imports(StringIO(code, **kwargs))) + return list(identify.imports(StringIO(code), **kwargs)) -def test_yield_edge_cases(): +def test_top_only(): + imports_in_function = """ +import abc + +def xyz(): + import defg +""" + assert len(imports_in_code(imports_in_function)) == 2 + assert len(imports_in_code(imports_in_function, top_only=True)) == 1 + + imports_after_class = """ +import abc + +class MyObject: + pass + +import defg +""" + assert len(imports_in_code(imports_after_class)) == 2 + assert len(imports_in_code(imports_after_class, top_only=True)) == 1 + + +def test_top_doc_string(): + assert ( + len( + imports_in_code( + ''' +#! /bin/bash import x +"""import abc +from y import z +""" +import abc +''' + ) + ) + == 1 + ) + + +def test_yield_and_raise_edge_cases(): assert not imports_in_code( """ raise SomeException("Blah") \\ @@ -137,3 +176,25 @@ def generator_function(): from \\ """ ) + assert ( + len( + imports_in_code( + """ +def generator_function(): + ( + ( + (((( + ((((( + (( + ((( + raise \\ + from \\ + import c + + import abc + import xyz +""" + ) + ) + == 2 + )