Skip to content

Commit

Permalink
Merge pull request #424 from topoteretes/feature/cog-971-preparing-sw…
Browse files Browse the repository at this point in the history
…e-bench-run

Feature/cog 971 preparing swe bench run
  • Loading branch information
Vasilije1990 authored Jan 10, 2025
2 parents f7e808e + 6f7bbb0 commit f694ca2
Show file tree
Hide file tree
Showing 11 changed files with 106 additions and 49 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
I need you to solve this issue by looking at the provided edges retrieved from a knowledge graph and
generate a single patch file that I can apply directly to this repository using git apply.
Please respond with a single patch file in the following format.
You are a senior software engineer. I need you to solve this issue by looking at the provided context and
generate a single patch file that I can apply directly to this repository using git apply.
Additionally, please make sure that you provide code only with correct syntax and
you apply the patch on the relevant files (together with their path that you can try to find out from the github issue). Don't change the names of existing
functions or classes, as they may be referenced from other code.
Please respond only with a single patch file in the following format without adding any additional context or string.
1 change: 1 addition & 0 deletions cognee/modules/chunking/models/DocumentChunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class DocumentChunk(DataPoint):
chunk_index: int
cut_type: str
is_part_of: Document
pydantic_type: str = "DocumentChunk"
contains: List[Entity] = None

_metadata: dict = {"index_fields": ["text"], "type": "DocumentChunk"}
1 change: 1 addition & 0 deletions cognee/modules/engine/models/Entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ class Entity(DataPoint):
name: str
is_a: EntityType
description: str
pydantic_type: str = "Entity"

_metadata: dict = {"index_fields": ["name"], "type": "Entity"}
1 change: 1 addition & 0 deletions cognee/modules/engine/models/EntityType.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ class EntityType(DataPoint):
__tablename__ = "entity_type"
name: str
description: str
pydantic_type: str = "EntityType"

_metadata: dict = {"index_fields": ["name"], "type": "EntityType"}
72 changes: 61 additions & 11 deletions cognee/modules/retrieval/description_to_codepart_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,35 @@
from cognee.modules.users.methods import get_default_user
from cognee.modules.users.models import User
from cognee.shared.utils import send_telemetry
from cognee.api.v1.search import SearchType
from cognee.api.v1.search.search_v2 import search
from cognee.infrastructure.llm.get_llm_client import get_llm_client


async def code_description_to_code_part_search(query: str, user: User = None, top_k=2) -> list:
async def code_description_to_code_part_search(
query: str, include_docs=False, user: User = None, top_k=5
) -> list:
if user is None:
user = await get_default_user()

if user is None:
raise PermissionError("No user found in the system. Please create a user.")

retrieved_codeparts = await code_description_to_code_part(query, user, top_k)
retrieved_codeparts = await code_description_to_code_part(query, user, top_k, include_docs)
return retrieved_codeparts


