Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support all tar archives in create and extract #166

Merged
merged 2 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 58 additions & 51 deletions audeer/core/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def create_archive(
*,
verbose: bool = False,
):
r"""Create ZIP or TAR.GZ archive.
r"""Create ZIP or TAR archive.

If a list with ``files`` is provided,
only those files will be included to the archive.
Expand Down Expand Up @@ -129,8 +129,9 @@ def create_archive(
Raises:
FileNotFoundError: if ``root`` or a file in ``files`` is not found
NotADirectoryError: if ``root`` is not a directory
RuntimeError: if archive does not end with ``zip`` or ``tar.gz``
or a file in ``files`` is not below ``root``
RuntimeError: if archive does not end with ``zip``,
``tar``, ``tar.gz``, ``tar.bz2``, ``tar.xz``
RuntimeError: if a file in ``files`` is not below ``root``

Examples:
>>> file_a = audeer.touch("a.txt")
Expand Down Expand Up @@ -211,21 +212,31 @@ def create_archive(
)
disable = not verbose

if archive.endswith("zip"):
with zipfile.ZipFile(archive, "w", zipfile.ZIP_DEFLATED) as zf:
for file in progress_bar(files, desc=desc, disable=disable):
full_file = safe_path(root, file)
zf.write(full_file, arcname=file)
elif archive.endswith("tar.gz"):
with tarfile.open(archive, "w:gz") as tf:
for file in progress_bar(files, desc=desc, disable=disable):
full_file = safe_path(root, file)
tf.add(full_file, file)
else:
archive_handlers = {
"zip": lambda path: zipfile.ZipFile(path, "w", zipfile.ZIP_DEFLATED),
"tar": lambda path: tarfile.open(path, "w:"),
"tar.gz": lambda path: tarfile.open(path, "w:gz"),
"tar.bz2": lambda path: tarfile.open(path, "w:bz2"),
"tar.xz": lambda path: tarfile.open(path, "w:xz"),
}

# Get the archive extension
extension = next((ext for ext in archive_handlers if archive.endswith(ext)), None)
if extension is None:
supported = ", ".join(archive_handlers.keys())
raise RuntimeError(
f"You can only create a ZIP or TAR.GZ archive, " f"not {archive}"
f"Unsupported archive format. Supported formats: {supported}"
)

# Create and populate the archive
with archive_handlers[extension](archive) as archive_file:
for file in progress_bar(files, desc=desc, disable=disable):
full_file = safe_path(root, file)
if extension == "zip":
archive_file.write(full_file, arcname=file)
else:
archive_file.add(full_file, file)


def download_url(
url: str,
Expand Down Expand Up @@ -282,10 +293,10 @@ def extract_archive(
keep_archive: bool = True,
verbose: bool = False,
) -> typing.List[str]:
r"""Extract ZIP or TAR.GZ file.
r"""Extract ZIP or TAR file.

Args:
archive: path to ZIP or TAR.GZ file
archive: path to ZIP or TAR file
destination: folder where the files will be extracted.
If the folder does not exists,
it will be created
Expand All @@ -300,7 +311,7 @@ def extract_archive(
FileNotFoundError: if ``archive`` is not found
IsADirectoryError: if ``archive`` is a directory
NotADirectoryError: if ``destination`` is not a directory
RuntimeError: if ``archive`` is not a ZIP or TAR.GZ file
RuntimeError: if ``archive`` is not a ZIP or TAR file
RuntimeError: if ``archive`` is malformed

Examples:
Expand Down Expand Up @@ -351,39 +362,35 @@ def extract_archive(
)
disable = not verbose

def extract_zip(archive: str) -> list:
with zipfile.ZipFile(archive, "r") as zf:
members = zf.infolist()
for member in progress_bar(members, desc=desc, disable=disable):
zf.extract(member, destination)
return [m.filename for m in members]

def extract_tar(archive: str) -> list:
with tarfile.open(archive, "r") as tf:
members = tf.getmembers()
for member in progress_bar(members, desc=desc, disable=disable):
# In Python 3.12 the `filter` argument was introduced,
# and it will be set automatically in Python 3.14,
# see
# https://docs.python.org/3.12/library/tarfile.html#tarfile-extraction-filter
# noqa: E501
kwargs = {"numeric_owner": True}
if sys.version_info >= (3, 12): # pragma: no cover
kwargs = kwargs | {"filter": "tar"}
tf.extract(member, destination, **kwargs)
return [m.name for m in members]

try:
if archive.endswith("zip"):
with zipfile.ZipFile(archive, "r") as zf:
members = zf.infolist()
for member in progress_bar(
members,
desc=desc,
disable=disable,
):
zf.extract(member, destination)
files = [m.filename for m in members]
elif archive.endswith("tar.gz"):
with tarfile.open(archive, "r") as tf:
members = tf.getmembers()
for member in progress_bar(
members,
desc=desc,
disable=disable,
):
# In Python 3.12 the `filter` argument was introduced,
# and it will be set automatically in Python 3.14,
# see
# https://docs.python.org/3.12/library/tarfile.html#tarfile-extraction-filter
# noqa: E501
kwargs = {"numeric_owner": True}
if sys.version_info >= (3, 12): # pragma: no cover
kwargs = kwargs | {"filter": "tar"}
tf.extract(member, destination, **kwargs)
files = [m.name for m in members]
files = extract_zip(archive)
elif tarfile.is_tarfile(archive):
files = extract_tar(archive)
else:
raise RuntimeError(
f"You can only extract ZIP and TAR.GZ files, " f"not {archive}"
)
raise RuntimeError(f"You can only extract ZIP and TAR files, not {archive}")
except (EOFError, zipfile.BadZipFile, tarfile.ReadError):
raise RuntimeError(f"Broken archive: {archive}")
except (KeyboardInterrupt, Exception): # pragma: no cover
Expand All @@ -410,10 +417,10 @@ def extract_archives(
keep_archive: bool = True,
verbose: bool = False,
) -> typing.List[str]:
r"""Extract multiple ZIP or TAR.GZ archives at once.
r"""Extract multiple ZIP or TAR archives at once.

Args:
archives: paths of ZIP or TAR.GZ files
archives: paths of ZIP or TAR files
destination: folder where the files will be extracted.
If the folder does not exists,
it will be created
Expand All @@ -428,7 +435,7 @@ def extract_archives(
FileNotFoundError: if an archive is not found
IsADirectoryError: if an archive is a directory
NotADirectoryError: if ``destination`` is not a directory
RuntimeError: if an archive is not a ZIP or TAR.GZ file
RuntimeError: if an archive is not a ZIP or TAR file
RuntimeError: if an archive file is malformed

Examples:
Expand Down
47 changes: 47 additions & 0 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,15 @@ def tree(tmpdir, request):
".",
[".hidden", "file.txt", "sub/a/b/file.txt"],
),
( # tar
["file.txt", "sub/a/b/file.txt"],
".",
["sub/a/b/file.txt", "file.txt"],
"archive.tar",
"archive.tar",
".",
["sub/a/b/file.txt", "file.txt"],
),
( # tar.gz
["file.txt", "sub/a/b/file.txt"],
".",
Expand All @@ -125,6 +134,24 @@ def tree(tmpdir, request):
".",
["sub/a/b/file.txt", "file.txt"],
),
( # tar.bz2
["file.txt", "sub/a/b/file.txt"],
".",
["sub/a/b/file.txt", "file.txt"],
"archive.tar.bz2",
"archive.tar.bz2",
".",
["sub/a/b/file.txt", "file.txt"],
),
( # tar.xz
["file.txt", "sub/a/b/file.txt"],
".",
["sub/a/b/file.txt", "file.txt"],
"archive.tar.xz",
"archive.tar.xz",
".",
["sub/a/b/file.txt", "file.txt"],
),
( # root is sub folder
["sub/file.txt"],
"./sub",
Expand Down Expand Up @@ -251,6 +278,26 @@ def tree(tmpdir, request):
None,
marks=pytest.mark.xfail(raises=RuntimeError),
),
pytest.param( # invalid .rar format
["file.txt", "sub/a/b/file.txt"],
".",
["sub/a/b/file.txt", "file.txt"],
"archive.rar",
"archive.rar",
".",
None,
marks=pytest.mark.xfail(raises=RuntimeError),
),
pytest.param( # invalid .7z format
["file.txt", "sub/a/b/file.txt"],
".",
["sub/a/b/file.txt", "file.txt"],
"archive.7z",
"archive.7z",
".",
None,
marks=pytest.mark.xfail(raises=RuntimeError),
),
pytest.param( # broken archive
["archive.zip"],
".",
Expand Down
Loading