Skip to content

Commit 1b7780d

Browse files
authored
Merge pull request #104 from commit-0/integration3
extending integration to more datasets
2 parents 08617ff + 3cf275c commit 1b7780d

File tree

4 files changed

+37
-13
lines changed

4 files changed

+37
-13
lines changed

commit0/cli.py

+10 -4
Original file line number | Diff line number | Diff line change
@@ -257,18 +257,24 @@ def test(
257257
if reference:
258258
branch = "reference"
259259
else:
260-
if "humaneval" not in commit0_config["dataset_name"].split("/")[-1].lower():
260+
dataset_name = commit0_config["dataset_name"].lower()
261+
if (
262+
"humaneval" in dataset_name
263+
or "mbpp" in dataset_name
264+
or "bigcodebench" in dataset_name
265+
or "codecontests" in dataset_name
266+
):
267+
branch = repo_or_repo_path
268+
else:
261269
if branch is None and not reference:
262270
git_path = os.path.join(
263271
commit0_config["base_dir"], repo_or_repo_path.split("/")[-1]
264272
)
265273
branch = get_active_branch(git_path)
266-
else:
267-
branch = test_ids
268274

269275
if stdin:
270276
# Read test names from stdin
271-
test_ids = sys.stdin.read().strip()
277+
test_ids = sys.stdin.read()
272278
elif test_ids is None:
273279
typer.echo("Error: test_ids must be provided or use --stdin option", err=True)
274280
raise typer.Exit(code=1)

commit0/harness/build.py

+9 -3
Original file line number | Diff line number | Diff line change
@@ -25,14 +25,20 @@ def main(
2525
dataset_name, split=dataset_split
2626
) # type: ignore
2727
specs = []
28-
if "swe" in dataset_name.lower():
28+
dataset_name = dataset_name.lower()
29+
if "swe" in dataset_name:
2930
dataset_type = "swebench"
30-
elif "humaneval" in dataset_name.lower():
31+
elif (
32+
"humaneval" in dataset_name
33+
or "mbpp" in dataset_name
34+
or "bigcodebench" in dataset_name
35+
or "codecontests" in dataset_name
36+
):
3137
dataset_type = "simple"
3238
else:
3339
dataset_type = "commit0"
3440
for example in dataset:
35-
if "swe" in dataset_name.lower() or dataset_type == "simple":
41+
if "swe" in dataset_name or dataset_type == "simple":
3642
if split != "all" and split not in example["instance_id"]:
3743
continue
3844
else:

commit0/harness/run_pytest_ids.py

+10 -4
Original file line number | Diff line number | Diff line change
@@ -51,17 +51,23 @@ def main(
5151
dataset: Iterator[Union[RepoInstance, SimpleInstance]] = load_dataset(
5252
dataset_name, split=dataset_split
5353
) # type: ignore
54+
dataset_name = dataset_name.lower()
5455
spec = None
5556
example = None
5657
repo_name = None
5758
dataset_type = None
5859
for example in dataset:
5960
if repo_or_repo_dir.endswith("/"):
6061
repo_or_repo_dir = repo_or_repo_dir[:-1]
61-
if "swe" in dataset_name.lower():
62+
if "swe" in dataset_name:
6263
repo_name = example["instance_id"]
6364
dataset_type = "swebench"
64-
elif "humaneval" in dataset_name.lower():
65+
elif (
66+
"humaneval" in dataset_name
67+
or "mbpp" in dataset_name
68+
or "bigcodebench" in dataset_name
69+
or "codecontests" in dataset_name
70+
):
6571
repo_name = example["instance_id"]
6672
dataset_type = "simple"
6773
else:
@@ -130,7 +136,7 @@ def main(
130136
)
131137

132138
# make patch file
133-
if "swe" in dataset_name.lower():
139+
if "swe" in dataset_name:
134140
if branch == "reference":
135141
patch = (
136142
example["test"]["patch"] + "\n\n" + example["test"]["test_patch"]
@@ -164,7 +170,7 @@ def main(
164170
+ example["test"]
165171
)
166172
else:
167-
solution = open(test_ids).read()
173+
solution = test_ids
168174
prompt = example["prompt"] if "prompt" in example.keys() else ""
169175
matches = extract_code_blocks(solution)
170176
if len(matches) > 0:

commit0/harness/setup.py

+8 -2
Original file line number | Diff line number | Diff line change
@@ -23,12 +23,18 @@ def main(
2323
base_dir: str,
2424
) -> None:
2525
dataset: Iterator[RepoInstance] = load_dataset(dataset_name, split=dataset_split) # type: ignore
26-
if "humaneval" in dataset_name.lower():
26+
dataset_name = dataset_name.lower()
27+
if (
28+
"humaneval" in dataset_name
29+
or "mbpp" in dataset_name
30+
or "bigcodebench" in dataset_name
31+
or "codecontests" in dataset_name
32+
):
2733
return
2834
for example in dataset:
2935
repo_name = example["repo"].split("/")[-1]
3036
clone_url = f"https://github.com/{example['repo']}.git"
31-
if "swe" in dataset_name.lower():
37+
if "swe" in dataset_name:
3238
if repo_split != "all" and repo_split not in example["instance_id"]:
3339
continue
3440
clone_dir = os.path.abspath(os.path.join(base_dir, example["instance_id"]))

0 commit comments

Comments
 (0)