async def code_description_to_code_part(query: str, user: User, top_k: int) -> List[str]:
async def code_description_to_code_part(
query: str, user: User, top_k: int, include_docs: bool = False
) -> List[str]:
"""
Maps a code description query to relevant code parts using a CodeGraph pipeline.
Args:
query (str): The search query describing the code parts.
user (User): The user performing the search.
top_k (int): Number of codegraph descriptions to match ( num of corresponding codeparts will be higher)
include_docs(bool): Boolean showing whether we have the docs in the graph or not
Returns:
Set[str]: A set of unique code parts matching the query.
Expand All @@ -55,21 +63,49 @@ async def code_description_to_code_part(query: str, user: User, top_k: int) -> L
)

try:
results = await vector_engine.search("code_summary_text", query_text=query, limit=top_k)
if not results:
if include_docs:
search_results = await search(SearchType.INSIGHTS, query_text=query)

concatenated_descriptions = " ".join(
obj["description"]
for tpl in search_results
for obj in tpl
if isinstance(obj, dict) and "description" in obj
)

llm_client = get_llm_client()
context_from_documents = await llm_client.acreate_structured_output(
text_input=f"The retrieved context from documents"
f" is {concatenated_descriptions}.",
system_prompt="You are a Senior Software Engineer, summarize the context from documents"
f" in a way that it is gonna be provided next to codeparts as context"
f" while trying to solve this github issue connected to the project: {query}]",
response_model=str,
)

code_summaries = await vector_engine.search(
"code_summary_text", query_text=query, limit=top_k
)
if not code_summaries:
logging.warning("No results found for query: '%s' by user: %s", query, user.id)
return []

memory_fragment = CogneeGraph()
await memory_fragment.project_graph_from_db(
graph_engine,
node_properties_to_project=["id", "type", "text", "source_code"],
node_properties_to_project=[
"id",
"type",
"text",
"source_code",
"pydantic_type",
],
edge_properties_to_project=["relationship_name"],
)

code_pieces_to_return = set()

for node in results:
for node in code_summaries:
node_id = str(node.id)
node_to_search_from = memory_fragment.get_node(node_id)

Expand All @@ -78,9 +114,16 @@ async def code_description_to_code_part(query: str, user: User, top_k: int) -> L
continue

for code_file in node_to_search_from.get_skeleton_neighbours():
for code_file_edge in code_file.get_skeleton_edges():
if code_file_edge.get_attribute("relationship_name") == "contains":
code_pieces_to_return.add(code_file_edge.get_destination_node())
if code_file.get_attribute("pydantic_type") == "SourceCodeChunk":
for code_file_edge in code_file.get_skeleton_edges():
if code_file_edge.get_attribute("relationship_name") == "code_chunk_of":
code_pieces_to_return.add(code_file_edge.get_destination_node())
elif code_file.get_attribute("pydantic_type") == "CodePart":
code_pieces_to_return.add(code_file)
elif code_file.get_attribute("pydantic_type") == "CodeFile":
for code_file_edge in code_file.get_skeleton_edges():
if code_file_edge.get_attribute("relationship_name") == "contains":
code_pieces_to_return.add(code_file_edge.get_destination_node())

logging.info(
"Search completed for user: %s, query: '%s'. Found %d code pieces.",
Expand All @@ -89,7 +132,14 @@ async def code_description_to_code_part(query: str, user: User, top_k: int) -> L
len(code_pieces_to_return),
)

return list(code_pieces_to_return)
context = ""
for code_piece in code_pieces_to_return:
context = context + code_piece.get_attribute("source_code")

if include_docs:
context = context_from_documents + context

return context

except Exception as exec_error:
logging.error(
Expand Down
4 changes: 4 additions & 0 deletions cognee/shared/CodeGraphEntities.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
class Repository(DataPoint):
__tablename__ = "Repository"
path: str
pydantic_type: str = "Repository"
_metadata: dict = {"index_fields": [], "type": "Repository"}


class CodeFile(DataPoint):
__tablename__ = "codefile"
extracted_id: str # actually file path
pydantic_type: str = "CodeFile"
source_code: Optional[str] = None
part_of: Optional[Repository] = None
depends_on: Optional[List["CodeFile"]] = None
Expand All @@ -22,6 +24,7 @@ class CodeFile(DataPoint):
class CodePart(DataPoint):
__tablename__ = "codepart"
# part_of: Optional[CodeFile] = None
pydantic_type: str = "CodePart"
source_code: Optional[str] = None
_metadata: dict = {"index_fields": [], "type": "CodePart"}

Expand All @@ -30,6 +33,7 @@ class SourceCodeChunk(DataPoint):
__tablename__ = "sourcecodechunk"
code_chunk_of: Optional[CodePart] = None
source_code: Optional[str] = None
pydantic_type: str = "SourceCodeChunk"
previous_chunk: Optional["SourceCodeChunk"] = None

_metadata: dict = {"index_fields": ["source_code"], "type": "SourceCodeChunk"}
Expand Down
4 changes: 4 additions & 0 deletions cognee/shared/data_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ class SummarizedContent(BaseModel):

summary: str
description: str
pydantic_type: str = "SummarizedContent"


class SummarizedFunction(BaseModel):
Expand All @@ -239,13 +240,15 @@ class SummarizedFunction(BaseModel):
inputs: Optional[List[str]] = None
outputs: Optional[List[str]] = None
decorators: Optional[List[str]] = None
pydantic_type: str = "SummarizedFunction"


class SummarizedClass(BaseModel):
name: str
description: str
methods: Optional[List[SummarizedFunction]] = None
decorators: Optional[List[str]] = None
pydantic_type: str = "SummarizedClass"


class SummarizedCode(BaseModel):
Expand All @@ -256,6 +259,7 @@ class SummarizedCode(BaseModel):
classes: List[SummarizedClass] = []
functions: List[SummarizedFunction] = []
workflow_description: Optional[str] = None
pydantic_type: str = "SummarizedCode"


class GraphDBType(Enum):
Expand Down
1 change: 0 additions & 1 deletion cognee/tasks/repo_processor/extract_code_parts.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Dict, List
import parso

import logging

logger = logging.getLogger(__name__)
Expand Down
1 change: 0 additions & 1 deletion cognee/tasks/repo_processor/get_local_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import jedi
import parso
from parso.tree import BaseNode

import logging

logger = logging.getLogger(__name__)
Expand Down
1 change: 1 addition & 0 deletions cognee/tasks/summarization/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@ class CodeSummary(DataPoint):
__tablename__ = "code_summary"
text: str
summarizes: Union[CodeFile, CodePart, SourceCodeChunk]
pydantic_type: str = "CodeSummary"

_metadata: dict = {"index_fields": ["text"], "type": "CodeSummary"}
60 changes: 27 additions & 33 deletions evals/eval_swe_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
from cognee.api.v1.search import SearchType
from cognee.infrastructure.llm.get_llm_client import get_llm_client
from cognee.infrastructure.llm.prompts import read_query_prompt
from cognee.modules.retrieval.brute_force_triplet_search import brute_force_triplet_search
from cognee.shared.utils import render_graph
from evals.eval_utils import download_github_repo, retrieved_edges_to_string
from cognee.modules.retrieval.description_to_codepart_search import (
code_description_to_code_part_search,
)


def check_install_package(package_name):
Expand All @@ -32,35 +32,28 @@ def check_install_package(package_name):
return False


async def generate_patch_with_cognee(instance, llm_client, search_type=SearchType.CHUNKS):
repo_path = download_github_repo(instance, "../RAW_GIT_REPOS")

async for result in run_code_graph_pipeline(repo_path, include_docs=True):
print(result)

print("Here we have the repo under the repo_path")

await render_graph(None, include_labels=True, include_nodes=True)

async def generate_patch_with_cognee(instance):
"""repo_path = download_github_repo(instance, "../RAW_GIT_REPOS")"""
include_docs = True
problem_statement = instance["problem_statement"]
instructions = read_query_prompt("patch_gen_kg_instructions.txt")

retrieved_edges = await brute_force_triplet_search(
problem_statement,
top_k=3,
collections=["code_summary_text"],
)
repo_path = "/Users/laszlohajdu/Documents/GitHub/graph_rag/"
async for result in run_code_graph_pipeline(repo_path, include_docs=include_docs):
print(result)

retrieved_edges_str = retrieved_edges_to_string(retrieved_edges)
retrieved_codeparts = await code_description_to_code_part_search(
problem_statement, include_docs=include_docs
)

prompt = "\n".join(
[
problem_statement,
"<patch>",
PATCH_EXAMPLE,
"</patch>",
"These are the retrieved edges:",
retrieved_edges_str,
"This is the additional context to solve the problem (description from documentation together with codeparts):",
retrieved_codeparts,
]
)

Expand All @@ -86,26 +79,25 @@ async def generate_patch_without_cognee(instance, llm_client):


async def get_preds(dataset, with_cognee=True):
llm_client = get_llm_client()

if with_cognee:
model_name = "with_cognee"
pred_func = generate_patch_with_cognee
else:
model_name = "without_cognee"
pred_func = generate_patch_without_cognee

futures = [(instance["instance_id"], pred_func(instance, llm_client)) for instance in dataset]
model_patches = await asyncio.gather(*[x[1] for x in futures])
preds = []

preds = [
{
"instance_id": instance_id,
"model_patch": model_patch,
"model_name_or_path": model_name,
}
for (instance_id, _), model_patch in zip(futures, model_patches)
]
for instance in dataset:
instance_id = instance["instance_id"]
model_patch = await pred_func(instance) # Sequentially await the async function
preds.append(
{
"instance_id": instance_id,
"model_patch": model_patch,
"model_name_or_path": model_name,
}
)

return preds

Expand Down Expand Up @@ -135,6 +127,7 @@ async def main():
with open(predictions_path, "w") as file:
json.dump(preds, file)

""" This part is for the evaluation
subprocess.run(
[
"python",
Expand All @@ -152,6 +145,7 @@ async def main():
"test_run",
]
)
"""


if __name__ == "__main__":
Expand Down

0 comments on commit f694ca2

Please sign in to comment.