Skip to content

Commit 904acc3

Browse files
committed
improved removal
1 parent 2fa9091 commit 904acc3

File tree

3 files changed

+64
-37
lines changed

3 files changed

+64
-37
lines changed

build_dataset.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ def create_instance(
3535
# extract_test_names needs to be called on the environment set up commit
3636
base_commit = generate_base_commit(
3737
repo,
38+
raw_info["src_dir"],
39+
raw_info["test_dir"],
3840
spec_url=raw_info["specification"],
3941
base_branch_name=base_branch_name,
4042
removal=removal,

remove_repos.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def main() -> None:
2424
repos = list_repos(organization)
2525

2626
for repo in repos:
27-
if ".github.io" in repo["name"] or "commit0" in repo["name"]:
27+
if ".github.io" in repo["name"] or "commit0" in repo["name"] or "analysis" in repo["name"] or "build_dataset" in repo["name"]:
2828
continue
2929
else:
3030
repo_name = repo["name"]

utils.py

Lines changed: 61 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,11 @@ def __init__(self, removal_method):
117117

118118
def visit_FunctionDef(self, node):
119119
transform = node
120+
121+
# Don't modify Python's special functions (like __init__, __str__, etc.)
122+
if node.name.startswith('__') and node.name.endswith('__'):
123+
return node # Leave special methods unchanged
124+
120125
# Check if the first statement is a docstring
121126
if (
122127
node.body
@@ -159,6 +164,17 @@ def visit_AsyncFunctionDef(self, node):
159164
# Handle async functions the same way as regular functions
160165
return self.visit_FunctionDef(node)
161166

167+
def visit_ClassDef(self, node):
168+
# Visit all the body of the class to apply transformations
169+
self.generic_visit(node)
170+
171+
# Check if the class is left with only a `Pass` or is entirely empty
172+
if not node.body or all(isinstance(child, ast.Pass) for child in node.body):
173+
# If class has no methods or only pass, return None to remove it
174+
return None
175+
176+
return node
177+
162178

163179
def clone_repo(clone_url, clone_dir, commit) -> None:
164180
"""Clone repo into a temporary directory
@@ -192,8 +208,44 @@ def remove_local_repo(clone_dir) -> None:
192208
os.system(f"rm -rf {clone_dir}")
193209
logger.info(f"Cleaned up the cloned repository at {clone_dir}")
194210

211+
def collect_test_files(directory):
212+
# List to store all the filenames
213+
test_files = []
214+
subdirs = []
215+
216+
# Walk through the directory
217+
for root, dirs, files in os.walk(directory):
218+
if root.endswith("/"):
219+
root = root[:-1]
220+
# Check if 'test' is part of the folder name
221+
if 'test' in os.path.basename(root).lower() or os.path.basename(root) in subdirs:
222+
for file in files:
223+
# Process only Python files
224+
if file.endswith('.py'):
225+
file_path = os.path.join(root, file)
226+
test_files.append(file_path)
227+
for d in dirs:
228+
subdirs.append(d)
229+
230+
return test_files
231+
232+
233+
def collect_python_files(directory):
234+
# List to store all the .py filenames
235+
python_files = []
236+
237+
# Walk through the directory recursively
238+
for root, _, files in os.walk(directory):
239+
for file in files:
240+
# Check if the file ends with '.py'
241+
if file.endswith('.py'):
242+
file_path = os.path.join(root, file)
243+
python_files.append(file_path)
244+
245+
return python_files
246+
195247

196-
def _find_files_to_edit(base_dir: str) -> list[str]:
248+
def _find_files_to_edit(base_dir: str, src_dir: str, test_dir: str) -> list[str]:
197249
"""Identify files to remove content by heuristics.
198250
We assume source code is under [lib]/[lib] or [lib]/src.
199251
We exclude test code. This function would not work
@@ -208,50 +260,23 @@ def _find_files_to_edit(base_dir: str) -> list[str]:
208260
files (list[str]): a list of files to be edited.
209261
210262
"""
263+
files = collect_python_files(os.path.join(base_dir, src_dir))
264+
test_files = collect_test_files(os.path.join(base_dir, test_dir))
265+
files = list(set(files) - set(test_files))
211266

212-
def find_helper(n):
213-
path_src = os.path.join(base_dir, n, "**", "*.py")
214-
files = glob(path_src, recursive=True)
215-
path_src = os.path.join(base_dir, "src", n, "**", "*.py")
216-
files += glob(path_src, recursive=True)
217-
path_src = os.path.join(base_dir, "src", "**", "*.py")
218-
files += glob(path_src, recursive=True)
219-
path_test = os.path.join(base_dir, n, "**", "test*", "**", "*.py")
220-
test_files = glob(path_test, recursive=True)
221-
files = list(set(files) - set(test_files))
222-
return files
223-
224-
name = os.path.basename(base_dir)
225-
files = find_helper(name)
226-
if name != name.lower():
227-
files += find_helper(name.lower())
228-
elif name.startswith("pyjwt"):
229-
files += find_helper("jwt")
230-
elif name == "tlslite-ng":
231-
files += find_helper("tlslite")
232-
elif name == "dnspython":
233-
files += find_helper("dns")
234-
elif name == "web3.py":
235-
files += find_helper("web3")
236-
elif name == "python-rsa":
237-
files += find_helper("rsa")
238-
elif name == "more-itertools":
239-
files += find_helper("more_itertools")
240-
elif name == "imbalanced-learn":
241-
files += find_helper("imblearn")
242-
elif name == "python-progressbar":
243-
files += find_helper("progressbar")
244-
elif name == "filesystem_spec":
245-
files += find_helper("fsspec")
246267
# don't edit __init__ files
247268
files = [f for f in files if "__init__" not in f]
269+
# don't edit __main__ files
270+
files = [f for f in files if "__main__" not in f]
248271
# don't edit confest.py files
249272
files = [f for f in files if "conftest.py" not in f]
250273
return files
251274

252275

253276
def generate_base_commit(
254277
repo: Repo,
278+
src_dir: str,
279+
test_dir: str,
255280
spec_url: str,
256281
base_branch_name: str = "commit0",
257282
removal: str = "all",
@@ -288,7 +313,7 @@ def generate_base_commit(
288313
else:
289314
logger.info("Creating commit 0")
290315
repo.local_repo.git.checkout("-b", branch_name)
291-
files = _find_files_to_edit(repo.clone_dir)
316+
files = _find_files_to_edit(repo.clone_dir, src_dir, test_dir)
292317
for f in files:
293318
tree = astor.parse_file(f)
294319
tree = RemoveMethod(removal).visit(tree)

0 commit comments

Comments
 (0)