python/packages/autogen-ext/pyproject.toml (1 addition, 1 deletion)

@@ -41,7 +41,7 @@ llama-cpp = [
     "llama-cpp-python>=0.3.8",
 ]
 
-graphrag = ["graphrag>=1.0.1"]
+graphrag = ["graphrag>=2.3.0"]
 chromadb = ["chromadb>=1.0.0"]
 mem0 = ["mem0ai>=0.1.98"]
 mem0-local = [
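Note: the version floor above is what opts this package into the graphrag 2.x API used throughout the rest of this diff. As a sanity check after upgrading (a minimal sketch; assumes graphrag was installed through this "graphrag" extra):

# Sketch: verify the resolved graphrag version satisfies the new floor.
from importlib.metadata import version

graphrag_version = version("graphrag")
assert int(graphrag_version.split(".")[0]) >= 2, (
    f"graphrag {graphrag_version} predates the 2.x API targeted by this change"
)
print(f"graphrag {graphrag_version} OK")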
@@ -3,19 +3,19 @@
 
 class DataConfig(BaseModel):
     input_dir: str
-    entity_table: str = "create_final_nodes"
-    entity_embedding_table: str = "create_final_entities"
+    entity_table: str = "entities"
+    entity_embedding_table: str = "entities"
+    community_table: str = "communities"
     community_level: int = 2
 
 
 class GlobalDataConfig(DataConfig):
-    community_table: str = "create_final_communities"
-    community_report_table: str = "create_final_community_reports"
+    community_report_table: str = "community_reports"
 
 
 class LocalDataConfig(DataConfig):
-    relationship_table: str = "create_final_relationships"
-    text_unit_table: str = "create_final_text_units"
+    relationship_table: str = "relationships"
+    text_unit_table: str = "text_units"
 
 
 class ContextConfig(BaseModel):
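These defaults track graphrag 2.x's renamed index artifacts: the create_final_* parquet outputs became entities, communities, community_reports, relationships, and text_units. The tools below resolve each table as f"{input_dir}/{table}.parquet", so the new defaults map to paths like this (a minimal sketch, assuming the config classes above are in scope; "./output" is a placeholder):

# Sketch: where the renamed defaults point on disk for a global search setup.
from pathlib import Path

config = GlobalDataConfig(input_dir="./output")
for table in (config.entity_table, config.community_table, config.community_report_table):
    print(Path(config.input_dir) / f"{table}.parquet")
# -> output/entities.parquet, output/communities.parquet, output/community_reports.parquet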
@@ -1,4 +1,3 @@
-# mypy: disable-error-code="no-any-unimported,misc"
 from pathlib import Path
 
 import pandas as pd
@@ -7,14 +6,15 @@
 from autogen_core.tools import BaseTool
 from pydantic import BaseModel, Field
 
-from graphrag.config.config_file_loader import load_config_from_file
+import graphrag.config.defaults as defs
+from graphrag.config.load_config import load_config
+from graphrag.language_model.manager import ModelManager
+from graphrag.language_model.protocol import ChatModel
 from graphrag.query.indexer_adapters import (
     read_indexer_communities,
     read_indexer_entities,
     read_indexer_reports,
 )
-from graphrag.query.llm.base import BaseLLM
-from graphrag.query.llm.get_client import get_llm
 from graphrag.query.structured_search.global_search.community_context import GlobalCommunityContext
 from graphrag.query.structured_search.global_search.search import GlobalSearch
 
@@ -64,6 +64,7 @@ class GlobalSearchTool(BaseTool[GlobalSearchToolArgs, GlobalSearchToolReturn]):
     .. code-block:: python
 
         import asyncio
+        from pathlib import Path
         from autogen_ext.models.openai import OpenAIChatCompletionClient
         from autogen_agentchat.ui import Console
         from autogen_ext.tools.graphrag import GlobalSearchTool
@@ -78,7 +79,7 @@ async def main():
             )
 
             # Set up global search tool
-            global_tool = GlobalSearchTool.from_settings(settings_path="./settings.yaml")
+            global_tool = GlobalSearchTool.from_settings(root_dir=Path("./"), config_filepath=Path("./settings.yaml"))
 
             # Create assistant agent with the global search tool
             assistant_agent = AssistantAgent(
@@ -104,7 +105,7 @@ async def main():
     def __init__(
         self,
         token_encoder: tiktoken.Encoding,
-        llm: BaseLLM,
+        model: ChatModel,
         data_config: DataConfig,
         context_config: ContextConfig = _default_context_config,
         mapreduce_config: MapReduceConfig = _default_mapreduce_config,
@@ -115,22 +116,20 @@ def __init__(
             name="global_search_tool",
             description="Perform a global search with given parameters using graphrag.",
         )
-        # Use the provided LLM
-        self._llm = llm
+        # Use the provided model
+        self._model = model
 
         # Load parquet files
         community_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.community_table}.parquet")  # type: ignore
         entity_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.entity_table}.parquet")  # type: ignore
         report_df: pd.DataFrame = pd.read_parquet(  # type: ignore
             f"{data_config.input_dir}/{data_config.community_report_table}.parquet"
         )
-        entity_embedding_df: pd.DataFrame = pd.read_parquet(  # type: ignore
-            f"{data_config.input_dir}/{data_config.entity_embedding_table}.parquet"
-        )
 
-        communities = read_indexer_communities(community_df, entity_df, report_df)
-        reports = read_indexer_reports(report_df, entity_df, data_config.community_level)
-        entities = read_indexer_entities(entity_df, entity_embedding_df, data_config.community_level)
+        # Fix: Use correct argument order and types for GraphRAG API
+        communities = read_indexer_communities(community_df, report_df)
+        reports = read_indexer_reports(report_df, community_df, data_config.community_level)
+        entities = read_indexer_entities(entity_df, community_df, data_config.community_level)
 
         context_builder = GlobalCommunityContext(
             community_reports=reports,
@@ -164,7 +163,7 @@ def __init__(
         }
 
         self._search_engine = GlobalSearch(
-            llm=self._llm,
+            model=self._model,
             context_builder=context_builder,
             token_encoder=token_encoder,
             max_data_tokens=context_config.max_data_tokens,
@@ -178,37 +177,56 @@
         )
 
     async def run(self, args: GlobalSearchToolArgs, cancellation_token: CancellationToken) -> GlobalSearchToolReturn:
-        search_result = await self._search_engine.asearch(args.query)
+        search_result = await self._search_engine.search(args.query)
         assert isinstance(search_result.response, str), "Expected response to be a string"
         return GlobalSearchToolReturn(answer=search_result.response)
 
     @classmethod
-    def from_settings(cls, settings_path: str | Path) -> "GlobalSearchTool":
+    def from_settings(cls, root_dir: str | Path, config_filepath: str | Path | None = None) -> "GlobalSearchTool":
        """Create a GlobalSearchTool instance from GraphRAG settings file.
 
        Args:
-            settings_path: Path to the GraphRAG settings.yaml file
+            root_dir: Path to the GraphRAG root directory
+            config_filepath: Path to the GraphRAG settings file (optional)
 
        Returns:
            An initialized GlobalSearchTool instance
        """
-        # Load GraphRAG config
-        config = load_config_from_file(settings_path)
-
-        # Initialize token encoder
-        token_encoder = tiktoken.get_encoding(config.encoding_model)
-
-        # Initialize LLM using graphrag's get_client
-        llm = get_llm(config)
+        if isinstance(root_dir, str):
+            root_dir = Path(root_dir)
+        if isinstance(config_filepath, str):
+            config_filepath = Path(config_filepath)
+        config = load_config(root_dir=root_dir, config_filepath=config_filepath)
+
+        # Get the language model configuration from the models section
+        chat_model_config = config.models.get(defs.DEFAULT_CHAT_MODEL_ID)
+
+        if chat_model_config is None:
+            raise ValueError("default_chat_model not found in config.models")
+
+        # Initialize token encoder based on the model being used
+        try:
+            token_encoder = tiktoken.encoding_for_model(chat_model_config.model)
+        except KeyError:
+            # Fallback to cl100k_base if model is not recognized by tiktoken
+            token_encoder = tiktoken.get_encoding("cl100k_base")
+
+        # Create the LLM using ModelManager
+        model = ModelManager().get_or_create_chat_model(
+            name="global_search_model",
+            model_type=chat_model_config.type,
+            config=chat_model_config,
+        )
 
         # Create data config from storage paths
         data_config = DataConfig(
-            input_dir=str(Path(config.storage.base_dir)),
+            input_dir=str(config.output.base_dir),
         )
 
         return cls(
             token_encoder=token_encoder,
-            llm=llm,
+            model=model,
             data_config=data_config,
             context_config=_default_context_config,
             mapreduce_config=_default_mapreduce_config,
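Taken together, from_settings now goes through graphrag 2.x's load_config and ModelManager instead of the removed load_config_from_file/get_llm helpers. A condensed sketch of that wiring in isolation, using only calls that appear in this diff (the registry name is a hypothetical label):

from pathlib import Path

import graphrag.config.defaults as defs
from graphrag.config.load_config import load_config
from graphrag.language_model.manager import ModelManager

config = load_config(root_dir=Path("./"), config_filepath=Path("./settings.yaml"))
chat_model_config = config.models[defs.DEFAULT_CHAT_MODEL_ID]  # raises KeyError if missing
model = ModelManager().get_or_create_chat_model(
    name="demo_chat_model",  # hypothetical cache key, not part of the diff
    model_type=chat_model_config.type,
    config=chat_model_config,
)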
@@ -1,5 +1,4 @@
 # mypy: disable-error-code="no-any-unimported,misc"
-import os
 from pathlib import Path
 
 import pandas as pd
@@ -8,14 +7,15 @@
 from autogen_core.tools import BaseTool
 from pydantic import BaseModel, Field
 
-from graphrag.config.config_file_loader import load_config_from_file
+import graphrag.config.defaults as defs
+from graphrag.config.load_config import load_config
+from graphrag.language_model.manager import ModelManager
+from graphrag.language_model.protocol import ChatModel, EmbeddingModel
 from graphrag.query.indexer_adapters import (
     read_indexer_entities,
     read_indexer_relationships,
     read_indexer_text_units,
 )
-from graphrag.query.llm.base import BaseLLM, BaseTextEmbedding
-from graphrag.query.llm.get_client import get_llm, get_text_embedder
 from graphrag.query.structured_search.local_search.mixed_context import LocalSearchMixedContext
 from graphrag.query.structured_search.local_search.search import LocalSearch
 from graphrag.vector_stores.lancedb import LanceDBVectorStore
@@ -64,6 +64,7 @@ class LocalSearchTool(BaseTool[LocalSearchToolArgs, LocalSearchToolReturn]):
     .. code-block:: python
 
         import asyncio
+        from pathlib import Path
         from autogen_ext.models.openai import OpenAIChatCompletionClient
         from autogen_agentchat.ui import Console
         from autogen_ext.tools.graphrag import LocalSearchTool
@@ -78,7 +79,7 @@ async def main():
             )
 
             # Set up local search tool
-            local_tool = LocalSearchTool.from_settings(settings_path="./settings.yaml")
+            local_tool = LocalSearchTool.from_settings(root_dir=Path("./"), config_filepath=Path("./settings.yaml"))
 
             # Create assistant agent with the local search tool
             assistant_agent = AssistantAgent(
@@ -103,8 +104,8 @@ async def main():
 
     Args:
         token_encoder (tiktoken.Encoding): The tokenizer used for text encoding
-        llm (BaseLLM): The language model to use for search
-        embedder (BaseTextEmbedding): The text embedding model to use
+        model: The chat model to use for search (GraphRAG ChatModel)
+        embedder: The text embedding model to use (GraphRAG EmbeddingModel)
         data_config (DataConfig): Configuration for data source locations and settings
         context_config (LocalContextConfig, optional): Configuration for context building. Defaults to default config.
         search_config (SearchConfig, optional): Configuration for search operations. Defaults to default config.
@@ -113,8 +114,8 @@
     def __init__(
         self,
         token_encoder: tiktoken.Encoding,
-        llm: BaseLLM,
-        embedder: BaseTextEmbedding,
+        model: ChatModel,  # ChatModel from GraphRAG
+        embedder: EmbeddingModel,  # EmbeddingModel from GraphRAG
         data_config: DataConfig,
         context_config: LocalContextConfig = _default_context_config,
         search_config: SearchConfig = _default_search_config,
@@ -125,30 +126,23 @@
             name="local_search_tool",
             description="Perform a local search with given parameters using graphrag.",
         )
-        # Use the adapter
-        self._llm = llm
+        # Use the provided models
+        self._model = model
         self._embedder = embedder
 
         # Load parquet files
         entity_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.entity_table}.parquet")  # type: ignore
-        entity_embedding_df: pd.DataFrame = pd.read_parquet(  # type: ignore
-            f"{data_config.input_dir}/{data_config.entity_embedding_table}.parquet"
-        )
         relationship_df: pd.DataFrame = pd.read_parquet(  # type: ignore
             f"{data_config.input_dir}/{data_config.relationship_table}.parquet"
         )
         text_unit_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.text_unit_table}.parquet")  # type: ignore
+        community_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.community_table}.parquet")  # type: ignore
 
         # Read data using indexer adapters
-        entities = read_indexer_entities(entity_df, entity_embedding_df, data_config.community_level)
+        entities = read_indexer_entities(entity_df, community_df, data_config.community_level)
         relationships = read_indexer_relationships(relationship_df)
         text_units = read_indexer_text_units(text_unit_df)
-        # Set up vector store for entity embeddings
-        description_embedding_store = LanceDBVectorStore(
-            collection_name="default-entity-description",
-        )
-        description_embedding_store.connect(db_uri=os.path.join(data_config.input_dir, "lancedb"))
 
         description_embedding_store = LanceDBVectorStore(
             collection_name="default-entity-description",
         )
@@ -180,47 +174,70 @@
         }
 
         self._search_engine = LocalSearch(
-            llm=self._llm,
+            model=self._model,
             context_builder=context_builder,
             token_encoder=token_encoder,
-            llm_params=llm_params,
-            context_builder_params=context_builder_params,
             response_type=search_config.response_type,
+            context_builder_params=context_builder_params,
+            model_params=llm_params,
         )
 
     async def run(self, args: LocalSearchToolArgs, cancellation_token: CancellationToken) -> LocalSearchToolReturn:
-        search_result = await self._search_engine.asearch(args.query)  # type: ignore
+        search_result = await self._search_engine.search(args.query)  # type: ignore[reportUnknownMemberType]
         assert isinstance(search_result.response, str), "Expected response to be a string"
         return LocalSearchToolReturn(answer=search_result.response)
 
     @classmethod
-    def from_settings(cls, settings_path: str | Path) -> "LocalSearchTool":
+    def from_settings(cls, root_dir: Path, config_filepath: Path | None = None) -> "LocalSearchTool":
        """Create a LocalSearchTool instance from GraphRAG settings file.
 
        Args:
-            settings_path: Path to the GraphRAG settings.yaml file
+            root_dir: Path to the GraphRAG root directory
+            config_filepath: Path to the GraphRAG settings file (optional)
 
        Returns:
            An initialized LocalSearchTool instance
        """
-        # Load GraphRAG config
-        config = load_config_from_file(settings_path)
-
-        # Initialize token encoder
-        token_encoder = tiktoken.get_encoding(config.encoding_model)
+        config = load_config(root_dir=root_dir, config_filepath=config_filepath)
+
+        # Get the language model configurations from the models section
+        chat_model_config = config.models.get(defs.DEFAULT_CHAT_MODEL_ID)
+        embedding_model_config = config.models.get(defs.DEFAULT_EMBEDDING_MODEL_ID)
+
+        if chat_model_config is None:
+            raise ValueError("default_chat_model not found in config.models")
+        if embedding_model_config is None:
+            raise ValueError("default_embedding_model not found in config.models")
+
+        # Initialize token encoder based on the model being used
+        try:
+            token_encoder = tiktoken.encoding_for_model(chat_model_config.model)
+        except KeyError:
+            # Fallback to cl100k_base if model is not recognized by tiktoken
+            token_encoder = tiktoken.get_encoding("cl100k_base")
+
+        # Create the models using ModelManager
+        model = ModelManager().get_or_create_chat_model(
+            name="local_search_model",
+            model_type=chat_model_config.type,
+            config=chat_model_config,
+        )
 
-        # Initialize LLM and embedder using graphrag's get_client functions
-        llm = get_llm(config)
-        embedder = get_text_embedder(config)
+        embedder = ModelManager().get_or_create_embedding_model(
+            name="local_search_embedder",
+            model_type=embedding_model_config.type,
+            config=embedding_model_config,
+        )
 
         # Create data config from storage paths
         data_config = DataConfig(
-            input_dir=str(Path(config.storage.base_dir)),
+            input_dir=str(config.output.base_dir),
         )
 
         return cls(
             token_encoder=token_encoder,
-            llm=llm,
+            model=model,
             embedder=embedder,
             data_config=data_config,
             context_config=_default_context_config,
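The local tool mirrors the global one but additionally needs an embedding model and the LanceDB store for entity description embeddings. A standalone sketch of that embedding path, restricted to calls visible in this diff (the registry name and the lancedb subdirectory under the output dir are assumptions):

import graphrag.config.defaults as defs
from pathlib import Path

from graphrag.config.load_config import load_config
from graphrag.language_model.manager import ModelManager
from graphrag.vector_stores.lancedb import LanceDBVectorStore

config = load_config(root_dir=Path("./"))
embedding_model_config = config.models[defs.DEFAULT_EMBEDDING_MODEL_ID]
embedder = ModelManager().get_or_create_embedding_model(
    name="demo_embedder",  # hypothetical cache key, not part of the diff
    model_type=embedding_model_config.type,
    config=embedding_model_config,
)

store = LanceDBVectorStore(collection_name="default-entity-description")
store.connect(db_uri=f"{config.output.base_dir}/lancedb")  # assumed on-disk layout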