@@ -93,13 +93,30 @@ def _tempfile(
93
93
pass
94
94
95
95
96
+ def _temp_file (path ):
97
+ return _tempfile (path .read_bytes , suffix = path .name )
98
+
99
+
100
+ def _is_present_dir (path : Traversable ) -> bool :
101
+ """
102
+ Some Traversables implement ``is_dir()`` to raise an
103
+ exception (i.e. ``FileNotFoundError``) when the
104
+ directory doesn't exist. This function wraps that call
105
+ to always return a boolean and only return True
106
+ if there's a dir and it exists.
107
+ """
108
+ with contextlib .suppress (FileNotFoundError ):
109
+ return path .is_dir ()
110
+ return False
111
+
112
+
96
113
@functools .singledispatch
97
114
def as_file (path ):
98
115
"""
99
116
Given a Traversable object, return that object as a
100
117
path on the local file system in a context manager.
101
118
"""
102
- return _tempfile (path . read_bytes , suffix = path . name )
119
+ return _temp_dir (path ) if _is_present_dir ( path ) else _temp_file ( path )
103
120
104
121
105
122
@as_file .register (pathlib .Path )
@@ -109,3 +126,34 @@ def _(path):
109
126
Degenerate behavior for pathlib.Path objects.
110
127
"""
111
128
yield path
129
+
130
+
131
+ @contextlib .contextmanager
132
+ def _temp_path (dir : tempfile .TemporaryDirectory ):
133
+ """
134
+ Wrap tempfile.TemporyDirectory to return a pathlib object.
135
+ """
136
+ with dir as result :
137
+ yield pathlib .Path (result )
138
+
139
+
140
+ @contextlib .contextmanager
141
+ def _temp_dir (path ):
142
+ """
143
+ Given a traversable dir, recursively replicate the whole tree
144
+ to the file system in a context manager.
145
+ """
146
+ assert path .is_dir ()
147
+ with _temp_path (tempfile .TemporaryDirectory ()) as temp_dir :
148
+ yield _write_contents (temp_dir , path )
149
+
150
+
151
+ def _write_contents (target , source ):
152
+ child = target .joinpath (source .name )
153
+ if source .is_dir ():
154
+ child .mkdir ()
155
+ for item in source .iterdir ():
156
+ _write_contents (child , item )
157
+ else :
158
+ child .open ('wb' ).write (source .read_bytes ())
159
+ return child
0 commit comments