Skip to content

Commit cc1423f

Browse files
committed
Implement caller inference to allow files() to be called without any parameter.
1 parent 38f789d commit cc1423f

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

importlib_resources/_common.py

+26-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
import contextlib
66
import types
77
import importlib
8+
import inspect
89
import warnings
10+
import itertools
911

1012
from typing import Union, Optional, cast
1113
from .abc import ResourceReader, Traversable
@@ -50,7 +52,7 @@ def wrapper(anchor=undefined, package=undefined):
5052

5153

5254
@package_to_anchor
53-
def files(anchor: Anchor) -> Traversable:
55+
def files(package: Optional[Anchor] = None) -> Traversable:
5456
"""
5557
Get a Traversable resource for an anchor.
5658
"""
@@ -74,7 +76,7 @@ def get_resource_reader(package: types.ModuleType) -> Optional[ResourceReader]:
7476

7577

7678
@functools.singledispatch
77-
def resolve(cand: Anchor) -> types.ModuleType:
79+
def resolve(cand: Optional[Anchor]) -> types.ModuleType:
7880
return cast(types.ModuleType, cand)
7981

8082

@@ -83,6 +85,28 @@ def _(cand: str) -> types.ModuleType:
8385
return importlib.import_module(cand)
8486

8587

88+
@resolve.register
89+
def _(cand: None) -> types.ModuleType:
90+
return resolve(_infer_caller().f_globals['__name__'])
91+
92+
93+
def _infer_caller():
94+
"""
95+
Walk the stack and find the frame of the first caller not in this module.
96+
"""
97+
98+
def is_this_file(frame_info):
99+
return frame_info.filename == __file__
100+
101+
def is_wrapper(frame_info):
102+
return frame_info.function == 'wrapper'
103+
104+
not_this_file = itertools.filterfalse(is_this_file, inspect.stack())
105+
# also exclude 'wrapper' due to singledispatch in the call stack
106+
callers = itertools.filterfalse(is_wrapper, not_this_file)
107+
return next(callers).frame
108+
109+
86110
def from_package(package: types.ModuleType):
87111
"""
88112
Return a Traversable object for the given package.

importlib_resources/tests/test_files.py

-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ def test_module_resources(self):
8989

9090

9191
class ImplicitContextFilesTests(SiteDir, unittest.TestCase):
92-
@__import__('pytest').mark.xfail(reason="work in progress")
9392
def test_implicit_files(self):
9493
"""
9594
Without any parameter, files() will infer the location as the caller.

0 commit comments

Comments
 (0)