Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow a package to resolve its own resources simply #260

Merged
merged 4 commits into from
Oct 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ v5.10.0
files was renamed from 'package' to 'anchor', with a
compatibility shim for those passing by keyword.

* #259: ``files`` no longer requires the anchor to be
specified and can infer the anchor from the caller's scope
(defaults to the caller's module).

v5.9.0
======

Expand Down
33 changes: 27 additions & 6 deletions importlib_resources/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import contextlib
import types
import importlib
import inspect
import warnings
import itertools

from typing import Union, Optional, cast
from .abc import ResourceReader, Traversable
Expand All @@ -22,12 +24,9 @@ def package_to_anchor(func):

Other errors should fall through.

>>> files()
Traceback (most recent call last):
TypeError: files() missing 1 required positional argument: 'anchor'
>>> files('a', 'b')
Traceback (most recent call last):
TypeError: files() takes 1 positional argument but 2 were given
TypeError: files() takes from 0 to 1 positional arguments but 2 were given
"""
undefined = object()

Expand All @@ -50,7 +49,7 @@ def wrapper(anchor=undefined, package=undefined):


@package_to_anchor
def files(anchor: Anchor) -> Traversable:
def files(anchor: Optional[Anchor] = None) -> Traversable:
"""
Get a Traversable resource for an anchor.
"""
Expand All @@ -74,7 +73,7 @@ def get_resource_reader(package: types.ModuleType) -> Optional[ResourceReader]:


@functools.singledispatch
def resolve(cand: Anchor) -> types.ModuleType:
def resolve(cand: Optional[Anchor]) -> types.ModuleType:
return cast(types.ModuleType, cand)


Expand All @@ -83,6 +82,28 @@ def _(cand: str) -> types.ModuleType:
return importlib.import_module(cand)


@resolve.register
def _(cand: None) -> types.ModuleType:
return resolve(_infer_caller().f_globals['__name__'])


def _infer_caller():
"""
Walk the stack and find the frame of the first caller not in this module.
"""

def is_this_file(frame_info):
return frame_info.filename == __file__

def is_wrapper(frame_info):
return frame_info.function == 'wrapper'

not_this_file = itertools.filterfalse(is_this_file, inspect.stack())
# also exclude 'wrapper' due to singledispatch in the call stack
callers = itertools.filterfalse(is_wrapper, not_this_file)
return next(callers).frame


def from_package(package: types.ModuleType):
"""
Return a Traversable object for the given package.
Expand Down
26 changes: 25 additions & 1 deletion importlib_resources/tests/test_files.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import typing
import textwrap
import unittest
import warnings
import importlib
import contextlib

import importlib_resources as resources
Expand Down Expand Up @@ -61,14 +63,16 @@ def setUp(self):
self.data = namespacedata01


class ModulesFilesTests(unittest.TestCase):
class SiteDir:
def setUp(self):
self.fixtures = contextlib.ExitStack()
self.addCleanup(self.fixtures.close)
self.site_dir = self.fixtures.enter_context(os_helper.temp_dir())
self.fixtures.enter_context(import_helper.DirsOnSysPath(self.site_dir))
self.fixtures.enter_context(import_helper.CleanImport())


class ModulesFilesTests(SiteDir, unittest.TestCase):
def test_module_resources(self):
"""
A module can have resources found adjacent to the module.
Expand All @@ -84,5 +88,25 @@ def test_module_resources(self):
assert actual == spec['res.txt']


class ImplicitContextFilesTests(SiteDir, unittest.TestCase):
def test_implicit_files(self):
"""
Without any parameter, files() will infer the location as the caller.
"""
spec = {
'somepkg': {
'__init__.py': textwrap.dedent(
"""
import importlib_resources as res
val = res.files().joinpath('res.txt').read_text()
"""
),
'res.txt': 'resources are the best',
},
}
_path.build(spec, self.site_dir)
assert importlib.import_module('somepkg').val == 'resources are the best'


if __name__ == '__main__':
unittest.main()