diff --git a/audeer/core/io.py b/audeer/core/io.py index 5d93777..0cbb6c7 100644 --- a/audeer/core/io.py +++ b/audeer/core/io.py @@ -99,7 +99,7 @@ def create_archive( *, verbose: bool = False, ): - r"""Create ZIP or TAR.GZ archive. + r"""Create ZIP or TAR archive. If a list with ``files`` is provided, only those files will be included to the archive. @@ -129,8 +129,9 @@ def create_archive( Raises: FileNotFoundError: if ``root`` or a file in ``files`` is not found NotADirectoryError: if ``root`` is not a directory - RuntimeError: if archive does not end with ``zip`` or ``tar.gz`` - or a file in ``files`` is not below ``root`` + RuntimeError: if archive does not end with ``zip``, + ``tar``, ``tar.gz``, ``tar.bz2``, ``tar.xz`` + RuntimeError: if a file in ``files`` is not below ``root`` Examples: >>> file_a = audeer.touch("a.txt") @@ -211,21 +212,31 @@ def create_archive( ) disable = not verbose - if archive.endswith("zip"): - with zipfile.ZipFile(archive, "w", zipfile.ZIP_DEFLATED) as zf: - for file in progress_bar(files, desc=desc, disable=disable): - full_file = safe_path(root, file) - zf.write(full_file, arcname=file) - elif archive.endswith("tar.gz"): - with tarfile.open(archive, "w:gz") as tf: - for file in progress_bar(files, desc=desc, disable=disable): - full_file = safe_path(root, file) - tf.add(full_file, file) - else: + archive_handlers = { + "zip": lambda path: zipfile.ZipFile(path, "w", zipfile.ZIP_DEFLATED), + "tar": lambda path: tarfile.open(path, "w:"), + "tar.gz": lambda path: tarfile.open(path, "w:gz"), + "tar.bz2": lambda path: tarfile.open(path, "w:bz2"), + "tar.xz": lambda path: tarfile.open(path, "w:xz"), + } + + # Get the archive extension + extension = next((ext for ext in archive_handlers if archive.endswith(ext)), None) + if extension is None: + supported = ", ".join(archive_handlers.keys()) raise RuntimeError( - f"You can only create a ZIP or TAR.GZ archive, " f"not {archive}" + f"Unsupported archive format. Supported formats: {supported}" ) + # Create and populate the archive + with archive_handlers[extension](archive) as archive_file: + for file in progress_bar(files, desc=desc, disable=disable): + full_file = safe_path(root, file) + if extension == "zip": + archive_file.write(full_file, arcname=file) + else: + archive_file.add(full_file, file) + def download_url( url: str, @@ -282,10 +293,10 @@ def extract_archive( keep_archive: bool = True, verbose: bool = False, ) -> typing.List[str]: - r"""Extract ZIP or TAR.GZ file. + r"""Extract ZIP or TAR file. Args: - archive: path to ZIP or TAR.GZ file + archive: path to ZIP or TAR file destination: folder where the files will be extracted. If the folder does not exists, it will be created @@ -300,7 +311,7 @@ def extract_archive( FileNotFoundError: if ``archive`` is not found IsADirectoryError: if ``archive`` is a directory NotADirectoryError: if ``destination`` is not a directory - RuntimeError: if ``archive`` is not a ZIP or TAR.GZ file + RuntimeError: if ``archive`` is not a ZIP or TAR file RuntimeError: if ``archive`` is malformed Examples: @@ -351,39 +362,35 @@ def extract_archive( ) disable = not verbose + def extract_zip(archive: str) -> list: + with zipfile.ZipFile(archive, "r") as zf: + members = zf.infolist() + for member in progress_bar(members, desc=desc, disable=disable): + zf.extract(member, destination) + return [m.filename for m in members] + + def extract_tar(archive: str) -> list: + with tarfile.open(archive, "r") as tf: + members = tf.getmembers() + for member in progress_bar(members, desc=desc, disable=disable): + # In Python 3.12 the `filter` argument was introduced, + # and it will be set automatically in Python 3.14, + # see + # https://docs.python.org/3.12/library/tarfile.html#tarfile-extraction-filter + # noqa: E501 + kwargs = {"numeric_owner": True} + if sys.version_info >= (3, 12): # pragma: no cover + kwargs = kwargs | {"filter": "tar"} + tf.extract(member, destination, **kwargs) + return [m.name for m in members] + try: if archive.endswith("zip"): - with zipfile.ZipFile(archive, "r") as zf: - members = zf.infolist() - for member in progress_bar( - members, - desc=desc, - disable=disable, - ): - zf.extract(member, destination) - files = [m.filename for m in members] - elif archive.endswith("tar.gz"): - with tarfile.open(archive, "r") as tf: - members = tf.getmembers() - for member in progress_bar( - members, - desc=desc, - disable=disable, - ): - # In Python 3.12 the `filter` argument was introduced, - # and it will be set automatically in Python 3.14, - # see - # https://docs.python.org/3.12/library/tarfile.html#tarfile-extraction-filter - # noqa: E501 - kwargs = {"numeric_owner": True} - if sys.version_info >= (3, 12): # pragma: no cover - kwargs = kwargs | {"filter": "tar"} - tf.extract(member, destination, **kwargs) - files = [m.name for m in members] + files = extract_zip(archive) + elif tarfile.is_tarfile(archive): + files = extract_tar(archive) else: - raise RuntimeError( - f"You can only extract ZIP and TAR.GZ files, " f"not {archive}" - ) + raise RuntimeError(f"You can only extract ZIP and TAR files, not {archive}") except (EOFError, zipfile.BadZipFile, tarfile.ReadError): raise RuntimeError(f"Broken archive: {archive}") except (KeyboardInterrupt, Exception): # pragma: no cover @@ -410,10 +417,10 @@ def extract_archives( keep_archive: bool = True, verbose: bool = False, ) -> typing.List[str]: - r"""Extract multiple ZIP or TAR.GZ archives at once. + r"""Extract multiple ZIP or TAR archives at once. Args: - archives: paths of ZIP or TAR.GZ files + archives: paths of ZIP or TAR files destination: folder where the files will be extracted. If the folder does not exists, it will be created @@ -428,7 +435,7 @@ def extract_archives( FileNotFoundError: if an archive is not found IsADirectoryError: if an archive is a directory NotADirectoryError: if ``destination`` is not a directory - RuntimeError: if an archive is not a ZIP or TAR.GZ file + RuntimeError: if an archive is not a ZIP or TAR file RuntimeError: if an archive file is malformed Examples: diff --git a/tests/test_io.py b/tests/test_io.py index 4f80503..1ee22d1 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -116,6 +116,15 @@ def tree(tmpdir, request): ".", [".hidden", "file.txt", "sub/a/b/file.txt"], ), + ( # tar + ["file.txt", "sub/a/b/file.txt"], + ".", + ["sub/a/b/file.txt", "file.txt"], + "archive.tar", + "archive.tar", + ".", + ["sub/a/b/file.txt", "file.txt"], + ), ( # tar.gz ["file.txt", "sub/a/b/file.txt"], ".", @@ -125,6 +134,24 @@ def tree(tmpdir, request): ".", ["sub/a/b/file.txt", "file.txt"], ), + ( # tar.bz2 + ["file.txt", "sub/a/b/file.txt"], + ".", + ["sub/a/b/file.txt", "file.txt"], + "archive.tar.bz2", + "archive.tar.bz2", + ".", + ["sub/a/b/file.txt", "file.txt"], + ), + ( # tar.xz + ["file.txt", "sub/a/b/file.txt"], + ".", + ["sub/a/b/file.txt", "file.txt"], + "archive.tar.xz", + "archive.tar.xz", + ".", + ["sub/a/b/file.txt", "file.txt"], + ), ( # root is sub folder ["sub/file.txt"], "./sub", @@ -251,6 +278,26 @@ def tree(tmpdir, request): None, marks=pytest.mark.xfail(raises=RuntimeError), ), + pytest.param( # invalid .rar format + ["file.txt", "sub/a/b/file.txt"], + ".", + ["sub/a/b/file.txt", "file.txt"], + "archive.rar", + "archive.rar", + ".", + None, + marks=pytest.mark.xfail(raises=RuntimeError), + ), + pytest.param( # invalid .7z format + ["file.txt", "sub/a/b/file.txt"], + ".", + ["sub/a/b/file.txt", "file.txt"], + "archive.7z", + "archive.7z", + ".", + None, + marks=pytest.mark.xfail(raises=RuntimeError), + ), pytest.param( # broken archive ["archive.zip"], ".",