Skip to content

Commit ad23e19

Browse files
committed
separate commit0-specific stuff from agent stuff
1 parent 585c526 commit ad23e19

File tree

6 files changed

+53
-93
lines changed

6 files changed

+53
-93
lines changed

baselines/class_types.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@ class Commit0Config:
1111

1212

1313
@dataclass
14-
class AiderConfig:
15-
llm_name: str
14+
class AgentConfig:
15+
agent_name: str
16+
model_name: str
1617
use_user_prompt: bool
1718
user_prompt: str
1819
use_repo_info: bool

baselines/baseline_utils.py renamed to baselines/commit0_utils.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from pathlib import Path
55
from typing import Any, Dict, List
66

7-
from baselines.class_types import AiderConfig
7+
from baselines.class_types import AgentConfig
88

99
PROMPT_HEADER = ">>> Here is the Task:\n"
1010
REFERENCE_HEADER = "\n\n>>> Here is the Reference for you to finish the task:\n"
@@ -138,51 +138,51 @@ def get_target_edit_files(target_dir: str) -> list[str]:
138138
return files
139139

140140

141-
def get_message_to_aider(
142-
aider_config: AiderConfig,
141+
def get_message(
142+
agent_config: AgentConfig,
143143
target_edit_files_cmd_args: str,
144144
repo_path: str,
145145
ds: Dict[str, Any],
146146
) -> str:
147147
"""Get the message to Aider."""
148-
prompt = f"{PROMPT_HEADER} " + aider_config.user_prompt
148+
prompt = f"{PROMPT_HEADER}" + agent_config.user_prompt
149149

150-
if aider_config.use_unit_tests_info and ds["test"]["test_dir"]:
150+
if agent_config.use_unit_tests_info and ds["test"]["test_dir"]:
151151
unit_tests_info = (
152152
f"\n{UNIT_TESTS_INFO_HEADER} "
153153
+ get_dir_info(
154154
dir_path=Path(os.path.join(repo_path, ds["test"]["test_dir"])),
155155
prefix="",
156156
include_stubs=True,
157-
)[: aider_config.max_unit_tests_info_length]
157+
)[: agent_config.max_unit_tests_info_length]
158158
)
159159
else:
160160
unit_tests_info = ""
161161

162162
# TODO: assuming we have specification, which we currently do not have
163-
if aider_config.use_reference_info and ds["specification"]:
163+
if agent_config.use_reference_info and ds["specification"]:
164164
reference = (
165165
f"\n{REFERENCE_HEADER} "
166166
+ get_reference(ds["specification"])[
167-
: aider_config.max_reference_info_length
167+
: agent_config.max_reference_info_length
168168
]
169169
)
170170
else:
171171
reference = ""
172172

173-
if aider_config.use_repo_info:
173+
if agent_config.use_repo_info:
174174
repo_info = (
175175
f"\n{REPO_INFO_HEADER} "
176176
+ get_dir_info(
177177
dir_path=Path(repo_path), prefix="", max_depth=2, include_stubs=False
178-
)[: aider_config.max_repo_info_length]
178+
)[: agent_config.max_repo_info_length]
179179
)
180180
else:
181181
repo_info = ""
182182

183-
message_to_aider = prompt + reference + repo_info + unit_tests_info
183+
message_to_agent = prompt + reference + repo_info + unit_tests_info
184184

185-
return message_to_aider
185+
return message_to_agent
186186

187187

188188
def get_reference(specification_pdf_path: str) -> str:

baselines/configs/aider.yaml renamed to baselines/configs/agent.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,11 @@ defaults:
66
commit0_config:
77
repo_split: minitorch
88

9-
aider_config:
9+
agent_config:
1010
use_user_prompt: false
1111
use_repo_info: false
1212
use_unit_tests_info: false
1313
use_reference_info: false
1414
use_lint_info: false
1515
pre_commit_config_path: .pre-commit-config.yaml
1616
run_tests: false
17-
llm_name: o1-preview

baselines/configs/base.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@ commit0_config:
1010
repo_split: "simpy"
1111
num_workers: 10
1212

