@@ -117,6 +117,11 @@ def __init__(self, removal_method):
117
117
118
118
def visit_FunctionDef (self , node ):
119
119
transform = node
120
+
121
+ # Don't modify Python's special functions (like __init__, __str__, etc.)
122
+ if node .name .startswith ('__' ) and node .name .endswith ('__' ):
123
+ return node # Leave special methods unchanged
124
+
120
125
# Check if the first statement is a docstring
121
126
if (
122
127
node .body
@@ -159,6 +164,17 @@ def visit_AsyncFunctionDef(self, node):
159
164
# Handle async functions the same way as regular functions
160
165
return self .visit_FunctionDef (node )
161
166
167
+ def visit_ClassDef (self , node ):
168
+ # Visit all the body of the class to apply transformations
169
+ self .generic_visit (node )
170
+
171
+ # Check if the class is left with only a `Pass` or is entirely empty
172
+ if not node .body or all (isinstance (child , ast .Pass ) for child in node .body ):
173
+ # If class has no methods or only pass, return None to remove it
174
+ return None
175
+
176
+ return node
177
+
162
178
163
179
def clone_repo (clone_url , clone_dir , commit ) -> None :
164
180
"""Clone repo into a temporary directory
@@ -192,8 +208,44 @@ def remove_local_repo(clone_dir) -> None:
192
208
os .system (f"rm -rf { clone_dir } " )
193
209
logger .info (f"Cleaned up the cloned repository at { clone_dir } " )
194
210
211
+ def collect_test_files (directory ):
212
+ # List to store all the filenames
213
+ test_files = []
214
+ subdirs = []
215
+
216
+ # Walk through the directory
217
+ for root , dirs , files in os .walk (directory ):
218
+ if root .endswith ("/" ):
219
+ root = root [:- 1 ]
220
+ # Check if 'test' is part of the folder name
221
+ if 'test' in os .path .basename (root ).lower () or os .path .basename (root ) in subdirs :
222
+ for file in files :
223
+ # Process only Python files
224
+ if file .endswith ('.py' ):
225
+ file_path = os .path .join (root , file )
226
+ test_files .append (file_path )
227
+ for d in dirs :
228
+ subdirs .append (d )
229
+
230
+ return test_files
231
+
232
+
233
+ def collect_python_files (directory ):
234
+ # List to store all the .py filenames
235
+ python_files = []
236
+
237
+ # Walk through the directory recursively
238
+ for root , _ , files in os .walk (directory ):
239
+ for file in files :
240
+ # Check if the file ends with '.py'
241
+ if file .endswith ('.py' ):
242
+ file_path = os .path .join (root , file )
243
+ python_files .append (file_path )
244
+
245
+ return python_files
246
+
195
247
196
- def _find_files_to_edit (base_dir : str ) -> list [str ]:
248
+ def _find_files_to_edit (base_dir : str , src_dir : str , test_dir : str ) -> list [str ]:
197
249
"""Identify files to remove content by heuristics.
198
250
We assume source code is under [lib]/[lib] or [lib]/src.
199
251
We exclude test code. This function would not work
@@ -208,50 +260,23 @@ def _find_files_to_edit(base_dir: str) -> list[str]:
208
260
files (list[str]): a list of files to be edited.
209
261
210
262
"""
263
+ files = collect_python_files (os .path .join (base_dir , src_dir ))
264
+ test_files = collect_test_files (os .path .join (base_dir , test_dir ))
265
+ files = list (set (files ) - set (test_files ))
211
266
212
- def find_helper (n ):
213
- path_src = os .path .join (base_dir , n , "**" , "*.py" )
214
- files = glob (path_src , recursive = True )
215
- path_src = os .path .join (base_dir , "src" , n , "**" , "*.py" )
216
- files += glob (path_src , recursive = True )
217
- path_src = os .path .join (base_dir , "src" , "**" , "*.py" )
218
- files += glob (path_src , recursive = True )
219
- path_test = os .path .join (base_dir , n , "**" , "test*" , "**" , "*.py" )
220
- test_files = glob (path_test , recursive = True )
221
- files = list (set (files ) - set (test_files ))
222
- return files
223
-
224
- name = os .path .basename (base_dir )
225
- files = find_helper (name )
226
- if name != name .lower ():
227
- files += find_helper (name .lower ())
228
- elif name .startswith ("pyjwt" ):
229
- files += find_helper ("jwt" )
230
- elif name == "tlslite-ng" :
231
- files += find_helper ("tlslite" )
232
- elif name == "dnspython" :
233
- files += find_helper ("dns" )
234
- elif name == "web3.py" :
235
- files += find_helper ("web3" )
236
- elif name == "python-rsa" :
237
- files += find_helper ("rsa" )
238
- elif name == "more-itertools" :
239
- files += find_helper ("more_itertools" )
240
- elif name == "imbalanced-learn" :
241
- files += find_helper ("imblearn" )
242
- elif name == "python-progressbar" :
243
- files += find_helper ("progressbar" )
244
- elif name == "filesystem_spec" :
245
- files += find_helper ("fsspec" )
246
267
# don't edit __init__ files
247
268
files = [f for f in files if "__init__" not in f ]
269
+ # don't edit __main__ files
270
+ files = [f for f in files if "__main__" not in f ]
248
271
# don't edit confest.py files
249
272
files = [f for f in files if "conftest.py" not in f ]
250
273
return files
251
274
252
275
253
276
def generate_base_commit (
254
277
repo : Repo ,
278
+ src_dir : str ,
279
+ test_dir : str ,
255
280
spec_url : str ,
256
281
base_branch_name : str = "commit0" ,
257
282
removal : str = "all" ,
@@ -288,7 +313,7 @@ def generate_base_commit(
288
313
else :
289
314
logger .info ("Creating commit 0" )
290
315
repo .local_repo .git .checkout ("-b" , branch_name )
291
- files = _find_files_to_edit (repo .clone_dir )
316
+ files = _find_files_to_edit (repo .clone_dir , src_dir , test_dir )
292
317
for f in files :
293
318
tree = astor .parse_file (f )
294
319
tree = RemoveMethod (removal ).visit (tree )
0 commit comments