 from commit0.cli import write_commit0_config_file
 
 import logging
+from typing import Any, NoReturn
 
 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
 analysis_files_path = "/share/rush/commit0_analysis_temp"
 
 
-def get_pytest_info(path_to_logs, repo_name, branch_name):
+def get_pytest_info(
+    path_to_logs: str, repo_name: str, branch_name: str
+) -> dict[str, dict[str, Any]]:
     pytest_info = {}
     for pytest_hash in os.listdir(path_to_logs):
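+        # Skip evaluation hashes that never produced an eval.sh script.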
+        if not os.path.exists(os.path.join(path_to_logs, pytest_hash, "eval.sh")):
+            continue
         eval_script = open(os.path.join(path_to_logs, pytest_hash, "eval.sh")).read()
         testname = re.search(r"([\S]+) > test_output", eval_script).group(1)
         patch_diff = open(os.path.join(path_to_logs, pytest_hash, "patch.diff")).read()
@@ -85,19 +90,19 @@ def get_pytest_info(path_to_logs, repo_name, branch_name):
                 "failure_string": failure_string,
                 "duration": duration,
             }
-    return pytest_info
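+    # An empty result means no pytest logs could be parsed; callers check for a str and surface it as the failure reason.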
+    return pytest_info if len(pytest_info) else "Could not evaluate"
 
 
-def get_coverage_info(path_to_logs, repo_name, branch_name):
+def get_coverage_info(path_to_logs: str, repo_name: str, branch_name: str) -> Any:
     raise NotImplementedError
 
 
 def get_blank_repo_metrics(
-    blank_source_code_folder,
-    spec_filename,
+    blank_source_code_folder: str,
+    spec_filename: str,
     tokenizer,
     code_file_filter=lambda filename: filename,
-):
+) -> dict[str, Any]:
     blank_repo_metrics = {
         "functions_to_edit": [],
     }
@@ -165,7 +170,7 @@ def get_blank_repo_metrics(
 
 
 leaderboard_header = """\n\n## Leaderboard ({split})
-| Name | Repos Resolved (/{num_repos}) | Tests Passed (Total: {total_num_tests}) | Test Duration (s) | Date | Analysis | Github |
+| Name | Repos Resolved (/{num_repos}) | Avg. pass rate | Test Duration (s) | Date | Analysis | Github |
 |------|:-------------------------:|:--------------------:|:--------------------:|:----------:|----|----|"""
 
 submission_table_header = """# Submission Name: **{display_name}** (split: {split})
@@ -179,7 +184,7 @@ def get_blank_repo_metrics(
 """
 
 
-def render_mds(overwrite_previous, subfolder="docs"):
+def render_mds(overwrite_previous: bool, subfolder: str = "docs") -> NoReturn:
     leaderboard = {}
 
     split_to_total_tests = {
@@ -193,11 +198,16 @@ def render_mds(overwrite_previous, subfolder="docs"):
         # repo_tests = subprocess.run(['commit0', 'get-tests', repo_name], capture_output=True, text=True).stdout.strip()
         # total_num_tests += len(repo_tests.splitlines())
         leaderboard[split] = []
-        leaderboard[split].append((split_to_total_tests[split]+1, leaderboard_header.format(
-            split=split,
-            num_repos=num_repos,
-            total_num_tests=split_to_total_tests[split],
-        )))
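+        # Rows are stored as (sort key, markdown row) tuples; the large numeric key presumably keeps this header row first.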
+        leaderboard[split].append(
+            (
+                split_to_total_tests[split] + 1,
+                leaderboard_header.format(
+                    split=split,
+                    num_repos=num_repos,
+                    total_num_tests=split_to_total_tests[split],
+                ),
+            )
+        )
 
     for org_path in tqdm.tqdm(glob.glob(os.path.join(analysis_files_path, "*"))):
         org_name = os.path.basename(org_path)
@@ -241,7 +251,7 @@ def render_mds(overwrite_previous, subfolder="docs"):
                     subfolder, f"analysis_{org_name}_{branch_name}_{repo_name}.md"
                 )
                 if isinstance(repo_pytest_results, str):
-                    submission_repo_page = f"# **{display_name}**: {repo_name}\n\n## Failed to clone\n\n{repo_pytest_results}"
+                    submission_repo_page = f"# **{display_name}**: {repo_name}\n\n## Failed\n\n{repo_pytest_results}"
                     org_branch_repo_filepath = os.path.join(
                         subfolder, f"analysis_{org_name}_{branch_name}_{repo_name}.md"
                     )
@@ -253,7 +263,7 @@ def render_mds(overwrite_previous, subfolder="docs"):
                     submission_page = submission_table_header.format(
                         display_name=display_name, split=split
                     ) + (
-                        f"\n| {repo_name} | No; Failed to clone. | - | - | "
+                        f"\n| {repo_name} | No; {repo_pytest_results} | - | - | "
                         f"[Analysis](/{f'analysis_{org_name}_{branch_name}_{repo_name}'}) | "
                         f"[Github]({github_hyperlink}) |"
                     )
@@ -274,16 +284,23 @@ def render_mds(overwrite_previous, subfolder="docs"):
                         )
                         pytest_details = "Pytest failed"
                         duration = "Failed."
-                        evaluate_numbers.append(0.)
-                        if split == "all" and repo_name in SPLIT['lite']:
-                            lite_evaluate_numbers.append(0.)
+                        evaluate_numbers.append(0.0)
+                        if split == "all" and repo_name in SPLIT["lite"]:
+                            lite_evaluate_numbers.append(0.0)
                     else:
                         resolved = False
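+                        # A repo counts as resolved when every collected test passed or was skipped.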
                         if "passed" in pytest_info["summary"]:
                             if "skipped" in pytest_info["summary"]:
-                                resolved = pytest_info["summary"]["passed"] + pytest_info["summary"]["skipped"] == pytest_info["summary"]["total"]
+                                resolved = (
+                                    pytest_info["summary"]["passed"]
+                                    + pytest_info["summary"]["skipped"]
+                                    == pytest_info["summary"]["total"]
+                                )
                             else:
-                                resolved = pytest_info["summary"]["passed"] == pytest_info["summary"]["total"]
+                                resolved = (
+                                    pytest_info["summary"]["passed"]
+                                    == pytest_info["summary"]["total"]
+                                )
                         if write_submission:
                             submission_repo_page += pytest_summary_table_header.format(
                                 pytest_group=pytest_group
@@ -307,11 +324,15 @@ def render_mds(overwrite_previous, subfolder="docs"):
                             )
                         # cum_tests_passed += pytest_info["summary"]["passed"]
                         num_tests = len(get_tests(repo_name, verbose=0))
-                        evaluate_numbers.append(pytest_info["summary"]["passed"] / num_tests)
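+                        # Record the per-repo pass rate (passed tests / total tests collected for the repo).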
+                        evaluate_numbers.append(
+                            pytest_info["summary"]["passed"] / num_tests
+                        )
                         total_duration += pytest_info["duration"]
                         repos_resolved += int(resolved)
-                        if split == "all" and repo_name in SPLIT['lite']:
-                            lite_evaluate_numbers.append(pytest_info["summary"]["passed"] / num_tests)
+                        if split == "all" and repo_name in SPLIT["lite"]:
+                            lite_evaluate_numbers.append(
+                                pytest_info["summary"]["passed"] / num_tests
+                            )
                             # lite_cum_tests_passed += pytest_info["summary"]["passed"]
                             lite_total_duration += pytest_info["duration"]
                             lite_repos_resolved += int(resolved)
@@ -341,26 +362,34 @@ def render_mds(overwrite_previous, subfolder="docs"):
             analysis_link = f"[Analysis](/{f'analysis_{org_name}_{branch_name}'})"
             github_link = f"[Github]({project_page_link})"
             avg_pass_rate = sum(evaluate_numbers) / len(evaluate_numbers)
-            leaderboard[split].append((avg_pass_rate * 100,
-                f"\n|{display_name}|"
-                f"{repos_resolved}|"
-                f"{avg_pass_rate * 100:.2f}%|"
-                f"{total_duration:.2f}|"
-                f"{submission_date}|"
-                f"{analysis_link}|"
-                f"{github_link}|"
-            ))
-            if ((split == "all") and ("Reference (Gold)" not in display_name)):
-                avg_lite_pass_rate = sum(lite_evaluate_numbers) / len(lite_evaluate_numbers)
-                leaderboard["lite"].append((avg_lite_pass_rate * 100,
-                    f"\n|{display_name} (subset of `all`)|"
-                    f"{lite_repos_resolved}|"
-                    f"{avg_lite_pass_rate * 100:.2f}%|"
-                    f"{lite_total_duration:.2f}|"
+            leaderboard[split].append(
+                (
+                    avg_pass_rate * 100,
+                    f"\n|{display_name}|"
+                    f"{repos_resolved}|"
+                    f"{avg_pass_rate * 100:.2f}%|"
+                    f"{total_duration:.2f}|"
                     f"{submission_date}|"
                     f"{analysis_link}|"
-                    f"{github_link}|"
-                ))
+                    f"{github_link}|",
+                )
+            )
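+            # Submissions evaluated on "all" (other than the reference) also contribute a derived row to the "lite" board.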
+            if (split == "all") and ("Reference (Gold)" not in display_name):
+                avg_lite_pass_rate = sum(lite_evaluate_numbers) / len(
+                    lite_evaluate_numbers
+                )
+                leaderboard["lite"].append(
+                    (
+                        avg_lite_pass_rate * 100,
+                        f"\n|{display_name} (subset of `all`)|"
+                        f"{lite_repos_resolved}|"
+                        f"{avg_lite_pass_rate * 100:.2f}%|"
+                        f"{lite_total_duration:.2f}|"
+                        f"{submission_date}|"
+                        f"{analysis_link}|"
+                        f"{github_link}|",
+                    )
+                )
 
     leaderboard_filepath = os.path.join(subfolder, "analysis.md")
     for split in ["lite", "all"]:
@@ -371,7 +400,7 @@ def render_mds(overwrite_previous, subfolder="docs"):
         wf.write(lite_leaderboard_string + "\n\n" + all_leaderboard_string)
 
 
-def get_args():
+def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--do_setup", action="store_true", help="Run commit0 setup with specified split"
@@ -400,14 +429,14 @@ def get_args():
     parser.add_argument(
         "--overwrite_previous_eval",
         action="store_true",
-        help="Overwrite cached pytest info"
+        help="Overwrite cached pytest info",
         # TODO add finer granularity so can specify which ones to overwrite
     )
 
     return parser.parse_args()
 
 
-def main(args):
+def main(args: argparse.Namespace) -> NoReturn:
     global analysis_files_path
 
     commit0_dataset_name = "wentingzhao/commit0_combined"
@@ -565,7 +594,7 @@ def main(args):
         )
         # run pytests
         os.system(
-            f"commit0 evaluate --branch {branch_name} --timeout 1800"
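+            # note: the trailing space below keeps the command from running into the next f-string fragment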
+            f"commit0 evaluate --branch {branch_name} --timeout 1800 "
             f"--commit0-config-file {commit0_dot_file_path}"
         )
         for example in dataset: