diff --git a/.gitignore b/.gitignore
index 9062184..0501971 100644
--- a/.gitignore
+++ b/.gitignore
@@ -162,9 +162,10 @@
 cython_debug/
 #.idea/
 logs/
-repos/
+repos*/
 config.yml
 hydra_outputs/
 .commit0*
 .agent*
-docs/analysis*.md
\ No newline at end of file
+docs/analysis*.md
+agent/run_agent_no_rich.py
\ No newline at end of file
diff --git a/agent/README.md b/agent/README.md
index 5d7a587..afcc998 100644
--- a/agent/README.md
+++ b/agent/README.md
@@ -1,6 +1,44 @@
 # Agent for Commit0
 
 This tool provides a command-line interface for configuring and running AI agents to assist with code development and testing.
+
+## (Update) Running with OpenHands
+
+**Step 1**: Clone [OpenHands](https://github.com/All-Hands-AI/OpenHands/tree/main) and install it following the [development environment guide](https://github.com/All-Hands-AI/OpenHands/blob/main/evaluation/README.md#development-environment).
+
+**Step 2**: Create `config.toml` with the following contents:
+
+```
+[core]
+workspace_base="~/OpenHands/evaluation/benchmarks/commit0"
+
+[llm]
+model="anthropic/claude-3-5-sonnet-20241022"
+api_key="..."
+embedding_model=""
+temperature = 0.0
+caching_prompt = true
+```
+
+
+**Step 3**: Run
+```bash
+./evaluation/benchmarks/commit0/scripts/run_infer.sh SPLIT MODEL HEAD CodeActAgent 16 STEPS PARALLEL_NUMBER
+
+# Example
+./evaluation/benchmarks/commit0/scripts/run_infer.sh lite llm.eval_deepseekv3 HEAD CodeActAgent 16 100 2
+```
+
+**Step 3.1**:
+To parallelize runs on OpenHands' remote runtime, export the following before running:
+
+```bash
+export RUNTIME=remote
+export SANDBOX_REMOTE_RUNTIME_API_URL="https://runtime.eval.all-hands.dev"
+export ALLHANDS_API_KEY=...
+```
+
+
 ## Quick Start
 Configure an agent:
 ```bash
@@ -12,6 +50,11 @@ Run an agent on a specific branch:
 ```bash
 agent run [OPTIONS] BRANCH
 ```
+### Example
+```bash
+agent run sonnet --max-parallel-repos 16 --agent-config-file .agent_sonnet.yaml --commit0-config-file .commit0.yaml
+```
+
 For more detailed information on available commands and options:
 ```bash
 agent -h
diff --git a/agent/agent_utils.py b/agent/agent_utils.py
index 4fdea82..72bd7ef 100644
--- a/agent/agent_utils.py
+++ b/agent/agent_utils.py
@@ -9,16 +9,22 @@
 from import_deps import ModuleSet
 from graphlib import TopologicalSorter, CycleError
 import yaml
-
+from rank_bm25 import BM25Okapi
 from agent.class_types import AgentConfig
+import subprocess
 
 PROMPT_HEADER = ">>> Here is the Task:\n"
+FUNCTION_HEADER = "\n\n>>> Here are all functions in the file; complete the implementations for all functions (i.e., those with pass statements):\n"
 REFERENCE_HEADER = "\n\n>>> Here is the Reference for you to finish the task:\n"
 REPO_INFO_HEADER = "\n\n>>> Here is the Repository Information:\n"
 UNIT_TESTS_INFO_HEADER = "\n\n>>> Here are the Unit Tests Information:\n"
 LINT_INFO_HEADER = "\n\n>>> Here is the Lint Information:\n"
 SPEC_INFO_HEADER = "\n\n>>> Here is the Specification Information:\n"
 IMPORT_DEPENDENCIES_HEADER = "\n\n>>> Here are the Import Dependencies:\n"
+FUNCTION_BY_FUNCTION_HEADER = """\nYour task is to implement function {unimplemented_functions} by replacing the pass statement with actual functional code.
+Please note that there could be multiple occurrences of {unimplemented_functions}, and you need to implement them all.
+Do not change the names of existing functions or classes, as they may be referenced from other code like unit tests, etc.
+When you generate code, you must maintain the original formatting of the function stubs (such as whitespace); otherwise we will not be able to apply search/replace blocks for code modifications, and you will receive a score of 0 for your generated code."""

 # prefix components:
 space = "    "
 branch = "│   "
@@ -123,6 +129,32 @@ def get_file_info(file_path: Path, prefix: str = "") -> str:
     return "\n".join(filter(None, tree_string))
 
 
+def get_unimplemented_functions(file_path: Path) -> List[str]:
+    """Get all unimplemented functions in a file."""
+    with open(file_path, "r") as f:
+        content = f.read()
+
+    # Find all function definitions with their bodies
+    pattern = r"def\s+(\w+)\s*\([^)]*\)[^:]*:(?:\s*(?:'''[\s\S]*?'''|\"\"\"[\s\S]*?\"\"\"))?\s*((?:(?!\ndef\s+).)*?)(?=\s*def\s+|\s*$)"
+    matches = re.finditer(pattern, content, re.DOTALL)
+
+    # Keep only functions that still have a 'pass' body
+    # List to store unimplemented function definitions
+    unimplemented_functions = []
+    for match in matches:
+        func_name = match.group(1)
+        func_body = match.group(2).strip()
+        # Check if the function body contains a 'pass' statement
+        if "pass" in func_body:
+            unimplemented_functions.append(f"def {func_name}()")
+        # # Find the full function definition using regex pattern
+        # func_pattern = rf"def\s+{func_name}\s*\([^)]*\)[^:]*:"
+        # func_match = re.search(func_pattern, content)
+        # if func_match:
+        #     unimplemented.append(func_match.group(0))
+    return unimplemented_functions
+
+
 def collect_test_files(directory: str) -> list[str]:
     """Collect all the test files in the directory."""
     test_files = []
@@ -347,6 +379,7 @@ def get_message(
     agent_config: AgentConfig,
     repo_path: str,
     test_files: list[str] | None = None,
+    input_file: str | None = None,
 ) -> str:
     """Get the message to Aider."""
     prompt = f"{PROMPT_HEADER}" + agent_config.user_prompt
@@ -383,11 +416,11 @@ def get_message(
         with bz2.open("spec.pdf.bz2", "rb") as in_file:
             with open("spec.pdf", "wb") as out_file:
                 out_file.write(in_file.read())
-        spec_info = (
-            f"\n{SPEC_INFO_HEADER} "
-            + get_specification(specification_pdf_path=Path(repo_path, "spec.pdf"))[
-                : agent_config.max_spec_info_length
-            ]
+        spec_info = f"\n{SPEC_INFO_HEADER} " + get_specification(
+            specification_pdf_path=Path(repo_path, "spec.pdf"),
+            use_retrieval=True,
+            query=input_file if input_file else "",
+            top_k=10,
         )
     else:
         spec_info = ""
@@ -397,6 +430,42 @@ def get_message(
     message_to_agent = prompt + repo_info + unit_tests_info + spec_info + lint_info
     return message_to_agent
 
 
+def get_message_function_by_function(
+    agent_config: AgentConfig,
+    repo_path: str,
+    input_file: str,
+    test_files: list[str] | None = None,
+) -> list[str]:
+    """Build the per-function messages to Aider."""
+    context = get_message(agent_config, repo_path, test_files)
+
+    if agent_config.implementation_strategy == "module_by_module":
+        function_info = []
+    elif agent_config.implementation_strategy == "function_by_function":
+        function_info = []
+        unimplemented_functions = get_unimplemented_functions(
+            file_path=Path(os.path.join(repo_path, input_file))
+        )
+        # Get the original function stubs and filter out implemented functions
+        for i in range(len(unimplemented_functions)):
+            function_info.append(
+                FUNCTION_BY_FUNCTION_HEADER.format(
+                    unimplemented_functions=unimplemented_functions[i]
+                )
+            )
+    else:
+        raise ValueError(
+            f"Invalid implementation strategy: {agent_config.implementation_strategy}"
+        )
+
+    if agent_config.implementation_strategy == "function_by_function":
+        messages_to_agent = [context + uf for uf in function_info if len(uf) > 0]
+    else:
+        messages_to_agent = []
+
+    return messages_to_agent
+
+
 def update_message_with_dependencies(message: str, dependencies: list[str]) -> str:
     """Update the message with the dependencies."""
     if len(dependencies) == 0:
@@ -411,19 +480,43 @@ def update_message_with_dependencies(message: str, dependencies: list[str]) -> str:
     return message
 
 
-def get_specification(specification_pdf_path: Path) -> str:
+def get_specification(
+    specification_pdf_path: Path,
+    use_retrieval: bool = True,
+    query: str = "",
+    top_k: int = 20,
+) -> str:
     """Get the reference for a given specification PDF path."""
     # TODO: after pdf_to_text is available, use it to extract the text from the PDF
     # Open the specified PDF file
+
     document = fitz.open(specification_pdf_path)
-    text = ""
+    corpus = []
+    # current_trunk = ""
     # Iterate through the pages
     for page_num in range(len(document)):
         page = document.load_page(page_num)  # loads the specified page
-        text += page.get_text()  # type: ignore
-    return text
+        current_page_text = page.get_text()  # type: ignore
+        # Cut page text into chunks of 1000 characters
+        text_chunks = [
+            current_page_text[i : i + 1000]
+            for i in range(0, len(current_page_text), 1000)
+        ]
+        corpus.extend(text_chunks)
+        # corpus.append(page.get_text())  # type: ignore
+    if not use_retrieval:
+        return "\n".join(corpus)
+
+    assert query != "", "query should not be empty"
+    query = open(query).read()
+    tokenized_corpus = [doc.split(" ") for doc in corpus]
+    bm25 = BM25Okapi(tokenized_corpus)
+    # Tokenize the query the same way as the corpus before scoring
+    doc_scores = bm25.get_scores(query.split(" "))
+    sorted_doc_scores = sorted(enumerate(doc_scores), key=lambda x: x[1], reverse=True)
+    sorted_doc_indices = [i for i, _ in sorted_doc_scores]
+    return "\n".join(corpus[i] for i in sorted_doc_indices[:top_k])
 
 
 def create_branch(repo: git.Repo, branch: str, from_commit: str) -> None:
@@ -486,6 +579,23 @@
         return []
 
 
+def run_eval_after_each_commit(
+    branch: str,
+    backend: str,
+    commit0_config_file: str,
+) -> str:
+    """Run the eval command after each commit."""
+    eval_cmd = f"python -m commit0 evaluate --branch {branch} --backend {backend} --commit0-config-file {commit0_config_file} --timeout 100"
+    try:
+        result = subprocess.run(
+            eval_cmd, shell=True, capture_output=True, text=True, check=True
+        )
+        return result.stdout
+    except subprocess.CalledProcessError as e:
+        print(f"Error running eval command: {e}")
+        return e.stdout if e.stdout else str(e)
+
+
 def args2string(agent_config: AgentConfig) -> str:
     """Converts specific fields from an `AgentConfig` object into a formatted string.
diff --git a/agent/agents.py b/agent/agents.py
index e908090..63ffd72 100644
--- a/agent/agents.py
+++ b/agent/agents.py
@@ -27,6 +27,8 @@ def __init__(self, log_file: Path):
         self.log_file = log_file
         self.last_cost = 0.0
+        self.total_token_in = 0
+        self.total_token_out = 0
 
 
 class Agents(ABC):
@@ -43,6 +45,8 @@ class AiderReturn(AgentReturn):
     def __init__(self, log_file: Path):
         super().__init__(log_file)
         self.last_cost = self.get_money_cost()
+        self.total_token_in = self.get_total_token_in()
+        self.total_token_out = self.get_total_token_out()
 
     def get_money_cost(self) -> float:
         """Get accumulated money cost from log file"""
@@ -57,18 +61,54 @@ def get_money_cost(self) -> float:
                     last_cost = float(match.group(1))
         return last_cost
 
+    def get_total_token_in(self) -> int:
+        """Get total input tokens from the log file"""
+        total_tokens = 0
+        with open(self.log_file, "r") as file:
+            for line in file:
+                if "Tokens:" in line:
+                    match = re.search(r"Tokens: ([\d.]+k?) sent", line)
+                    if match:
+                        token_str = match.group(1)
+                        if token_str.endswith("k"):
+                            total_tokens = int(float(token_str[:-1]) * 1000)
+                        else:
+                            total_tokens = int(float(token_str))
+        return total_tokens
+
+    def get_total_token_out(self) -> int:
+        """Get total output tokens from the log file"""
+        total_tokens = 0
+        with open(self.log_file, "r") as file:
+            for line in file:
+                if "Tokens:" in line:
+                    match = re.search(r"([\d.]+k?) received", line)
+                    if match:
+                        total_str = match.group(1)
+                        if total_str.endswith("k"):
+                            total_tokens = int(float(total_str[:-1]) * 1000)
+                        else:
+                            total_tokens = int(float(total_str))
+        return total_tokens
+
 
 class AiderAgents(Agents):
     def __init__(self, max_iteration: int, model_name: str):
         super().__init__(max_iteration)
         self.model = Model(model_name)
         # Check if API key is set for the model
-        if "gpt" in model_name:
+        if "openrouter" in model_name:
+            api_key = os.environ.get("OPENROUTER_API_KEY", None)
+        elif "gpt" in model_name:
             api_key = os.environ.get("OPENAI_API_KEY", None)
         elif "claude" in model_name:
             api_key = os.environ.get("ANTHROPIC_API_KEY", None)
         elif "gemini" in model_name:
-            api_key = os.environ.get("API_KEY", None)
+            api_key = os.environ.get("GEMINI_API_KEY", None)
+        elif "deepseek" in model_name:
+            api_key = os.environ.get("DEEPSEEK_API_KEY", None)
+        elif "mistral" in model_name:
+            api_key = os.environ.get("MISTRAL_API_KEY", None)
         else:
             raise ValueError(f"Unsupported model: {model_name}")
@@ -87,6 +127,7 @@ def run(
         log_dir: Path,
         test_first: bool = False,
         lint_first: bool = False,
+        current_attempt: int = 0,
     ) -> AgentReturn:
         """Start aider agent"""
         if test_cmd:
@@ -99,11 +140,22 @@ def run(
             auto_lint = False
         log_dir = log_dir.resolve()
         log_dir.mkdir(parents=True, exist_ok=True)
-        input_history_file = log_dir / ".aider.input.history"
-        chat_history_file = log_dir / ".aider.chat.history.md"
-
+        input_history_file = (
+            log_dir / ".aider.input.history"
+            if current_attempt == 0
+            else log_dir / f".aider_{current_attempt}.input.history"
+        )
+        chat_history_file = (
+            log_dir / ".aider.chat.history.md"
+            if current_attempt == 0
+            else log_dir / f".aider_{current_attempt}.chat.history.md"
+        )
         # Set up logging
-        log_file = log_dir / "aider.log"
+        log_file = (
+            log_dir / "aider.log"
+            if current_attempt == 0
+            else log_dir / f"aider_{current_attempt}.log"
+        )
         logging.basicConfig(
             filename=log_file,
             level=logging.INFO,
@@ -133,7 +185,7 @@ def run(
             io=io,
         )
         coder.max_reflections = self.max_iteration
-        coder.stream = True
+        coder.stream = False
 
         # Run the agent
         if test_first:
diff --git a/agent/class_types.py b/agent/class_types.py
index 12c74d4..457f0b9 100644
--- a/agent/class_types.py
+++ b/agent/class_types.py
@@ -22,3 +22,5 @@ class AgentConfig:
     run_tests: bool
     max_iteration: int
     record_test_for_each_commit: bool
+    implementation_strategy: str
+    repeat_times_for_each_inquiry: int
diff --git a/agent/cli.py b/agent/cli.py
index b02bf7a..53f6393 100644
--- a/agent/cli.py
+++ b/agent/cli.py
@@ -1,5 +1,4 @@
 import typer
-from agent.run_agent_no_rich import run_agent as run_agent_no_rich
 from agent.run_agent import run_agent
 from commit0.harness.constants import RUN_AGENT_LOG_DIR
 import subprocess
@@ -135,6 +134,14 @@ def config(
         False,
         help="Run the lint on the entire directory",
     ),
+    implementation_strategy: str = typer.Option(
+        "module_by_module",
+        help="Implementation strategy to use",
+    ),
+    repeat_times_for_each_inquiry: int = typer.Option(
+        1,
+        help="Repeat times for each inquiry",
+    ),
     record_test_for_each_commit: bool = typer.Option(
         False,
         help="Record the test for each commit",
     ),
@@ -173,6 +180,8 @@ def config(
         "use_lint_info": use_lint_info,
         "max_lint_info_length": max_lint_info_length,
         "run_entire_dir_lint": run_entire_dir_lint,
+        "implementation_strategy": implementation_strategy,
+        "repeat_times_for_each_inquiry": repeat_times_for_each_inquiry,
         "pre_commit_config_path": pre_commit_config_path,
         "record_test_for_each_commit": record_test_for_each_commit,
     }
@@ -232,12 +241,4 @@ def run(
             display_repo_progress_num,
         )
     else:
-        run_agent_no_rich(
-            branch,
-            override_previous_changes,
-            backend,
-            agent_config_file,
-            commit0_config_file,
-            log_dir,
-            max_parallel_repos,
-        )
+        raise NotImplementedError("Running without the rich display is currently not supported")
diff --git a/agent/display.py b/agent/display.py
index b5605d1..1dce704 100644
--- a/agent/display.py
+++ b/agent/display.py
@@ -215,7 +215,8 @@ def update_agent_display(
             ("use_spec_info", "Use Spec", use_spec_info),
             ("use_lint_info", "Use Lint", use_lint_info),
         ]
-
+        self.name_of_agent = agent_name.replace("/", "_")
+        self.name_of_llm = model_name.replace("/", "_")
         for attr_name, title, value in info_items:
             text = Text(f"{value}", justify="center")
             setattr(self, attr_name, text)
@@ -438,11 +439,11 @@ def __exit__(
         }
 
         with open(
-            f"processing_summary_{self.branch_name}.json",
+            f"{self.log_dir_display}/processing_summary_{self.branch_name}_{self.name_of_agent}_{self.name_of_llm}.json",
             "w",
         ) as json_file:
             json.dump(summary_data, json_file, indent=4)
 
         print(
-            f"\nSummary has been written to processing_summary_{self.branch_name}.json"
+            f"\nSummary has been written to {self.log_dir_display}/processing_summary_{self.branch_name}_{self.name_of_agent}_{self.name_of_llm}.json"
         )
diff --git a/agent/run_agent.py b/agent/run_agent.py
index 8978e17..f63af85 100644
--- a/agent/run_agent.py
+++ b/agent/run_agent.py
@@ -6,12 +6,15 @@
 from agent.agent_utils import (
     create_branch,
     get_message,
+    get_message_function_by_function,
+    run_eval_after_each_commit,
     get_target_edit_files,
     get_changed_files_from_commits,
     update_message_with_dependencies,
     get_lint_cmd,
     read_yaml_config,
 )
+from agent.agents import AgentReturn
 import json
 import subprocess
 from agent.agents import AiderAgents
@@ -46,19 +49,71 @@ def __exit__(
         os.chdir(self.cwd)
 
 
-def run_eval_after_each_commit(
-    branch: str, backend: str, commit0_config_file: str
-) -> str:
-    """Run the eval command after each commit."""
-    eval_cmd = f"python -m commit0 evaluate --branch {branch} --backend {backend} --commit0-config-file {commit0_config_file} --timeout 100"
-    try:
-        result = subprocess.run(
-            eval_cmd, shell=True, capture_output=True, text=True, check=True
-        )
-        return result.stdout
-    except subprocess.CalledProcessError as e:
-        print(f"Error running eval command: {e}")
-        return e.stdout if e.stdout else str(e)
+def run_agent_multiple_times_on_same_inquiry(
+    agent: AiderAgents,
+    repo: Repo,
+    branch: str,
+    message: str,
+    fnames: list[str],
+    test_cmd: str,
+    test_first: bool,
+    lint_cmd: str,
+    lint_first: bool,
+    log_dir: Path,
+    repeat_times_for_each_inquiry: int,
+    backend: str,
+    commit0_config_file: str,
+) -> AgentReturn | None:
+    """Run the agent multiple times on the same inquiry and return the best-performing attempt's result"""
+    if repeat_times_for_each_inquiry == 1:
+        return agent.run(
+            message, test_cmd, lint_cmd, fnames, log_dir, test_first, lint_first
+        )
+    else:
+        commit_before_run = repo.head.commit.hexsha
+        commit_results = {}
+        best_commit_diff = ""
+        best_eval_result = float("-inf")
+        best_agent_return = None
+        for attempt in range(repeat_times_for_each_inquiry):
+            agent_return = agent.run(
+                message,
+                test_cmd,
+                lint_cmd,
+                fnames,
+                log_dir,
+                test_first,
+                lint_first,
+                attempt,
+            )
+            current_commit = repo.head.commit.hexsha
+            eval_result = run_eval_after_each_commit(
+                branch, backend, commit0_config_file
+            )
+            # Get diff and store results
+            diff = repo.git.diff(commit_before_run, current_commit)
+            commit_results[current_commit] = {"eval_result": eval_result, "diff": diff}
+            # Track best performing commit's diff
+            score = float(
+                eval_result.split("average pass rate: ")[-1]
+                if "average pass rate: " in eval_result
+                else 0
+            )
+            if score > best_eval_result:
+                best_eval_result = score
+                best_commit_diff = diff
+                best_agent_return = agent_return
+
+        repo.git.reset("--hard", commit_before_run)
+        with open(log_dir.resolve() / "eval_results.json", "w") as f:
+            json.dump(commit_results, f, indent=4)
+        patch_path = os.path.abspath(str(log_dir.resolve() / "best_diff.patch"))
+        with open(patch_path, "w") as f:
+            f.write(best_commit_diff + "\n")
+        repo.git.execute(["git", "apply", patch_path])
+        repo.git.add(fnames)
+        repo.git.commit("-m", f"Applied best performing changes for {fnames}")
+        return best_agent_return
 
 
 def run_agent_for_repo(
@@ -156,6 +211,7 @@
         if agent_config is None:
             raise ValueError("Invalid input")
 
+        agent_return = None
         if agent_config.run_tests:
             update_queue.put(("start_repo", (repo_name, len(test_files))))
             # when unit test feedback is available, iterate over test files
@@ -170,13 +226,20 @@
                 message = get_message(agent_config, repo_path, test_files=[test_file])
 
                 # display the test file to terminal
-                agent_return = agent.run(
-                    "",
-                    test_cmd,
-                    lint_cmd,
-                    target_edit_files,
-                    test_log_dir,
+                agent_return = run_agent_multiple_times_on_same_inquiry(
+                    agent=agent,
+                    repo=local_repo,
+                    branch=branch,
+                    message=message,
+                    fnames=[test_file],
+                    test_cmd=test_cmd,
                     test_first=True,
+                    lint_cmd=lint_cmd,
+                    lint_first=False,
+                    log_dir=test_log_dir,
+                    repeat_times_for_each_inquiry=agent_config.repeat_times_for_each_inquiry,
+                    backend=backend,
+                    commit0_config_file=commit0_config_file,
                 )
                 if agent_config.record_test_for_each_commit:
                     current_commit = local_repo.head.commit.hexsha
@@ -188,7 +251,11 @@
                 update_queue.put(
                     (
                         "update_money_display",
-                        (repo_name, test_file, agent_return.last_cost),
+                        (
+                            repo_name,
+                            test_file,
+                            agent_return.last_cost if agent_return is not None else 0,
+                        ),
                     )
                 )
         elif agent_config.run_entire_dir_lint:
@@ -203,13 +270,20 @@
                 )
 
                 # display the test file to terminal
-                agent_return = agent.run(
-                    "",
-                    "",
-                    lint_cmd,
-                    [lint_file],
-                    lint_log_dir,
+                agent_return = run_agent_multiple_times_on_same_inquiry(
+                    agent=agent,
+                    repo=local_repo,
+                    branch=branch,
+                    message="",
+                    fnames=[lint_file],
+                    test_cmd="",
+                    test_first=False,
+                    lint_cmd=lint_cmd,
                     lint_first=True,
+                    log_dir=lint_log_dir,
+                    repeat_times_for_each_inquiry=agent_config.repeat_times_for_each_inquiry,
+                    backend=backend,
+                    commit0_config_file=commit0_config_file,
                 )
                 if agent_config.record_test_for_each_commit:
                     current_commit = local_repo.head.commit.hexsha
@@ -221,25 +295,74 @@
                 update_queue.put(
                     (
                         "update_money_display",
-                        (repo_name, lint_file, agent_return.last_cost),
+                        (
+                            repo_name,
+                            lint_file,
+                            agent_return.last_cost if agent_return is not None else 0,
+                        ),
                     )
                 )
         else:
             # when unit test feedback is not available, iterate over target files to edit
-            message = get_message(agent_config, repo_path, test_files=test_files)
-
             update_queue.put(("start_repo", (repo_name, len(target_edit_files))))
             for f in target_edit_files:
                 update_queue.put(("set_current_file", (repo_name, f)))
-                if agent_config.add_import_module_to_context:
-                    dependencies = import_dependencies.get(f, [])
-                    message = update_message_with_dependencies(message, dependencies)
                 file_name = f.replace(".py", "").replace("/", "__")
                 file_log_dir = experiment_log_dir / file_name
                 lint_cmd = get_lint_cmd(
                     repo_name, agent_config.use_lint_info, commit0_config_file
                 )
-                agent_return = agent.run(message, "", lint_cmd, [f], file_log_dir)
+                if agent_config.implementation_strategy == "function_by_function":
+                    messages = get_message_function_by_function(
+                        agent_config, repo_path, f, test_files
+                    )
+                    for message in messages:
+                        agent_return = run_agent_multiple_times_on_same_inquiry(
+                            agent=agent,
+                            repo=local_repo,
+                            branch=branch,
+                            message=message,
+                            fnames=[f],
+                            test_cmd="",
+                            test_first=False,
+                            lint_cmd=lint_cmd,
+                            lint_first=False,
+                            log_dir=file_log_dir,
+                            repeat_times_for_each_inquiry=agent_config.repeat_times_for_each_inquiry,
+                            backend=backend,
+                            commit0_config_file=commit0_config_file,
+                        )
+
+                elif agent_config.implementation_strategy == "module_by_module":
+                    message = get_message(
+                        agent_config, repo_path, test_files=test_files, input_file=f
+                    )
+                    if agent_config.add_import_module_to_context:
+                        dependencies = import_dependencies.get(f, [])
+                        message = update_message_with_dependencies(
+                            message, dependencies
+                        )
+
+                    agent_return = run_agent_multiple_times_on_same_inquiry(
+                        agent=agent,
+                        repo=local_repo,
+                        branch=branch,
+                        message=message,
+                        fnames=[f],
+                        test_cmd="",
+                        test_first=False,
+                        lint_cmd=lint_cmd,
+                        lint_first=False,
+                        log_dir=file_log_dir,
+                        repeat_times_for_each_inquiry=agent_config.repeat_times_for_each_inquiry,
+                        backend=backend,
+                        commit0_config_file=commit0_config_file,
+                    )
+                else:
+                    raise ValueError(
+                        f"Invalid implementation strategy: {agent_config.implementation_strategy}"
+                    )
+
                 if agent_config.record_test_for_each_commit:
                     current_commit = local_repo.head.commit.hexsha
                     eval_results[current_commit] = run_eval_after_each_commit(
@@ -249,7 +372,11 @@
                 update_queue.put(
                     (
                         "update_money_display",
-                        (repo_name, file_name, agent_return.last_cost),
+                        (
+                            repo_name,
+                            file_name,
+                            agent_return.last_cost if agent_return is not None else 0,
+                        ),
                     )
                 )
         if agent_config.record_test_for_each_commit:
@@ -292,10 +419,8 @@
             in SPLIT.get(commit0_config["repo_split"], [])
         )
     ]
-    assert len(filtered_dataset) > 0, "No examples available"
 
-    # if len(filtered_dataset) > 1:
-    #     sys.stdout = open(os.devnull, "w")
+    assert len(filtered_dataset) > 0, "No examples available"
 
     if agent_config.add_import_module_to_context:
         # Install Chrome for Playwright for browser-based agents
diff --git a/agent/run_agent_no_rich.py b/agent/run_agent_no_rich.py
deleted file mode 100644
index 7f953e0..0000000
--- a/agent/run_agent_no_rich.py
+++ /dev/null
@@ -1,269 +0,0 @@
-import os
-import yaml
-import multiprocessing
-from tqdm import tqdm
-from datasets import load_dataset
-from git import Repo
-from agent.agent_utils import (
-    create_branch,
-    get_message,
-    get_target_edit_files,
-    get_changed_files_from_commits,
-    update_message_with_dependencies,
-    get_lint_cmd,
-    read_yaml_config,
-)
-import subprocess
-import json
-from agent.agents import AiderAgents
-from typing import cast
-from agent.class_types import AgentConfig
-from commit0.harness.constants import SPLIT
-from commit0.harness.get_pytest_ids import main as get_tests
-from commit0.harness.constants import RUN_AGENT_LOG_DIR, RepoInstance
-from commit0.cli import read_commit0_config_file
-from pathlib import Path
-from datetime import datetime
-from agent.run_agent import DirContext, run_eval_after_each_commit
-
-
-def run_agent_for_repo(
-    repo_base_dir: str,
-    agent_config: AgentConfig,
-    example: RepoInstance,
-    branch: str,
-    override_previous_changes: bool = False,
-    backend: str = "modal",
-    log_dir: str = str(RUN_AGENT_LOG_DIR.resolve()),
-    commit0_config_file: str = "",
-) -> None:
-    """Run Aider for a given repository."""
-    # get repo info
-    commit0_config = read_commit0_config_file(commit0_config_file)
-
-    assert "commit0" in commit0_config["dataset_name"]
-    _, repo_name = example["repo"].split("/")
-
-    # repo_name = repo_name.lower()
-    # repo_name = repo_name.replace(".", "-")
-
-    repo_path = os.path.join(repo_base_dir, repo_name)
-    repo_path = os.path.abspath(repo_path)
-
-    try:
-        local_repo = Repo(repo_path)
-    except Exception:
-        raise Exception(
-            f"{repo_path} is not a git repo. Check if base_dir is correctly specified."
-        )
-
-    if agent_config.agent_name == "aider":
-        agent = AiderAgents(agent_config.max_iteration, agent_config.model_name)
-    else:
-        raise NotImplementedError(
-            f"{agent_config.agent_name} is not implemented; please add your implementations in baselines/agents.py."
-        )
-
-    # Check if there are changes in the current branch
-    if local_repo.is_dirty():
-        # Stage all changes
-        local_repo.git.add(A=True)
-        # Commit changes with the message "left from last change"
-        local_repo.index.commit("left from last change")
-
-    # # if branch_name is not provided, create a new branch name based on agent_config
-    # if branch is None:
-    #     branch = args2string(agent_config)
-    create_branch(local_repo, branch, example["base_commit"])
-
-    # in cases where the latest commit of branch is not commit 0
-    # set it back to commit 0
-    latest_commit = local_repo.commit(branch)
-    if latest_commit.hexsha != example["base_commit"] and override_previous_changes:
-        local_repo.git.reset("--hard", example["base_commit"])
-
-    # get target files to edit and test files to run
-    target_edit_files, import_dependencies = get_target_edit_files(
-        local_repo,
-        example["src_dir"],
-        example["test"]["test_dir"],
-        branch,
-        example["reference_commit"],
-        agent_config.use_topo_sort_dependencies,
-    )
-
-    lint_files = get_changed_files_from_commits(
-        local_repo, "HEAD", example["base_commit"]
-    )
-    # Call the commit0 get-tests command to retrieve test files
-    test_files_str = [xx for x in get_tests(repo_name, verbose=0) for xx in x]
-    test_files = sorted(list(set([i.split(":")[0] for i in test_files_str])))
-
-    # prepare the log dir
-    experiment_log_dir = (
-        Path(log_dir)
-        / repo_name
-        / branch
-        / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
-    )
-    experiment_log_dir.mkdir(parents=True, exist_ok=True)
-    eval_results = {}
-
-    # write agent_config to .agent.yaml in the log_dir for record
-    agent_config_log_file = experiment_log_dir / ".agent.yaml"
-    with open(agent_config_log_file, "w") as agent_config_file:
-        yaml.dump(agent_config, agent_config_file)
-
-    with DirContext(repo_path):
-        if agent_config is None:
-            raise ValueError("Invalid input")
-
-        if agent_config.run_tests:
-            # when unit test feedback is available, iterate over test files
-            for test_file in test_files:
-                test_cmd = f"python -m commit0 test {repo_path} {test_file} --branch {branch} --backend {backend} --commit0-config-file {commit0_config_file} --timeout 100"
-                test_file_name = test_file.replace(".py", "").replace("/", "__")
-                test_log_dir = experiment_log_dir / test_file_name
-                lint_cmd = get_lint_cmd(
-                    repo_name, agent_config.use_lint_info, commit0_config_file
-                )
-                message = get_message(agent_config, repo_path, test_files=[test_file])
-
-                # display the test file to terminal
-                _ = agent.run(
-                    "",
-                    test_cmd,
-                    lint_cmd,
-                    target_edit_files,
-                    test_log_dir,
-                    test_first=True,
-                )
-                if agent_config.record_test_for_each_commit:
-                    current_commit = local_repo.head.commit.hexsha
-                    eval_results[current_commit] = run_eval_after_each_commit(
-                        branch, backend, commit0_config_file
-                    )
-        elif agent_config.run_entire_dir_lint:
-            # when unit test feedback is available, iterate over test files
-            for lint_file in lint_files:
-                lint_file_name = lint_file.replace(".py", "").replace("/", "__")
-                lint_log_dir = experiment_log_dir / lint_file_name
-                lint_cmd = get_lint_cmd(
-                    repo_name, agent_config.use_lint_info, commit0_config_file
-                )
-
-                # display the test file to terminal
-                _ = agent.run(
-                    "",
-                    "",
-                    lint_cmd,
-                    [lint_file],
-                    lint_log_dir,
-                    lint_first=True,
-                )
-                if agent_config.record_test_for_each_commit:
-                    current_commit = local_repo.head.commit.hexsha
-                    eval_results[current_commit] = run_eval_after_each_commit(
-                        branch, backend, commit0_config_file
-                    )
-        else:
-            # when unit test feedback is not available, iterate over target files to edit
-            message = get_message(agent_config, repo_path, test_files=test_files)
-
-            for f in target_edit_files:
-                if agent_config.add_import_module_to_context:
-                    dependencies = import_dependencies.get(f, [])
-                    message = update_message_with_dependencies(message, dependencies)
-                file_name = f.replace(".py", "").replace("/", "__")
-                file_log_dir = experiment_log_dir / file_name
-                lint_cmd = get_lint_cmd(
-                    repo_name, agent_config.use_lint_info, commit0_config_file
-                )
-                _ = agent.run(message, "", lint_cmd, [f], file_log_dir)
-                if agent_config.record_test_for_each_commit:
-                    current_commit = local_repo.head.commit.hexsha
-                    eval_results[current_commit] = run_eval_after_each_commit(
-                        branch, backend, commit0_config_file
-                    )
-    if agent_config.record_test_for_each_commit:
-        with open(experiment_log_dir / "eval_results.json", "w") as f:
-            json.dump(eval_results, f)
-
-
-def run_agent(
-    branch: str,
-    override_previous_changes: bool,
-    backend: str,
-    agent_config_file: str,
-    commit0_config_file: str,
-    log_dir: str,
-    max_parallel_repos: int,
-) -> None:
-    """Main function to run Aider for a given repository.
-
-    Will run in parallel for each repo.
-    """
-    config = read_yaml_config(agent_config_file)
-
-    agent_config = AgentConfig(**config)
-
-    commit0_config_file = os.path.abspath(commit0_config_file)
-    commit0_config = read_commit0_config_file(commit0_config_file)
-
-    dataset = load_dataset(
-        commit0_config["dataset_name"], split=commit0_config["dataset_split"]
-    )
-    filtered_dataset = [
-        example
-        for example in dataset
-        if commit0_config["repo_split"] == "all"
-        or (
-            isinstance(example, dict)
-            and "repo" in example
-            and isinstance(example["repo"], str)
-            and example["repo"].split("/")[-1]
-            in SPLIT.get(commit0_config["repo_split"], [])
-        )
-    ]
-    assert len(filtered_dataset) > 0, "No examples available"
-
-    # if len(filtered_dataset) > 1:
-    #     sys.stdout = open(os.devnull, "w")
-    if agent_config.add_import_module_to_context:
-        # Install Chrome for Playwright for browser-based agents
-        try:
-            subprocess.run(["playwright", "install", "chromium"], check=True)
-            print("Chrome installed successfully for Playwright")
-        except subprocess.CalledProcessError as e:
-            print(f"Error installing Chrome for Playwright: {e}")
-        except FileNotFoundError:
-            print("Playwright not found. Make sure it's installed and in your PATH.")
-
-    with tqdm(
-        total=len(filtered_dataset), smoothing=0, desc="Running Aider for repos"
-    ) as pbar:
-        with multiprocessing.Pool(processes=max_parallel_repos) as pool:
-            results = []
-
-            # Use apply_async to submit jobs and add progress bar updates
-            for example in filtered_dataset:
-                result = pool.apply_async(
-                    run_agent_for_repo,
-                    args=(
-                        commit0_config["base_dir"],
-                        agent_config,
-                        cast(RepoInstance, example),
-                        branch,
-                        override_previous_changes,
-                        backend,
-                        log_dir,
-                        commit0_config_file,
-                    ),
-                    callback=lambda _: pbar.update(
-                        1
-                    ),  # Update progress bar on task completion
-                )
-                results.append(result)
-
-            for result in results:
-                result.wait()
diff --git a/pyproject.toml b/pyproject.toml
index 7eabe00..aa44ecf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,6 +22,7 @@ dependencies = [
     "strenum>=0.4.15",
     "e2b-code-interpreter>=1.0.4",
     "python-dotenv>=1.0.1",
+    "rank_bm25>=0.2.1",
 ]
 classifiers = [
     "License :: OSI Approved :: MIT License",
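
A note on the new `rank_bm25` dependency added above: the retrieval flow that `get_specification` now implements can be summarized with the following minimal sketch. This is not part of the patch; the page texts and the query file path are illustrative stand-ins, while the 1000-character chunking, whitespace tokenization, and `top_k` selection mirror the code in the diff (the call site in `get_message` uses `top_k=10`).

```python
from rank_bm25 import BM25Okapi

# Stand-ins for the per-page text that fitz extracts from spec.pdf.
pages = ["...text of page one...", "...text of page two..."]

# Cut each page into 1000-character chunks, as get_specification does.
corpus = [page[i : i + 1000] for page in pages for i in range(0, len(page), 1000)]

# Index the whitespace-tokenized chunks, then score them against the
# tokenized contents of the file being implemented (the "query").
bm25 = BM25Okapi([chunk.split(" ") for chunk in corpus])
query = open("src/some_module.py").read()  # illustrative query file
doc_scores = bm25.get_scores(query.split(" "))

# Keep the top_k highest-scoring chunks as the spec context for the prompt.
top_k = 10
ranked = sorted(range(len(corpus)), key=lambda i: doc_scores[i], reverse=True)
spec_info = "\n".join(corpus[i] for i in ranked[:top_k])
```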