Skip to content

Commit

Permalink
Merge pull request #131 from mgxd/fix/zip-extract-race-condition
Browse files Browse the repository at this point in the history
FIX: Avoid directory clobber during zip extraction
  • Loading branch information
mgxd authored Apr 5, 2024
2 parents cb9566f + a9f7f5b commit 1d47d66
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 32 deletions.
16 changes: 2 additions & 14 deletions templateflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,8 @@
del version
del PackageNotFoundError

import os

from . import api
from .conf import TF_USE_DATALAD, update

if not TF_USE_DATALAD and os.getenv('TEMPLATEFLOW_AUTOUPDATE', '1') not in (
'false',
'off',
'0',
'no',
'n',
):
# trigger skeleton autoupdate
update(local=True, overwrite=False, silent=True)
from templateflow import api
from templateflow.conf import update

__all__ = [
'__copyright__',
Expand Down
3 changes: 2 additions & 1 deletion templateflow/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

from templateflow import __package__, api
from templateflow._loader import Loader as _Loader
from templateflow.conf import TF_HOME, TF_USE_DATALAD
from templateflow.conf import TF_HOME, TF_USE_DATALAD, TF_AUTOUPDATE

load_data = _Loader(__package__)

Expand Down Expand Up @@ -91,6 +91,7 @@ def config():
TEMPLATEFLOW_HOME={TF_HOME}
TEMPLATEFLOW_USE_DATALAD={'on' if TF_USE_DATALAD else 'off'}
TEMPLATEFLOW_AUTOUPDATE={'on' if TF_AUTOUPDATE else 'off'}
""")


Expand Down
35 changes: 27 additions & 8 deletions templateflow/conf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,36 @@

load_data = Loader(__package__)


def _env_to_bool(envvar: str, default: bool) -> bool:
    """Interpret an environment variable as an on/off switch.

    Parameters
    ----------
    envvar : str
        Name of the environment variable to read.
    default : bool
        Value to fall back on when the variable is unset or holds a value
        that is not a recognized switch.

    Returns
    -------
    bool
        ``True`` for ``true/on/1/yes/y``, ``False`` for ``false/off/0/no/n``
        (case-insensitive, surrounding whitespace ignored); otherwise the
        default.

    """
    switches = {
        'on': {'true', 'on', '1', 'yes', 'y'},
        'off': {'false', 'off', '0', 'no', 'n'},
    }

    val = getenv(envvar, default)
    if not isinstance(val, str):
        # Variable is unset: ``getenv`` handed back ``default`` unchanged.
        return bool(val)

    # Normalize once. Stripping keeps values such as "off " or " on" (easy to
    # produce from shell scripts and .env files) from being rejected.
    normalized = val.strip().lower()
    if normalized in switches['on']:
        return True
    if normalized in switches['off']:
        return False

    # TODO: Create templateflow logger
    print(
        f'{envvar} is set to unknown value <{val}>. '
        f'Falling back to default value <{default}>'
    )
    return default


TF_DEFAULT_HOME = Path.home() / '.cache' / 'templateflow'
TF_HOME = Path(getenv('TEMPLATEFLOW_HOME', str(TF_DEFAULT_HOME)))
TF_GITHUB_SOURCE = 'https://github.com/templateflow/templateflow.git'
TF_S3_ROOT = 'https://templateflow.s3.amazonaws.com'
TF_USE_DATALAD = getenv('TEMPLATEFLOW_USE_DATALAD', 'false').lower() in (
'true',
'on',
'1',
'yes',
'y',
)
TF_USE_DATALAD = _env_to_bool('TEMPLATEFLOW_USE_DATALAD', False)
TF_AUTOUPDATE = _env_to_bool('TEMPLATEFLOW_AUTOUPDATE', True)
TF_CACHED = True
TF_GET_TIMEOUT = 10

Expand Down Expand Up @@ -50,7 +69,7 @@ def _init_cache():
if not TF_USE_DATALAD:
from ._s3 import update as _update_s3

_update_s3(TF_HOME, local=True, overwrite=True)
_update_s3(TF_HOME, local=True, overwrite=TF_AUTOUPDATE, silent=True)


_init_cache()
Expand Down
27 changes: 18 additions & 9 deletions templateflow/conf/_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,23 +80,32 @@ def _update_skeleton(skel_file, dest, overwrite=True, silent=False):
dest = Path(dest)
dest.mkdir(exist_ok=True, parents=True)
with ZipFile(skel_file, 'r') as zipref:
allfiles = sorted(zipref.namelist())

if overwrite:
zipref.extractall(str(dest))
return True
newfiles = allfiles
else:
current_files = [s.relative_to(dest) for s in dest.glob('**/*')]
existing = sorted({'%s/' % s.parent for s in current_files}) + [
str(s) for s in current_files
]
newfiles = sorted(set(allfiles) - set(existing))

allfiles = zipref.namelist()
current_files = [s.relative_to(dest) for s in dest.glob('**/*')]
existing = sorted({'%s/' % s.parent for s in current_files}) + [
str(s) for s in current_files
]
newfiles = sorted(set(allfiles) - set(existing))
if newfiles:
if not silent:
print(
'Updating TEMPLATEFLOW_HOME using S3. Adding:\n%s'
% '\n'.join(newfiles)
)
zipref.extractall(str(dest), members=newfiles)
for fl in newfiles:
localpath = dest / fl
if localpath.exists():
continue
try:
zipref.extract(fl, path=dest)
except FileExistsError:
# If there is a conflict, do not clobber
pass
return True
if not silent:
print('TEMPLATEFLOW_HOME directory (S3 type) was up-to-date.')
Expand Down
27 changes: 27 additions & 0 deletions templateflow/tests/test_multiproc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import os
from concurrent.futures import ProcessPoolExecutor

import pytest

CPUs = os.cpu_count() or 1


def _update():
    """Run a non-local, overwriting, silent TemplateFlow update.

    Submitted to worker processes by ``test_multi_proc_update``; returns
    ``True`` once ``update`` has completed without raising.
    """
    # NOTE(review): the import is kept inside the function, presumably so it
    # happens in the worker process rather than at collection time -- confirm.
    import templateflow.conf as tf_conf

    tf_conf.update(local=False, overwrite=True, silent=True)
    return True


@pytest.mark.skipif(CPUs < 2, reason='At least 2 CPUs are required')
def test_multi_proc_update(tmp_path, monkeypatch):
    """Check that two concurrent updates of the same home both succeed.

    ``TEMPLATEFLOW_HOME`` is pointed at a fresh directory under ``tmp_path``,
    then ``_update`` is submitted twice to a two-worker process pool and each
    result is asserted to be truthy.
    """
    monkeypatch.setenv('TEMPLATEFLOW_HOME', str(tmp_path / 'tf_home'))

    with ProcessPoolExecutor(max_workers=2) as pool:
        pending = [pool.submit(_update) for _ in range(2)]

    # ``result()`` re-raises any exception raised inside a worker.
    for job in pending:
        assert job.result()

0 comments on commit 1d47d66

Please sign in to comment.