Skip to content

Commit

Permalink
[hot-fix] Malicious repo can overwrite any file on disk (#1429)
Browse files Browse the repository at this point in the history
* Add regression test

* add protections + fix tests

* fix widnwso test

* fix widnwso test

* FIX resolving path without following symlinks

* increase pause time in test

* fix windows test

* Update src/huggingface_hub/file_download.py
  • Loading branch information
Wauplin committed Apr 6, 2023
1 parent 58d8242 commit 0848f80
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 9 deletions.
6 changes: 2 additions & 4 deletions src/huggingface_hub/_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,15 +121,13 @@ def interpreter_login() -> None:
For more details, see [`login`].
"""
print( # docstyle-ignore
"""
print("""
_| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_|
_| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
_|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_|
_| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|
_| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_|
"""
)
""") # docstyle-ignore
if HfFolder.get_token() is not None:
print(
" A token is already saved on your machine. Run `huggingface-cli"
Expand Down
32 changes: 29 additions & 3 deletions src/huggingface_hub/file_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,11 +1132,17 @@ def hf_hub_download(

# cross platform transcription of filename, to be used as a local file path.
relative_filename = os.path.join(*filename.split("/"))
if os.name == "nt":
if relative_filename.startswith("..\\") or "\\..\\" in relative_filename:
raise ValueError(
f"Invalid filename: cannot handle filename '{relative_filename}' on Windows. Please ask the repository"
" owner to rename this file."
)

# if user provides a commit_hash and they already have the file on disk,
# shortcut everything.
if REGEX_COMMIT_HASH.match(revision):
pointer_path = os.path.join(storage_folder, "snapshots", revision, relative_filename)
pointer_path = _get_pointer_path(storage_folder, revision, relative_filename)
if os.path.exists(pointer_path):
if local_dir is not None:
return _to_local_dir(pointer_path, local_dir, relative_filename, use_symlinks=local_dir_use_symlinks)
Expand Down Expand Up @@ -1231,7 +1237,7 @@ def hf_hub_download(

# Return pointer file if exists
if commit_hash is not None:
pointer_path = os.path.join(storage_folder, "snapshots", commit_hash, relative_filename)
pointer_path = _get_pointer_path(storage_folder, commit_hash, relative_filename)
if os.path.exists(pointer_path):
if local_dir is not None:
return _to_local_dir(
Expand Down Expand Up @@ -1260,7 +1266,7 @@ def hf_hub_download(
assert etag is not None, "etag must have been retrieved from server"
assert commit_hash is not None, "commit_hash must have been retrieved from server"
blob_path = os.path.join(storage_folder, "blobs", etag)
pointer_path = os.path.join(storage_folder, "snapshots", commit_hash, relative_filename)
pointer_path = _get_pointer_path(storage_folder, commit_hash, relative_filename)

os.makedirs(os.path.dirname(blob_path), exist_ok=True)
os.makedirs(os.path.dirname(pointer_path), exist_ok=True)
Expand Down Expand Up @@ -1549,14 +1555,34 @@ def _chmod_and_replace(src: str, dst: str) -> None:
os.replace(src, dst)


def _get_pointer_path(storage_folder: str, revision: str, relative_filename: str) -> str:
# Using `os.path.abspath` instead of `Path.resolve()` to avoid resolving symlinks
snapshot_path = os.path.join(storage_folder, "snapshots")
pointer_path = os.path.join(snapshot_path, revision, relative_filename)
if Path(os.path.abspath(snapshot_path)) not in Path(os.path.abspath(pointer_path)).parents:
raise ValueError(
"Invalid pointer path: cannot create pointer path in snapshot folder if"
f" `storage_folder='{storage_folder}'`, `revision='{revision}'` and"
f" `relative_filename='{relative_filename}'`."
)
return pointer_path


def _to_local_dir(
path: str, local_dir: str, relative_filename: str, use_symlinks: Union[bool, Literal["auto"]]
) -> str:
"""Place a file in a local dir (different than cache_dir).
Either symlink to blob file in cache or duplicate file depending on `use_symlinks` and file size.
"""
# Using `os.path.abspath` instead of `Path.resolve()` to avoid resolving symlinks
local_dir_filepath = os.path.join(local_dir, relative_filename)
if Path(os.path.abspath(local_dir)) not in Path(os.path.abspath(local_dir_filepath)).parents:
raise ValueError(
f"Cannot copy file '{relative_filename}' to local dir '{local_dir}': file would not be in the local"
" directory."
)

os.makedirs(os.path.dirname(local_dir_filepath), exist_ok=True)
real_blob_path = os.path.realpath(path)

Expand Down
65 changes: 65 additions & 0 deletions tests/test_file_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
from huggingface_hub.file_download import (
_CACHED_NO_EXIST,
_create_symlink,
_get_pointer_path,
_to_local_dir,
cached_download,
filename_to_url,
get_hf_file_metadata,
Expand Down Expand Up @@ -737,6 +739,69 @@ def test_hf_hub_download_on_awful_subfolder_and_filename(self):
self.assertTrue(local_path.endswith(self.filepath))


@pytest.mark.usefixtures("fx_cache_dir")
class TestHfHubDownloadRelativePaths(unittest.TestCase):
"""Regression test for HackerOne report 1928845.
Issue was that any file outside of the local dir could be overwritten (Windows only).
In the end, multiple protections have been added to prevent this (..\\ in filename forbidden on Windows, always check
the filepath is in local_dir/snapshot_dir).
"""

cache_dir: Path

@classmethod
def setUpClass(cls):
cls.api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN)
cls.repo_id = cls.api.create_repo(repo_id=repo_name()).repo_id
cls.api.upload_file(path_or_fileobj=b"content", path_in_repo="..\\ddd", repo_id=cls.repo_id)
cls.api.upload_file(path_or_fileobj=b"content", path_in_repo="folder/..\\..\\..\\file", repo_id=cls.repo_id)

@classmethod
def tearDownClass(cls) -> None:
cls.api.delete_repo(repo_id=cls.repo_id)

@xfail_on_windows(reason="Windows paths cannot start with '..\\'.", raises=ValueError)
def test_download_file_in_cache_dir(self) -> None:
hf_hub_download(self.repo_id, "..\\ddd", cache_dir=self.cache_dir)

@xfail_on_windows(reason="Windows paths cannot start with '..\\'.", raises=ValueError)
def test_download_file_to_local_dir(self) -> None:
with SoftTemporaryDirectory() as local_dir:
hf_hub_download(self.repo_id, "..\\ddd", cache_dir=self.cache_dir, local_dir=local_dir)

@xfail_on_windows(reason="Windows paths cannot contain '\\..\\'.", raises=ValueError)
def test_download_folder_file_in_cache_dir(self) -> None:
hf_hub_download(self.repo_id, "folder/..\\..\\..\\file", cache_dir=self.cache_dir)

@xfail_on_windows(reason="Windows paths cannot contain '\\..\\'.", raises=ValueError)
def test_download_folder_file_to_local_dir(self) -> None:
with SoftTemporaryDirectory() as local_dir:
hf_hub_download(self.repo_id, "folder/..\\..\\..\\file", cache_dir=self.cache_dir, local_dir=local_dir)

def test_get_pointer_path_and_valid_relative_filename(self) -> None:
# Cannot happen because of other protections, but just in case.
self.assertEqual(
_get_pointer_path("path/to/storage", "abcdef", "path/to/file.txt"),
os.path.join("path/to/storage", "snapshots", "abcdef", "path/to/file.txt"),
)

def test_get_pointer_path_but_invalid_relative_filename(self) -> None:
# Cannot happen because of other protections, but just in case.
relative_filename = "folder\\..\\..\\..\\file.txt" if os.name == "nt" else "folder/../../../file.txt"
with self.assertRaises(ValueError):
_get_pointer_path("path/to/storage", "abcdef", relative_filename)

def test_to_local_dir_but_invalid_relative_filename(self) -> None:
# Cannot happen because of other protections, but just in case.
relative_filename = "folder\\..\\..\\..\\file.txt" if os.name == "nt" else "folder/../../../file.txt"
with self.assertRaises(ValueError):
_to_local_dir(
"path/to/file_to_copy", "path/to/local/dir", relative_filename=relative_filename, use_symlinks=False
)


class CreateSymlinkTest(unittest.TestCase):
@unittest.skipIf(os.name == "nt", "No symlinks on Windows")
@patch("huggingface_hub.file_download.are_symlinks_supported")
Expand Down
5 changes: 3 additions & 2 deletions tests/test_hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2439,8 +2439,9 @@ def test_space_runtime(self) -> None:
def test_pause_and_restart_space(self) -> None:
runtime_after_pause = self.api.pause_space(self.repo_id)
self.assertEqual(runtime_after_pause.stage, SpaceStage.PAUSED)

runtime_after_restart = self.api.restart_space(self.repo_id)
self.api.restart_space(self.repo_id)
time.sleep(1.0)
runtime_after_restart = self.api.get_space_runtime(self.repo_id)
self.assertIn(runtime_after_restart.stage, (SpaceStage.BUILDING, SpaceStage.RUNNING_BUILDING))


Expand Down

0 comments on commit 0848f80

Please sign in to comment.