diff --git a/libarchive/test_zip.py b/libarchive/test_zip.py
new file mode 100644
index 0000000..e11d1c0
--- /dev/null
+++ b/libarchive/test_zip.py
@@ -0,0 +1,64 @@
+import pytest
+import os
+import tempfile
+from libarchive.zip import sanitize_filename, ZipFile  # Package-qualified so it resolves from any cwd
+
+def test_sanitize_filename_safe():
+    assert sanitize_filename("test.txt") == "test.txt"
+
+def test_sanitize_filename_traversal():
+    with pytest.raises(ValueError, match="Potential directory traversal attempt detected"):
+        sanitize_filename("../etc/passwd")
+
+def test_sanitize_filename_absolute_path():
+    with pytest.raises(ValueError, match="Potential directory traversal attempt detected"):
+        sanitize_filename("/etc/passwd")
+
+def test_sanitize_filename_keeps_subdirectories():
+    assert sanitize_filename("subdir/file2.txt") == os.path.join("subdir", "file2.txt")
+
+def create_test_zip(zip_path, filenames):
+    """Helper function to create a test ZIP file with given filenames."""
+    import zipfile
+    with zipfile.ZipFile(zip_path, 'w') as zf:
+        for filename in filenames:
+            zf.writestr(filename, "Test content")
+
+def test_extract_safe():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        zip_path = os.path.join(temp_dir, "test.zip")
+        create_test_zip(zip_path, ["file1.txt", "subdir/file2.txt"])
+
+        with ZipFile(zip_path, 'r') as zip_ref:
+            zip_ref.extract("file1.txt", temp_dir)
+
+        assert os.path.exists(os.path.join(temp_dir, "file1.txt"))
+
+def test_extract_traversal_attack():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        zip_path = os.path.join(temp_dir, "test.zip")
+        create_test_zip(zip_path, ["../evil.txt"])
+
+        with ZipFile(zip_path, 'r') as zip_ref:
+            with pytest.raises(ValueError, match="Potential directory traversal attempt detected"):
+                zip_ref.extract("../evil.txt", temp_dir)
+
+def test_extractall_safe():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        zip_path = os.path.join(temp_dir, "test.zip")
+        create_test_zip(zip_path, ["file1.txt", "subdir/file2.txt"])
+
+        with ZipFile(zip_path, 'r') as zip_ref:
+            zip_ref.extractall(temp_dir)
+
+        assert os.path.exists(os.path.join(temp_dir, "file1.txt"))
+        assert os.path.exists(os.path.join(temp_dir, "subdir", "file2.txt"))
+
+def test_extractall_with_traversal_attack():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        zip_path = os.path.join(temp_dir, "test.zip")
+        create_test_zip(zip_path, ["file1.txt", "../evil.txt"])
+
+        with ZipFile(zip_path, 'r') as zip_ref:
+            with pytest.raises(ValueError, match="Potential directory traversal attempt detected"):
+                zip_ref.extractall(temp_dir)
diff --git a/libarchive/zip.py b/libarchive/zip.py
index a200aa7..eb5e81c 100644
--- a/libarchive/zip.py
+++ b/libarchive/zip.py
@@ -1,4 +1,5 @@
-import os, time
+import os
+import time
 from libarchive import is_archive, Entry, SeekableArchive, _libarchive
 from zipfile import ZIP_STORED, ZIP_DEFLATED
 
@@ -7,6 +8,26 @@ def is_zipfile(filename):
     return is_archive(filename, formats=('zip',))
 
 
+def sanitize_filename(filename, base_path=None):
+    """Validate an archive member name and return it relative to base_path.
+
+    filename is resolved against base_path (the current working directory,
+    looked up at call time, when base_path is None). A ValueError is raised
+    if the resolved path is not strictly inside base_path, which rejects
+    absolute names and ".." traversal. Subdirectory components are kept so
+    nested members are not flattened onto each other.
+    """
+    if base_path is None:
+        base_path = os.getcwd()
+    base = os.path.abspath(base_path)
+    target = os.path.abspath(os.path.join(base, filename))
+    # A root base ("/") already ends with a separator; do not double it.
+    prefix = base if base.endswith(os.sep) else base + os.sep
+    if target == base or not target.startswith(prefix):
+        raise ValueError("Invalid filename: Potential directory traversal attempt detected.")
+    return os.path.relpath(target, base)
+
+
 class ZipEntry(Entry):
     def __init__(self, *args, **kwargs):
         super(ZipEntry, self).__init__(*args, **kwargs)
@@ -60,30 +81,28 @@ def _set_missing(self, value):
     CRC = property(_get_missing, _set_missing)
     compress_size = property(_get_missing, _set_missing)
 
+
 # encryption is one of (traditional = zipcrypt, aes128, aes256)
 class ZipFile(SeekableArchive):
     def __init__(self, f, mode='r', compression=ZIP_DEFLATED, allowZip64=False, password=None,
-            encryption=None):
+                 encryption=None):
         self.compression = compression
         self.encryption = encryption
         super(ZipFile, self).__init__(
             f, mode=mode, format='zip', entry_class=ZipEntry, encoding='CP437', password=password
         )
 
-
     getinfo = SeekableArchive.getentry
 
     def set_initial_options(self):
         if self.mode == 'w' and self.compression == ZIP_STORED:
             # Disable compression for writing.
             _libarchive.archive_write_set_format_option(self._a, "zip", "compression", "store")
-        
+
         if self.mode == 'w' and self.password:
             if not self.encryption:
                 self.encryption = "traditional"
             _libarchive.archive_write_set_format_option(self._a, "zip", "encryption", self.encryption)
-
-
 
     def namelist(self):
         return list(self.iterpaths())
@@ -104,7 +123,10 @@ def extract(self, name, path=None, pwd=None):
             self.add_passphrase(pwd)
         if not path:
             path = os.getcwd()
-        return self.readpath(name, os.path.join(path, name))
+        # Look the member up by its archive name, but write it to the
+        # validated location under the extraction directory.
+        safe_name = sanitize_filename(name, path)
+        return self.readpath(name, os.path.join(path, safe_name))
 
     def extractall(self, path, names=None, pwd=None):
         if pwd:
@@ -113,7 +135,8 @@ def extractall(self, path, names=None, pwd=None):
             names = self.namelist()
         if names:
             for name in names:
+                # extract() validates each name against ``path``.
                 self.extract(name, path)
 
     def read(self, name, pwd=None):
         if pwd:
diff --git a/tests.py b/tests.py
index 9380e63..5766bf8 100644
--- a/tests.py
+++ b/tests.py
@@ -26,7 +26,12 @@
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-import os, unittest, tempfile, random, string, sys
+import os
+import unittest
+import tempfile
+import random
+import string
+import sys
 import zipfile
 import io
 
@@ -309,6 +314,7 @@ def test_read_with_wrong_password(self):
         self.assertRaises(RuntimeError, z.read, ITEM_NAME)
         z.close()
 
+
 class TestProtectedWriting(unittest.TestCase):
     def setUp(self):
         create_protected_zip()