1
- import os , time
1
+ import os
2
+ import time
2
3
from libarchive import is_archive , Entry , SeekableArchive , _libarchive
3
4
from zipfile import ZIP_STORED , ZIP_DEFLATED
4
5
@@ -7,6 +8,12 @@ def is_zipfile(filename):
7
8
return is_archive (filename , formats = ('zip' ,))
8
9
9
10
11
def sanitize_filename(filename, base_path=None):
    """Validate that *filename* stays inside *base_path* and return its basename.

    The name is joined onto ``base_path`` and normalised with
    ``os.path.abspath``; if the result does not lie strictly below the base
    directory, the name is treated as a directory traversal attempt.

    Args:
        filename: Archive member name (relative or absolute path).
        base_path: Directory the name must resolve under.  Defaults to the
            current working directory *at call time*.

    Returns:
        The final path component of the resolved name.  Any directory
        components of ``filename`` are dropped.

    Raises:
        ValueError: If the resolved path is not located below ``base_path``
            (this includes names that resolve to ``base_path`` itself).
    """
    # Resolve the default lazily: a default of ``os.getcwd()`` in the signature
    # is evaluated once at import time and would ignore later ``os.chdir`` calls.
    if base_path is None:
        base_path = os.getcwd()

    base_abs = os.path.abspath(base_path)
    abs_path = os.path.abspath(os.path.join(base_path, filename))

    # Compare against "<base><sep>" so that a sibling such as "/tmp/outx" is
    # not accepted for base "/tmp/out".  A filesystem root already ends with
    # the separator; appending another one would reject every path.
    prefix = base_abs if base_abs.endswith(os.sep) else base_abs + os.sep
    if not abs_path.startswith(prefix):
        raise ValueError("Invalid filename: Potential directory traversal attempt detected.")

    # NOTE(review): only the basename is returned, so nested member paths are
    # flattened -- confirm callers do not use the result to look up members.
    return os.path.basename(abs_path)
16
+
10
17
class ZipEntry (Entry ):
11
18
def __init__ (self , * args , ** kwargs ):
12
19
super (ZipEntry , self ).__init__ (* args , ** kwargs )
@@ -60,30 +67,26 @@ def _set_missing(self, value):
60
67
CRC = property (_get_missing , _set_missing )
61
68
compress_size = property (_get_missing , _set_missing )
62
69
63
- # encryption is one of (traditional = zipcrypt, aes128, aes256)
70
+
64
71
class ZipFile (SeekableArchive ):
65
72
def __init__(self, f, mode='r', compression=ZIP_DEFLATED, allowZip64=False, password=None,
             encryption=None):
    """Open *f* as a zip archive through the seekable archive base class.

    Args:
        f: Archive source/target, forwarded unchanged to the base initializer.
        mode: Open mode, forwarded to the base initializer (default ``'r'``).
        compression: Stored on the instance as ``self.compression``.
        allowZip64: Accepted in the signature; not used by this initializer.
        password: Forwarded to the base initializer.
        encryption: Stored on the instance as ``self.encryption``.
            NOTE(review): an earlier comment in this file listed the values
            as traditional (zipcrypt), aes128, aes256 -- confirm.
    """
    # Both attributes are set before delegating to the base initializer
    # (presumably so option setup can read them -- confirm).
    self.encryption = encryption
    self.compression = compression

    base_kwargs = dict(
        mode=mode,
        format='zip',
        entry_class=ZipEntry,
        encoding='CP437',
        password=password,
    )
    super(ZipFile, self).__init__(f, **base_kwargs)
72
-
73
79
74
80
getinfo = SeekableArchive .getentry
75
81
76
82
def set_initial_options(self):
    """Apply zip format options to the underlying archive when writing.

    Nothing is done unless ``self.mode`` is ``'w'``.  In write mode:

    * if ``self.compression`` is ``ZIP_STORED``, the ``compression`` option
      is set to ``"store"``;
    * if ``self.password`` is truthy, the ``encryption`` option is set to
      ``self.encryption``, which is first defaulted to ``"traditional"``
      when it is empty.
    """
    if self.mode != 'w':
        return

    def set_zip_option(key, value):
        # Single place that forwards a zip format option to libarchive.
        _libarchive.archive_write_set_format_option(self._a, "zip", key, value)

    if self.compression == ZIP_STORED:
        set_zip_option("compression", "store")

    if self.password:
        if not self.encryption:
            self.encryption = "traditional"
        set_zip_option("encryption", self.encryption)
85
-
86
-
87
90
88
91
def namelist(self):
    """Return a list holding every item yielded by ``self.iterpaths()``."""
    names = []
    names.extend(self.iterpaths())
    return names
@@ -104,7 +107,8 @@ def extract(self, name, path=None, pwd=None):
104
107
self .add_passphrase (pwd )
105
108
if not path :
106
109
path = os .getcwd ()
107
- return self .readpath (name , os .path .join (path , name ))
110
+ sanitized_name = sanitize_filename (name )
111
+ return self .readpath (sanitized_name , os .path .join (path , sanitized_name ))
108
112
109
113
def extractall (self , path , names = None , pwd = None ):
110
114
if pwd :
@@ -113,7 +117,8 @@ def extractall(self, path, names=None, pwd=None):
113
117
names = self .namelist ()
114
118
if names :
115
119
for name in names :
116
- self .extract (name , path )
120
+ sanitized_name = sanitize_filename (name , path )
121
+ self .extract (sanitized_name , path )
117
122
118
123
def read (self , name , pwd = None ):
119
124
if pwd :
0 commit comments