From 94c528486aba0285e396c9047d2ad8d0618afe55 Mon Sep 17 00:00:00 2001 From: Wenting Date: Wed, 4 Dec 2024 16:08:59 +0000 Subject: [PATCH 1/2] extending to more datasets --- commit0/cli.py | 9 +++++---- commit0/harness/build.py | 7 ++++--- commit0/harness/run_pytest_ids.py | 9 +++++---- commit0/harness/setup.py | 5 +++-- 4 files changed, 17 insertions(+), 13 deletions(-) diff --git a/commit0/cli.py b/commit0/cli.py index 310f6d2..02fe204 100644 --- a/commit0/cli.py +++ b/commit0/cli.py @@ -257,18 +257,19 @@ def test( if reference: branch = "reference" else: - if "humaneval" not in commit0_config["dataset_name"].split("/")[-1].lower(): + dataset_name = commit0_config["dataset_name"].lower() + if "humaneval" in dataset_name or "mbpp" in dataset_name or "bigcodebench" in dataset_name or "codecontests" in dataset_name: + branch = repo_or_repo_path + else: if branch is None and not reference: git_path = os.path.join( commit0_config["base_dir"], repo_or_repo_path.split("/")[-1] ) branch = get_active_branch(git_path) - else: - branch = test_ids if stdin: # Read test names from stdin - test_ids = sys.stdin.read().strip() + test_ids = sys.stdin.read() elif test_ids is None: typer.echo("Error: test_ids must be provided or use --stdin option", err=True) raise typer.Exit(code=1) diff --git a/commit0/harness/build.py b/commit0/harness/build.py index b38fe33..57ca0bc 100644 --- a/commit0/harness/build.py +++ b/commit0/harness/build.py @@ -25,14 +25,15 @@ def main( dataset_name, split=dataset_split ) # type: ignore specs = [] - if "swe" in dataset_name.lower(): + dataset_name = dataset_name.lower() + if "swe" in dataset_name: dataset_type = "swebench" - elif "humaneval" in dataset_name.lower(): + elif "humaneval" in dataset_name or "mbpp" in dataset_name or "bigcodebench" in dataset_name or "codecontests" in dataset_name: dataset_type = "simple" else: dataset_type = "commit0" for example in dataset: - if "swe" in dataset_name.lower() or dataset_type == "simple": + if "swe" in dataset_name or dataset_type == "simple": if split != "all" and split not in example["instance_id"]: continue else: diff --git a/commit0/harness/run_pytest_ids.py b/commit0/harness/run_pytest_ids.py index eab9d93..da6136f 100644 --- a/commit0/harness/run_pytest_ids.py +++ b/commit0/harness/run_pytest_ids.py @@ -51,6 +51,7 @@ def main( dataset: Iterator[Union[RepoInstance, SimpleInstance]] = load_dataset( dataset_name, split=dataset_split ) # type: ignore + dataset_name = dataset_name.lower() spec = None example = None repo_name = None @@ -58,10 +59,10 @@ def main( for example in dataset: if repo_or_repo_dir.endswith("/"): repo_or_repo_dir = repo_or_repo_dir[:-1] - if "swe" in dataset_name.lower(): + if "swe" in dataset_name: repo_name = example["instance_id"] dataset_type = "swebench" - elif "humaneval" in dataset_name.lower(): + elif "humaneval" in dataset_name or "mbpp" in dataset_name or "bigcodebench" in dataset_name or "codecontests" in dataset_name: repo_name = example["instance_id"] dataset_type = "simple" else: @@ -130,7 +131,7 @@ def main( ) # make patch file - if "swe" in dataset_name.lower(): + if "swe" in dataset_name: if branch == "reference": patch = ( example["test"]["patch"] + "\n\n" + example["test"]["test_patch"] @@ -164,7 +165,7 @@ def main( + example["test"] ) else: - solution = open(test_ids).read() + solution = test_ids prompt = example["prompt"] if "prompt" in example.keys() else "" matches = extract_code_blocks(solution) if len(matches) > 0: diff --git a/commit0/harness/setup.py b/commit0/harness/setup.py index 816cd91..ab0de0a 100644 --- a/commit0/harness/setup.py +++ b/commit0/harness/setup.py @@ -23,12 +23,13 @@ def main( base_dir: str, ) -> None: dataset: Iterator[RepoInstance] = load_dataset(dataset_name, split=dataset_split) # type: ignore - if "humaneval" in dataset_name.lower(): + dataset_name = dataset_name.lower() + if "humaneval" in dataset_name or "mbpp" in dataset_name or "bigcodebench" in dataset_name or "codecontests" in dataset_name: return for example in dataset: repo_name = example["repo"].split("/")[-1] clone_url = f"https://github.com/{example['repo']}.git" - if "swe" in dataset_name.lower(): + if "swe" in dataset_name: if repo_split != "all" and repo_split not in example["instance_id"]: continue clone_dir = os.path.abspath(os.path.join(base_dir, example["instance_id"])) From 3cf275cb5844c60b1b4cb9e5ab19fa2129119854 Mon Sep 17 00:00:00 2001 From: Wenting Date: Wed, 4 Dec 2024 16:11:21 +0000 Subject: [PATCH 2/2] pre-commit fixes --- commit0/cli.py | 7 ++++++- commit0/harness/build.py | 7 ++++++- commit0/harness/run_pytest_ids.py | 7 ++++++- commit0/harness/setup.py | 7 ++++++- 4 files changed, 24 insertions(+), 4 deletions(-) diff --git a/commit0/cli.py b/commit0/cli.py index 02fe204..1b67080 100644 --- a/commit0/cli.py +++ b/commit0/cli.py @@ -258,7 +258,12 @@ def test( branch = "reference" else: dataset_name = commit0_config["dataset_name"].lower() - if "humaneval" in dataset_name or "mbpp" in dataset_name or "bigcodebench" in dataset_name or "codecontests" in dataset_name: + if ( + "humaneval" in dataset_name + or "mbpp" in dataset_name + or "bigcodebench" in dataset_name + or "codecontests" in dataset_name + ): branch = repo_or_repo_path else: if branch is None and not reference: diff --git a/commit0/harness/build.py b/commit0/harness/build.py index 57ca0bc..216a10c 100644 --- a/commit0/harness/build.py +++ b/commit0/harness/build.py @@ -28,7 +28,12 @@ def main( dataset_name = dataset_name.lower() if "swe" in dataset_name: dataset_type = "swebench" - elif "humaneval" in dataset_name or "mbpp" in dataset_name or "bigcodebench" in dataset_name or "codecontests" in dataset_name: + elif ( + "humaneval" in dataset_name + or "mbpp" in dataset_name + or "bigcodebench" in dataset_name + or "codecontests" in dataset_name + ): dataset_type = "simple" else: dataset_type = "commit0" diff --git a/commit0/harness/run_pytest_ids.py b/commit0/harness/run_pytest_ids.py index da6136f..06857df 100644 --- a/commit0/harness/run_pytest_ids.py +++ b/commit0/harness/run_pytest_ids.py @@ -62,7 +62,12 @@ def main( if "swe" in dataset_name: repo_name = example["instance_id"] dataset_type = "swebench" - elif "humaneval" in dataset_name or "mbpp" in dataset_name or "bigcodebench" in dataset_name or "codecontests" in dataset_name: + elif ( + "humaneval" in dataset_name + or "mbpp" in dataset_name + or "bigcodebench" in dataset_name + or "codecontests" in dataset_name + ): repo_name = example["instance_id"] dataset_type = "simple" else: diff --git a/commit0/harness/setup.py b/commit0/harness/setup.py index ab0de0a..a69dd60 100644 --- a/commit0/harness/setup.py +++ b/commit0/harness/setup.py @@ -24,7 +24,12 @@ def main( ) -> None: dataset: Iterator[RepoInstance] = load_dataset(dataset_name, split=dataset_split) # type: ignore dataset_name = dataset_name.lower() - if "humaneval" in dataset_name or "mbpp" in dataset_name or "bigcodebench" in dataset_name or "codecontests" in dataset_name: + if ( + "humaneval" in dataset_name + or "mbpp" in dataset_name + or "bigcodebench" in dataset_name + or "codecontests" in dataset_name + ): return for example in dataset: repo_name = example["repo"].split("/")[-1]