diff --git a/libarchive/__init__.py b/libarchive/__init__.py index 424e7a4..a4c250f 100644 --- a/libarchive/__init__.py +++ b/libarchive/__init__.py @@ -725,3 +725,45 @@ def readstream(self, member): self.seek(entry) self._stream = EntryReadStream(self, entry.size) return self._stream + + @classmethod + def sanitize_filename(cls, filename: str) -> str: + """ + Method for sanitizing provided file names. Logic behind the method is borrowed from pyzipper project. + + Parameters + ---------- + filename: str + name, which will be sanitized + Returns + ------- + sanitized filename, that should be secure to join using os.path.join method + """ + arcname: str = filename.replace('/', os.path.sep) + # Replace all alternative path separators + if os.path.altsep is not None: + arcname = arcname.replace(os.path.altsep, os.path.sep) + arcname = os.path.splitdrive(arcname)[1] + invalid_path_parts = ('', os.path.curdir, os.path.pardir) + arcname = os.path.sep.join(x for x in arcname.split(os.path.sep) + if x not in invalid_path_parts) + if os.path.sep == '\\': + # filter illegal characters on Windows + arcname = cls._sanitize_windows_name(arcname, os.path.sep) + return arcname + + + @classmethod + def _sanitize_windows_name(cls, arcname, pathsep): + """Replace bad characters and remove trailing dots from parts.""" + table = cls._windows_illegal_name_trans_table + if not table: + illegal = ':<>|"?*' + table = str.maketrans(illegal, '_' * len(illegal)) + cls._windows_illegal_name_trans_table = table + arcname = arcname.translate(table) + # remove trailing dots + arcname = (x.rstrip('.') for x in arcname.split(pathsep)) + # rejoin, removing empty parts. 
+ arcname = pathsep.join(x for x in arcname if x) + return arcname \ No newline at end of file diff --git a/libarchive/zip.py b/libarchive/zip.py index a200aa7..6f5a8e6 100644 --- a/libarchive/zip.py +++ b/libarchive/zip.py @@ -99,21 +99,58 @@ def open(self, name, mode, pwd=None): else: return self.writestream(name) - def extract(self, name, path=None, pwd=None): + def extract(self, name: str, path=None, pwd=None, withoutpath: bool = True): + """ + Method for extracting a single file from the zip archive. + + Parameters + ---------- + name: str + name of file inside the archive + path: str + target directory, where the archive should be extracted + pwd: str + password to the archive being extracted + withoutpath: bool + boolean flag to determine whether the name of extracted file + should remain same (False) or should be sanitized (True) + """ if pwd: self.add_passphrase(pwd) if not path: path = os.getcwd() - return self.readpath(name, os.path.join(path, name)) - - def extractall(self, path, names=None, pwd=None): + if withoutpath: + arcname = self.sanitize_filename(filename=name) + targetpath = os.path.join(path, arcname) + targetpath = os.path.normpath(targetpath) + else: + targetpath = os.path.join(path, name) + return self.readpath(name, targetpath) + + def extractall(self, path, names=None, pwd=None, withoutpath: bool = True): + """ + Method for extracting all the files provided in names array. In case names are not provided, they are + obtained by namelist method. 
+ + Parameters + ---------- + path: str + target directory, where the archive should be extracted + names: list + array of names of files to be extracted + pwd: str + password to the archive being extracted + withoutpath: bool + boolean flag to determine whether the name of extracted file + should remain same (False) or should be sanitized (True) + """ if pwd: self.add_passphrase(pwd) if not names: names = self.namelist() if names: for name in names: - self.extract(name, path) + self.extract(name, path, withoutpath=withoutpath) def read(self, name, pwd=None): if pwd: