 from transformers import AutoTokenizer
 
 from commit0.harness.constants import SPLIT
+from commit0.harness.get_pytest_ids import main as get_tests
 from commit0.harness.utils import clone_repo
 from commit0.cli import write_commit0_config_file
 
 import logging
+from typing import Any, NoReturn
 
 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
 analysis_files_path = "/share/rush/commit0_analysis_temp"
 
 
-def get_pytest_info(path_to_logs, repo_name, branch_name):
+def get_pytest_info(
+    path_to_logs: str, repo_name: str, branch_name: str
+) -> dict[str, dict[str, Any]]:
     pytest_info = {}
     for pytest_hash in os.listdir(path_to_logs):
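+        # skip pytest hashes whose logs are incomplete (no eval.sh was written)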
+        if not os.path.exists(os.path.join(path_to_logs, pytest_hash, "eval.sh")):
+            continue
         eval_script = open(os.path.join(path_to_logs, pytest_hash, "eval.sh")).read()
         testname = re.search(r"([\S]+) > test_output", eval_script).group(1)
         patch_diff = open(os.path.join(path_to_logs, pytest_hash, "patch.diff")).read()
@@ -84,19 +90,19 @@ def get_pytest_info(path_to_logs, repo_name, branch_name):
                 "failure_string": failure_string,
                 "duration": duration,
             }
-    return pytest_info
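+    # an empty dict means no run could be scored, so report that to the caller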
+    return pytest_info if len(pytest_info) else "Could not evaluate"
 
 
-def get_coverage_info(path_to_logs, repo_name, branch_name):
+def get_coverage_info(path_to_logs: str, repo_name: str, branch_name: str) -> Any:
     raise NotImplementedError
 
 
 def get_blank_repo_metrics(
-    blank_source_code_folder,
-    spec_filename,
+    blank_source_code_folder: str,
+    spec_filename: str,
     tokenizer,
     code_file_filter=lambda filename: filename,
-):
+) -> dict[str, Any]:
     blank_repo_metrics = {
         "functions_to_edit": [],
     }
@@ -164,7 +170,7 @@ def get_blank_repo_metrics(
 
 
 leaderboard_header = """\n\n## Leaderboard ({split})
-| Name | Repos Resolved (/{num_repos}) | Total Tests Passed (/{total_num_tests}) | Test Duration (s) | Date | Analysis | Github |
+| Name | Repos Resolved (/{num_repos}) | Avg. pass rate | Test Duration (s) | Date | Analysis | Github |
 |------|:-------------------------:|:--------------------:|:--------------------:|:----------:|----|----|"""
 
 submission_table_header = """# Submission Name: **{display_name}** (split: {split})
@@ -178,33 +184,44 @@ def get_blank_repo_metrics(
 """
 
 
-def render_mds(overwrite_previous, subfolder="docs"):
+def render_mds(overwrite_previous: bool, subfolder: str = "docs") -> NoReturn:
     leaderboard = {}
 
     split_to_total_tests = {
         "lite": 3628,
         "all": 140926,
     }  # hard-coded to skip running it later
-    for split in tqdm.tqdm(["lite", "all"]):
+    for split in ["lite", "all"]:
         num_repos = len(SPLIT[split])
         # total_num_tests = 0
         # for repo_name in SPLIT[split]:
         #     repo_tests = subprocess.run(['commit0', 'get-tests', repo_name], capture_output=True, text=True).stdout.strip()
         #     total_num_tests += len(repo_tests.splitlines())
-        leaderboard[split] = leaderboard_header.format(
-            split=split,
-            num_repos=num_repos,
-            total_num_tests=split_to_total_tests[split],
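+        # each leaderboard entry is a (sort_key, markdown_row) tuple; the header's key
+        # (total test count + 1) is larger than any pass-rate percentage, so it stays
+        # on top after the descending sort at the end of this function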
+        leaderboard[split] = []
+        leaderboard[split].append(
+            (
+                split_to_total_tests[split] + 1,
+                leaderboard_header.format(
+                    split=split,
+                    num_repos=num_repos,
+                    total_num_tests=split_to_total_tests[split],
+                ),
+            )
         )
 
     for org_path in tqdm.tqdm(glob.glob(os.path.join(analysis_files_path, "*"))):
         org_name = os.path.basename(org_path)
         if org_name in {"blank", "repos", "submission_repos"}:
             continue
         for branch_path in glob.glob(os.path.join(org_path, "*.json")):
-            cum_tests_passed = 0
+            evaluate_numbers = []
+            lite_evaluate_numbers = []
+            # cum_tests_passed = 0
             repos_resolved = 0
             total_duration = 0.0
+            # lite_cum_tests_passed = 0
+            lite_repos_resolved = 0
+            lite_total_duration = 0.0
             branch_metrics = json.load(open(branch_path))
             submission_info = branch_metrics["submission_info"]
             split = submission_info["split"]
@@ -234,7 +251,7 @@ def render_mds(overwrite_previous, subfolder="docs"):
                     subfolder, f"analysis_{org_name}_{branch_name}_{repo_name}.md"
                 )
                 if isinstance(repo_pytest_results, str):
-                    submission_repo_page = f"# **{display_name}**: {repo_name}\n\n## Failed to clone\n\n{repo_pytest_results}"
+                    submission_repo_page = f"# **{display_name}**: {repo_name}\n\n## Failed\n\n{repo_pytest_results}"
                     org_branch_repo_filepath = os.path.join(
                         subfolder, f"analysis_{org_name}_{branch_name}_{repo_name}.md"
                     )
@@ -246,7 +263,7 @@ def render_mds(overwrite_previous, subfolder="docs"):
                     submission_page = submission_table_header.format(
                         display_name=display_name, split=split
                     ) + (
-                        f"\n| {repo_name} | No; Failed to clone. | - | - | "
+                        f"\n| {repo_name} | No; {repo_pytest_results} | - | - | "
                         f"[Analysis](/{f'analysis_{org_name}_{branch_name}_{repo_name}'}) | "
                         f"[Github]({github_hyperlink}) |"
                     )
@@ -267,13 +284,23 @@ def render_mds(overwrite_previous, subfolder="docs"):
                             )
                         pytest_details = "Pytest failed"
                         duration = "Failed."
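+                        # a failed pytest run counts as a 0% pass rate for this repo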
+                        evaluate_numbers.append(0.0)
+                        if split == "all" and repo_name in SPLIT["lite"]:
+                            lite_evaluate_numbers.append(0.0)
                     else:
                         resolved = False
                         if "passed" in pytest_info["summary"]:
                             if "skipped" in pytest_info["summary"]:
-                                resolved = pytest_info["summary"]["passed"] + pytest_info["summary"]["skipped"] == pytest_info["summary"]["total"]
+                                resolved = (
+                                    pytest_info["summary"]["passed"]
+                                    + pytest_info["summary"]["skipped"]
+                                    == pytest_info["summary"]["total"]
+                                )
                             else:
-                                resolved = pytest_info["summary"]["passed"] == pytest_info["summary"]["total"]
+                                resolved = (
+                                    pytest_info["summary"]["passed"]
+                                    == pytest_info["summary"]["total"]
+                                )
                         if write_submission:
                             submission_repo_page += pytest_summary_table_header.format(
                                 pytest_group=pytest_group
@@ -295,9 +322,21 @@ def render_mds(overwrite_previous, subfolder="docs"):
                                     f"### {shortened_testname}\n\n<details><summary> <pre>{shortened_testname}"
                                     f"</pre></summary><pre>\n{failure['failure_string']}\n</pre>\n</details>\n"
                                 )
-                        cum_tests_passed += pytest_info["summary"]["passed"]
+                        # cum_tests_passed += pytest_info["summary"]["passed"]
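+                        # score each repo by its pass fraction rather than a raw passed-test count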
+                        num_tests = len(get_tests(repo_name, verbose=0))
+                        evaluate_numbers.append(
+                            pytest_info["summary"]["passed"] / num_tests
+                        )
                         total_duration += pytest_info["duration"]
                         repos_resolved += int(resolved)
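+                        # repos in the lite split also feed the derived lite leaderboard row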
+                        if split == "all" and repo_name in SPLIT["lite"]:
+                            lite_evaluate_numbers.append(
+                                pytest_info["summary"]["passed"] / num_tests
+                            )
+                            # lite_cum_tests_passed += pytest_info["summary"]["passed"]
+                            lite_total_duration += pytest_info["duration"]
+                            lite_repos_resolved += int(resolved)
+
                     if write_submission:
                         pytest_details = f"{pytest_info['summary']['passed']} / {pytest_info['summary']['total']}"
                         duration = f"{pytest_info['duration']:.2f}"
@@ -322,22 +361,46 @@ def render_mds(overwrite_previous, subfolder="docs"):
                 wf.write(back_button + "\n" + submission_page)
             analysis_link = f"[Analysis](/{f'analysis_{org_name}_{branch_name}'})"
             github_link = f"[Github]({project_page_link})"
-            leaderboard[split] += (
-                f"\n|{display_name}|"
-                f"{repos_resolved}|"
-                f"{cum_tests_passed}|"
-                f"{total_duration:.2f}|"
-                f"{submission_date}|"
-                f"{analysis_link}|"
-                f"{github_link}|"
+            avg_pass_rate = sum(evaluate_numbers) / len(evaluate_numbers)
+            leaderboard[split].append(
+                (
+                    avg_pass_rate * 100,
+                    f"\n|{display_name}|"
+                    f"{repos_resolved}|"
+                    f"{avg_pass_rate * 100:.2f}%|"
+                    f"{total_duration:.2f}|"
+                    f"{submission_date}|"
+                    f"{analysis_link}|"
+                    f"{github_link}|",
+                )
             )
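+            # submissions evaluated on `all` also get a derived row on the lite board,
+            # aggregated over just the lite repos; the gold reference is skipped here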
+            if (split == "all") and ("Reference (Gold)" not in display_name):
+                avg_lite_pass_rate = sum(lite_evaluate_numbers) / len(
+                    lite_evaluate_numbers
+                )
+                leaderboard["lite"].append(
+                    (
+                        avg_lite_pass_rate * 100,
+                        f"\n|{display_name} (subset of `all`)|"
+                        f"{lite_repos_resolved}|"
+                        f"{avg_lite_pass_rate * 100:.2f}%|"
+                        f"{lite_total_duration:.2f}|"
+                        f"{submission_date}|"
+                        f"{analysis_link}|"
+                        f"{github_link}|",
+                    )
+                )
 
     leaderboard_filepath = os.path.join(subfolder, "analysis.md")
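+    # sort each board by average pass rate, highest first; header rows keep their oversized keys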
+    for split in ["lite", "all"]:
+        leaderboard[split] = sorted(leaderboard[split], key=lambda elt: -elt[0])
     with open(leaderboard_filepath, "w") as wf:
-        wf.write(leaderboard["lite"] + "\n\n" + leaderboard["all"])
+        lite_leaderboard_string = "".join(string for (_, string) in leaderboard["lite"])
+        all_leaderboard_string = "".join(string for (_, string) in leaderboard["all"])
+        wf.write(lite_leaderboard_string + "\n\n" + all_leaderboard_string)
 
 
-def get_args():
+def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--do_setup", action="store_true", help="Run commit0 setup with specified split"
@@ -366,14 +429,14 @@ def get_args():
     parser.add_argument(
         "--overwrite_previous_eval",
         action="store_true",
-        help="Overwrite cached pytest info"
+        help="Overwrite cached pytest info",
         # TODO add finer granularity so can specify which ones to overwrite
     )
 
     return parser.parse_args()
 
 
-def main(args):
+def main(args: argparse.Namespace) -> NoReturn:
     global analysis_files_path
 
     commit0_dataset_name = "wentingzhao/commit0_combined"
@@ -493,6 +556,7 @@ def main(args):
         )
         if os.path.exists(submission_repos_path):
             shutil.rmtree(submission_repos_path)
+            print(f"Removed existing at {submission_repos_path}")
         os.makedirs(os.path.join(analysis_files_path, org_name), exist_ok=True)
         commit0_config_file = os.path.join(
             analysis_files_path,
@@ -530,7 +594,7 @@ def main(args):
         )
         # run pytests
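+        # cap each evaluation at 1800 seconds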
         os.system(
-            f"commit0 evaluate --branch {branch_name} "
+            f"commit0 evaluate --branch {branch_name} --timeout 1800 "
            f"--commit0-config-file {commit0_config_file}"
         )
         for example in dataset: