5
5
import contextlib
6
6
import types
7
7
import importlib
8
+ import inspect
9
+ import warnings
10
+ import itertools
8
11
9
- from typing import Union , Optional
12
+ from typing import Union , Optional , cast
10
13
from .abc import ResourceReader , Traversable
11
14
12
15
from ._adapters import wrap_spec
13
16
14
17
Package = Union [types .ModuleType , str ]
18
+ Anchor = Package
15
19
16
20
17
- def files (package ):
18
- # type: (Package) -> Traversable
21
+ def package_to_anchor (func ):
19
22
"""
20
- Get a Traversable resource from a package
23
+ Replace 'package' parameter as 'anchor' and warn about the change.
24
+
25
+ Other errors should fall through.
26
+
27
+ >>> files('a', 'b')
28
+ Traceback (most recent call last):
29
+ TypeError: files() takes from 0 to 1 positional arguments but 2 were given
30
+ """
31
+ undefined = object ()
32
+
33
+ @functools .wraps (func )
34
+ def wrapper (anchor = undefined , package = undefined ):
35
+ if package is not undefined :
36
+ if anchor is not undefined :
37
+ return func (anchor , package )
38
+ warnings .warn (
39
+ "First parameter to files is renamed to 'anchor'" ,
40
+ DeprecationWarning ,
41
+ stacklevel = 2 ,
42
+ )
43
+ return func (package )
44
+ elif anchor is undefined :
45
+ return func ()
46
+ return func (anchor )
47
+
48
+ return wrapper
49
+
50
+
51
+ @package_to_anchor
52
+ def files (anchor : Optional [Anchor ] = None ) -> Traversable :
53
+ """
54
+ Get a Traversable resource for an anchor.
21
55
"""
22
- return from_package (get_package ( package ))
56
+ return from_package (resolve ( anchor ))
23
57
24
58
25
- def get_resource_reader (package ):
26
- # type: (types.ModuleType) -> Optional[ResourceReader]
59
+ def get_resource_reader (package : types .ModuleType ) -> Optional [ResourceReader ]:
27
60
"""
28
61
Return the package's loader if it's a ResourceReader.
29
62
"""
@@ -39,24 +72,39 @@ def get_resource_reader(package):
39
72
return reader (spec .name ) # type: ignore
40
73
41
74
42
- def resolve (cand ):
43
- # type: (Package) -> types.ModuleType
44
- return cand if isinstance (cand , types .ModuleType ) else importlib .import_module (cand )
75
+ @functools .singledispatch
76
+ def resolve (cand : Optional [Anchor ]) -> types .ModuleType :
77
+ return cast (types .ModuleType , cand )
78
+
79
+
80
+ @resolve .register
81
+ def _ (cand : str ) -> types .ModuleType :
82
+ return importlib .import_module (cand )
83
+
45
84
85
+ @resolve .register
86
+ def _ (cand : None ) -> types .ModuleType :
87
+ return resolve (_infer_caller ().f_globals ['__name__' ])
46
88
47
- def get_package (package ):
48
- # type: (Package) -> types.ModuleType
49
- """Take a package name or module object and return the module.
50
89
51
- Raise an exception if the resolved module is not a package.
90
+ def _infer_caller ():
52
91
"""
53
- resolved = resolve (package )
54
- if wrap_spec (resolved ).submodule_search_locations is None :
55
- raise TypeError (f'{ package !r} is not a package' )
56
- return resolved
92
+ Walk the stack and find the frame of the first caller not in this module.
93
+ """
94
+
95
+ def is_this_file (frame_info ):
96
+ return frame_info .filename == __file__
97
+
98
+ def is_wrapper (frame_info ):
99
+ return frame_info .function == 'wrapper'
100
+
101
+ not_this_file = itertools .filterfalse (is_this_file , inspect .stack ())
102
+ # also exclude 'wrapper' due to singledispatch in the call stack
103
+ callers = itertools .filterfalse (is_wrapper , not_this_file )
104
+ return next (callers ).frame
57
105
58
106
59
- def from_package (package ):
107
+ def from_package (package : types . ModuleType ):
60
108
"""
61
109
Return a Traversable object for the given package.
62
110
0 commit comments