diff --git a/commit0/cli.py b/commit0/cli.py index 310f6d2..1b67080 100644 --- a/commit0/cli.py +++ b/commit0/cli.py @@ -257,18 +257,24 @@ def test( if reference: branch = "reference" else: - if "humaneval" not in commit0_config["dataset_name"].split("/")[-1].lower(): + dataset_name = commit0_config["dataset_name"].lower() + if ( + "humaneval" in dataset_name + or "mbpp" in dataset_name + or "bigcodebench" in dataset_name + or "codecontests" in dataset_name + ): + branch = repo_or_repo_path + else: if branch is None and not reference: git_path = os.path.join( commit0_config["base_dir"], repo_or_repo_path.split("/")[-1] ) branch = get_active_branch(git_path) - else: - branch = test_ids if stdin: # Read test names from stdin - test_ids = sys.stdin.read().strip() + test_ids = sys.stdin.read() elif test_ids is None: typer.echo("Error: test_ids must be provided or use --stdin option", err=True) raise typer.Exit(code=1) diff --git a/commit0/harness/build.py b/commit0/harness/build.py index b38fe33..216a10c 100644 --- a/commit0/harness/build.py +++ b/commit0/harness/build.py @@ -25,14 +25,20 @@ def main( dataset_name, split=dataset_split ) # type: ignore specs = [] - if "swe" in dataset_name.lower(): + dataset_name = dataset_name.lower() + if "swe" in dataset_name: dataset_type = "swebench" - elif "humaneval" in dataset_name.lower(): + elif ( + "humaneval" in dataset_name + or "mbpp" in dataset_name + or "bigcodebench" in dataset_name + or "codecontests" in dataset_name + ): dataset_type = "simple" else: dataset_type = "commit0" for example in dataset: - if "swe" in dataset_name.lower() or dataset_type == "simple": + if "swe" in dataset_name or dataset_type == "simple": if split != "all" and split not in example["instance_id"]: continue else: diff --git a/commit0/harness/run_pytest_ids.py b/commit0/harness/run_pytest_ids.py index eab9d93..06857df 100644 --- a/commit0/harness/run_pytest_ids.py +++ b/commit0/harness/run_pytest_ids.py @@ -51,6 +51,7 @@ def main( dataset: Iterator[Union[RepoInstance, SimpleInstance]] = load_dataset( dataset_name, split=dataset_split ) # type: ignore + dataset_name = dataset_name.lower() spec = None example = None repo_name = None @@ -58,10 +59,15 @@ def main( for example in dataset: if repo_or_repo_dir.endswith("/"): repo_or_repo_dir = repo_or_repo_dir[:-1] - if "swe" in dataset_name.lower(): + if "swe" in dataset_name: repo_name = example["instance_id"] dataset_type = "swebench" - elif "humaneval" in dataset_name.lower(): + elif ( + "humaneval" in dataset_name + or "mbpp" in dataset_name + or "bigcodebench" in dataset_name + or "codecontests" in dataset_name + ): repo_name = example["instance_id"] dataset_type = "simple" else: @@ -130,7 +136,7 @@ def main( ) # make patch file - if "swe" in dataset_name.lower(): + if "swe" in dataset_name: if branch == "reference": patch = ( example["test"]["patch"] + "\n\n" + example["test"]["test_patch"] @@ -164,7 +170,7 @@ def main( + example["test"] ) else: - solution = open(test_ids).read() + solution = test_ids prompt = example["prompt"] if "prompt" in example.keys() else "" matches = extract_code_blocks(solution) if len(matches) > 0: diff --git a/commit0/harness/setup.py b/commit0/harness/setup.py index 816cd91..a69dd60 100644 --- a/commit0/harness/setup.py +++ b/commit0/harness/setup.py @@ -23,12 +23,18 @@ def main( base_dir: str, ) -> None: dataset: Iterator[RepoInstance] = load_dataset(dataset_name, split=dataset_split) # type: ignore - if "humaneval" in dataset_name.lower(): + dataset_name = dataset_name.lower() + if ( + "humaneval" in dataset_name + or "mbpp" in dataset_name + or "bigcodebench" in dataset_name + or "codecontests" in dataset_name + ): return for example in dataset: repo_name = example["repo"].split("/")[-1] clone_url = f"https://github.com/{example['repo']}.git" - if "swe" in dataset_name.lower(): + if "swe" in dataset_name: if repo_split != "all" and repo_split not in example["instance_id"]: continue clone_dir = os.path.abspath(os.path.join(base_dir, example["instance_id"]))