Skip to content

Commit ef21c53

Browse files
committed
Implement caller inference to allow files() to be called without any parameter.
1 parent b4e0cd4 commit ef21c53

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

importlib_resources/_common.py

+26-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import contextlib
66
import types
77
import importlib
8+
import inspect
9+
import itertools
810

911
from typing import Union, Optional, cast
1012
from .abc import ResourceReader, Traversable
@@ -15,7 +17,7 @@
1517
Anchor = Package
1618

1719

18-
def files(package: Anchor) -> Traversable:
20+
def files(package: Optional[Anchor] = None) -> Traversable:
1921
"""
2022
Get a Traversable resource for an anchor.
2123
"""
@@ -39,7 +41,7 @@ def get_resource_reader(package: types.ModuleType) -> Optional[ResourceReader]:
3941

4042

4143
@functools.singledispatch
42-
def resolve(cand: Anchor) -> types.ModuleType:
44+
def resolve(cand: Optional[Anchor]) -> types.ModuleType:
4345
return cast(types.ModuleType, cand)
4446

4547

@@ -48,6 +50,28 @@ def _(cand: str) -> types.ModuleType:
4850
return importlib.import_module(cand)
4951

5052

53+
@resolve.register
54+
def _(cand: None) -> types.ModuleType:
55+
return resolve(_infer_caller().f_globals['__name__'])
56+
57+
58+
def _infer_caller():
59+
"""
60+
Walk the stack and find the frame of the first caller not in this module.
61+
"""
62+
63+
def is_this_file(frame_info):
64+
return frame_info.filename == __file__
65+
66+
def is_wrapper(frame_info):
67+
return frame_info.function == 'wrapper'
68+
69+
not_this_file = itertools.filterfalse(is_this_file, inspect.stack())
70+
# also exclude 'wrapper' due to singledispatch in the call stack
71+
callers = itertools.filterfalse(is_wrapper, not_this_file)
72+
return next(callers).frame
73+
74+
5175
def from_package(package: types.ModuleType):
5276
"""
5377
Return a Traversable object for the given package.

importlib_resources/tests/test_files.py

-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@ def test_module_resources(self):
7373

7474

7575
class ImplicitContextFilesTests(SiteDir, unittest.TestCase):
76-
@__import__('pytest').mark.xfail(reason="work in progress")
7776
def test_implicit_files(self):
7877
"""
7978
Without any parameter, files() will infer the location as the caller.

0 commit comments

Comments
 (0)