Skip to content

Commit 08617ff

Browse files
authored
Merge pull request #103 from commit-0/integration2
Switched to pytest style output
2 parents fa81bc4 + 5157a12 commit 08617ff

File tree

4 files changed

+40
-10
lines changed

4 files changed

+40
-10
lines changed

commit0/harness/constants.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from enum import Enum
22
from pathlib import Path
3-
from typing import Dict, ItemsView
3+
from typing import Dict, ItemsView, KeysView
44
from pydantic import BaseModel
55

66

@@ -16,17 +16,24 @@ class RepoInstance(BaseModel):
1616
def __getitem__(self, item: str):
1717
return getattr(self, item)
1818

19+
def keys(self) -> KeysView[str]:
20+
"""Return the field names of the model as dictionary keys."""
21+
return self.__annotations__.keys()
22+
1923

2024
class SimpleInstance(BaseModel):
2125
instance_id: str
2226
prompt: str
2327
canonical_solution: str
2428
test: str
25-
entry_point: str
2629

2730
def __getitem__(self, item: str):
2831
return getattr(self, item)
2932

33+
def keys(self) -> KeysView[str]:
34+
"""Return the field names of the model as dictionary keys."""
35+
return self.__annotations__.keys()
36+
3037

3138
class Files(BaseModel):
3239
eval_script: Dict[str, Path]

commit0/harness/run_pytest_ids.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import git
22
import os
3-
import re
43
import sys
54
import traceback
65
from datasets import load_dataset
@@ -21,6 +20,7 @@
2120
generate_patch_between_commits,
2221
setup_logger,
2322
close_logger,
23+
extract_code_blocks,
2424
)
2525
from commit0.harness.execution_context import (
2626
ExecutionBackend,
@@ -165,15 +165,13 @@ def main(
165165
)
166166
else:
167167
solution = open(test_ids).read()
168-
pattern = r"```python\n(.*?)```"
169-
matches = re.finditer(pattern, solution, re.DOTALL)
170-
matches = [match.group(1).strip() for match in matches]
168+
prompt = example["prompt"] if "prompt" in example.keys() else ""
169+
matches = extract_code_blocks(solution)
171170
if len(matches) > 0:
172171
solution = "\n\n".join(matches)
173172
else:
174-
solution = example["prompt"] + "\n\n" + solution
173+
solution = prompt + "\n\n" + solution
175174
patch = solution + "\n\n" + example["test"]
176-
patch = patch + "\n\n" + f"check({example['entry_point']})"
177175
eval_script = spec.eval_script
178176

179177
patch_file = Path(log_dir / "patch.diff")

commit0/harness/spec.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ def make_repo_script_list(self) -> list[str]:
185185
f"mkdir {self.repo_directory} && cd {self.repo_directory}",
186186
"uv venv --python 3.12",
187187
"source .venv/bin/activate",
188+
"uv pip install -U pytest pytest-cov coverage pytest-json-report",
188189
"which python",
189190
]
190191
return setup_commands
@@ -195,7 +196,7 @@ def make_eval_script_list(self) -> list[str]:
195196
f"cd {self.repo_directory}",
196197
"source .venv/bin/activate",
197198
"cat /patch.diff > test.py",
198-
"uv run test.py > test_output.txt 2>&1",
199+
"pytest test.py > test_output.txt 2>&1",
199200
"echo $? > pytest_exit_code.txt",
200201
]
201202
return eval_script_list

commit0/harness/utils.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
import logging
55
import os
66
import time
7+
import re
78
import sys
89
from pathlib import Path
9-
from typing import Optional, Union
10+
from typing import List, Optional, Union
1011

1112
from fastcore.net import HTTP404NotFoundError, HTTP403ForbiddenError # type: ignore
1213
from ghapi.core import GhApi
@@ -220,4 +221,27 @@ def get_active_branch(repo_path: Union[str, Path]) -> str:
220221
return branch
221222

222223

224+
def extract_code_blocks(text: str) -> List[str]:
    """Extract Python code blocks from a given text wrapped in markdown markers.

    This function identifies and extracts all Python code blocks within a provided
    text. The code blocks should be surrounded by markdown-style markers, such as
    ```python ... ```. The opening fence may be followed by trailing spaces or
    tabs, and its line may end with either LF or CRLF.

    Args:
    ----
        text (str): The input text containing Python code blocks marked with
            ```python ... ```.

    Returns:
    -------
        List[str]: A list of strings, each containing a Python code block extracted
            from the text, with surrounding whitespace stripped. The list is empty
            when the text contains no such block.

    """
    # Allow trailing horizontal whitespace and an optional carriage return after
    # the language tag so that fences written with CRLF line endings or a stray
    # space are not silently skipped. DOTALL lets the lazy group span lines.
    pattern = r"```python[ \t]*\r?\n(.*?)```"
    return [block.strip() for block in re.findall(pattern, text, re.DOTALL)]
245+
246+
223247
__all__ = []

0 commit comments

Comments
 (0)