Skip to content

Commit

Permalink
fix: Remove memory leak caused by file upload. #1584
Browse files Browse the repository at this point in the history
  • Loading branch information
mturoci committed Aug 17, 2022
1 parent 6f850a0 commit c3a5c25
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 3 deletions.
45 changes: 42 additions & 3 deletions py/h2o_wave/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,36 @@ def upload(self, files: List[str]) -> List[str]:
return json.loads(res.text)['files']
raise ServiceError(f'Upload failed (code={res.status_code}): {res.text}')

def upload_dir(self, directory: str) -> str:
"""
Upload local files to the site.
Args:
files: A list of file paths of the files to be uploaded.
Returns:
A list of remote URLs for the uploaded files, in order.
"""
if not os.path.isdir(directory):
raise ValueError(f'{directory} is not a directory.')

upload_files = []
file_handles: List[BufferedReader] = []
for f in _get_files_in_directory(directory, []):
file_handle = open(f, 'rb')
upload_files.append(('files', (os.path.relpath(f, directory), file_handle)))
file_handles.append(file_handle)

res = self._http.post(f'{_config.hub_address}_f/', headers={'Wave-Directory-Upload': "True"}, files=upload_files)

for h in file_handles:
h.close()

if res.status_code == 200:
return json.loads(res.text)['files']
raise ServiceError(f'Upload failed (code={res.status_code}): {res.text}')


def download(self, url: str, path: str) -> str:
"""
Download a file from the site.
Expand Down Expand Up @@ -806,9 +836,18 @@ async def upload_dir(self, directory: str) -> str:
if not os.path.isdir(directory):
raise ValueError(f'{directory} is not a directory.')

files = _get_files_in_directory(directory, [])
res = await self._http.post(f'{_config.hub_address}_f/', headers={'Wave-Directory-Upload': "True"},
files=[('files', (os.path.relpath(f, directory), open(f, 'rb'))) for f in files])
upload_files = []
file_handles: List[BufferedReader] = []
for f in _get_files_in_directory(directory, []):
file_handle = open(f, 'rb')
upload_files.append(('files', (os.path.relpath(f, directory), file_handle)))
file_handles.append(file_handle)

res = await self._http.post(f'{_config.hub_address}_f/', headers={'Wave-Directory-Upload': "True"}, files=upload_files)

for h in file_handles:
h.close()

if res.status_code == 200:
return json.loads(res.text)['files']
raise ServiceError(f'Upload failed (code={res.status_code}): {res.text}')
Expand Down
8 changes: 8 additions & 0 deletions py/tests/test_python_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,3 +451,11 @@ def test_multipart_server(self):
file_handle.close()
site.unlink('test_stream')
assert len(p) > 0

def test_upload_dir(self):
upload_path, = site.upload_dir(os.path.join('tests', 'test_folder'))
base_url = os.getenv('H2O_WAVE_BASE_URL', '/')
download_path = site.download(f'{base_url}{upload_path}test.txt', 'test.txt')
txt = _read_file(download_path)
os.remove(download_path)
assert len(txt) > 0

0 comments on commit c3a5c25

Please sign in to comment.