@@ -67,10 +67,14 @@ def from_package(package):
6767
6868
6969@contextlib .contextmanager
70- def _tempfile (reader , suffix = '' ,
71- # gh-93353: Keep a reference to call os.remove() in late Python
72- # finalization.
73- * , _os_remove = os .remove ):
70+ def _tempfile (
71+ reader ,
72+ suffix = '' ,
73+ # gh-93353: Keep a reference to call os.remove() in late Python
74+ # finalization.
75+ * ,
76+ _os_remove = os .remove ,
77+ ):
7478 # Not using tempfile.NamedTemporaryFile as it leads to deeper 'try'
7579 # blocks due to the need to close the temporary file to work on Windows
7680 # properly.
@@ -89,13 +93,30 @@ def _tempfile(reader, suffix='',
8993 pass
9094
9195
96+ def _temp_file (path ):
97+ return _tempfile (path .read_bytes , suffix = path .name )
98+
99+
100+ def _is_present_dir (path : Traversable ) -> bool :
101+ """
102+ Some Traversables implement ``is_dir()`` to raise an
103+ exception (i.e. ``FileNotFoundError``) when the
104+ directory doesn't exist. This function wraps that call
105+ to always return a boolean and only return True
106+ if there's a dir and it exists.
107+ """
108+ with contextlib .suppress (FileNotFoundError ):
109+ return path .is_dir ()
110+ return False
111+
112+
92113@functools .singledispatch
93114def as_file (path ):
94115 """
95116 Given a Traversable object, return that object as a
96117 path on the local file system in a context manager.
97118 """
98- return _tempfile (path . read_bytes , suffix = path . name )
119+ return _temp_dir (path ) if _is_present_dir ( path ) else _temp_file ( path )
99120
100121
101122@as_file .register (pathlib .Path )
@@ -105,3 +126,34 @@ def _(path):
105126 Degenerate behavior for pathlib.Path objects.
106127 """
107128 yield path
129+
130+
131+ @contextlib .contextmanager
132+ def _temp_path (dir : tempfile .TemporaryDirectory ):
133+ """
134+ Wrap tempfile.TemporyDirectory to return a pathlib object.
135+ """
136+ with dir as result :
137+ yield pathlib .Path (result )
138+
139+
140+ @contextlib .contextmanager
141+ def _temp_dir (path ):
142+ """
143+ Given a traversable dir, recursively replicate the whole tree
144+ to the file system in a context manager.
145+ """
146+ assert path .is_dir ()
147+ with _temp_path (tempfile .TemporaryDirectory ()) as temp_dir :
148+ yield _write_contents (temp_dir , path )
149+
150+
151+ def _write_contents (target , source ):
152+ child = target .joinpath (source .name )
153+ if source .is_dir ():
154+ child .mkdir ()
155+ for item in source .iterdir ():
156+ _write_contents (child , item )
157+ else :
158+ child .open ('wb' ).write (source .read_bytes ())
159+ return child
0 commit comments