Skip to content

extending integration to more datasets #104

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions commit0/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,18 +257,24 @@ def test(
if reference:
branch = "reference"
else:
if "humaneval" not in commit0_config["dataset_name"].split("/")[-1].lower():
dataset_name = commit0_config["dataset_name"].lower()
if (
"humaneval" in dataset_name
or "mbpp" in dataset_name
or "bigcodebench" in dataset_name
or "codecontests" in dataset_name
):
branch = repo_or_repo_path
else:
if branch is None and not reference:
git_path = os.path.join(
commit0_config["base_dir"], repo_or_repo_path.split("/")[-1]
)
branch = get_active_branch(git_path)
else:
branch = test_ids

if stdin:
# Read test names from stdin
test_ids = sys.stdin.read().strip()
test_ids = sys.stdin.read()
elif test_ids is None:
typer.echo("Error: test_ids must be provided or use --stdin option", err=True)
raise typer.Exit(code=1)
Expand Down
12 changes: 9 additions & 3 deletions commit0/harness/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,20 @@ def main(
dataset_name, split=dataset_split
) # type: ignore
specs = []
if "swe" in dataset_name.lower():
dataset_name = dataset_name.lower()
if "swe" in dataset_name:
dataset_type = "swebench"
elif "humaneval" in dataset_name.lower():
elif (
"humaneval" in dataset_name
or "mbpp" in dataset_name
or "bigcodebench" in dataset_name
or "codecontests" in dataset_name
):
dataset_type = "simple"
else:
dataset_type = "commit0"
for example in dataset:
if "swe" in dataset_name.lower() or dataset_type == "simple":
if "swe" in dataset_name or dataset_type == "simple":
if split != "all" and split not in example["instance_id"]:
continue
else:
Expand Down
14 changes: 10 additions & 4 deletions commit0/harness/run_pytest_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,23 @@ def main(
dataset: Iterator[Union[RepoInstance, SimpleInstance]] = load_dataset(
dataset_name, split=dataset_split
) # type: ignore
dataset_name = dataset_name.lower()
spec = None
example = None
repo_name = None
dataset_type = None
for example in dataset:
if repo_or_repo_dir.endswith("/"):
repo_or_repo_dir = repo_or_repo_dir[:-1]
if "swe" in dataset_name.lower():
if "swe" in dataset_name:
repo_name = example["instance_id"]
dataset_type = "swebench"
elif "humaneval" in dataset_name.lower():
elif (
"humaneval" in dataset_name
or "mbpp" in dataset_name
or "bigcodebench" in dataset_name
or "codecontests" in dataset_name
):
repo_name = example["instance_id"]
dataset_type = "simple"
else:
Expand Down Expand Up @@ -130,7 +136,7 @@ def main(
)

# make patch file
if "swe" in dataset_name.lower():
if "swe" in dataset_name:
if branch == "reference":
patch = (
example["test"]["patch"] + "\n\n" + example["test"]["test_patch"]
Expand Down Expand Up @@ -164,7 +170,7 @@ def main(
+ example["test"]
)
else:
solution = open(test_ids).read()
solution = test_ids
prompt = example["prompt"] if "prompt" in example.keys() else ""
matches = extract_code_blocks(solution)
if len(matches) > 0:
Expand Down
10 changes: 8 additions & 2 deletions commit0/harness/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,18 @@ def main(
base_dir: str,
) -> None:
dataset: Iterator[RepoInstance] = load_dataset(dataset_name, split=dataset_split) # type: ignore
if "humaneval" in dataset_name.lower():
dataset_name = dataset_name.lower()
if (
"humaneval" in dataset_name
or "mbpp" in dataset_name
or "bigcodebench" in dataset_name
or "codecontests" in dataset_name
):
return
for example in dataset:
repo_name = example["repo"].split("/")[-1]
clone_url = f"https://github.com/{example['repo']}.git"
if "swe" in dataset_name.lower():
if "swe" in dataset_name:
if repo_split != "all" and repo_split not in example["instance_id"]:
continue
clone_dir = os.path.abspath(os.path.join(base_dir, example["instance_id"]))
Expand Down
Loading