@@ -13,10 +13,12 @@
 from transformers import AutoTokenizer
 
 from commit0.harness.constants import SPLIT
+from commit0.harness.get_pytest_ids import main as get_tests
 from commit0.harness.utils import clone_repo
 from commit0.cli import write_commit0_config_file
 
 import logging
+from typing import Any, NoReturn
 
 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
@@ -26,9 +28,13 @@
 analysis_files_path = "/share/rush/commit0_analysis_temp"
 
 
-def get_pytest_info(path_to_logs, repo_name, branch_name):
+def get_pytest_info(
+    path_to_logs: str, repo_name: str, branch_name: str
+) -> dict[str, dict[str, Any]]:
     pytest_info = {}
     for pytest_hash in os.listdir(path_to_logs):
+        if not os.path.exists(os.path.join(path_to_logs, pytest_hash, "eval.sh")):
+            continue
         eval_script = open(os.path.join(path_to_logs, pytest_hash, "eval.sh")).read()
         testname = re.search(r"([\S]+) > test_output", eval_script).group(1)
         patch_diff = open(os.path.join(path_to_logs, pytest_hash, "patch.diff")).read()
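Not part of the change itself: a minimal, self-contained sketch of the guard the added lines introduce, which skips per-run log folders that never produced an `eval.sh` instead of raising. The folder names and the `eval.sh` contents below are invented for illustration.

```python
import os
import tempfile

# Fake log layout: one complete run, one partial run with no eval.sh.
path_to_logs = tempfile.mkdtemp()
os.makedirs(os.path.join(path_to_logs, "hash-complete"))
os.makedirs(os.path.join(path_to_logs, "hash-partial"))
with open(os.path.join(path_to_logs, "hash-complete", "eval.sh"), "w") as f:
    f.write("pytest tests > test_output 2>&1\n")

for pytest_hash in os.listdir(path_to_logs):
    # Same existence check as the added lines above.
    if not os.path.exists(os.path.join(path_to_logs, pytest_hash, "eval.sh")):
        continue
    print("would parse logs for", pytest_hash)  # only "hash-complete" reaches here
```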
@@ -84,19 +90,19 @@ def get_pytest_info(path_to_logs, repo_name, branch_name):
                 "failure_string": failure_string,
                 "duration": duration,
             }
-    return pytest_info
+    return pytest_info if len(pytest_info) else "Could not evaluate"
 
 
-def get_coverage_info(path_to_logs, repo_name, branch_name):
+def get_coverage_info(path_to_logs: str, repo_name: str, branch_name: str) -> Any:
     raise NotImplementedError
 
 
 def get_blank_repo_metrics(
-    blank_source_code_folder,
-    spec_filename,
+    blank_source_code_folder: str,
+    spec_filename: str,
     tokenizer,
     code_file_filter=lambda filename: filename,
-):
+) -> dict[str, Any]:
     blank_repo_metrics = {
         "functions_to_edit": [],
     }
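A hedged sketch, not from the diff, of the convention this hunk introduces: `get_pytest_info` now returns either the populated dict or the plain string `"Could not evaluate"`, and later hunks branch on `isinstance(..., str)` to render a failure page. The function and path names below are illustrative only.

```python
from typing import Any, Union


def load_pytest_info(path_to_logs: str) -> Union[dict[str, Any], str]:
    results: dict[str, Any] = {}
    # ... would be filled from per-hash log folders, as in get_pytest_info ...
    return results if len(results) else "Could not evaluate"


repo_pytest_results = load_pytest_info("/tmp/empty-logs")
if isinstance(repo_pytest_results, str):
    print(f"## Failed\n\n{repo_pytest_results}")  # mirrors the failure page rendered later
else:
    print(f"{len(repo_pytest_results)} pytest runs parsed")
```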
@@ -164,7 +170,7 @@ def get_blank_repo_metrics(
 
 
 leaderboard_header = """\n\n## Leaderboard ({split})
-| Name | Repos Resolved (/{num_repos}) | Total Tests Passed (/{total_num_tests}) | Test Duration (s) | Date | Analysis | Github |
+| Name | Repos Resolved (/{num_repos}) | Avg. pass rate | Test Duration (s) | Date | Analysis | Github |
 |------|:-------------------------:|:--------------------:|:--------------------:|:----------:|----|----| """
 
 submission_table_header = """# Submission Name: **{display_name}** (split: {split})
@@ -178,33 +184,44 @@ def get_blank_repo_metrics(
 """
 
 
-def render_mds(overwrite_previous, subfolder="docs"):
+def render_mds(overwrite_previous: bool, subfolder: str = "docs") -> NoReturn:
     leaderboard = {}
 
     split_to_total_tests = {
         "lite": 3628,
         "all": 140926,
     }  # hard-coded to skip running it later
-    for split in tqdm.tqdm(["lite", "all"]):
+    for split in ["lite", "all"]:
         num_repos = len(SPLIT[split])
         # total_num_tests = 0
         # for repo_name in SPLIT[split]:
         #     repo_tests = subprocess.run(['commit0', 'get-tests', repo_name], capture_output=True, text=True).stdout.strip()
         #     total_num_tests += len(repo_tests.splitlines())
-        leaderboard[split] = leaderboard_header.format(
-            split=split,
-            num_repos=num_repos,
-            total_num_tests=split_to_total_tests[split],
+        leaderboard[split] = []
+        leaderboard[split].append(
+            (
+                split_to_total_tests[split] + 1,
+                leaderboard_header.format(
+                    split=split,
+                    num_repos=num_repos,
+                    total_num_tests=split_to_total_tests[split],
+                ),
+            )
         )
 
     for org_path in tqdm.tqdm(glob.glob(os.path.join(analysis_files_path, "*"))):
         org_name = os.path.basename(org_path)
         if org_name in {"blank", "repos", "submission_repos"}:
             continue
         for branch_path in glob.glob(os.path.join(org_path, "*.json")):
-            cum_tests_passed = 0
+            evaluate_numbers = []
+            lite_evaluate_numbers = []
+            # cum_tests_passed = 0
             repos_resolved = 0
             total_duration = 0.0
+            # lite_cum_tests_passed = 0
+            lite_repos_resolved = 0
+            lite_total_duration = 0.0
             branch_metrics = json.load(open(branch_path))
             submission_info = branch_metrics["submission_info"]
             split = submission_info["split"]
@@ -234,7 +251,7 @@ def render_mds(overwrite_previous, subfolder="docs"):
                     subfolder, f"analysis_{org_name}_{branch_name}_{repo_name}.md"
                 )
                 if isinstance(repo_pytest_results, str):
-                    submission_repo_page = f"# **{display_name}**: {repo_name}\n\n## Failed to clone\n\n{repo_pytest_results}"
+                    submission_repo_page = f"# **{display_name}**: {repo_name}\n\n## Failed\n\n{repo_pytest_results}"
                     org_branch_repo_filepath = os.path.join(
                         subfolder, f"analysis_{org_name}_{branch_name}_{repo_name}.md"
                     )
@@ -246,7 +263,7 @@ def render_mds(overwrite_previous, subfolder="docs"):
                     submission_page = submission_table_header.format(
                         display_name=display_name, split=split
                     ) + (
-                        f"\n| {repo_name} | No; Failed to clone. | - | - | "
+                        f"\n| {repo_name} | No; {repo_pytest_results} | - | - | "
                         f"[Analysis](/{f'analysis_{org_name}_{branch_name}_{repo_name}'}) | "
                         f"[Github]({github_hyperlink}) |"
                     )
@@ -267,13 +284,23 @@ def render_mds(overwrite_previous, subfolder="docs"):
                     )
                     pytest_details = "Pytest failed"
                     duration = "Failed."
+                    evaluate_numbers.append(0.0)
+                    if split == "all" and repo_name in SPLIT["lite"]:
+                        lite_evaluate_numbers.append(0.0)
                 else:
                     resolved = False
                     if "passed" in pytest_info["summary"]:
                         if "skipped" in pytest_info["summary"]:
-                            resolved = pytest_info["summary"]["passed"] + pytest_info["summary"]["skipped"] == pytest_info["summary"]["total"]
+                            resolved = (
+                                pytest_info["summary"]["passed"]
+                                + pytest_info["summary"]["skipped"]
+                                == pytest_info["summary"]["total"]
+                            )
                         else:
-                            resolved = pytest_info["summary"]["passed"] == pytest_info["summary"]["total"]
+                            resolved = (
+                                pytest_info["summary"]["passed"]
+                                == pytest_info["summary"]["total"]
+                            )
                     if write_submission:
                         submission_repo_page += pytest_summary_table_header.format(
                             pytest_group=pytest_group
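An equivalent condensed form of the `resolved` check reformatted above, written here only for illustration: a repo counts as resolved when `passed` (plus `skipped`, when reported) equals `total`.

```python
def is_resolved(summary: dict) -> bool:
    # Same logic as the if/else in the hunk above: skipped tests still count toward "resolved".
    if "passed" not in summary:
        return False
    return summary["passed"] + summary.get("skipped", 0) == summary["total"]


assert is_resolved({"passed": 8, "skipped": 2, "total": 10})
assert not is_resolved({"passed": 8, "total": 10})
```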
@@ -295,9 +322,21 @@ def render_mds(overwrite_previous, subfolder="docs"):
                                 f"### {shortened_testname}\n\n<details><summary> <pre>{shortened_testname}"
                                 f"</pre></summary><pre>\n{failure['failure_string']}\n</pre>\n</details>\n"
                             )
-                    cum_tests_passed += pytest_info["summary"]["passed"]
+                    # cum_tests_passed += pytest_info["summary"]["passed"]
+                    num_tests = len(get_tests(repo_name, verbose=0))
+                    evaluate_numbers.append(
+                        pytest_info["summary"]["passed"] / num_tests
+                    )
                     total_duration += pytest_info["duration"]
                     repos_resolved += int(resolved)
+                    if split == "all" and repo_name in SPLIT["lite"]:
+                        lite_evaluate_numbers.append(
+                            pytest_info["summary"]["passed"] / num_tests
+                        )
+                        # lite_cum_tests_passed += pytest_info["summary"]["passed"]
+                        lite_total_duration += pytest_info["duration"]
+                        lite_repos_resolved += int(resolved)
+
                     if write_submission:
                         pytest_details = f"{pytest_info['summary']['passed']} / {pytest_info['summary']['total']}"
                         duration = f"{pytest_info['duration']:.2f}"
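For context, a toy calculation (numbers invented) of the new averaging scheme this hunk switches to: each repo contributes `passed / num_tests`, with `num_tests` taken from `commit0.harness.get_pytest_ids`, a failed evaluation contributes 0.0, and the leaderboard later reports the mean as a percentage rather than a cumulative passed-test count.

```python
# Hypothetical per-repo fractions: two partial repos, one fully passing, one failed eval.
evaluate_numbers = [45 / 100, 30 / 60, 20 / 20, 0.0]
avg_pass_rate = sum(evaluate_numbers) / len(evaluate_numbers)
print(f"{avg_pass_rate * 100:.2f}%")  # -> 48.75%
```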
@@ -322,22 +361,46 @@ def render_mds(overwrite_previous, subfolder="docs"):
                 wf.write(back_button + "\n" + submission_page)
             analysis_link = f"[Analysis](/{f'analysis_{org_name}_{branch_name}'})"
             github_link = f"[Github]({project_page_link})"
-            leaderboard[split] += (
-                f"\n|{display_name}|"
-                f"{repos_resolved}|"
-                f"{cum_tests_passed}|"
-                f"{total_duration:.2f}|"
-                f"{submission_date}|"
-                f"{analysis_link}|"
-                f"{github_link}|"
+            avg_pass_rate = sum(evaluate_numbers) / len(evaluate_numbers)
+            leaderboard[split].append(
+                (
+                    avg_pass_rate * 100,
+                    f"\n|{display_name}|"
+                    f"{repos_resolved}|"
+                    f"{avg_pass_rate * 100:.2f}%|"
+                    f"{total_duration:.2f}|"
+                    f"{submission_date}|"
+                    f"{analysis_link}|"
+                    f"{github_link}|",
+                )
             )
+            if (split == "all") and ("Reference (Gold)" not in display_name):
+                avg_lite_pass_rate = sum(lite_evaluate_numbers) / len(
+                    lite_evaluate_numbers
+                )
+                leaderboard["lite"].append(
+                    (
+                        avg_lite_pass_rate * 100,
+                        f"\n|{display_name} (subset of `all`)|"
+                        f"{lite_repos_resolved}|"
+                        f"{avg_lite_pass_rate * 100:.2f}%|"
+                        f"{lite_total_duration:.2f}|"
+                        f"{submission_date}|"
+                        f"{analysis_link}|"
+                        f"{github_link}|",
+                    )
+                )
 
     leaderboard_filepath = os.path.join(subfolder, "analysis.md")
+    for split in ["lite", "all"]:
+        leaderboard[split] = sorted(leaderboard[split], key=lambda elt: -elt[0])
     with open(leaderboard_filepath, "w") as wf:
-        wf.write(leaderboard["lite"] + "\n\n" + leaderboard["all"])
+        lite_leaderboard_string = "".join(string for (_, string) in leaderboard["lite"])
+        all_leaderboard_string = "".join(string for (_, string) in leaderboard["all"])
+        wf.write(lite_leaderboard_string + "\n\n" + all_leaderboard_string)
 
 
-def get_args():
+def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--do_setup", action="store_true", help="Run commit0 setup with specified split"
@@ -366,14 +429,14 @@ def get_args():
     parser.add_argument(
         "--overwrite_previous_eval",
         action="store_true",
-        help="Overwrite cached pytest info"
+        help="Overwrite cached pytest info",
         # TODO add finer granularity so can specify which ones to overwrite
     )
 
     return parser.parse_args()
 
 
-def main(args):
+def main(args: argparse.Namespace) -> NoReturn:
     global analysis_files_path
 
     commit0_dataset_name = "wentingzhao/commit0_combined"
@@ -493,6 +556,7 @@ def main(args):
             )
             if os.path.exists(submission_repos_path):
                 shutil.rmtree(submission_repos_path)
+                print(f"Removed existing at {submission_repos_path}")
             os.makedirs(os.path.join(analysis_files_path, org_name), exist_ok=True)
             commit0_config_file = os.path.join(
                 analysis_files_path,
@@ -530,7 +594,7 @@ def main(args):
             )
             # run pytests
             os.system(
-                f"commit0 evaluate --branch {branch_name} "
+                f"commit0 evaluate --branch {branch_name} --timeout 1800 "
                 f"--commit0-config-file {commit0_config_file} "
             )
             for example in dataset:
for example in dataset :
0 commit comments