Merged
45 changes: 9 additions & 36 deletions docs/examples_notebooks/drift_search.ipynb
@@ -17,7 +17,6 @@
"outputs": [],
"source": [
"import os\n",
"from pathlib import Path\n",
"\n",
"import pandas as pd\n",
"from graphrag.config.enums import ModelType\n",
@@ -62,15 +61,13 @@
"# load description embeddings to an in-memory lancedb vectorstore\n",
"# to connect to a remote db, specify url and port values.\n",
"description_embedding_store = LanceDBVectorStore(\n",
" vector_store_schema_config=VectorStoreSchemaConfig(\n",
" index_name=\"default-entity-description\"\n",
" ),\n",
" vector_store_schema_config=VectorStoreSchemaConfig(index_name=\"entity_description\"),\n",
")\n",
"description_embedding_store.connect(db_uri=LANCEDB_URI)\n",
"\n",
"full_content_embedding_store = LanceDBVectorStore(\n",
" vector_store_schema_config=VectorStoreSchemaConfig(\n",
" index_name=\"default-community-full_content\"\n",
" index_name=\"community_full_content\"\n",
" )\n",
")\n",
"full_content_embedding_store.connect(db_uri=LANCEDB_URI)\n",
@@ -88,7 +85,11 @@
"text_units = read_indexer_text_units(text_unit_df)\n",
"\n",
"print(f\"Text unit records: {len(text_unit_df)}\")\n",
"text_unit_df.head()"
"text_unit_df.head()\n",
"\n",
"report_df = pd.read_parquet(f\"{INPUT_DIR}/{COMMUNITY_REPORT_TABLE}.parquet\")\n",
"reports = read_indexer_reports(report_df, community_df, COMMUNITY_LEVEL)\n",
"read_indexer_report_embeddings(reports, full_content_embedding_store)"
]
},
{
@@ -118,7 +119,7 @@
" api_key=api_key,\n",
" type=ModelType.Embedding,\n",
" model_provider=\"openai\",\n",
" model=\"text-embedding-3-small\",\n",
" model=\"text-embedding-3-large\",\n",
" max_retries=20,\n",
")\n",
"\n",
@@ -129,44 +130,16 @@
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def read_community_reports(\n",
" input_dir: str,\n",
" community_report_table: str = COMMUNITY_REPORT_TABLE,\n",
"):\n",
" \"\"\"Embeds the full content of the community reports and saves the DataFrame with embeddings to the output path.\"\"\"\n",
" input_path = Path(input_dir) / f\"{community_report_table}.parquet\"\n",
" return pd.read_parquet(input_path)\n",
"\n",
"\n",
"report_df = read_community_reports(INPUT_DIR)\n",
"reports = read_indexer_reports(\n",
" report_df,\n",
" community_df,\n",
" COMMUNITY_LEVEL,\n",
" content_embedding_col=\"full_content_embeddings\",\n",
")\n",
"read_indexer_report_embeddings(reports, full_content_embedding_store)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"drift_params = DRIFTSearchConfig(\n",
" temperature=0,\n",
" max_tokens=12_000,\n",
" primer_folds=1,\n",
" drift_k_followups=3,\n",
" n_depth=3,\n",
" n=1,\n",
")\n",
"\n",
"context_builder = DRIFTSearchContextBuilder(\n",
@@ -216,7 +189,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "graphrag",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
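The notebook's vector-store wiring is the core of this change: index names drop the "default-" container prefix and switch to underscores, and report embeddings are loaded straight from the store instead of being re-embedded. A minimal end-to-end sketch of the updated flow, assuming graphrag is installed, an index already exists under ./output, and that the VectorStoreSchemaConfig import path matches your graphrag version (all paths and table names below are illustrative):

import pandas as pd

from graphrag.config.models.vector_store_schema_config import (  # assumed import path
    VectorStoreSchemaConfig,
)
from graphrag.query.indexer_adapters import (
    read_indexer_report_embeddings,
    read_indexer_reports,
)
from graphrag.vector_stores.lancedb import LanceDBVectorStore

INPUT_DIR = "./output"                        # illustrative
LANCEDB_URI = f"{INPUT_DIR}/lancedb"          # illustrative
COMMUNITY_REPORT_TABLE = "community_reports"  # illustrative table name
COMMUNITY_LEVEL = 2

# Index names no longer carry the "default-" container prefix.
full_content_embedding_store = LanceDBVectorStore(
    vector_store_schema_config=VectorStoreSchemaConfig(
        index_name="community_full_content"
    )
)
full_content_embedding_store.connect(db_uri=LANCEDB_URI)

# Reports are read as plain records, then their stored embeddings are
# attached from the vector store -- no re-embedding fallback in the adapter.
community_df = pd.read_parquet(f"{INPUT_DIR}/communities.parquet")  # illustrative
report_df = pd.read_parquet(f"{INPUT_DIR}/{COMMUNITY_REPORT_TABLE}.parquet")
reports = read_indexer_reports(report_df, community_df, COMMUNITY_LEVEL)
read_indexer_report_embeddings(reports, full_content_embedding_store)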
Many binary test-fixture files updated or deleted (contents not shown).
@@ -0,0 +1 @@
$dcac110d-2a49-4777-a51e-5078fed1b0df�id = '__DUMMY__'
Binary file not shown.
@@ -0,0 +1 @@
$c2afb84c-4b3a-4ccd-8843-0deaa25bd971�id = '__DUMMY__'
Additional binary fixture files updated (contents not shown).
@@ -28,18 +28,11 @@ async def embed_text(
num_threads: int,
vector_store: BaseVectorStore,
id_column: str = "id",
title_column: str | None = None,
):
"""Embed a piece of text into a vector space. The operation outputs a new column containing a mapping between doc_id and vector."""
if embed_column not in input.columns:
msg = f"Column {embed_column} not found in input dataframe with columns {input.columns}"
raise ValueError(msg)
title = title_column or embed_column
if title not in input.columns:
msg = (
f"Column {title} not found in input dataframe with columns {input.columns}"
)
raise ValueError(msg)
if id_column not in input.columns:
msg = f"Column {id_column} not found in input dataframe with columns {input.columns}"
raise ValueError(msg)
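With title_column removed, the only validation left in embed_text is a per-column existence check for the embed and id columns. As a standalone illustration of that guard pattern (this helper is a sketch, not graphrag's code):

import pandas as pd

def require_columns(frame: pd.DataFrame, *columns: str) -> None:
    # Fail fast when a required column is missing, mirroring the checks
    # that remain after the title_column parameter was dropped.
    for column in columns:
        if column not in frame.columns:
            msg = f"Column {column} not found in input dataframe with columns {list(frame.columns)}"
            raise ValueError(msg)

df = pd.DataFrame({"id": [1], "text": ["hello"]})
require_columns(df, "id", "text")   # passes silently
# require_columns(df, "title")      # would raise ValueError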
@@ -3,11 +3,10 @@

"""All the steps to transform final entities."""

from uuid import uuid4

import pandas as pd

from graphrag.data_model.schemas import COMMUNITY_REPORTS_FINAL_COLUMNS
from graphrag.index.utils.hashing import gen_sha512_hash


def finalize_community_reports(
@@ -25,7 +24,9 @@ def finalize_community_reports(

community_reports["community"] = community_reports["community"].astype(int)
community_reports["human_readable_id"] = community_reports["community"]
community_reports["id"] = [uuid4().hex for _ in range(len(community_reports))]
community_reports["id"] = community_reports.apply(
lambda row: gen_sha512_hash(row, ["full_content"]), axis=1
)

return community_reports.loc[
:,
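Report ids switch from random uuid4 values to content hashes, so re-running the pipeline over unchanged reports yields stable ids. A self-contained sketch of the idea (the sha512_of helper below is illustrative, not graphrag's gen_sha512_hash):

import hashlib

import pandas as pd

def sha512_of(row: pd.Series, columns: list[str]) -> str:
    # Hash the selected column values so identical report content always
    # produces the same id across runs.
    payload = "".join(str(row[col]) for col in columns).encode("utf-8")
    return hashlib.sha512(payload).hexdigest()

reports = pd.DataFrame({
    "community": [0, 1],
    "full_content": ["report a", "report b"],
})
reports["id"] = reports.apply(lambda row: sha512_of(row, ["full_content"]), axis=1)
print(reports["id"].str.len().unique())  # [128] -- one sha512 hex digest each

Deterministic ids keep outputs reproducible and joinable across runs, which random UUIDs cannot guarantee.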
@@ -65,7 +65,7 @@
Add sections and commentary to the response as appropriate for the length and format.
Additionally provide a score between 0 and 100 representing how well the response addresses the overall research question: {global_query}. Based on your response, suggest up to five follow-up questions that could be asked to further explore the topic as it relates to the overall research question. Do not include scores or follow up questions in the 'response' field of the JSON, add them to the respective 'score' and 'follow_up_queries' keys of the JSON output. Format your response in JSON with the following keys and values:
Additionally provide a score between 0 and 100 representing how well the response addresses the overall research question: {global_query}. Based on your response, suggest up to {followups} follow-up questions that could be asked to further explore the topic as it relates to the overall research question. Do not include scores or follow up questions in the 'response' field of the JSON, add them to the respective 'score' and 'follow_up_queries' keys of the JSON output. Format your response in JSON with the following keys and values:
{{'response': str, Put your answer, formatted in markdown, here. Do not answer the global query in this section.
'score': int,
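The prompt's hard-coded "five" becomes a {followups} placeholder, filled at query time (see the followups= kwarg further down in this diff). A trivial sketch of the substitution:

template = "suggest up to {followups} follow-up questions"
print(template.format(followups=3))
# -> suggest up to 3 follow-up questions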
28 changes: 1 addition & 27 deletions packages/graphrag/graphrag/query/indexer_adapters.py
@@ -11,14 +11,12 @@

import pandas as pd

from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.data_model.community import Community
from graphrag.data_model.community_report import CommunityReport
from graphrag.data_model.covariate import Covariate
from graphrag.data_model.entity import Entity
from graphrag.data_model.relationship import Relationship
from graphrag.data_model.text_unit import TextUnit
from graphrag.language_model.manager import ModelManager
from graphrag.language_model.protocol.base import EmbeddingModel
from graphrag.query.input.loaders.dfs import (
read_communities,
@@ -76,8 +74,6 @@ def read_indexer_reports(
final_communities: pd.DataFrame,
community_level: int | None,
dynamic_community_selection: bool = False,
content_embedding_col: str = "full_content_embedding",
config: GraphRagConfig | None = None,
) -> list[CommunityReport]:
"""Read in the Community Reports from the raw indexing outputs.

@@ -102,29 +98,7 @@
filtered_community_df, on="community", how="inner"
)

if config and (
content_embedding_col not in reports_df.columns
or reports_df.loc[:, content_embedding_col].isna().any()
):
# TODO: Find a way to retrieve the right embedding model id.
embedding_model_settings = config.get_language_model_config(
"default_embedding_model"
)
embedder = ModelManager().get_or_create_embedding_model(
name="default_embedding",
model_type=embedding_model_settings.type,
config=embedding_model_settings,
)
reports_df = embed_community_reports(
reports_df, embedder, embedding_col=content_embedding_col
)

return read_community_reports(
df=reports_df,
id_col="id",
short_id_col="community",
content_embedding_col=content_embedding_col,
)
return read_community_reports(df=reports_df, id_col="id", short_id_col="community")


def read_indexer_report_embeddings(
4 changes: 0 additions & 4 deletions packages/graphrag/graphrag/query/input/loaders/dfs.py
@@ -197,7 +197,6 @@ def read_community_reports(
summary_col: str = "summary",
content_col: str = "full_content",
rank_col: str | None = "rank",
content_embedding_col: str | None = "full_content_embedding",
attributes_cols: list[str] | None = None,
) -> list[CommunityReport]:
"""Read community reports from a dataframe using pre-converted records."""
@@ -213,9 +212,6 @@
summary=to_str(row, summary_col),
full_content=to_str(row, content_col),
rank=to_optional_float(row, rank_col),
full_content_embedding=to_optional_list(
row, content_embedding_col, item_type=float
),
attributes=(
{col: row.get(col) for col in attributes_cols}
if attributes_cols
@@ -50,7 +50,13 @@ def is_complete(self) -> bool:
"""Check if the action is complete (i.e., an answer is available)."""
return self.answer is not None

async def search(self, search_engine: Any, global_query: str, scorer: Any = None):
async def search(
self,
search_engine: Any,
global_query: str,
k_followups: int,
scorer: Any = None,
):
"""
Execute an asynchronous search using the search engine, and update the action with the results.

@@ -71,7 +77,9 @@ async def search(self, search_engine: Any, global_query: str, scorer: Any = None
return self

search_result = await search_engine.search(
drift_query=global_query, query=self.query
query=self.query,
drift_query=global_query,
k_followups=k_followups,
)

# Do not launch exception as it will roll up with other steps
@@ -40,6 +40,7 @@ class DRIFTSearchContextBuilder(DRIFTContextBuilder):
def __init__(
self,
model: ChatModel,
config: DRIFTSearchConfig,
text_embedder: EmbeddingModel,
entities: list[Entity],
entity_text_embeddings: BaseVectorStore,
@@ -49,14 +50,13 @@ def __init__(
covariates: dict[str, list[Covariate]] | None = None,
tokenizer: Tokenizer | None = None,
embedding_vectorstore_key: str = EntityVectorStoreKey.ID,
config: DRIFTSearchConfig | None = None,
local_system_prompt: str | None = None,
local_mixed_context: LocalSearchMixedContext | None = None,
reduce_system_prompt: str | None = None,
response_type: str | None = None,
):
"""Initialize the DRIFT search context builder with necessary components."""
self.config = config or DRIFTSearchConfig()
self.config = config
self.model = model
self.text_embedder = text_embedder
self.tokenizer = tokenizer or get_tokenizer()
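config is promoted from an optional parameter with a silent DRIFTSearchConfig() fallback to a required argument, so a missing configuration now fails loudly at construction time. A small sketch of the changed call site, using only fields that appear elsewhere in this diff (the import path is assumed):

from graphrag.config.models.drift_search_config import DRIFTSearchConfig  # assumed path

drift_params = DRIFTSearchConfig(drift_k_followups=3, n_depth=3)

# DRIFTSearchContextBuilder(model=..., config=drift_params, ...) is now the
# required shape; omitting config raises a TypeError instead of quietly
# falling back to defaults.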
@@ -156,7 +156,11 @@ def _process_primer_results(
raise ValueError(error_msg)

async def _search_step(
self, global_query: str, search_engine: LocalSearch, actions: list[DriftAction]
self,
global_query: str,
k_followups: int,
search_engine: LocalSearch,
actions: list[DriftAction],
) -> list[DriftAction]:
"""
Perform an asynchronous search step by executing each DriftAction asynchronously.
@@ -171,7 +175,11 @@ async def _search_step(
list[DriftAction]: The results from executing the search actions asynchronously.
"""
tasks = [
action.search(search_engine=search_engine, global_query=global_query)
action.search(
search_engine=search_engine,
global_query=global_query,
k_followups=k_followups,
)
for action in actions
]
return await tqdm_asyncio.gather(*tasks, leave=False)
@@ -241,7 +249,10 @@ async def search(
)
# Process actions
results = await self._search_step(
global_query=query, search_engine=self.local_search, actions=actions
global_query=query,
k_followups=self.context_builder.config.drift_k_followups,
search_engine=self.local_search,
actions=actions,
)

# Update query state
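The follow-up budget now flows explicitly from the context builder's config into the per-action fan-out instead of being read implicitly inside each engine. A runnable sketch of that plumbing, with stub classes standing in for DRIFTSearchConfig and DriftAction:

import asyncio
from dataclasses import dataclass, field

@dataclass
class StubConfig:
    drift_k_followups: int = 3  # mirrors the config field used above

@dataclass
class StubAction:
    query: str
    followups: list[str] = field(default_factory=list)

    async def search(self, search_engine, global_query: str, k_followups: int):
        # Stand-in for DriftAction.search: record the budget it was handed.
        self.followups = [f"{self.query} / follow-up {i}" for i in range(k_followups)]
        return self

async def search_step(config, actions):
    tasks = [
        action.search(
            search_engine=None,
            global_query="global research question",
            k_followups=config.drift_k_followups,
        )
        for action in actions
    ]
    return await asyncio.gather(*tasks)

done = asyncio.run(search_step(StubConfig(), [StubAction("q1"), StubAction("q2")]))
print(done[0].followups)  # three follow-ups recorded for the first action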
@@ -76,6 +76,7 @@ async def search(
context_data=context_result.context_chunks,
response_type=self.response_type,
global_query=drift_query,
followups=kwargs.get("k_followups", 0),
)
else:
search_prompt = self.system_prompt.format(
4 changes: 0 additions & 4 deletions packages/graphrag/graphrag/vector_stores/cosmosdb.py
@@ -156,14 +156,10 @@ def load_documents(self, documents: list[VectorStoreDocument]) -> None:
# Upload documents to CosmosDB
for doc in documents:
if doc.vector is not None:
print("Document to store:") # noqa: T201
print(doc) # noqa: T201
doc_json = {
self.id_field: doc.id,
self.vector_field: doc.vector,
}
print("Storing document in CosmosDB:") # noqa: T201
print(doc_json) # noqa: T201
self._container_client.upsert_item(doc_json)

def similarity_search_by_vector(
Expand Down
7 changes: 1 addition & 6 deletions packages/graphrag/graphrag/vector_stores/lancedb.py
@@ -84,12 +84,7 @@ def load_documents(self, documents: list[VectorStoreDocument]) -> None:
})

if data:
self.document_collection = self.db_connection.create_table(
self.index_name if self.index_name else "",
data=data,
mode="overwrite",
schema=data.schema,
)
self.document_collection.add(data)

def similarity_search_by_vector(
self, query_embedding: list[float] | np.ndarray, k: int = 10
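load_documents now appends to the existing LanceDB table instead of recreating it with mode="overwrite" on every call, so repeated loads accumulate rather than clobber earlier rows. A standalone LanceDB sketch of the difference (table and column names are illustrative):

import lancedb
import pyarrow as pa

db = lancedb.connect("./lancedb-demo")  # illustrative local path
schema = pa.schema([
    pa.field("id", pa.string()),
    pa.field("vector", pa.list_(pa.float32(), 3)),
])

# Create once, e.g. at connect/setup time.
table = db.create_table("docs", schema=schema, mode="overwrite")

# Subsequent loads append; previously each load re-created the table and
# silently dropped anything written earlier.
table.add([
    {"id": "a", "vector": [0.1, 0.2, 0.3]},
    {"id": "b", "vector": [0.4, 0.5, 0.6]},
])
table.add([{"id": "c", "vector": [0.7, 0.8, 0.9]}])
print(table.count_rows())  # 3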
2 changes: 1 addition & 1 deletion tests/fixtures/min-csv/config.json
@@ -98,7 +98,7 @@
},
{
"query": "What is the major conflict in this story and who are the protagonist and antagonist?",
"method": "global"
"method": "drift"
}
],
"slow": false