13-
aider_config:
14-
llm_name: "claude-3-5-sonnet-20240620"
13+
agent_config:
14+
agent_name: "aider"
15+
model_name: "claude-3-5-sonnet-20240620"
1516
use_user_prompt: false
1617
user_prompt: "Here is your task:\nYou need to implement all functions with 'NotImplementedError('IMPLEMENT ME HERE')' and pass the unit tests.\nDo not change the names of existing functions or classes, as they may be referenced from other code like unit tests, etc.\nWhen you generate code, you must maintain the original formatting of the function stubs (such as whitespaces), otherwise we will not able to search/replace blocks for code modifications, and therefore you will receive a score of 0 for your generated code."
1718
use_repo_info: false

baselines/run_aider.py renamed to baselines/run_agent.py

Lines changed: 32 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,21 @@
44
import hydra
55
from datasets import load_dataset
66
import traceback
7-
from baselines.baseline_utils import (
8-
get_message_to_aider,
7+
from baselines.commit0_utils import (
8+
get_message,
99
get_target_edit_files,
1010
)
11+
from baselines.agents import AiderAgents
1112
from typing import Optional, Type
1213
from types import TracebackType
1314
from hydra.core.config_store import ConfigStore
14-
from baselines.class_types import AiderConfig, Commit0Config
15+
from baselines.class_types import AgentConfig, Commit0Config
1516
from commit0.harness.constants import SPLIT
1617
from commit0.harness.get_pytest_ids import main as get_tests
1718
from tqdm import tqdm
1819
from concurrent.futures import ThreadPoolExecutor, as_completed
1920
from commit0.harness.constants import RUN_AIDER_LOG_DIR
2021

21-
from aider.coders import Coder
22-
from aider.models import Model
23-
from aider.io import InputOutput
24-
2522

2623
class DirContext:
2724
def __init__(self, d):
@@ -38,33 +35,9 @@ def __exit__(
3835
os.chdir(self.cwd)
3936

4037

41-
def run_aider(
42-
model_name: str,
43-
fnames: list[str],
44-
message: str,
45-
test_cmd: str,
46-
lint_cmd: str,
47-
log_dir: Path,
48-
) -> None:
49-
if test_cmd:
50-
auto_test = True
51-
else:
52-
auto_test = False
53-
if lint_cmd:
54-
auto_lint = True
55-
else:
56-
auto_lint = False
57-
model = Model(model_name)
58-
input_history_file = log_dir / ".aider.input.history"
59-
chat_history_file = log_dir / ".aider.chat.history.md"
60-
io = InputOutput(yes=True, input_history_file=input_history_file, chat_history_file=chat_history_file)
61-
coder = Coder.create(main_model=model, fnames=fnames, auto_lint=auto_lint, lint_cmds=lint_cmd, io=io)
62-
coder.run(message)
63-
64-
65-
def run_aider_for_repo(
38+
def run_agent_for_repo(
6639
commit0_config: Commit0Config | None,
67-
aider_config: AiderConfig | None,
40+
agent_config: AgentConfig | None,
6841
ds: dict,
6942
) -> None:
7043
"""Run Aider for a given repository."""
@@ -83,59 +56,42 @@ def run_aider_for_repo(
8356

8457
target_edit_files = get_target_edit_files(repo_path)
8558

59+
if agent_config.agent_name == "aider":
60+
agent = AiderAgents(agent_config.model_name)
61+
else:
62+
raise NotImplementedError(f"{agent_config.agent} is not implemented; please add your implementations in baselines/agents.py.")
63+
8664
with DirContext(repo_path):
87-
if commit0_config is None or aider_config is None:
65+
if commit0_config is None or agent_config is None:
8866
raise ValueError("Invalid input")
8967

90-
message_to_aider = get_message_to_aider(
91-
aider_config, target_edit_files, repo_path, ds
68+
message = get_message(
69+
agent_config, target_edit_files, repo_path, ds
9270
)
9371

94-
if aider_config.use_lint_info:
72+
if agent_config.use_lint_info:
9573
lint_cmd = "pre-commit run --config ../../.pre-commit-config.yaml --files"
9674
else:
9775
lint_cmd = ""
9876

99-
if aider_config.run_tests:
77+
if agent_config.run_tests:
78+
# when unit test feedback is available, iterate over test files
10079
for test_file in test_files:
10180
test_cmd = f"python -m commit0 test {repo_path} {test_file}"
102-
# set up logging
10381
test_file_name = test_file.replace(".py", "").replace("/", "__")
10482
log_dir = RUN_AIDER_LOG_DIR / "with_tests" / test_file_name
105-
log_dir.mkdir(parents=True, exist_ok=True)
106-
log_file = log_dir / "run_aider.log"
107-
108-
aider_cmd = run_aider(
109-
aider_config.llm_name,
110-
target_edit_files,
111-
message_to_aider,
112-
test_cmd,
113-
lint_cmd,
114-
log_dir,
115-
)
11683

117-
# write aider command to log file
118-
aider_cmd_file = Path(log_dir / "aider_cmd.sh")
119-
aider_cmd_file.write_text(aider_cmd)
120-
121-
# write test command to log file
122-
test_cmd_file = Path(log_dir / "test_cmd.sh")
123-
test_cmd_file.write_text(test_cmd)
84+
agent.run(
85+
message, test_cmd, lint_cmd, target_edit_files, log_dir,
86+
)
12487
else:
125-
test_cmd = ""
88+
# when unit test feedback is not available, iterate over target files to edit
12689
for f in target_edit_files:
12790
file_name = f.replace(".py", "").replace("/", "__")
12891
log_dir = RUN_AIDER_LOG_DIR / "no_tests" / file_name
129-
log_dir.mkdir(parents=True, exist_ok=True)
130-
log_file = log_dir / "run_aider.log"
131-
132-
aider_cmd = run_aider(
133-
aider_config.llm_name,
134-
[f],
135-
message_to_aider,
136-
test_cmd,
137-
lint_cmd,
138-
log_dir,
92+
93+
agent.run(
94+
message, "", lint_cmd, [f], log_dir
13995
)
14096

14197

@@ -146,15 +102,15 @@ def main() -> None:
146102
"""
147103
cs = ConfigStore.instance()
148104
cs.store(name="user", node=Commit0Config)
149-
cs.store(name="user", node=AiderConfig)
105+
cs.store(name="user", node=AgentConfig)
150106

151107
hydra.initialize(version_base=None, config_path="configs")
152-
config = hydra.compose(config_name="aider")
108+
config = hydra.compose(config_name="agent")
153109

154110
commit0_config = Commit0Config(**config.commit0_config)
155-
aider_config = AiderConfig(**config.aider_config)
111+
agent_config = AgentConfig(**config.agent_config)
156112

157-
if commit0_config is None or aider_config is None:
113+
if commit0_config is None or agent_config is None:
158114
raise ValueError("Invalid input")
159115

160116
dataset = load_dataset(
@@ -173,6 +129,7 @@ def main() -> None:
173129
in SPLIT.get(commit0_config.repo_split, [])
174130
)
175131
]
132+
assert len(filtered_dataset) > 0, "No examples available"
176133

177134
with tqdm(
178135
total=len(filtered_dataset), smoothing=0, desc="Running Aider for repos"
@@ -181,10 +138,10 @@ def main() -> None:
181138
# Create a future for running Aider for each repo
182139
futures = {
183140
executor.submit(
184-
run_aider_for_repo,
141+
run_agent_for_repo,
185142
commit0_config,
186-
aider_config,
187-
example if isinstance(example, dict) else {},
143+
agent_config,
144+
example
188145
): example
189146
for example in filtered_dataset
190147
}

commit0/harness/run_pytest_ids.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ def main(
5252
repo_name = None
5353
for example in dataset:
5454
repo_name = example["repo"].split("/")[-1]
55+
if repo_or_repo_dir.endswith("/"):
56+
repo_or_repo_dir = repo_or_repo_dir[:-1]
5557
if repo_name in os.path.basename(repo_or_repo_dir):
5658
spec = make_spec(example)
5759
break

0 commit comments

Comments
 (0)