diff --git a/.semversioner/next-release/major-20251009203808375389.json b/.semversioner/next-release/major-20251009203808375389.json
new file mode 100644
index 0000000000..4bf235ecb0
--- /dev/null
+++ b/.semversioner/next-release/major-20251009203808375389.json
@@ -0,0 +1,4 @@
+{
+    "type": "major",
+    "description": "Simplify internal args with stronger types and firmer boundaries."
+}
diff --git a/docs/examples_notebooks/index_migration_to_v1.ipynb b/docs/examples_notebooks/index_migration_to_v1.ipynb
index ecff51929a..581f5cef64 100644
--- a/docs/examples_notebooks/index_migration_to_v1.ipynb
+++ b/docs/examples_notebooks/index_migration_to_v1.ipynb
@@ -202,45 +202,44 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from graphrag.index.flows.generate_text_embeddings import generate_text_embeddings\n",
-    "\n",
     "from graphrag.cache.factory import CacheFactory\n",
     "from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks\n",
-    "from graphrag.config.embeddings import get_embedded_fields, get_embedding_settings\n",
+    "from graphrag.config.get_vector_store_settings import get_vector_store_settings\n",
+    "from graphrag.index.workflows.generate_text_embeddings import generate_text_embeddings\n",
     "\n",
     "# We only need to re-run the embeddings workflow, to ensure that embeddings for all required search fields are in place\n",
     "# We'll construct the context and run this function flow directly to avoid everything else\n",
     "\n",
     "\n",
-    "embedded_fields = get_embedded_fields(config)\n",
-    "text_embed = get_embedding_settings(config)\n",
+    "vector_store_config = get_vector_store_settings(config)\n",
+    "model_config = config.get_language_model_config(config.embed_text.model_id)\n",
     "callbacks = NoopWorkflowCallbacks()\n",
     "cache_config = config.cache.model_dump()  # type: ignore\n",
     "cache = CacheFactory().create_cache(\n",
     "    cache_type=cache_config[\"type\"],  # type: ignore\n",
-    "    root_dir=PROJECT_DIRECTORY,\n",
-    "    kwargs=cache_config,\n",
+    "    **cache_config,\n",
     ")\n",
     "\n",
     "await generate_text_embeddings(\n",
-    "    final_documents=None,\n",
-    "    final_relationships=None,\n",
-    "    final_text_units=final_text_units,\n",
-    "    final_entities=final_entities,\n",
-    "    final_community_reports=final_community_reports,\n",
+    "    documents=None,\n",
+    "    relationships=None,\n",
+    "    text_units=final_text_units,\n",
+    "    entities=final_entities,\n",
+    "    community_reports=final_community_reports,\n",
     "    callbacks=callbacks,\n",
     "    cache=cache,\n",
-    "    storage=storage,\n",
-    "    text_embed_config=text_embed,\n",
-    "    embedded_fields=embedded_fields,\n",
-    "    snapshot_embeddings_enabled=False,\n",
+    "    model_config=model_config,\n",
+    "    batch_size=config.embed_text.batch_size,\n",
+    "    batch_max_tokens=config.embed_text.batch_max_tokens,\n",
+    "    vector_store_config=vector_store_config,\n",
+    "    embedded_fields=config.embed_text.names,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
-   "display_name": ".venv",
+   "display_name": "graphrag",
   "language": "python",
   "name": "python3"
  },
@@ -254,7 +253,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-   "version": "3.11.9"
+   "version": "3.12.10"
  }
 },
 "nbformat": 4,
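The migrated cell above presupposes a loaded config and the output dataframes already in scope. A minimal setup sketch, assuming the standard load_config helper and the default parquet output names (both assumptions; adjust to your project layout):

from pathlib import Path

import pandas as pd

from graphrag.config.load_config import load_config

PROJECT_DIRECTORY = "./my_project"  # hypothetical project root
config = load_config(Path(PROJECT_DIRECTORY))
# Parquet names below assume the current default output layout.
final_text_units = pd.read_parquet(f"{PROJECT_DIRECTORY}/output/text_units.parquet")
final_entities = pd.read_parquet(f"{PROJECT_DIRECTORY}/output/entities.parquet")
final_community_reports = pd.read_parquet(
    f"{PROJECT_DIRECTORY}/output/community_reports.parquet"
)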
diff --git a/graphrag/config/defaults.py b/graphrag/config/defaults.py
index c2a0cb6fb9..eadafd9860 100644
--- a/graphrag/config/defaults.py
+++ b/graphrag/config/defaults.py
@@ -60,6 +60,7 @@
 ENCODING_MODEL = "o200k_base"
 COGNITIVE_SERVICES_AUDIENCE = "https://cognitiveservices.azure.com/.default"
+DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"]
 
 DEFAULT_RETRY_SERVICES: dict[str, Callable[..., Retry]] = {
     "native": NativeRetry,
@@ -125,7 +126,6 @@ class CommunityReportDefaults:
     text_prompt: None = None
     max_length: int = 2000
     max_input_length: int = 8000
-    strategy: None = None
     model_id: str = DEFAULT_CHAT_MODEL_ID
 
 
@@ -162,10 +162,9 @@ class DriftSearchDefaults:
 class EmbedTextDefaults:
     """Default values for embedding text."""
 
-    model: str = "text-embedding-3-small"
+    model_id: str = DEFAULT_EMBEDDING_MODEL_ID
     batch_size: int = 16
     batch_max_tokens: int = 8191
-    model_id: str = DEFAULT_EMBEDDING_MODEL_ID
     names: list[str] = field(default_factory=lambda: default_embeddings)
     strategy: None = None
     vector_store_id: str = DEFAULT_VECTOR_STORE_ID
diff --git a/graphrag/config/get_embedding_settings.py b/graphrag/config/get_vector_store_settings.py
similarity index 53%
rename from graphrag/config/get_embedding_settings.py
rename to graphrag/config/get_vector_store_settings.py
index 9522f31359..3771d65ff3 100644
--- a/graphrag/config/get_embedding_settings.py
+++ b/graphrag/config/get_vector_store_settings.py
@@ -1,19 +1,16 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License
 
-"""A module containing get_embedding_settings."""
+"""A module containing get_vector_store_settings."""
 
 from graphrag.config.models.graph_rag_config import GraphRagConfig
 
 
-def get_embedding_settings(
+def get_vector_store_settings(
     settings: GraphRagConfig,
     vector_store_params: dict | None = None,
 ) -> dict:
     """Transform GraphRAG config into settings for workflows."""
-    embeddings_llm_settings = settings.get_language_model_config(
-        settings.embed_text.model_id
-    )
     vector_store_settings = settings.get_vector_store_config(
         settings.embed_text.vector_store_id
     ).model_dump()
@@ -23,16 +20,7 @@
     # settings.vector_store.base contains connection information, or may be undefined
     # settings.vector_store.<vector_store_id> contains the specific settings for this embedding
     #
-    strategy = settings.embed_text.resolved_strategy(
-        embeddings_llm_settings
-    )  # get the default strategy
-    strategy.update({
-        "vector_store": {
-            **(vector_store_params or {}),
-            **(vector_store_settings),
-        }
-    })  # update the default strategy with the vector store settings
-    # This ensures the vector store config is part of the strategy and not the global config
     return {
-        "strategy": strategy,
+        **(vector_store_params or {}),
+        **(vector_store_settings),
     }
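With the strategy wrapper gone, get_vector_store_settings now returns a flat dict that callers hand straight to embed_text. A usage sketch; the keys shown are illustrative, since the real ones come from VectorStoreConfig.model_dump() plus any caller overrides:

from graphrag.config.get_vector_store_settings import get_vector_store_settings

vector_store_config = get_vector_store_settings(
    config,  # a GraphRagConfig loaded elsewhere
    vector_store_params={"overwrite": False},  # optional caller overrides
)
# e.g. {"type": "lancedb", "db_uri": "./output/lancedb", "overwrite": False, ...}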
diff --git a/graphrag/config/init_content.py b/graphrag/config/init_content.py
index 421cfe2e30..aadacf8f38 100644
--- a/graphrag/config/init_content.py
+++ b/graphrag/config/init_content.py
@@ -28,7 +28,6 @@
     # api_version: 2024-05-01-preview
     model_supports_json: true # recommended if this is available for your model.
     concurrent_requests: {language_model_defaults.concurrent_requests}
-    async_mode: {language_model_defaults.async_mode.value} # or asyncio
    retry_strategy: {language_model_defaults.retry_strategy}
     max_retries: {language_model_defaults.max_retries}
     tokens_per_minute: null
@@ -42,7 +41,6 @@
     # api_base: https://<instance>.openai.azure.com
     # api_version: 2024-05-01-preview
     concurrent_requests: {language_model_defaults.concurrent_requests}
-    async_mode: {language_model_defaults.async_mode.value} # or asyncio
     retry_strategy: {language_model_defaults.retry_strategy}
     max_retries: {language_model_defaults.max_retries}
     tokens_per_minute: null
@@ -102,7 +100,6 @@
 extract_graph_nlp:
   text_analyzer:
     extractor_type: {graphrag_config_defaults.extract_graph_nlp.text_analyzer.extractor_type.value} # [regex_english, syntactic_parser, cfg]
-  async_mode: {graphrag_config_defaults.extract_graph_nlp.async_mode.value} # or asyncio
 
 cluster_graph:
   max_cluster_size: {graphrag_config_defaults.cluster_graph.max_cluster_size}
diff --git a/graphrag/config/models/community_reports_config.py b/graphrag/config/models/community_reports_config.py
index b4e9259489..5369dea232 100644
--- a/graphrag/config/models/community_reports_config.py
+++ b/graphrag/config/models/community_reports_config.py
@@ -3,12 +3,24 @@
 
 """Parameterization settings for the default configuration."""
 
+from dataclasses import dataclass
 from pathlib import Path
 
 from pydantic import BaseModel, Field
 
 from graphrag.config.defaults import graphrag_config_defaults
-from graphrag.config.models.language_model_config import LanguageModelConfig
+from graphrag.prompts.index.community_report import COMMUNITY_REPORT_PROMPT
+from graphrag.prompts.index.community_report_text_units import (
+    COMMUNITY_REPORT_TEXT_PROMPT,
+)
+
+
+@dataclass
+class CommunityReportPrompts:
+    """Community report prompt templates."""
+
+    graph_prompt: str
+    text_prompt: str
 
 
 class CommunityReportsConfig(BaseModel):
@@ -34,32 +46,16 @@ class CommunityReportsConfig(BaseModel):
         description="The maximum input length in tokens to use when generating reports.",
         default=graphrag_config_defaults.community_reports.max_input_length,
     )
-    strategy: dict | None = Field(
-        description="The override strategy to use.",
-        default=graphrag_config_defaults.community_reports.strategy,
-    )
 
-    def resolved_strategy(
-        self, root_dir: str, model_config: LanguageModelConfig
-    ) -> dict:
-        """Get the resolved community report extraction strategy."""
-        from graphrag.index.operations.summarize_communities.typing import (
-            CreateCommunityReportsStrategyType,
-        )
-
-        return self.strategy or {
-            "type": CreateCommunityReportsStrategyType.graph_intelligence,
-            "llm": model_config.model_dump(),
-            "graph_prompt": (Path(root_dir) / self.graph_prompt).read_text(
+    def resolved_prompts(self, root_dir: str) -> CommunityReportPrompts:
+        """Get the resolved community report extraction prompts."""
+        return CommunityReportPrompts(
+            graph_prompt=(Path(root_dir) / self.graph_prompt).read_text(
                 encoding="utf-8"
             )
             if self.graph_prompt
-            else None,
-            "text_prompt": (Path(root_dir) / self.text_prompt).read_text(
-                encoding="utf-8"
-            )
+            else COMMUNITY_REPORT_PROMPT,
+            text_prompt=(Path(root_dir) / self.text_prompt).read_text(encoding="utf-8")
             if self.text_prompt
-            else None,
-            "max_report_length": self.max_length,
-            "max_input_length": self.max_input_length,
-        }
+            else COMMUNITY_REPORT_TEXT_PROMPT,
+        )
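The resolved_prompts pattern introduced here (and repeated for claims, graph extraction, and description summarization below) replaces resolved_strategy: it returns a small typed dataclass and falls back to the packaged default prompt when no file is configured. A brief sketch, assuming a loaded GraphRagConfig named config:

prompts = config.community_reports.resolved_prompts(config.root_dir)
print(prompts.graph_prompt[:80])  # custom file contents, or COMMUNITY_REPORT_PROMPT
print(prompts.text_prompt[:80])   # custom file contents, or COMMUNITY_REPORT_TEXT_PROMPT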
diff --git a/graphrag/config/models/text_embedding_config.py b/graphrag/config/models/embed_text_config.py
similarity index 57%
rename from graphrag/config/models/text_embedding_config.py
rename to graphrag/config/models/embed_text_config.py
index e154675a1e..f785bf6eed 100644
--- a/graphrag/config/models/text_embedding_config.py
+++ b/graphrag/config/models/embed_text_config.py
@@ -6,10 +6,9 @@
 from pydantic import BaseModel, Field
 
 from graphrag.config.defaults import graphrag_config_defaults
-from graphrag.config.models.language_model_config import LanguageModelConfig
 
 
-class TextEmbeddingConfig(BaseModel):
+class EmbedTextConfig(BaseModel):
     """Configuration section for text embeddings."""
 
     model_id: str = Field(
@@ -32,21 +31,3 @@
         description="The specific embeddings to perform.",
         default=graphrag_config_defaults.embed_text.names,
     )
-    strategy: dict | None = Field(
-        description="The override strategy to use.",
-        default=graphrag_config_defaults.embed_text.strategy,
-    )
-
-    def resolved_strategy(self, model_config: LanguageModelConfig) -> dict:
-        """Get the resolved text embedding strategy."""
-        from graphrag.index.operations.embed_text.embed_text import (
-            TextEmbedStrategyType,
-        )
-
-        return self.strategy or {
-            "type": TextEmbedStrategyType.openai,
-            "llm": model_config.model_dump(),
-            "num_threads": model_config.concurrent_requests,
-            "batch_size": self.batch_size,
-            "batch_max_tokens": self.batch_max_tokens,
-        }
diff --git a/graphrag/config/models/extract_claims_config.py b/graphrag/config/models/extract_claims_config.py
index 166cc29d4e..78fe926795 100644
--- a/graphrag/config/models/extract_claims_config.py
+++ b/graphrag/config/models/extract_claims_config.py
@@ -3,15 +3,23 @@
 
 """Parameterization settings for the default configuration."""
 
+from dataclasses import dataclass
 from pathlib import Path
 
 from pydantic import BaseModel, Field
 
 from graphrag.config.defaults import graphrag_config_defaults
-from graphrag.config.models.language_model_config import LanguageModelConfig
+from graphrag.prompts.index.extract_claims import EXTRACT_CLAIMS_PROMPT
 
 
-class ClaimExtractionConfig(BaseModel):
+@dataclass
+class ClaimExtractionPrompts:
+    """Claim extraction prompt templates."""
+
+    extraction_prompt: str
+
+
+class ExtractClaimsConfig(BaseModel):
     """Configuration section for claim extraction."""
 
     enabled: bool = Field(
@@ -34,22 +42,11 @@ class ExtractClaimsConfig(BaseModel):
         description="The maximum number of entity gleanings to use.",
         default=graphrag_config_defaults.extract_claims.max_gleanings,
     )
-    strategy: dict | None = Field(
-        description="The override strategy to use.",
-        default=graphrag_config_defaults.extract_claims.strategy,
-    )
 
-    def resolved_strategy(
-        self, root_dir: str, model_config: LanguageModelConfig
-    ) -> dict:
-        """Get the resolved claim extraction strategy."""
-        return self.strategy or {
-            "llm": model_config.model_dump(),
-            "extraction_prompt": (Path(root_dir) / self.prompt).read_text(
-                encoding="utf-8"
-            )
+    def resolved_prompts(self, root_dir: str) -> ClaimExtractionPrompts:
+        """Get the resolved claim extraction prompts."""
+        return ClaimExtractionPrompts(
+            extraction_prompt=(Path(root_dir) / self.prompt).read_text(encoding="utf-8")
             if self.prompt
-            else None,
-            "claim_description": self.description,
-            "max_gleanings": self.max_gleanings,
-        }
+            else EXTRACT_CLAIMS_PROMPT,
+        )
"""Parameterization settings for the default configuration.""" +from dataclasses import dataclass from pathlib import Path from pydantic import BaseModel, Field from graphrag.config.defaults import graphrag_config_defaults -from graphrag.config.models.language_model_config import LanguageModelConfig +from graphrag.prompts.index.extract_graph import GRAPH_EXTRACTION_PROMPT + + +@dataclass +class ExtractGraphPrompts: + """Graph extraction prompt templates.""" + + extraction_prompt: str class ExtractGraphConfig(BaseModel): @@ -30,26 +38,11 @@ class ExtractGraphConfig(BaseModel): description="The maximum number of entity gleanings to use.", default=graphrag_config_defaults.extract_graph.max_gleanings, ) - strategy: dict | None = Field( - description="Override the default entity extraction strategy", - default=graphrag_config_defaults.extract_graph.strategy, - ) - def resolved_strategy( - self, root_dir: str, model_config: LanguageModelConfig - ) -> dict: - """Get the resolved entity extraction strategy.""" - from graphrag.index.operations.extract_graph.typing import ( - ExtractEntityStrategyType, - ) - - return self.strategy or { - "type": ExtractEntityStrategyType.graph_intelligence, - "llm": model_config.model_dump(), - "extraction_prompt": (Path(root_dir) / self.prompt).read_text( - encoding="utf-8" - ) + def resolved_prompts(self, root_dir: str) -> ExtractGraphPrompts: + """Get the resolved graph extraction prompts.""" + return ExtractGraphPrompts( + extraction_prompt=(Path(root_dir) / self.prompt).read_text(encoding="utf-8") if self.prompt - else None, - "max_gleanings": self.max_gleanings, - } + else GRAPH_EXTRACTION_PROMPT, + ) diff --git a/graphrag/config/models/graph_rag_config.py b/graphrag/config/models/graph_rag_config.py index 4846984903..2b5961321a 100644 --- a/graphrag/config/models/graph_rag_config.py +++ b/graphrag/config/models/graph_rag_config.py @@ -19,7 +19,8 @@ from graphrag.config.models.cluster_graph_config import ClusterGraphConfig from graphrag.config.models.community_reports_config import CommunityReportsConfig from graphrag.config.models.drift_search_config import DRIFTSearchConfig -from graphrag.config.models.extract_claims_config import ClaimExtractionConfig +from graphrag.config.models.embed_text_config import EmbedTextConfig +from graphrag.config.models.extract_claims_config import ExtractClaimsConfig from graphrag.config.models.extract_graph_config import ExtractGraphConfig from graphrag.config.models.extract_graph_nlp_config import ExtractGraphNLPConfig from graphrag.config.models.global_search_config import GlobalSearchConfig @@ -33,7 +34,6 @@ from graphrag.config.models.summarize_descriptions_config import ( SummarizeDescriptionsConfig, ) -from graphrag.config.models.text_embedding_config import TextEmbeddingConfig from graphrag.config.models.vector_store_config import VectorStoreConfig from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter_factory import ( RateLimiterFactory, @@ -249,9 +249,9 @@ def _validate_reporting_base_dir(self) -> None: ) """List of workflows to run, in execution order.""" - embed_text: TextEmbeddingConfig = Field( + embed_text: EmbedTextConfig = Field( description="Text embedding configuration.", - default=TextEmbeddingConfig(), + default=EmbedTextConfig(), ) """Text embedding configuration.""" @@ -285,9 +285,9 @@ def _validate_reporting_base_dir(self) -> None: ) """The cluster graph configuration to use.""" - extract_claims: ClaimExtractionConfig = Field( + extract_claims: ExtractClaimsConfig = Field( 
         description="The claim extraction configuration to use.",
-        default=ClaimExtractionConfig(
+        default=ExtractClaimsConfig(
             enabled=graphrag_config_defaults.extract_claims.enabled,
         ),
     )
diff --git a/graphrag/config/models/summarize_descriptions_config.py b/graphrag/config/models/summarize_descriptions_config.py
index ef293f69c8..3ab1fdaec4 100644
--- a/graphrag/config/models/summarize_descriptions_config.py
+++ b/graphrag/config/models/summarize_descriptions_config.py
@@ -3,12 +3,20 @@
 
 """Parameterization settings for the default configuration."""
 
+from dataclasses import dataclass
 from pathlib import Path
 
 from pydantic import BaseModel, Field
 
 from graphrag.config.defaults import graphrag_config_defaults
-from graphrag.config.models.language_model_config import LanguageModelConfig
+from graphrag.prompts.index.summarize_descriptions import SUMMARIZE_PROMPT
+
+
+@dataclass
+class SummarizeDescriptionsPrompts:
+    """Description summarization prompt templates."""
+
+    summarize_prompt: str
 
 
 class SummarizeDescriptionsConfig(BaseModel):
@@ -30,27 +38,11 @@ class SummarizeDescriptionsConfig(BaseModel):
         description="Maximum tokens to submit from the input entity descriptions.",
         default=graphrag_config_defaults.summarize_descriptions.max_input_tokens,
     )
-    strategy: dict | None = Field(
-        description="The override strategy to use.",
-        default=graphrag_config_defaults.summarize_descriptions.strategy,
-    )
 
-    def resolved_strategy(
-        self, root_dir: str, model_config: LanguageModelConfig
-    ) -> dict:
-        """Get the resolved description summarization strategy."""
-        from graphrag.index.operations.summarize_descriptions.summarize_descriptions import (
-            SummarizeStrategyType,
-        )
-
-        return self.strategy or {
-            "type": SummarizeStrategyType.graph_intelligence,
-            "llm": model_config.model_dump(),
-            "summarize_prompt": (Path(root_dir) / self.prompt).read_text(
-                encoding="utf-8"
-            )
+    def resolved_prompts(self, root_dir: str) -> SummarizeDescriptionsPrompts:
+        """Get the resolved description summarization prompts."""
+        return SummarizeDescriptionsPrompts(
+            summarize_prompt=(Path(root_dir) / self.prompt).read_text(encoding="utf-8")
             if self.prompt
-            else None,
-            "max_summary_length": self.max_length,
-            "max_input_tokens": self.max_input_tokens,
-        }
+            else SUMMARIZE_PROMPT,
+        )
diff --git a/graphrag/index/operations/build_noun_graph/build_noun_graph.py b/graphrag/index/operations/build_noun_graph/build_noun_graph.py
index dca2644ca9..8d3310e766 100644
--- a/graphrag/index/operations/build_noun_graph/build_noun_graph.py
+++ b/graphrag/index/operations/build_noun_graph/build_noun_graph.py
@@ -8,7 +8,6 @@
 import numpy as np
 import pandas as pd
 
-from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
 from graphrag.cache.pipeline_cache import PipelineCache
 from graphrag.config.enums import AsyncType
 from graphrag.index.operations.build_noun_graph.np_extractors.base import (
@@ -23,9 +22,9 @@
 async def build_noun_graph(
     text_unit_df: pd.DataFrame,
     text_analyzer: BaseNounPhraseExtractor,
     normalize_edge_weights: bool,
-    num_threads: int = 4,
-    async_mode: AsyncType = AsyncType.Threaded,
-    cache: PipelineCache | None = None,
+    num_threads: int,
+    async_mode: AsyncType,
+    cache: PipelineCache,
 ) -> tuple[pd.DataFrame, pd.DataFrame]:
     """Build a noun graph from text units."""
     text_units = text_unit_df.loc[:, ["id", "text"]]
@@ -43,9 +42,9 @@
 async def _extract_nodes(
     text_unit_df: pd.DataFrame,
     text_analyzer: BaseNounPhraseExtractor,
-    num_threads: int = 4,
-    async_mode: AsyncType = AsyncType.Threaded,
-    cache: PipelineCache | None = None,
+    num_threads: int,
+    async_mode: AsyncType,
+    cache: PipelineCache,
 ) -> pd.DataFrame:
     """
     Extract initial nodes and edges from text units.
@@ -53,7 +52,6 @@
     Input: text unit df with schema [id, text, document_id]
     Returns a dataframe with schema [id, title, frequency, text_unit_ids].
     """
-    cache = cache or NoopPipelineCache()
     cache = cache.child("extract_noun_phrases")
 
     async def extract(row):
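build_noun_graph no longer defaults its concurrency or cache arguments, so callers spell them out. A call sketch under assumed names (text_unit_df, analyzer); the no-op cache mirrors the old implicit fallback:

from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
from graphrag.config.enums import AsyncType
from graphrag.index.operations.build_noun_graph.build_noun_graph import build_noun_graph

nodes_df, edges_df = await build_noun_graph(
    text_unit_df,
    text_analyzer=analyzer,  # a BaseNounPhraseExtractor built elsewhere
    normalize_edge_weights=True,
    num_threads=4,                  # formerly the default
    async_mode=AsyncType.Threaded,  # formerly the default
    cache=NoopPipelineCache(),      # explicit stand-in for the old implicit cache
)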
diff --git a/graphrag/index/operations/embed_text/embed_text.py b/graphrag/index/operations/embed_text/embed_text.py
index 46b81ce212..04099f8109 100644
--- a/graphrag/index/operations/embed_text/embed_text.py
+++ b/graphrag/index/operations/embed_text/embed_text.py
@@ -4,51 +4,37 @@
 """A module containing embed_text, load_strategy and create_row_from_embedding_data methods definition."""
 
 import logging
-from enum import Enum
-from typing import Any
 
 import numpy as np
 import pandas as pd
 
-from graphrag.cache.pipeline_cache import PipelineCache
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config.embeddings import create_index_name
 from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
-from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingStrategy
+from graphrag.index.operations.embed_text.run_embed_text import run_embed_text
+from graphrag.language_model.protocol.base import EmbeddingModel
+from graphrag.tokenizer.tokenizer import Tokenizer
 from graphrag.vector_stores.base import BaseVectorStore, VectorStoreDocument
 from graphrag.vector_stores.factory import VectorStoreFactory
 
 logger = logging.getLogger(__name__)
 
-# Per Azure OpenAI Limits
-# https://learn.microsoft.com/en-us/azure/ai-services/openai/reference
-DEFAULT_EMBEDDING_BATCH_SIZE = 500
-
-
-class TextEmbedStrategyType(str, Enum):
-    """TextEmbedStrategyType class definition."""
-
-    openai = "openai"
-    mock = "mock"
-
-    def __repr__(self):
-        """Get a string representation."""
-        return f'"{self.value}"'
-
 
 async def embed_text(
     input: pd.DataFrame,
     callbacks: WorkflowCallbacks,
-    cache: PipelineCache,
+    model: EmbeddingModel,
+    tokenizer: Tokenizer,
     embed_column: str,
-    strategy: dict,
     embedding_name: str,
+    batch_size: int,
+    batch_max_tokens: int,
+    num_threads: int,
+    vector_store_config: dict,
     id_column: str = "id",
     title_column: str | None = None,
 ):
     """Embed a piece of text into a vector space.
 
     The operation outputs a new column containing a mapping between doc_id and vector."""
-    vector_store_config = strategy.get("vector_store")
-
     if vector_store_config:
         index_name = _get_index_name(vector_store_config, embedding_name)
         vector_store: BaseVectorStore = _create_vector_store(
@@ -60,11 +46,14 @@
         return await _text_embed_with_vector_store(
             input=input,
             callbacks=callbacks,
-            cache=cache,
+            model=model,
+            tokenizer=tokenizer,
             embed_column=embed_column,
-            strategy=strategy,
             vector_store=vector_store,
             vector_store_config=vector_store_workflow_config,
+            batch_size=batch_size,
+            batch_max_tokens=batch_max_tokens,
+            num_threads=num_threads,
             id_column=id_column,
             title_column=title_column,
         )
@@ -72,25 +61,29 @@
     return await _text_embed_in_memory(
         input=input,
         callbacks=callbacks,
-        cache=cache,
+        model=model,
+        tokenizer=tokenizer,
         embed_column=embed_column,
-        strategy=strategy,
+        batch_size=batch_size,
+        batch_max_tokens=batch_max_tokens,
+        num_threads=num_threads,
     )
 
 
 async def _text_embed_in_memory(
     input: pd.DataFrame,
     callbacks: WorkflowCallbacks,
-    cache: PipelineCache,
+    model: EmbeddingModel,
+    tokenizer: Tokenizer,
     embed_column: str,
-    strategy: dict,
+    batch_size: int,
+    batch_max_tokens: int,
+    num_threads: int,
 ):
-    strategy_type = strategy["type"]
-    strategy_exec = load_strategy(strategy_type)
-    strategy_config = {**strategy}
-
     texts: list[str] = input[embed_column].tolist()
-    result = await strategy_exec(texts, callbacks, cache, strategy_config)
+    result = await run_embed_text(
+        texts, callbacks, model, tokenizer, batch_size, batch_max_tokens, num_threads
+    )
 
     return result.embeddings
 
@@ -98,22 +91,18 @@
 async def _text_embed_with_vector_store(
     input: pd.DataFrame,
     callbacks: WorkflowCallbacks,
-    cache: PipelineCache,
+    model: EmbeddingModel,
+    tokenizer: Tokenizer,
     embed_column: str,
-    strategy: dict[str, Any],
     vector_store: BaseVectorStore,
     vector_store_config: dict,
-    id_column: str = "id",
+    batch_size: int,
+    batch_max_tokens: int,
+    num_threads: int,
+    id_column: str,
     title_column: str | None = None,
 ):
-    strategy_type = strategy["type"]
-    strategy_exec = load_strategy(strategy_type)
-    strategy_config = {**strategy}
-
     # Get vector-storage configuration
-    insert_batch_size: int = (
-        vector_store_config.get("batch_size") or DEFAULT_EMBEDDING_BATCH_SIZE
-    )
     overwrite: bool = vector_store_config.get("overwrite", True)
@@ -142,18 +131,26 @@
     all_results = []
 
-    num_total_batches = (input.shape[0] + insert_batch_size - 1) // insert_batch_size
-    while insert_batch_size * i < input.shape[0]:
+    num_total_batches = (input.shape[0] + batch_size - 1) // batch_size
+    while batch_size * i < input.shape[0]:
         logger.info(
             "uploading text embeddings batch %d/%d of size %d to vector store",
             i + 1,
             num_total_batches,
-            insert_batch_size,
+            batch_size,
         )
-        batch = input.iloc[insert_batch_size * i : insert_batch_size * (i + 1)]
+        batch = input.iloc[batch_size * i : batch_size * (i + 1)]
         texts: list[str] = batch[embed_column].tolist()
         ids: list[str] = batch[id_column].tolist()
-        result = await strategy_exec(texts, callbacks, cache, strategy_config)
+        result = await run_embed_text(
+            texts,
+            callbacks,
+            model,
+            tokenizer,
+            batch_size,
+            batch_max_tokens,
+            num_threads,
+        )
         if result.embeddings:
             embeddings = [
                 embedding for embedding in result.embeddings if embedding is not None
@@ -219,23 +216,3 @@ def _get_index_name(vector_store_config: dict, embedding_name: str) -> str:
= f"using vector store {vector_store_config.get('type')} with container_name {container_name} for embedding {embedding_name}: {index_name}" logger.info(msg) return index_name - - -def load_strategy(strategy: TextEmbedStrategyType) -> TextEmbeddingStrategy: - """Load strategy method definition.""" - match strategy: - case TextEmbedStrategyType.openai: - from graphrag.index.operations.embed_text.strategies.openai import ( - run as run_openai, - ) - - return run_openai - case TextEmbedStrategyType.mock: - from graphrag.index.operations.embed_text.strategies.mock import ( - run as run_mock, - ) - - return run_mock - case _: - msg = f"Unknown strategy: {strategy}" - raise ValueError(msg) diff --git a/graphrag/index/operations/embed_text/strategies/openai.py b/graphrag/index/operations/embed_text/run_embed_text.py similarity index 80% rename from graphrag/index/operations/embed_text/strategies/openai.py rename to graphrag/index/operations/embed_text/run_embed_text.py index ef8aadff6b..fc3da5ee6c 100644 --- a/graphrag/index/operations/embed_text/strategies/openai.py +++ b/graphrag/index/operations/embed_text/run_embed_text.py @@ -5,47 +5,43 @@ import asyncio import logging -from typing import Any +from dataclasses import dataclass import numpy as np -from graphrag.cache.pipeline_cache import PipelineCache from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks -from graphrag.config.models.language_model_config import LanguageModelConfig -from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingResult from graphrag.index.text_splitting.text_splitting import TokenTextSplitter from graphrag.index.utils.is_null import is_null -from graphrag.language_model.manager import ModelManager from graphrag.language_model.protocol.base import EmbeddingModel from graphrag.logger.progress import ProgressTicker, progress_ticker -from graphrag.tokenizer.get_tokenizer import get_tokenizer +from graphrag.tokenizer.tokenizer import Tokenizer logger = logging.getLogger(__name__) -async def run( +@dataclass +class TextEmbeddingResult: + """Text embedding result class definition.""" + + embeddings: list[list[float] | None] | None + + +async def run_embed_text( input: list[str], callbacks: WorkflowCallbacks, - cache: PipelineCache, - args: dict[str, Any], + model: EmbeddingModel, + tokenizer: Tokenizer, + batch_size: int, + batch_max_tokens: int, + num_threads: int, ) -> TextEmbeddingResult: """Run the Claim extraction chain.""" if is_null(input): return TextEmbeddingResult(embeddings=None) - batch_size = args.get("batch_size", 16) - batch_max_tokens = args.get("batch_max_tokens", 8191) - llm_config = args["llm"] - llm_config = LanguageModelConfig(**args["llm"]) - splitter = _get_splitter(llm_config, batch_max_tokens) - model = ModelManager().get_or_create_embedding_model( - name="text_embedding", - model_type=llm_config.type, - config=llm_config, - callbacks=callbacks, - cache=cache, - ) - semaphore: asyncio.Semaphore = asyncio.Semaphore(args.get("num_threads", 4)) + splitter = _get_splitter(tokenizer, batch_max_tokens) + + semaphore: asyncio.Semaphore = asyncio.Semaphore(num_threads) # Break up the input texts. 
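embed_text now takes a concrete EmbeddingModel and Tokenizer plus explicit batching knobs instead of a strategy dict and cache. A call sketch; the embedding_model, tokenizer, and vector_store_config objects are assumed to be built elsewhere (e.g., via the model manager and get_vector_store_settings):

embeddings = await embed_text(
    input=text_units_df,
    callbacks=callbacks,
    model=embedding_model,   # an EmbeddingModel instance
    tokenizer=tokenizer,     # a Tokenizer instance
    embed_column="text",
    embedding_name="text_unit.text",
    batch_size=16,           # mirrors EmbedTextDefaults
    batch_max_tokens=8191,
    num_threads=4,
    vector_store_config=vector_store_config,
)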
diff --git a/graphrag/index/operations/embed_text/strategies/openai.py b/graphrag/index/operations/embed_text/run_embed_text.py
similarity index 80%
rename from graphrag/index/operations/embed_text/strategies/openai.py
rename to graphrag/index/operations/embed_text/run_embed_text.py
index ef8aadff6b..fc3da5ee6c 100644
--- a/graphrag/index/operations/embed_text/strategies/openai.py
+++ b/graphrag/index/operations/embed_text/run_embed_text.py
@@ -5,47 +5,43 @@
 
 import asyncio
 import logging
-from typing import Any
+from dataclasses import dataclass
 
 import numpy as np
 
-from graphrag.cache.pipeline_cache import PipelineCache
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
-from graphrag.config.models.language_model_config import LanguageModelConfig
-from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingResult
 from graphrag.index.text_splitting.text_splitting import TokenTextSplitter
 from graphrag.index.utils.is_null import is_null
-from graphrag.language_model.manager import ModelManager
 from graphrag.language_model.protocol.base import EmbeddingModel
 from graphrag.logger.progress import ProgressTicker, progress_ticker
-from graphrag.tokenizer.get_tokenizer import get_tokenizer
+from graphrag.tokenizer.tokenizer import Tokenizer
 
 logger = logging.getLogger(__name__)
 
 
-async def run(
+@dataclass
+class TextEmbeddingResult:
+    """Text embedding result class definition."""
+
+    embeddings: list[list[float] | None] | None
+
+
+async def run_embed_text(
     input: list[str],
     callbacks: WorkflowCallbacks,
-    cache: PipelineCache,
-    args: dict[str, Any],
+    model: EmbeddingModel,
+    tokenizer: Tokenizer,
+    batch_size: int,
+    batch_max_tokens: int,
+    num_threads: int,
 ) -> TextEmbeddingResult:
     """Run the Claim extraction chain."""
     if is_null(input):
         return TextEmbeddingResult(embeddings=None)
 
-    batch_size = args.get("batch_size", 16)
-    batch_max_tokens = args.get("batch_max_tokens", 8191)
-    llm_config = args["llm"]
-    llm_config = LanguageModelConfig(**args["llm"])
-    splitter = _get_splitter(llm_config, batch_max_tokens)
-    model = ModelManager().get_or_create_embedding_model(
-        name="text_embedding",
-        model_type=llm_config.type,
-        config=llm_config,
-        callbacks=callbacks,
-        cache=cache,
-    )
-    semaphore: asyncio.Semaphore = asyncio.Semaphore(args.get("num_threads", 4))
+    splitter = _get_splitter(tokenizer, batch_max_tokens)
+
+    semaphore: asyncio.Semaphore = asyncio.Semaphore(num_threads)
 
     # Break up the input texts. The sizes here indicate how many snippets are in each input text
     texts, input_sizes = _prepare_embed_texts(input, splitter)
@@ -76,11 +72,9 @@
     return TextEmbeddingResult(embeddings=embeddings)
 
 
-def _get_splitter(
-    config: LanguageModelConfig, batch_max_tokens: int
-) -> TokenTextSplitter:
+def _get_splitter(tokenizer: Tokenizer, batch_max_tokens: int) -> TokenTextSplitter:
     return TokenTextSplitter(
-        tokenizer=get_tokenizer(model_config=config),
+        tokenizer=tokenizer,
         chunk_size=batch_max_tokens,
     )
diff --git a/graphrag/index/operations/embed_text/strategies/__init__.py b/graphrag/index/operations/embed_text/strategies/__init__.py
deleted file mode 100644
index 8cbe7a580e..0000000000
--- a/graphrag/index/operations/embed_text/strategies/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""The Indexing Engine embed strategies package root."""
diff --git a/graphrag/index/operations/embed_text/strategies/mock.py b/graphrag/index/operations/embed_text/strategies/mock.py
deleted file mode 100644
index a65ad9721f..0000000000
--- a/graphrag/index/operations/embed_text/strategies/mock.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""A module containing run and _embed_text methods definitions."""
-
-import random
-from collections.abc import Iterable
-from typing import Any
-
-from graphrag.cache.pipeline_cache import PipelineCache
-from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
-from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingResult
-from graphrag.logger.progress import ProgressTicker, progress_ticker
-
-
-async def run(  # noqa RUF029 async is required for interface
-    input: list[str],
-    callbacks: WorkflowCallbacks,
-    cache: PipelineCache,
-    _args: dict[str, Any],
-) -> TextEmbeddingResult:
-    """Run the Claim extraction chain."""
-    input = input if isinstance(input, Iterable) else [input]
-    ticker = progress_ticker(
-        callbacks.progress, len(input), description="generate embeddings progress: "
-    )
-    return TextEmbeddingResult(
-        embeddings=[_embed_text(cache, text, ticker) for text in input]
-    )
-
-
-def _embed_text(_cache: PipelineCache, _text: str, tick: ProgressTicker) -> list[float]:
-    """Embed a single piece of text."""
-    tick(1)
-    return [random.random(), random.random(), random.random()]  # noqa S311
diff --git a/graphrag/index/operations/embed_text/strategies/typing.py b/graphrag/index/operations/embed_text/strategies/typing.py
deleted file mode 100644
index f45a7eb36e..0000000000
--- a/graphrag/index/operations/embed_text/strategies/typing.py
+++ /dev/null
@@ -1,28 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""A module containing 'TextEmbeddingResult' model."""
-
-from collections.abc import Awaitable, Callable
-from dataclasses import dataclass
-
-from graphrag.cache.pipeline_cache import PipelineCache
-from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
-
-
-@dataclass
-class TextEmbeddingResult:
-    """Text embedding result class definition."""
-
-    embeddings: list[list[float] | None] | None
-
-
-TextEmbeddingStrategy = Callable[
-    [
-        list[str],
-        WorkflowCallbacks,
-        PipelineCache,
-        dict,
-    ],
-    Awaitable[TextEmbeddingResult],
-]
diff --git a/graphrag/index/operations/extract_covariates/claim_extractor.py b/graphrag/index/operations/extract_covariates/claim_extractor.py
index e50d05ce83..99db5d7e2b 100644
--- a/graphrag/index/operations/extract_covariates/claim_extractor.py
+++ b/graphrag/index/operations/extract_covariates/claim_extractor.py
@@ -13,13 +13,18 @@
 from graphrag.language_model.protocol.base import ChatModel
 from graphrag.prompts.index.extract_claims import (
     CONTINUE_PROMPT,
-    EXTRACT_CLAIMS_PROMPT,
     LOOP_PROMPT,
 )
 
-DEFAULT_TUPLE_DELIMITER = "<|>"
-DEFAULT_RECORD_DELIMITER = "##"
-DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
+INPUT_TEXT_KEY = "input_text"
+INPUT_ENTITY_SPEC_KEY = "entity_specs"
+INPUT_CLAIM_DESCRIPTION_KEY = "claim_description"
+INPUT_RESOLVED_ENTITIES_KEY = "resolved_entities"
+RECORD_DELIMITER_KEY = "record_delimiter"
+COMPLETION_DELIMITER_KEY = "completion_delimiter"
+TUPLE_DELIMITER = "<|>"
+RECORD_DELIMITER = "##"
+COMPLETION_DELIMITER = "<|COMPLETE|>"
 
 logger = logging.getLogger(__name__)
 
@@ -36,47 +41,19 @@ class ClaimExtractor:
 
     _model: ChatModel
     _extraction_prompt: str
-    _summary_prompt: str
-    _output_formatter_prompt: str
-    _input_text_key: str
-    _input_entity_spec_key: str
-    _input_claim_description_key: str
-    _tuple_delimiter_key: str
-    _record_delimiter_key: str
-    _completion_delimiter_key: str
     _max_gleanings: int
     _on_error: ErrorHandlerFn
 
     def __init__(
         self,
-        model_invoker: ChatModel,
-        extraction_prompt: str | None = None,
-        input_text_key: str | None = None,
-        input_entity_spec_key: str | None = None,
-        input_claim_description_key: str | None = None,
-        input_resolved_entities_key: str | None = None,
-        tuple_delimiter_key: str | None = None,
-        record_delimiter_key: str | None = None,
-        completion_delimiter_key: str | None = None,
+        model: ChatModel,
+        extraction_prompt: str,
         max_gleanings: int | None = None,
         on_error: ErrorHandlerFn | None = None,
     ):
         """Init method definition."""
-        self._model = model_invoker
-        self._extraction_prompt = extraction_prompt or EXTRACT_CLAIMS_PROMPT
-        self._input_text_key = input_text_key or "input_text"
-        self._input_entity_spec_key = input_entity_spec_key or "entity_specs"
-        self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
-        self._record_delimiter_key = record_delimiter_key or "record_delimiter"
-        self._completion_delimiter_key = (
-            completion_delimiter_key or "completion_delimiter"
-        )
-        self._input_claim_description_key = (
-            input_claim_description_key or "claim_description"
-        )
-        self._input_resolved_entities_key = (
-            input_resolved_entities_key or "resolved_entities"
-        )
+        self._model = model
+        self._extraction_prompt = extraction_prompt
         self._max_gleanings = (
             max_gleanings
             if max_gleanings is not None
@@ -85,35 +62,21 @@
         self._on_error = on_error or (lambda _e, _s, _d: None)
 
     async def __call__(
-        self, inputs: dict[str, Any], prompt_variables: dict | None = None
+        self,
+        texts,
+        entity_spec,
+        resolved_entities,
+        claim_description,
     ) -> ClaimExtractorResult:
         """Call method definition."""
-        if prompt_variables is None:
-            prompt_variables = {}
-        texts = inputs[self._input_text_key]
-        entity_spec = str(inputs[self._input_entity_spec_key])
-        claim_description = inputs[self._input_claim_description_key]
-        resolved_entities = inputs.get(self._input_resolved_entities_key, {})
         source_doc_map = {}
-
-        prompt_args = {
-            self._input_entity_spec_key: entity_spec,
-            self._input_claim_description_key: claim_description,
-            self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key)
-            or DEFAULT_TUPLE_DELIMITER,
-            self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
-            or DEFAULT_RECORD_DELIMITER,
-            self._completion_delimiter_key: prompt_variables.get(
-                self._completion_delimiter_key
-            )
-            or DEFAULT_COMPLETION_DELIMITER,
-        }
-
         all_claims: list[dict] = []
         for doc_index, text in enumerate(texts):
             document_id = f"d{doc_index}"
             try:
-                claims = await self._process_document(prompt_args, text, doc_index)
+                claims = await self._process_document(
+                    text, claim_description, entity_spec
+                )
                 all_claims += [
                     self._clean_claim(c, document_id, resolved_entities) for c in claims
                 ]
@@ -147,23 +110,17 @@ def _clean_claim(
         return claim
 
     async def _process_document(
-        self, prompt_args: dict, doc, doc_index: int
+        self, text: str, claim_description: str, entity_spec: dict
     ) -> list[dict]:
-        record_delimiter = prompt_args.get(
-            self._record_delimiter_key, DEFAULT_RECORD_DELIMITER
-        )
-        completion_delimiter = prompt_args.get(
-            self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER
-        )
-
         response = await self._model.achat(
             self._extraction_prompt.format(**{
-                self._input_text_key: doc,
-                **prompt_args,
+                INPUT_TEXT_KEY: text,
+                INPUT_CLAIM_DESCRIPTION_KEY: claim_description,
+                INPUT_ENTITY_SPEC_KEY: entity_spec,
             })
         )
         results = response.output.content or ""
-        claims = results.strip().removesuffix(completion_delimiter)
+        claims = results.strip().removesuffix(COMPLETION_DELIMITER)
 
         # if gleanings are specified, enter a loop to extract more claims
         # there are two exit criteria: (a) we hit the configured max, (b) the model says there are no more claims
@@ -175,8 +132,8 @@
                 history=response.history,
             )
             extension = response.output.content or ""
-            claims += record_delimiter + extension.strip().removesuffix(
-                completion_delimiter
+            claims += RECORD_DELIMITER + extension.strip().removesuffix(
+                COMPLETION_DELIMITER
             )
 
             # If this isn't the last loop, check to see if we should continue
@@ -192,37 +149,26 @@
             if response.output.content != "Y":
                 break
 
-        return self._parse_claim_tuples(results, prompt_args)
+        return self._parse_claim_tuples(results)
 
-    def _parse_claim_tuples(
-        self, claims: str, prompt_variables: dict
-    ) -> list[dict[str, Any]]:
+    def _parse_claim_tuples(self, claims: str) -> list[dict[str, Any]]:
         """Parse claim tuples."""
-        record_delimiter = prompt_variables.get(
-            self._record_delimiter_key, DEFAULT_RECORD_DELIMITER
-        )
-        completion_delimiter = prompt_variables.get(
-            self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER
-        )
-        tuple_delimiter = prompt_variables.get(
-            self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER
-        )
 
         def pull_field(index: int, fields: list[str]) -> str | None:
             return fields[index].strip() if len(fields) > index else None
 
         result: list[dict[str, Any]] = []
         claims_values = (
-
-            claims.strip().removesuffix(completion_delimiter).split(record_delimiter)
+            claims.strip().removesuffix(COMPLETION_DELIMITER).split(RECORD_DELIMITER)
         )
         for claim in claims_values:
             claim = claim.strip().removeprefix("(").removesuffix(")")
 
             # Ignore the completion delimiter
-            if claim == completion_delimiter:
+            if claim == COMPLETION_DELIMITER:
                 continue
 
-            claim_fields = claim.split(tuple_delimiter)
+            claim_fields = claim.split(TUPLE_DELIMITER)
             result.append({
                 "subject_id": pull_field(0, claim_fields),
                 "object_id": pull_field(1, claim_fields),
diff --git a/graphrag/index/operations/extract_covariates/extract_covariates.py b/graphrag/index/operations/extract_covariates/extract_covariates.py
index d29ca61e9d..bc2e1fa9de 100644
--- a/graphrag/index/operations/extract_covariates/extract_covariates.py
+++ b/graphrag/index/operations/extract_covariates/extract_covariates.py
@@ -10,55 +10,45 @@
 
 import pandas as pd
 
-from graphrag.cache.pipeline_cache import PipelineCache
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
-from graphrag.config.defaults import graphrag_config_defaults
 from graphrag.config.enums import AsyncType
-from graphrag.config.models.language_model_config import LanguageModelConfig
 from graphrag.index.operations.extract_covariates.claim_extractor import ClaimExtractor
 from graphrag.index.operations.extract_covariates.typing import (
     Covariate,
     CovariateExtractionResult,
 )
 from graphrag.index.utils.derive_from_rows import derive_from_rows
-from graphrag.language_model.manager import ModelManager
+from graphrag.language_model.protocol.base import ChatModel
 
 logger = logging.getLogger(__name__)
 
-DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"]
-
-
 async def extract_covariates(
     input: pd.DataFrame,
     callbacks: WorkflowCallbacks,
-    cache: PipelineCache,
+    model: ChatModel,
     column: str,
     covariate_type: str,
-    strategy: dict[str, Any] | None,
-    async_mode: AsyncType = AsyncType.AsyncIO,
-    entity_types: list[str] | None = None,
-    num_threads: int = 4,
+    max_gleanings: int,
+    claim_description: str,
+    prompt: str,
+    entity_types: list[str],
+    num_threads: int,
+    async_type: AsyncType,
 ):
     """Extract claims from a piece of text."""
-    logger.debug("extract_covariates strategy=%s", strategy)
-    if entity_types is None:
-        entity_types = DEFAULT_ENTITY_TYPES
-
     resolved_entities_map = {}
 
-    strategy = strategy or {}
-    strategy_config = {**strategy}
-
     async def run_strategy(row):
         text = row[column]
         result = await run_extract_claims(
             input=text,
             entity_types=entity_types,
             resolved_entities_map=resolved_entities_map,
-            callbacks=callbacks,
-            cache=cache,
-            strategy_config=strategy_config,
+            model=model,
+            max_gleanings=max_gleanings,
+            claim_description=claim_description,
+            prompt=prompt,
         )
         return [
             create_row_from_claim_data(row, item, covariate_type)
@@ -69,8 +59,8 @@ async def run_strategy(row):
         input,
         run_strategy,
         callbacks,
-        async_type=async_mode,
         num_threads=num_threads,
+        async_type=async_type,
         progress_msg="extract covariates progress: ",
     )
     return pd.DataFrame([item for row in results for item in row or []])
@@ -85,53 +75,29 @@ async def run_extract_claims(
     input: str | Iterable[str],
     entity_types: list[str],
     resolved_entities_map: dict[str, str],
-    callbacks: WorkflowCallbacks,
-    cache: PipelineCache,
-    strategy_config: dict[str, Any],
+    model: ChatModel,
+    max_gleanings: int,
+    claim_description: str,
+    prompt: str,
 ) -> CovariateExtractionResult:
     """Run the Claim extraction chain."""
-    llm_config = LanguageModelConfig(**strategy_config["llm"])
-    llm = ModelManager().get_or_create_chat_model(
-        name="extract_claims",
-        model_type=llm_config.type,
-        config=llm_config,
-        callbacks=callbacks,
-        cache=cache,
-    )
-
-    extraction_prompt = strategy_config.get("extraction_prompt")
-    max_gleanings = strategy_config.get(
-        "max_gleanings", graphrag_config_defaults.extract_claims.max_gleanings
-    )
-    tuple_delimiter = strategy_config.get("tuple_delimiter")
-    record_delimiter = strategy_config.get("record_delimiter")
-    completion_delimiter = strategy_config.get("completion_delimiter")
-
     extractor = ClaimExtractor(
-        model_invoker=llm,
-        extraction_prompt=extraction_prompt,
+        model=model,
+        extraction_prompt=prompt,
         max_gleanings=max_gleanings,
         on_error=lambda e, s, d: logger.error(
             "Claim Extraction Error", exc_info=e, extra={"stack": s, "details": d}
         ),
     )
 
-    claim_description = strategy_config.get("claim_description")
-    if claim_description is None:
-        msg = "claim_description is required for claim extraction"
-        raise ValueError(msg)
-
     input = [input] if isinstance(input, str) else input
-    results = await extractor({
-        "input_text": input,
-        "entity_specs": entity_types,
-        "resolved_entities": resolved_entities_map,
-        "claim_description": claim_description,
-        "tuple_delimiter": tuple_delimiter,
-        "record_delimiter": record_delimiter,
-        "completion_delimiter": completion_delimiter,
-    })
+    results = await extractor(
+        texts=input,
+        entity_spec=entity_types,
+        resolved_entities=resolved_entities_map,
+        claim_description=claim_description,
+    )
     claim_data = results.output
     return CovariateExtractionResult([create_covariate(item) for item in claim_data])
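extract_covariates likewise receives a ready ChatModel and fully resolved prompt and limits rather than a strategy dict. A call sketch with assumed chat_model, callbacks, and text_units_df objects:

from graphrag.config.enums import AsyncType

covariates_df = await extract_covariates(
    input=text_units_df,
    callbacks=callbacks,
    model=chat_model,  # a ChatModel created by the workflow
    column="text",
    covariate_type="claim",
    max_gleanings=1,
    claim_description=config.extract_claims.description,
    prompt=config.extract_claims.resolved_prompts(config.root_dir).extraction_prompt,
    entity_types=["organization", "person", "geo", "event"],
    num_threads=4,
    async_type=AsyncType.AsyncIO,
)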
diff --git a/graphrag/index/operations/extract_graph/extract_graph.py b/graphrag/index/operations/extract_graph/extract_graph.py
index 76bcf40c76..3aa87404ec 100644
--- a/graphrag/index/operations/extract_graph/extract_graph.py
+++ b/graphrag/index/operations/extract_graph/extract_graph.py
@@ -4,58 +4,49 @@
 """A module containing entity_extract methods."""
 
 import logging
-from typing import Any
 
+import networkx as nx
 import pandas as pd
 
-from graphrag.cache.pipeline_cache import PipelineCache
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config.enums import AsyncType
+from graphrag.index.operations.extract_graph.graph_extractor import GraphExtractor
 from graphrag.index.operations.extract_graph.typing import (
     Document,
-    EntityExtractStrategy,
-    ExtractEntityStrategyType,
+    EntityExtractionResult,
+    EntityTypes,
 )
 from graphrag.index.utils.derive_from_rows import derive_from_rows
+from graphrag.language_model.protocol.base import ChatModel
 
 logger = logging.getLogger(__name__)
 
-DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"]
-
-
 async def extract_graph(
     text_units: pd.DataFrame,
     callbacks: WorkflowCallbacks,
-    cache: PipelineCache,
     text_column: str,
     id_column: str,
-    strategy: dict[str, Any] | None,
-    async_mode: AsyncType = AsyncType.AsyncIO,
-    entity_types=DEFAULT_ENTITY_TYPES,
-    num_threads: int = 4,
+    model: ChatModel,
+    prompt: str,
+    entity_types: list[str],
+    max_gleanings: int,
+    num_threads: int,
+    async_type: AsyncType,
 ) -> tuple[pd.DataFrame, pd.DataFrame]:
     """Extract a graph from a piece of text using a language model."""
-    logger.debug("entity_extract strategy=%s", strategy)
-    if entity_types is None:
-        entity_types = DEFAULT_ENTITY_TYPES
-    strategy = strategy or {}
-    strategy_exec = _load_strategy(
-        strategy.get("type", ExtractEntityStrategyType.graph_intelligence)
-    )
-    strategy_config = {**strategy}
-
     num_started = 0
 
     async def run_strategy(row):
         nonlocal num_started
         text = row[text_column]
         id = row[id_column]
-        result = await strategy_exec(
+        result = await run_extract_graph(
             [Document(text=text, id=id)],
             entity_types,
-            cache,
-            strategy_config,
+            model,
+            prompt,
+            max_gleanings,
         )
         num_started += 1
         return [result.entities, result.relationships, result.graph]
@@ -64,8 +55,8 @@
         text_units,
         run_strategy,
         callbacks,
-        async_type=async_mode,
         num_threads=num_threads,
+        async_type=async_type,
         progress_msg="extract graph progress: ",
     )
 
@@ -82,19 +73,52 @@ async def run_strategy(row):
     return (entities, relationships)
 
 
-def _load_strategy(strategy_type: ExtractEntityStrategyType) -> EntityExtractStrategy:
-    """Load strategy method definition."""
-    match strategy_type:
-        case ExtractEntityStrategyType.graph_intelligence:
-            from graphrag.index.operations.extract_graph.graph_intelligence_strategy import (
-                run_graph_intelligence,
+async def run_extract_graph(
+    docs: list[Document],
+    entity_types: EntityTypes,
+    model: ChatModel,
+    prompt: str,
+    max_gleanings: int,
+) -> EntityExtractionResult:
+    """Run the graph intelligence entity extraction strategy."""
+    extractor = GraphExtractor(
+        model=model,
+        prompt=prompt,
+        max_gleanings=max_gleanings,
+        on_error=lambda e, s, d: logger.error(
+            "Entity Extraction Error", exc_info=e, extra={"stack": s, "details": d}
+        ),
+    )
+    text_list = [doc.text.strip() for doc in docs]
+
+    results = await extractor(
+        list(text_list),
+        entity_types=entity_types,
+    )
+
+    graph = results.output
+    # Map the "source_id" back to the "id" field
+    for _, node in graph.nodes(data=True):  # type: ignore
+        if node is not None:
+            node["source_id"] = ",".join(
+                docs[int(id)].id for id in node["source_id"].split(",")
             )
-
-            return run_graph_intelligence
+    for _, _, edge in graph.edges(data=True):  # type: ignore
+        if edge is not None:
+            edge["source_id"] = ",".join(
+                docs[int(id)].id for id in edge["source_id"].split(",")
+            )
+
+    entities = [
+        ({"title": item[0], **(item[1] or {})})
+        for item in graph.nodes(data=True)
+        if item is not None
+    ]
+
+    relationships = nx.to_pandas_edgelist(graph)
-
-        case _:
-            msg = f"Unknown strategy: {strategy_type}"
-            raise ValueError(msg)
+
+    return EntityExtractionResult(entities, relationships, graph)
 
 
 def _merge_entities(entity_dfs) -> pd.DataFrame:
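With the strategy loader removed, run_extract_graph is awaited directly. A minimal sketch, again assuming a chat_model and a loaded config:

from graphrag.index.operations.extract_graph.typing import Document

result = await run_extract_graph(
    docs=[Document(text="Acme Corp hired Jane Doe in Paris.", id="0")],
    entity_types=["organization", "person", "geo", "event"],
    model=chat_model,
    prompt=config.extract_graph.resolved_prompts(config.root_dir).extraction_prompt,
    max_gleanings=1,
)
# result.entities, result.relationships, and result.graph (a networkx graph)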
diff --git a/graphrag/index/operations/extract_graph/graph_extractor.py b/graphrag/index/operations/extract_graph/graph_extractor.py
index d1f66c3e81..8e98eb7ec4 100644
--- a/graphrag/index/operations/extract_graph/graph_extractor.py
+++ b/graphrag/index/operations/extract_graph/graph_extractor.py
@@ -12,19 +12,21 @@
 
 import networkx as nx
 
-from graphrag.config.defaults import graphrag_config_defaults
 from graphrag.index.typing.error_handler import ErrorHandlerFn
 from graphrag.index.utils.string import clean_str
 from graphrag.language_model.protocol.base import ChatModel
 from graphrag.prompts.index.extract_graph import (
     CONTINUE_PROMPT,
-    GRAPH_EXTRACTION_PROMPT,
     LOOP_PROMPT,
 )
 
-DEFAULT_TUPLE_DELIMITER = "<|>"
-DEFAULT_RECORD_DELIMITER = "##"
-DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
+INPUT_TEXT_KEY = "input_text"
+RECORD_DELIMITER_KEY = "record_delimiter"
+COMPLETION_DELIMITER_KEY = "completion_delimiter"
+ENTITY_TYPES_KEY = "entity_types"
+TUPLE_DELIMITER = "<|>"
+RECORD_DELIMITER = "##"
+COMPLETION_DELIMITER = "<|COMPLETE|>"
 DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"]
 
 logger = logging.getLogger(__name__)
@@ -43,79 +45,36 @@ class GraphExtractor:
 
     _model: ChatModel
     _join_descriptions: bool
-    _tuple_delimiter_key: str
-    _record_delimiter_key: str
-    _entity_types_key: str
-    _input_text_key: str
-    _completion_delimiter_key: str
-    _entity_name_key: str
-    _input_descriptions_key: str
     _extraction_prompt: str
-    _summarization_prompt: str
     _max_gleanings: int
     _on_error: ErrorHandlerFn
 
     def __init__(
         self,
-        model_invoker: ChatModel,
-        tuple_delimiter_key: str | None = None,
-        record_delimiter_key: str | None = None,
-        input_text_key: str | None = None,
-        entity_types_key: str | None = None,
-        completion_delimiter_key: str | None = None,
-        prompt: str | None = None,
+        model: ChatModel,
+        prompt: str,
+        max_gleanings: int,
         join_descriptions=True,
-        max_gleanings: int | None = None,
         on_error: ErrorHandlerFn | None = None,
     ):
         """Init method definition."""
-        # TODO: streamline construction
-        self._model = model_invoker
+        self._model = model
         self._join_descriptions = join_descriptions
-        self._input_text_key = input_text_key or "input_text"
-        self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
-        self._record_delimiter_key = record_delimiter_key or "record_delimiter"
-        self._completion_delimiter_key = (
-            completion_delimiter_key or "completion_delimiter"
-        )
-        self._entity_types_key = entity_types_key or "entity_types"
-        self._extraction_prompt = prompt or GRAPH_EXTRACTION_PROMPT
-        self._max_gleanings = (
-            max_gleanings
-            if max_gleanings is not None
-            else graphrag_config_defaults.extract_graph.max_gleanings
-        )
+        self._extraction_prompt = prompt
+        self._max_gleanings = max_gleanings
         self._on_error = on_error or (lambda _e, _s, _d: None)
 
     async def __call__(
-        self, texts: list[str], prompt_variables: dict[str, Any] | None = None
+        self, texts: list[str], entity_types: list[str]
     ) -> GraphExtractionResult:
         """Call method definition."""
-        if prompt_variables is None:
-            prompt_variables = {}
         all_records: dict[int, str] = {}
         source_doc_map: dict[int, str] = {}
 
-        # Wire defaults into the prompt variables
-        prompt_variables = {
-            **prompt_variables,
-            self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key)
-            or DEFAULT_TUPLE_DELIMITER,
-            self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
-            or DEFAULT_RECORD_DELIMITER,
-            self._completion_delimiter_key: prompt_variables.get(
-                self._completion_delimiter_key
-            )
-            or DEFAULT_COMPLETION_DELIMITER,
-            self._entity_types_key: ",".join(
-                prompt_variables[self._entity_types_key] or DEFAULT_ENTITY_TYPES
-            ),
-        }
-
         for doc_index, text in enumerate(texts):
             try:
                 # Invoke the entity extraction
-                result = await self._process_document(text, prompt_variables)
+                result = await self._process_document(text, entity_types)
                 source_doc_map[doc_index] = text
                 all_records[doc_index] = result
             except Exception as e:
@@ -131,8 +90,8 @@ async def __call__(
 
         output = await self._process_results(
             all_records,
-            prompt_variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER),
-            prompt_variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER),
+            TUPLE_DELIMITER,
+            RECORD_DELIMITER,
         )
 
         return GraphExtractionResult(
@@ -140,13 +99,11 @@ async def __call__(
             source_docs=source_doc_map,
         )
 
-    async def _process_document(
-        self, text: str, prompt_variables: dict[str, str]
-    ) -> str:
+    async def _process_document(self, text: str, entity_types: list[str]) -> str:
         response = await self._model.achat(
             self._extraction_prompt.format(**{
-                **prompt_variables,
-                self._input_text_key: text,
+                INPUT_TEXT_KEY: text,
+                ENTITY_TYPES_KEY: ",".join(entity_types),
             }),
         )
         results = response.output.content or ""
diff --git a/graphrag/index/operations/extract_graph/graph_intelligence_strategy.py b/graphrag/index/operations/extract_graph/graph_intelligence_strategy.py
deleted file mode 100644
index b335d191a6..0000000000
--- a/graphrag/index/operations/extract_graph/graph_intelligence_strategy.py
+++ /dev/null
@@ -1,102 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""A module containing run_graph_intelligence, run_extract_graph and _create_text_splitter methods to run graph intelligence."""
-
-import logging
-
-import networkx as nx
-
-from graphrag.cache.pipeline_cache import PipelineCache
-from graphrag.config.defaults import graphrag_config_defaults
-from graphrag.config.models.language_model_config import LanguageModelConfig
-from graphrag.index.operations.extract_graph.graph_extractor import GraphExtractor
-from graphrag.index.operations.extract_graph.typing import (
-    Document,
-    EntityExtractionResult,
-    EntityTypes,
-    StrategyConfig,
-)
-from graphrag.language_model.manager import ModelManager
-from graphrag.language_model.protocol.base import ChatModel
-
-logger = logging.getLogger(__name__)
-
-
-async def run_graph_intelligence(
-    docs: list[Document],
-    entity_types: EntityTypes,
-    cache: PipelineCache,
-    args: StrategyConfig,
-) -> EntityExtractionResult:
-    """Run the graph intelligence entity extraction strategy."""
-    llm_config = LanguageModelConfig(**args["llm"])
-
-    llm = ModelManager().get_or_create_chat_model(
-        name="extract_graph",
-        model_type=llm_config.type,
-        config=llm_config,
-        cache=cache,
-    )
-
-    return await run_extract_graph(llm, docs, entity_types, args)
-
-
-async def run_extract_graph(
-    model: ChatModel,
-    docs: list[Document],
-    entity_types: EntityTypes,
-    args: StrategyConfig,
-) -> EntityExtractionResult:
-    """Run the entity extraction chain."""
-    tuple_delimiter = args.get("tuple_delimiter", None)
-    record_delimiter = args.get("record_delimiter", None)
-    completion_delimiter = args.get("completion_delimiter", None)
-    extraction_prompt = args.get("extraction_prompt", None)
-    max_gleanings = args.get(
-        "max_gleanings", graphrag_config_defaults.extract_graph.max_gleanings
-    )
-
-    extractor = GraphExtractor(
-        model_invoker=model,
-        prompt=extraction_prompt,
-        max_gleanings=max_gleanings,
-        on_error=lambda e, s, d: logger.error(
-            "Entity Extraction Error", exc_info=e, extra={"stack": s, "details": d}
-        ),
-    )
-    text_list = [doc.text.strip() for doc in docs]
-
-    results = await extractor(
-        list(text_list),
-        {
-            "entity_types": entity_types,
-            "tuple_delimiter": tuple_delimiter,
-            "record_delimiter": record_delimiter,
-            "completion_delimiter": completion_delimiter,
-        },
-    )
-
-    graph = results.output
-    # Map the "source_id" back to the "id" field
-    for _, node in graph.nodes(data=True):  # type: ignore
-        if node is not None:
-            node["source_id"] = ",".join(
-                docs[int(id)].id for id in node["source_id"].split(",")
-            )
-
-    for _, _, edge in graph.edges(data=True):  # type: ignore
-        if edge is not None:
-            edge["source_id"] = ",".join(
-                docs[int(id)].id for id in edge["source_id"].split(",")
-            )
-
-    entities = [
-        ({"title": item[0], **(item[1] or {})})
-        for item in graph.nodes(data=True)
-        if item is not None
-    ]
-
-    relationships = nx.to_pandas_edgelist(graph)
-
-    return EntityExtractionResult(entities, relationships, graph)
b/graphrag/index/operations/extract_graph/typing.py @@ -5,7 +5,6 @@ from collections.abc import Awaitable, Callable from dataclasses import dataclass -from enum import Enum from typing import Any import networkx as nx @@ -44,14 +43,3 @@ class EntityExtractionResult: ], Awaitable[EntityExtractionResult], ] - - -class ExtractEntityStrategyType(str, Enum): - """ExtractEntityStrategyType class definition.""" - - graph_intelligence = "graph_intelligence" - nltk = "nltk" - - def __repr__(self): - """Get a string representation.""" - return f'"{self.value}"' diff --git a/graphrag/index/operations/summarize_communities/community_reports_extractor.py b/graphrag/index/operations/summarize_communities/community_reports_extractor.py index 1442a44a1b..3ca29251a0 100644 --- a/graphrag/index/operations/summarize_communities/community_reports_extractor.py +++ b/graphrag/index/operations/summarize_communities/community_reports_extractor.py @@ -11,7 +11,6 @@ from graphrag.index.typing.error_handler import ErrorHandlerFn from graphrag.language_model.protocol.base import ChatModel -from graphrag.prompts.index.community_report import COMMUNITY_REPORT_PROMPT logger = logging.getLogger(__name__) @@ -58,16 +57,16 @@ class CommunityReportsExtractor: def __init__( self, - model_invoker: ChatModel, - extraction_prompt: str | None = None, + model: ChatModel, + extraction_prompt: str, + max_report_length: int, on_error: ErrorHandlerFn | None = None, - max_report_length: int | None = None, ): """Init method definition.""" - self._model = model_invoker - self._extraction_prompt = extraction_prompt or COMMUNITY_REPORT_PROMPT + self._model = model + self._extraction_prompt = extraction_prompt self._on_error = on_error or (lambda _e, _s, _d: None) - self._max_report_length = max_report_length or 1500 + self._max_report_length = max_report_length async def __call__(self, input_text: str): """Call method definition.""" diff --git a/graphrag/index/operations/summarize_communities/strategies.py b/graphrag/index/operations/summarize_communities/strategies.py deleted file mode 100644 index 06ce44b27d..0000000000 --- a/graphrag/index/operations/summarize_communities/strategies.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
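Note: with the fallbacks removed from `CommunityReportsExtractor`, callers must now supply the prompt and report length themselves. A minimal sketch of the new construction, assuming a `ChatModel` named `chat_model` and a prompt string `report_prompt` are created upstream (both hypothetical names):

```python
# Sketch only: CommunityReportsExtractor under the new explicit signature.
# `chat_model` and `report_prompt` are assumed to exist; 2000 mirrors the
# max_length default in graphrag.config.defaults and is illustrative.
from graphrag.index.operations.summarize_communities.community_reports_extractor import (
    CommunityReportsExtractor,
)


async def summarize_one_community(chat_model, report_prompt: str, context_text: str):
    extractor = CommunityReportsExtractor(
        model=chat_model,                 # renamed from model_invoker
        extraction_prompt=report_prompt,  # required: no COMMUNITY_REPORT_PROMPT fallback
        max_report_length=2000,           # required: no implicit 1500 fallback
    )
    return await extractor(context_text)  # yields output plus structured_output
```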
-# Licensed under the MIT License - -"""A module containing run, _run_extractor and _load_nodes_edges_for_claim_chain methods definition.""" - -import logging - -from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks -from graphrag.config.models.language_model_config import LanguageModelConfig -from graphrag.index.operations.summarize_communities.community_reports_extractor import ( - CommunityReportsExtractor, -) -from graphrag.index.operations.summarize_communities.typing import ( - CommunityReport, - Finding, - StrategyConfig, -) -from graphrag.language_model.manager import ModelManager -from graphrag.language_model.protocol.base import ChatModel - -logger = logging.getLogger(__name__) - - -async def run_graph_intelligence( - community: str | int, - input: str, - level: int, - callbacks: WorkflowCallbacks, - cache: PipelineCache, - args: StrategyConfig, -) -> CommunityReport | None: - """Run the graph intelligence entity extraction strategy.""" - llm_config = LanguageModelConfig(**args["llm"]) - llm = ModelManager().get_or_create_chat_model( - name="community_reporting", - model_type=llm_config.type, - config=llm_config, - callbacks=callbacks, - cache=cache, - ) - - return await _run_extractor(llm, community, input, level, args) - - -async def _run_extractor( - model: ChatModel, - community: str | int, - input: str, - level: int, - args: StrategyConfig, -) -> CommunityReport | None: - extractor = CommunityReportsExtractor( - model, - extraction_prompt=args.get("extraction_prompt", None), - max_report_length=args.get("max_report_length", None), - on_error=lambda e, stack, _data: logger.error( - "Community Report Extraction Error", exc_info=e, extra={"stack": stack} - ), - ) - - try: - results = await extractor(input) - report = results.structured_output - if report is None: - logger.warning("No report found for community: %s", community) - return None - - return CommunityReport( - community=community, - full_content=results.output, - level=level, - rank=report.rating, - title=report.title, - rating_explanation=report.rating_explanation, - summary=report.summary, - findings=[ - Finding(explanation=f.explanation, summary=f.summary) - for f in report.findings - ], - full_content_json=report.model_dump_json(indent=4), - ) - except Exception: - logger.exception("Error processing community: %s", community) - return None diff --git a/graphrag/index/operations/summarize_communities/summarize_communities.py b/graphrag/index/operations/summarize_communities/summarize_communities.py index c31a4b0d77..f6028332bc 100644 --- a/graphrag/index/operations/summarize_communities/summarize_communities.py +++ b/graphrag/index/operations/summarize_communities/summarize_communities.py @@ -9,19 +9,22 @@ import pandas as pd import graphrag.data_model.schemas as schemas -from graphrag.cache.pipeline_cache import PipelineCache from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.enums import AsyncType +from graphrag.index.operations.summarize_communities.community_reports_extractor import ( + CommunityReportsExtractor, +) from graphrag.index.operations.summarize_communities.typing import ( CommunityReport, CommunityReportsStrategy, - CreateCommunityReportsStrategyType, + Finding, ) from graphrag.index.operations.summarize_communities.utils import ( get_levels, ) from graphrag.index.utils.derive_from_rows import derive_from_rows +from 
graphrag.language_model.protocol.base import ChatModel from graphrag.logger.progress import progress_ticker from graphrag.tokenizer.tokenizer import Tokenizer @@ -34,18 +37,17 @@ async def summarize_communities( local_contexts, level_context_builder: Callable, callbacks: WorkflowCallbacks, - cache: PipelineCache, - strategy: dict, + model: ChatModel, + prompt: str, tokenizer: Tokenizer, max_input_length: int, - async_mode: AsyncType = AsyncType.AsyncIO, - num_threads: int = 4, + max_report_length: int, + num_threads: int, + async_type: AsyncType, ): """Generate community summaries.""" reports: list[CommunityReport | None] = [] tick = progress_ticker(callbacks.progress, len(local_contexts)) - strategy_exec = load_strategy(strategy["type"]) - strategy_config = {**strategy} community_hierarchy = ( communities.explode("children") .rename({"children": "sub_community"}, axis=1) @@ -70,13 +72,13 @@ async def summarize_communities( async def run_generate(record): result = await _generate_report( - strategy_exec, + run_extractor, community_id=record[schemas.COMMUNITY_ID], community_level=record[schemas.COMMUNITY_LEVEL], community_context=record[schemas.CONTEXT_STRING], - callbacks=callbacks, - cache=cache, - strategy=strategy_config, + model=model, + extraction_prompt=prompt, + max_report_length=max_report_length, ) tick() return result @@ -86,7 +88,7 @@ async def run_generate(record): run_generate, callbacks=NoopWorkflowCallbacks(), num_threads=num_threads, - async_type=async_mode, + async_type=async_type, progress_msg=f"level {levels[i]} summarize communities progress: ", ) reports.extend([lr for lr in local_reports if lr is not None]) @@ -96,35 +98,63 @@ async def run_generate(record): async def _generate_report( runner: CommunityReportsStrategy, - callbacks: WorkflowCallbacks, - cache: PipelineCache, - strategy: dict, + model: ChatModel, + extraction_prompt: str, community_id: int, community_level: int, community_context: str, + max_report_length: int, ) -> CommunityReport | None: """Generate a report for a single community.""" return await runner( community_id, community_context, community_level, - callbacks, - cache, - strategy, + model, + extraction_prompt, + max_report_length, ) -def load_strategy( - strategy: CreateCommunityReportsStrategyType, -) -> CommunityReportsStrategy: - """Load strategy method definition.""" - match strategy: - case CreateCommunityReportsStrategyType.graph_intelligence: - from graphrag.index.operations.summarize_communities.strategies import ( - run_graph_intelligence, - ) +async def run_extractor( + community: str | int, + input: str, + level: int, + model: ChatModel, + extraction_prompt: str, + max_report_length: int, +) -> CommunityReport | None: + """Run the community reports extractor for a single community.""" + extractor = CommunityReportsExtractor( + model, + extraction_prompt=extraction_prompt, + max_report_length=max_report_length, + on_error=lambda e, stack, _data: logger.error( + "Community Report Extraction Error", exc_info=e, extra={"stack": stack} + ), + ) + + try: + results = await extractor(input) + report = results.structured_output + if report is None: + logger.warning("No report found for community: %s", community) + return None - return run_graph_intelligence - case _: - msg = f"Unknown strategy: {strategy}" - raise ValueError(msg) + return CommunityReport( + community=community, + full_content=results.output, + level=level, + rank=report.rating, + title=report.title, + rating_explanation=report.rating_explanation, + summary=report.summary, +
findings=[ + Finding(explanation=f.explanation, summary=f.summary) + for f in report.findings + ], + full_content_json=report.model_dump_json(indent=4), + ) + except Exception: + logger.exception("Error processing community: %s", community) + return None diff --git a/graphrag/index/operations/summarize_communities/typing.py b/graphrag/index/operations/summarize_communities/typing.py index 6dddf3d6b1..709c5ccc6a 100644 --- a/graphrag/index/operations/summarize_communities/typing.py +++ b/graphrag/index/operations/summarize_communities/typing.py @@ -4,16 +4,13 @@ """A module containing 'Finding' and 'CommunityReport' models.""" from collections.abc import Awaitable, Callable -from enum import Enum from typing import Any from typing_extensions import TypedDict -from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks +from graphrag.language_model.protocol.base import ChatModel ExtractedEntity = dict[str, Any] -StrategyConfig = dict[str, Any] RowContext = dict[str, Any] EntityTypes = list[str] Claim = dict[str, Any] @@ -45,19 +42,9 @@ class CommunityReport(TypedDict): str | int, str, int, - WorkflowCallbacks, - PipelineCache, - StrategyConfig, + ChatModel, + str, + int, ], Awaitable[CommunityReport | None], ] - - -class CreateCommunityReportsStrategyType(str, Enum): - """CreateCommunityReportsStrategyType class definition.""" - - graph_intelligence = "graph_intelligence" - - def __repr__(self): - """Get a string representation.""" - return f'"{self.value}"' diff --git a/graphrag/index/operations/summarize_descriptions/description_summary_extractor.py b/graphrag/index/operations/summarize_descriptions/description_summary_extractor.py index 6a44ee1df7..4a999fd685 100644 --- a/graphrag/index/operations/summarize_descriptions/description_summary_extractor.py +++ b/graphrag/index/operations/summarize_descriptions/description_summary_extractor.py @@ -8,7 +8,6 @@ from graphrag.index.typing.error_handler import ErrorHandlerFn from graphrag.language_model.protocol.base import ChatModel -from graphrag.prompts.index.summarize_descriptions import SUMMARIZE_PROMPT from graphrag.tokenizer.get_tokenizer import get_tokenizer # these tokens are used in the prompt @@ -36,17 +35,17 @@ class SummarizeExtractor: def __init__( self, - model_invoker: ChatModel, + model: ChatModel, max_summary_length: int, max_input_tokens: int, - summarization_prompt: str | None = None, + summarization_prompt: str, on_error: ErrorHandlerFn | None = None, ): """Init method definition.""" # TODO: streamline construction - self._model = model_invoker - self._tokenizer = get_tokenizer(model_invoker.config) - self._summarization_prompt = summarization_prompt or SUMMARIZE_PROMPT + self._model = model + self._tokenizer = get_tokenizer(model.config) + self._summarization_prompt = summarization_prompt self._on_error = on_error or (lambda _e, _s, _d: None) self._max_summary_length = max_summary_length self._max_input_tokens = max_input_tokens diff --git a/graphrag/index/operations/summarize_descriptions/graph_intelligence_strategy.py b/graphrag/index/operations/summarize_descriptions/graph_intelligence_strategy.py deleted file mode 100644 index d3259b290f..0000000000 --- a/graphrag/index/operations/summarize_descriptions/graph_intelligence_strategy.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
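Note: `SummarizeExtractor` follows the same pattern as the community reports extractor: the prompt is now mandatory and the tokenizer is derived from `model.config` via `get_tokenizer`. A minimal sketch, with `chat_model` and the two limits as assumed/illustrative values:

```python
# Sketch only: SummarizeExtractor under the new explicit signature.
# `chat_model` is assumed to be a ChatModel built upstream; the length limits
# here are illustrative, not prescribed defaults.
from graphrag.index.operations.summarize_descriptions.description_summary_extractor import (
    SummarizeExtractor,
)


async def summarize_edge(chat_model, prompt: str) -> str:
    extractor = SummarizeExtractor(
        model=chat_model,             # renamed from model_invoker
        summarization_prompt=prompt,  # required: no SUMMARIZE_PROMPT fallback
        max_summary_length=500,
        max_input_tokens=4000,
    )
    result = await extractor(id=("NODE_A", "NODE_B"), descriptions=["first", "second"])
    return result.description
```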
-# Licensed under the MIT License - -"""A module containing run_graph_intelligence, run_resolve_entities and _create_text_list_splitter methods to run graph intelligence.""" - -import logging - -from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.config.models.language_model_config import LanguageModelConfig -from graphrag.index.operations.summarize_descriptions.description_summary_extractor import ( - SummarizeExtractor, -) -from graphrag.index.operations.summarize_descriptions.typing import ( - StrategyConfig, - SummarizedDescriptionResult, -) -from graphrag.language_model.manager import ModelManager -from graphrag.language_model.protocol.base import ChatModel - -logger = logging.getLogger(__name__) - - -async def run_graph_intelligence( - id: str | tuple[str, str], - descriptions: list[str], - cache: PipelineCache, - args: StrategyConfig, -) -> SummarizedDescriptionResult: - """Run the graph intelligence entity extraction strategy.""" - llm_config = LanguageModelConfig(**args["llm"]) - llm = ModelManager().get_or_create_chat_model( - name="summarize_descriptions", - model_type=llm_config.type, - config=llm_config, - cache=cache, - ) - - return await run_summarize_descriptions(llm, id, descriptions, args) - - -async def run_summarize_descriptions( - model: ChatModel, - id: str | tuple[str, str], - descriptions: list[str], - args: StrategyConfig, -) -> SummarizedDescriptionResult: - """Run the entity extraction chain.""" - # Extraction Arguments - summarize_prompt = args.get("summarize_prompt", None) - max_input_tokens = args["max_input_tokens"] - max_summary_length = args["max_summary_length"] - extractor = SummarizeExtractor( - model_invoker=model, - summarization_prompt=summarize_prompt, - on_error=lambda e, stack, details: logger.error( - "Entity Extraction Error", - exc_info=e, - extra={"stack": stack, "details": details}, - ), - max_summary_length=max_summary_length, - max_input_tokens=max_input_tokens, - ) - - result = await extractor(id=id, descriptions=descriptions) - return SummarizedDescriptionResult(id=result.id, description=result.description) diff --git a/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py b/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py index 780c94b329..48aaf37a00 100644 --- a/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py +++ b/graphrag/index/operations/summarize_descriptions/summarize_descriptions.py @@ -5,16 +5,17 @@ import asyncio import logging -from typing import Any import pandas as pd -from graphrag.cache.pipeline_cache import PipelineCache from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks +from graphrag.index.operations.summarize_descriptions.description_summary_extractor import ( + SummarizeExtractor, +) from graphrag.index.operations.summarize_descriptions.typing import ( - SummarizationStrategy, - SummarizeStrategyType, + SummarizedDescriptionResult, ) +from graphrag.language_model.protocol.base import ChatModel from graphrag.logger.progress import ProgressTicker, progress_ticker logger = logging.getLogger(__name__) @@ -24,17 +25,13 @@ async def summarize_descriptions( entities_df: pd.DataFrame, relationships_df: pd.DataFrame, callbacks: WorkflowCallbacks, - cache: PipelineCache, - strategy: dict[str, Any] | None = None, - num_threads: int = 4, + model: ChatModel, + max_summary_length: int, + max_input_tokens: int, + prompt: str, + num_threads: int, ) -> tuple[pd.DataFrame, pd.DataFrame]: """Summarize entity and relationship 
descriptions from an entity graph, using a language model.""" - logger.debug("summarize_descriptions strategy=%s", strategy) - strategy = strategy or {} - strategy_exec = load_strategy( - strategy.get("type", SummarizeStrategyType.graph_intelligence) - ) - strategy_config = {**strategy} async def get_summarized( nodes: pd.DataFrame, edges: pd.DataFrame, semaphore: asyncio.Semaphore @@ -99,7 +96,14 @@ async def do_summarize_descriptions( semaphore: asyncio.Semaphore, ): async with semaphore: - results = await strategy_exec(id, descriptions, cache, strategy_config) + results = await run_summarize_descriptions( + id, + descriptions, + model, + max_summary_length, + max_input_tokens, + prompt, + ) ticker(1) return results @@ -108,15 +112,26 @@ async def do_summarize_descriptions( return await get_summarized(entities_df, relationships_df, semaphore) -def load_strategy(strategy_type: SummarizeStrategyType) -> SummarizationStrategy: - """Load strategy method definition.""" - match strategy_type: - case SummarizeStrategyType.graph_intelligence: - from graphrag.index.operations.summarize_descriptions.graph_intelligence_strategy import ( - run_graph_intelligence, - ) +async def run_summarize_descriptions( + id: str | tuple[str, str], + descriptions: list[str], + model: ChatModel, + max_summary_length: int, + max_input_tokens: int, + prompt: str, +) -> SummarizedDescriptionResult: + """Summarize a set of descriptions for a single entity or relationship.""" + extractor = SummarizeExtractor( + model=model, + summarization_prompt=prompt, + on_error=lambda e, stack, details: logger.error( + "Description Summarization Error", + exc_info=e, + extra={"stack": stack, "details": details}, + ), + max_summary_length=max_summary_length, + max_input_tokens=max_input_tokens, + ) - return run_graph_intelligence - case _: - msg = f"Unknown strategy: {strategy_type}" - raise ValueError(msg) + result = await extractor(id=id, descriptions=descriptions) + return SummarizedDescriptionResult(id=result.id, description=result.description) diff --git a/graphrag/index/operations/summarize_descriptions/typing.py b/graphrag/index/operations/summarize_descriptions/typing.py index 55b079090d..5a912caec4 100644 --- a/graphrag/index/operations/summarize_descriptions/typing.py +++ b/graphrag/index/operations/summarize_descriptions/typing.py @@ -3,15 +3,9 @@ """A module containing 'SummarizedDescriptionResult' model.""" -from collections.abc import Awaitable, Callable from dataclasses import dataclass -from enum import Enum from typing import Any, NamedTuple -from graphrag.cache.pipeline_cache import PipelineCache - -StrategyConfig = dict[str, Any] - @dataclass class SummarizedDescriptionResult: @@ -21,28 +15,7 @@ class SummarizedDescriptionResult: description: str -SummarizationStrategy = Callable[ - [ - str | tuple[str, str], - list[str], - PipelineCache, - StrategyConfig, - ], - Awaitable[SummarizedDescriptionResult], -] - - class DescriptionSummarizeRow(NamedTuple): """DescriptionSummarizeRow class definition.""" graph: Any - - -class SummarizeStrategyType(str, Enum): - """SummarizeStrategyType class definition.""" - - graph_intelligence = "graph_intelligence" - - def __repr__(self): - """Get a string representation.""" - return f'"{self.value}"' diff --git a/graphrag/index/workflows/create_base_text_units.py b/graphrag/index/workflows/create_base_text_units.py index feeba7a065..df57713126 100644 --- a/graphrag/index/workflows/create_base_text_units.py +++ b/graphrag/index/workflows/create_base_text_units.py @@ -56,8 +56,8 @@ def 
create_base_text_units( overlap: int, encoding_model: str, strategy: ChunkStrategyType, - prepend_metadata: bool = False, - chunk_size_includes_metadata: bool = False, + prepend_metadata: bool, + chunk_size_includes_metadata: bool, ) -> pd.DataFrame: """All the steps to transform base text_units.""" documents.sort_values(by=["id"], ascending=[True], inplace=True) diff --git a/graphrag/index/workflows/create_community_reports.py b/graphrag/index/workflows/create_community_reports.py index bf98655c01..e9f19533ef 100644 --- a/graphrag/index/workflows/create_community_reports.py +++ b/graphrag/index/workflows/create_community_reports.py @@ -8,12 +8,9 @@ import pandas as pd import graphrag.data_model.schemas as schemas -from graphrag.cache.pipeline_cache import PipelineCache from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks -from graphrag.config.defaults import graphrag_config_defaults from graphrag.config.enums import AsyncType from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.config.models.language_model_config import LanguageModelConfig from graphrag.index.operations.finalize_community_reports import ( finalize_community_reports, ) @@ -29,7 +26,10 @@ ) from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput +from graphrag.language_model.manager import ModelManager +from graphrag.language_model.protocol.base import ChatModel from graphrag.tokenizer.get_tokenizer import get_tokenizer +from graphrag.tokenizer.tokenizer import Tokenizer from graphrag.utils.storage import ( load_table_from_storage, storage_has_table, @@ -54,25 +54,32 @@ async def run_workflow( ): claims = await load_table_from_storage("covariates", context.output_storage) - community_reports_llm_settings = config.get_language_model_config( - config.community_reports.model_id - ) - async_mode = community_reports_llm_settings.async_mode - num_threads = community_reports_llm_settings.concurrent_requests - summarization_strategy = config.community_reports.resolved_strategy( - config.root_dir, community_reports_llm_settings + model_config = config.get_language_model_config(config.community_reports.model_id) + prompts = config.community_reports.resolved_prompts(config.root_dir) + + model = ModelManager().get_or_create_chat_model( + name="community_reporting", + model_type=model_config.type, + config=model_config, + callbacks=context.callbacks, + cache=context.cache, ) + tokenizer = get_tokenizer(model_config) + output = await create_community_reports( edges_input=edges, entities=entities, communities=communities, claims_input=claims, callbacks=context.callbacks, - cache=context.cache, - summarization_strategy=summarization_strategy, - async_mode=async_mode, - num_threads=num_threads, + model=model, + tokenizer=tokenizer, + prompt=prompts.graph_prompt, + max_input_length=config.community_reports.max_input_length, + max_report_length=config.community_reports.max_length, + num_threads=model_config.concurrent_requests, + async_type=model_config.async_mode, ) await write_table_to_storage(output, "community_reports", context.output_storage) @@ -87,10 +94,13 @@ async def create_community_reports( communities: pd.DataFrame, claims_input: pd.DataFrame | None, callbacks: WorkflowCallbacks, - cache: PipelineCache, - summarization_strategy: dict, - async_mode: AsyncType = AsyncType.AsyncIO, - num_threads: int = 4, + model: ChatModel, + tokenizer: Tokenizer, + prompt: str, + max_input_length: int, + max_report_length: int, + 
num_threads: int, + async_type: AsyncType, ) -> pd.DataFrame: """All the steps to transform community reports.""" nodes = explode_communities(communities, entities) @@ -102,15 +112,6 @@ async def create_community_reports( if claims_input is not None: claims = _prep_claims(claims_input) - summarization_strategy["extraction_prompt"] = summarization_strategy["graph_prompt"] - - model_config = LanguageModelConfig(**summarization_strategy["llm"]) - tokenizer = get_tokenizer(model_config) - - max_input_length = summarization_strategy.get( - "max_input_length", graphrag_config_defaults.community_reports.max_input_length - ) - local_contexts = build_local_context( nodes, edges, @@ -126,12 +127,13 @@ async def create_community_reports( local_contexts, build_level_context, callbacks, - cache, - summarization_strategy, + model=model, + prompt=prompt, tokenizer=tokenizer, max_input_length=max_input_length, - async_mode=async_mode, + max_report_length=max_report_length, num_threads=num_threads, + async_type=async_type, ) return finalize_community_reports(community_reports, communities) diff --git a/graphrag/index/workflows/create_community_reports_text.py b/graphrag/index/workflows/create_community_reports_text.py index 20286f7c55..80b236890d 100644 --- a/graphrag/index/workflows/create_community_reports_text.py +++ b/graphrag/index/workflows/create_community_reports_text.py @@ -7,12 +7,9 @@ import pandas as pd -from graphrag.cache.pipeline_cache import PipelineCache from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks -from graphrag.config.defaults import graphrag_config_defaults from graphrag.config.enums import AsyncType from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.config.models.language_model_config import LanguageModelConfig from graphrag.index.operations.finalize_community_reports import ( finalize_community_reports, ) @@ -28,7 +25,10 @@ ) from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput +from graphrag.language_model.manager import ModelManager +from graphrag.language_model.protocol.base import ChatModel from graphrag.tokenizer.get_tokenizer import get_tokenizer +from graphrag.tokenizer.tokenizer import Tokenizer from graphrag.utils.storage import load_table_from_storage, write_table_to_storage logger = logging.getLogger(__name__) @@ -45,24 +45,31 @@ async def run_workflow( text_units = await load_table_from_storage("text_units", context.output_storage) - community_reports_llm_settings = config.get_language_model_config( - config.community_reports.model_id - ) - async_mode = community_reports_llm_settings.async_mode - num_threads = community_reports_llm_settings.concurrent_requests - summarization_strategy = config.community_reports.resolved_strategy( - config.root_dir, community_reports_llm_settings + model_config = config.get_language_model_config(config.community_reports.model_id) + model = ModelManager().get_or_create_chat_model( + name="community_reporting", + model_type=model_config.type, + config=model_config, + callbacks=context.callbacks, + cache=context.cache, ) + tokenizer = get_tokenizer(model_config) + + prompts = config.community_reports.resolved_prompts(config.root_dir) + output = await create_community_reports_text( entities, communities, text_units, context.callbacks, - context.cache, - summarization_strategy, - async_mode=async_mode, - num_threads=num_threads, + model=model, + tokenizer=tokenizer, + prompt=prompts.text_prompt, + 
max_input_length=config.community_reports.max_input_length, + max_report_length=config.community_reports.max_length, + num_threads=model_config.concurrent_requests, + async_type=model_config.async_mode, ) await write_table_to_storage(output, "community_reports", context.output_storage) @@ -76,23 +83,17 @@ async def create_community_reports_text( communities: pd.DataFrame, text_units: pd.DataFrame, callbacks: WorkflowCallbacks, - cache: PipelineCache, - summarization_strategy: dict, - async_mode: AsyncType = AsyncType.AsyncIO, - num_threads: int = 4, + model: ChatModel, + tokenizer: Tokenizer, + prompt: str, + max_input_length: int, + max_report_length: int, + num_threads: int, + async_type: AsyncType, ) -> pd.DataFrame: """All the steps to transform community reports.""" nodes = explode_communities(communities, entities) - summarization_strategy["extraction_prompt"] = summarization_strategy["text_prompt"] - - max_input_length = summarization_strategy.get( - "max_input_length", graphrag_config_defaults.community_reports.max_input_length - ) - - model_config = LanguageModelConfig(**summarization_strategy["llm"]) - tokenizer = get_tokenizer(model_config) - local_contexts = build_local_context( communities, text_units, nodes, tokenizer, max_input_length ) @@ -103,12 +104,13 @@ async def create_community_reports_text( local_contexts, build_level_context, callbacks, - cache, - summarization_strategy, + model=model, + prompt=prompt, tokenizer=tokenizer, max_input_length=max_input_length, - async_mode=async_mode, + max_report_length=max_report_length, num_threads=num_threads, + async_type=async_type, ) return finalize_community_reports(community_reports, communities) diff --git a/graphrag/index/workflows/extract_covariates.py b/graphrag/index/workflows/extract_covariates.py index 63adb1d0b9..76889a56db 100644 --- a/graphrag/index/workflows/extract_covariates.py +++ b/graphrag/index/workflows/extract_covariates.py @@ -4,13 +4,12 @@ """A module containing run_workflow method definition.""" import logging -from typing import Any from uuid import uuid4 import pandas as pd -from graphrag.cache.pipeline_cache import PipelineCache from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks +from graphrag.config.defaults import DEFAULT_ENTITY_TYPES from graphrag.config.enums import AsyncType from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.data_model.schemas import COVARIATES_FINAL_COLUMNS @@ -19,6 +18,8 @@ ) from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput +from graphrag.language_model.manager import ModelManager +from graphrag.language_model.protocol.base import ChatModel from graphrag.utils.storage import load_table_from_storage, write_table_to_storage logger = logging.getLogger(__name__) @@ -34,25 +35,29 @@ async def run_workflow( if config.extract_claims.enabled: text_units = await load_table_from_storage("text_units", context.output_storage) - extract_claims_llm_settings = config.get_language_model_config( - config.extract_claims.model_id - ) - extraction_strategy = config.extract_claims.resolved_strategy( - config.root_dir, extract_claims_llm_settings + model_config = config.get_language_model_config(config.extract_claims.model_id) + + model = ModelManager().get_or_create_chat_model( + name="extract_claims", + model_type=model_config.type, + config=model_config, + callbacks=context.callbacks, + cache=context.cache, ) - async_mode = extract_claims_llm_settings.async_mode - 
num_threads = extract_claims_llm_settings.concurrent_requests + prompts = config.extract_claims.resolved_prompts(config.root_dir) output = await extract_covariates( - text_units, - context.callbacks, - context.cache, - "claim", - extraction_strategy, - async_mode=async_mode, - entity_types=None, - num_threads=num_threads, + text_units=text_units, + callbacks=context.callbacks, + model=model, + covariate_type="claim", + max_gleanings=config.extract_claims.max_gleanings, + claim_description=config.extract_claims.description, + prompt=prompts.extraction_prompt, + entity_types=DEFAULT_ENTITY_TYPES, + num_threads=model_config.concurrent_requests, + async_type=model_config.async_mode, ) await write_table_to_storage(output, "covariates", context.output_storage) @@ -64,27 +69,32 @@ async def run_workflow( async def extract_covariates( text_units: pd.DataFrame, callbacks: WorkflowCallbacks, - cache: PipelineCache, + model: ChatModel, covariate_type: str, - extraction_strategy: dict[str, Any] | None, - async_mode: AsyncType = AsyncType.AsyncIO, - entity_types: list[str] | None = None, - num_threads: int = 4, + max_gleanings: int, + claim_description: str, + prompt: str, + entity_types: list[str], + num_threads: int, + async_type: AsyncType, ) -> pd.DataFrame: """All the steps to extract and format covariates.""" # reassign the id because it will be overwritten in the output by a covariate one # this also results in text_unit_id being copied to the output covariate table text_units["text_unit_id"] = text_units["id"] + covariates = await extractor( input=text_units, callbacks=callbacks, - cache=cache, + model=model, column="text", covariate_type=covariate_type, - strategy=extraction_strategy, - async_mode=async_mode, + max_gleanings=max_gleanings, + claim_description=claim_description, + prompt=prompt, entity_types=entity_types, num_threads=num_threads, + async_type=async_type, ) text_units.drop(columns=["text_unit_id"], inplace=True) # don't pollute the global covariates["id"] = covariates["covariate_type"].apply(lambda _x: str(uuid4())) diff --git a/graphrag/index/workflows/extract_graph.py b/graphrag/index/workflows/extract_graph.py index 592502f6da..ce466379c7 100644 --- a/graphrag/index/workflows/extract_graph.py +++ b/graphrag/index/workflows/extract_graph.py @@ -4,11 +4,9 @@ """A module containing run_workflow method definition.""" import logging -from typing import Any import pandas as pd -from graphrag.cache.pipeline_cache import PipelineCache from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.enums import AsyncType from graphrag.config.models.graph_rag_config import GraphRagConfig @@ -20,6 +18,8 @@ ) from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput +from graphrag.language_model.manager import ModelManager +from graphrag.language_model.protocol.base import ChatModel from graphrag.utils.storage import load_table_from_storage, write_table_to_storage logger = logging.getLogger(__name__) @@ -33,30 +33,44 @@ async def run_workflow( logger.info("Workflow started: extract_graph") text_units = await load_table_from_storage("text_units", context.output_storage) - extract_graph_llm_settings = config.get_language_model_config( + extraction_model_config = config.get_language_model_config( config.extract_graph.model_id ) - extraction_strategy = config.extract_graph.resolved_strategy( - config.root_dir, extract_graph_llm_settings + extraction_prompts = 
config.extract_graph.resolved_prompts(config.root_dir) + extraction_model = ModelManager().get_or_create_chat_model( + name="extract_graph", + model_type=extraction_model_config.type, + config=extraction_model_config, + cache=context.cache, ) - summarization_llm_settings = config.get_language_model_config( + summarization_model_config = config.get_language_model_config( config.summarize_descriptions.model_id ) - summarization_strategy = config.summarize_descriptions.resolved_strategy( - config.root_dir, summarization_llm_settings + summarization_prompts = config.summarize_descriptions.resolved_prompts( + config.root_dir + ) + summarization_model = ModelManager().get_or_create_chat_model( + name="summarize_descriptions", + model_type=summarization_model_config.type, + config=summarization_model_config, + cache=context.cache, ) entities, relationships, raw_entities, raw_relationships = await extract_graph( text_units=text_units, callbacks=context.callbacks, - cache=context.cache, - extraction_strategy=extraction_strategy, - extraction_num_threads=extract_graph_llm_settings.concurrent_requests, - extraction_async_mode=extract_graph_llm_settings.async_mode, + extraction_model=extraction_model, + extraction_prompt=extraction_prompts.extraction_prompt, entity_types=config.extract_graph.entity_types, - summarization_strategy=summarization_strategy, - summarization_num_threads=summarization_llm_settings.concurrent_requests, + max_gleanings=config.extract_graph.max_gleanings, + extraction_num_threads=extraction_model_config.concurrent_requests, + extraction_async_type=extraction_model_config.async_mode, + summarization_model=summarization_model, + max_summary_length=config.summarize_descriptions.max_length, + max_input_tokens=config.summarize_descriptions.max_input_tokens, + summarization_prompt=summarization_prompts.summarize_prompt, + summarization_num_threads=summarization_model_config.concurrent_requests, ) await write_table_to_storage(entities, "entities", context.output_storage) @@ -82,26 +96,31 @@ async def run_workflow( async def extract_graph( text_units: pd.DataFrame, callbacks: WorkflowCallbacks, - cache: PipelineCache, - extraction_strategy: dict[str, Any] | None = None, - extraction_num_threads: int = 4, - extraction_async_mode: AsyncType = AsyncType.AsyncIO, - entity_types: list[str] | None = None, - summarization_strategy: dict[str, Any] | None = None, - summarization_num_threads: int = 4, + extraction_model: ChatModel, + extraction_prompt: str, + entity_types: list[str], + max_gleanings: int, + extraction_num_threads: int, + extraction_async_type: AsyncType, + summarization_model: ChatModel, + max_summary_length: int, + max_input_tokens: int, + summarization_prompt: str, + summarization_num_threads: int, ) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]: """All the steps to create the base entity graph.""" # this returns a graph for each text unit, to be merged later extracted_entities, extracted_relationships = await extractor( text_units=text_units, callbacks=callbacks, - cache=cache, text_column="text", id_column="id", - strategy=extraction_strategy, - async_mode=extraction_async_mode, + model=extraction_model, + prompt=extraction_prompt, entity_types=entity_types, + max_gleanings=max_gleanings, num_threads=extraction_num_threads, + async_type=extraction_async_type, ) if not _validate_data(extracted_entities): @@ -124,9 +143,11 @@ async def extract_graph( extracted_entities=extracted_entities, extracted_relationships=extracted_relationships, callbacks=callbacks, 
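Note: this is the wiring pattern repeated across the reworked workflows: each one resolves its model config, builds the `ChatModel` through `ModelManager`, and resolves typed prompts, instead of threading a strategy dict through. A condensed sketch, assuming `config` is a loaded `GraphRagConfig` and `context` a `PipelineRunContext` supplied by the pipeline runner:

```python
# Sketch only: the explicit model/prompt wiring the workflows now use.
# `config` and `context` are assumed to be provided by the pipeline runner.
from graphrag.language_model.manager import ModelManager

model_config = config.get_language_model_config(config.extract_graph.model_id)
model = ModelManager().get_or_create_chat_model(
    name="extract_graph",
    model_type=model_config.type,
    config=model_config,
    cache=context.cache,
)
prompts = config.extract_graph.resolved_prompts(config.root_dir)
# model, prompts.extraction_prompt, and config.extract_graph.max_gleanings are
# then passed into extract_graph(...) as plain, strongly-typed arguments.
```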
- cache=cache, - summarization_strategy=summarization_strategy, - summarization_num_threads=summarization_num_threads, + model=summarization_model, + max_summary_length=max_summary_length, + max_input_tokens=max_input_tokens, + summarization_prompt=summarization_prompt, + num_threads=summarization_num_threads, ) return (entities, relationships, raw_entities, raw_relationships) @@ -136,18 +157,22 @@ async def get_summarized_entities_relationships( extracted_entities: pd.DataFrame, extracted_relationships: pd.DataFrame, callbacks: WorkflowCallbacks, - cache: PipelineCache, - summarization_strategy: dict[str, Any] | None = None, - summarization_num_threads: int = 4, + model: ChatModel, + max_summary_length: int, + max_input_tokens: int, + summarization_prompt: str, + num_threads: int, ) -> tuple[pd.DataFrame, pd.DataFrame]: """Summarize the entities and relationships.""" entity_summaries, relationship_summaries = await summarize_descriptions( entities_df=extracted_entities, relationships_df=extracted_relationships, callbacks=callbacks, - cache=cache, - strategy=summarization_strategy, - num_threads=summarization_num_threads, + model=model, + max_summary_length=max_summary_length, + max_input_tokens=max_input_tokens, + prompt=summarization_prompt, + num_threads=num_threads, ) relationships = extracted_relationships.drop(columns=["description"]).merge( diff --git a/graphrag/index/workflows/extract_graph_nlp.py b/graphrag/index/workflows/extract_graph_nlp.py index 90afedf46c..b49f3e7a1f 100644 --- a/graphrag/index/workflows/extract_graph_nlp.py +++ b/graphrag/index/workflows/extract_graph_nlp.py @@ -8,9 +8,12 @@ import pandas as pd from graphrag.cache.pipeline_cache import PipelineCache -from graphrag.config.models.extract_graph_nlp_config import ExtractGraphNLPConfig +from graphrag.config.enums import AsyncType from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.operations.build_noun_graph.build_noun_graph import build_noun_graph +from graphrag.index.operations.build_noun_graph.np_extractors.base import ( + BaseNounPhraseExtractor, +) from graphrag.index.operations.build_noun_graph.np_extractors.factory import ( create_noun_phrase_extractor, ) @@ -29,10 +32,16 @@ async def run_workflow( logger.info("Workflow started: extract_graph_nlp") text_units = await load_table_from_storage("text_units", context.output_storage) + text_analyzer_config = config.extract_graph_nlp.text_analyzer + text_analyzer = create_noun_phrase_extractor(text_analyzer_config) + entities, relationships = await extract_graph_nlp( text_units, context.cache, - extraction_config=config.extract_graph_nlp, + text_analyzer=text_analyzer, + normalize_edge_weights=config.extract_graph_nlp.normalize_edge_weights, + num_threads=config.extract_graph_nlp.concurrent_requests, + async_type=config.extract_graph_nlp.async_mode, ) await write_table_to_storage(entities, "entities", context.output_storage) @@ -51,17 +60,18 @@ async def run_workflow( async def extract_graph_nlp( text_units: pd.DataFrame, cache: PipelineCache, - extraction_config: ExtractGraphNLPConfig, + text_analyzer: BaseNounPhraseExtractor, + normalize_edge_weights: bool, + num_threads: int, + async_type: AsyncType, ) -> tuple[pd.DataFrame, pd.DataFrame]: """All the steps to create the base entity graph.""" - text_analyzer_config = extraction_config.text_analyzer - text_analyzer = create_noun_phrase_extractor(text_analyzer_config) extracted_nodes, extracted_edges = await build_noun_graph( text_units, text_analyzer=text_analyzer, - 
normalize_edge_weights=extraction_config.normalize_edge_weights, - num_threads=extraction_config.concurrent_requests, - async_mode=extraction_config.async_mode, + normalize_edge_weights=normalize_edge_weights, + num_threads=num_threads, + async_mode=async_type, cache=cache, ) diff --git a/graphrag/index/workflows/generate_text_embeddings.py b/graphrag/index/workflows/generate_text_embeddings.py index 9e6dde6d34..3a0ae41259 100644 --- a/graphrag/index/workflows/generate_text_embeddings.py +++ b/graphrag/index/workflows/generate_text_embeddings.py @@ -7,7 +7,6 @@ import pandas as pd -from graphrag.cache.pipeline_cache import PipelineCache from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.embeddings import ( community_full_content_embedding, @@ -19,11 +18,15 @@ relationship_description_embedding, text_unit_text_embedding, ) -from graphrag.config.get_embedding_settings import get_embedding_settings +from graphrag.config.get_vector_store_settings import get_vector_store_settings from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.operations.embed_text.embed_text import embed_text from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput +from graphrag.language_model.manager import ModelManager +from graphrag.language_model.protocol.base import EmbeddingModel +from graphrag.tokenizer.get_tokenizer import get_tokenizer +from graphrag.tokenizer.tokenizer import Tokenizer from graphrag.utils.storage import ( load_table_from_storage, write_table_to_storage, @@ -67,7 +70,19 @@ async def run_workflow( "community_reports", context.output_storage ) - text_embed = get_embedding_settings(config) + vector_store_config = get_vector_store_settings(config) + + model_config = config.get_language_model_config(config.embed_text.model_id) + + model = ModelManager().get_or_create_embedding_model( + name="text_embedding", + model_type=model_config.type, + config=model_config, + callbacks=context.callbacks, + cache=context.cache, + ) + + tokenizer = get_tokenizer(model_config) output = await generate_text_embeddings( documents=documents, @@ -76,8 +91,12 @@ async def run_workflow( entities=entities, community_reports=community_reports, callbacks=context.callbacks, - cache=context.cache, - text_embed_config=text_embed, + model=model, + tokenizer=tokenizer, + batch_size=config.embed_text.batch_size, + batch_max_tokens=config.embed_text.batch_max_tokens, + num_threads=model_config.concurrent_requests, + vector_store_config=vector_store_config, embedded_fields=embedded_fields, ) @@ -100,8 +119,12 @@ async def generate_text_embeddings( entities: pd.DataFrame | None, community_reports: pd.DataFrame | None, callbacks: WorkflowCallbacks, - cache: PipelineCache, - text_embed_config: dict, + model: EmbeddingModel, + tokenizer: Tokenizer, + batch_size: int, + batch_max_tokens: int, + num_threads: int, + vector_store_config: dict, embedded_fields: list[str], ) -> dict[str, pd.DataFrame]: """All the steps to generate all embeddings.""" @@ -164,8 +187,12 @@ async def generate_text_embeddings( outputs[field] = await _run_embeddings( name=field, callbacks=callbacks, - cache=cache, - text_embed_config=text_embed_config, + model=model, + tokenizer=tokenizer, + vector_store_config=vector_store_config, + batch_size=batch_size, + batch_max_tokens=batch_max_tokens, + num_threads=num_threads, **embedding_param_map[field], ) return outputs @@ -176,17 +203,25 @@ async def _run_embeddings( data: 
pd.DataFrame, embed_column: str, callbacks: WorkflowCallbacks, - cache: PipelineCache, - text_embed_config: dict, + model: EmbeddingModel, + tokenizer: Tokenizer, + batch_size: int, + batch_max_tokens: int, + num_threads: int, + vector_store_config: dict, ) -> pd.DataFrame: """All the steps to generate single embedding.""" data["embedding"] = await embed_text( input=data, callbacks=callbacks, - cache=cache, + model=model, + tokenizer=tokenizer, embed_column=embed_column, embedding_name=name, - strategy=text_embed_config["strategy"], + batch_size=batch_size, + batch_max_tokens=batch_max_tokens, + num_threads=num_threads, + vector_store_config=vector_store_config, ) return data.loc[:, ["id", "embedding"]] diff --git a/graphrag/index/workflows/update_entities_relationships.py b/graphrag/index/workflows/update_entities_relationships.py index cd8ad82553..69bdefa3f2 100644 --- a/graphrag/index/workflows/update_entities_relationships.py +++ b/graphrag/index/workflows/update_entities_relationships.py @@ -16,6 +16,7 @@ from graphrag.index.update.entities import _group_and_resolve_entities from graphrag.index.update.relationships import _update_and_merge_relationships from graphrag.index.workflows.extract_graph import get_summarized_entities_relationships +from graphrag.language_model.manager import ModelManager from graphrag.storage.pipeline_storage import PipelineStorage from graphrag.utils.storage import load_table_from_storage, write_table_to_storage @@ -77,11 +78,15 @@ async def _update_entities_and_relationships( delta_relationships, ) - summarization_llm_settings = config.get_language_model_config( + summarization_model_config = config.get_language_model_config( config.summarize_descriptions.model_id ) - summarization_strategy = config.summarize_descriptions.resolved_strategy( - config.root_dir, summarization_llm_settings + prompts = config.summarize_descriptions.resolved_prompts(config.root_dir) + model = ModelManager().get_or_create_chat_model( + name="summarize_descriptions", + model_type=summarization_model_config.type, + config=summarization_model_config, + cache=cache, ) ( @@ -91,9 +96,11 @@ async def _update_entities_and_relationships( extracted_entities=merged_entities_df, extracted_relationships=merged_relationships_df, callbacks=callbacks, - cache=cache, - summarization_strategy=summarization_strategy, - summarization_num_threads=summarization_llm_settings.concurrent_requests, + model=model, + max_summary_length=config.summarize_descriptions.max_length, + max_input_tokens=config.summarize_descriptions.max_input_tokens, + summarization_prompt=prompts.summarize_prompt, + num_threads=summarization_model_config.concurrent_requests, ) # Save the updated entities back to storage diff --git a/graphrag/index/workflows/update_text_embeddings.py b/graphrag/index/workflows/update_text_embeddings.py index 11bce16d3e..4b349f5809 100644 --- a/graphrag/index/workflows/update_text_embeddings.py +++ b/graphrag/index/workflows/update_text_embeddings.py @@ -5,12 +5,14 @@ import logging -from graphrag.config.get_embedding_settings import get_embedding_settings +from graphrag.config.get_vector_store_settings import get_vector_store_settings from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.run.utils import get_update_storages from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput from graphrag.index.workflows.generate_text_embeddings import generate_text_embeddings +from 
graphrag.language_model.manager import ModelManager +from graphrag.tokenizer.get_tokenizer import get_tokenizer from graphrag.utils.storage import write_table_to_storage logger = logging.getLogger(__name__) @@ -35,7 +37,20 @@ async def run_workflow( ] embedded_fields = config.embed_text.names - text_embed = get_embedding_settings(config) + vector_store_config = get_vector_store_settings(config) + + model_config = config.get_language_model_config(config.embed_text.model_id) + + model = ModelManager().get_or_create_embedding_model( + name="text_embedding", + model_type=model_config.type, + config=model_config, + callbacks=context.callbacks, + cache=context.cache, + ) + + tokenizer = get_tokenizer(model_config) + result = await generate_text_embeddings( documents=final_documents_df, relationships=merged_relationships_df, @@ -43,8 +58,12 @@ async def run_workflow( entities=merged_entities_df, community_reports=merged_community_reports, callbacks=context.callbacks, - cache=context.cache, - text_embed_config=text_embed, + model=model, + tokenizer=tokenizer, + batch_size=config.embed_text.batch_size, + batch_max_tokens=config.embed_text.batch_max_tokens, + num_threads=model_config.concurrent_requests, + vector_store_config=vector_store_config, embedded_fields=embedded_fields, ) if config.snapshots.embeddings: diff --git a/graphrag/prompt_tune/loader/input.py b/graphrag/prompt_tune/loader/input.py index 235a78fe7a..bdc42e5b5c 100644 --- a/graphrag/prompt_tune/loader/input.py +++ b/graphrag/prompt_tune/loader/input.py @@ -13,16 +13,18 @@ from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.index.input.factory import create_input -from graphrag.index.operations.embed_text.strategies.openai import ( - run as run_embed_text, +from graphrag.index.operations.embed_text.run_embed_text import ( + run_embed_text, ) from graphrag.index.workflows.create_base_text_units import create_base_text_units +from graphrag.language_model.manager import ModelManager from graphrag.prompt_tune.defaults import ( LIMIT, N_SUBSET_MAX, K, ) from graphrag.prompt_tune.types import DocSelectionType +from graphrag.tokenizer.get_tokenizer import get_tokenizer from graphrag.utils.api import create_storage_from_config @@ -53,6 +55,14 @@ async def load_docs_in_chunks( embeddings_llm_settings = config.get_language_model_config( config.embed_text.model_id ) + model = ModelManager().get_or_create_embedding_model( + name="text_embedding", + model_type=embeddings_llm_settings.type, + config=embeddings_llm_settings, + callbacks=NoopWorkflowCallbacks(), + cache=NoopPipelineCache(), + ) + tokenizer = get_tokenizer(embeddings_llm_settings) input_storage = create_storage_from_config(config.input.storage) dataset = await create_input(config.input, input_storage) chunk_config = config.chunks @@ -89,13 +99,11 @@ async def load_docs_in_chunks( embedding_results = await run_embed_text( sampled_text_chunks, callbacks=NoopWorkflowCallbacks(), - cache=NoopPipelineCache(), - args={ - "llm": embeddings_llm_settings.model_dump(), - "num_threads": embeddings_llm_settings.concurrent_requests, - "batch_size": config.embed_text.batch_size, - "batch_max_tokens": config.embed_text.batch_max_tokens, - }, + model=model, + tokenizer=tokenizer, + batch_size=config.embed_text.batch_size, + batch_max_tokens=config.embed_text.batch_max_tokens, + num_threads=embeddings_llm_settings.concurrent_requests, ) embeddings = np.array(embedding_results.embeddings) 
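Note: the renamed `run_embed_text` entry point takes the embedding model, tokenizer, and batching parameters directly instead of an `args` dict. A minimal sketch, assuming `embedding_model` and `model_config` were produced by `ModelManager` and `get_language_model_config` as above:

```python
# Sketch only: calling the new run_embed_text directly. The batch values mirror
# the embed_text defaults (16 / 8191) and are illustrative.
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.index.operations.embed_text.run_embed_text import run_embed_text
from graphrag.tokenizer.get_tokenizer import get_tokenizer


async def embed_chunks(embedding_model, model_config, chunks: list[str]):
    results = await run_embed_text(
        chunks,
        callbacks=NoopWorkflowCallbacks(),
        model=embedding_model,
        tokenizer=get_tokenizer(model_config),
        batch_size=16,
        batch_max_tokens=8191,
        num_threads=model_config.concurrent_requests,
    )
    return results.embeddings
```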
chunks_df = _sample_chunks_from_embeddings(chunks_df, embeddings, k=k) diff --git a/graphrag/prompt_tune/prompt/entity_relationship.py b/graphrag/prompt_tune/prompt/entity_relationship.py index 66eefa9947..ec7dca55a4 100644 --- a/graphrag/prompt_tune/prompt/entity_relationship.py +++ b/graphrag/prompt_tune/prompt/entity_relationship.py @@ -12,7 +12,7 @@ - entity_name: Name of the entity, capitalized - entity_type: One of the following types: [{entity_types}] - entity_description: Comprehensive description of the entity's attributes and activities -Format each entity as ("entity"{{tuple_delimiter}}{{tuple_delimiter}}{{tuple_delimiter}}) +Format each entity as ("entity"<|><|><|>) 2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other. For each pair of related entities, extract the following information: @@ -20,13 +20,13 @@ - target_entity: name of the target entity, as identified in step 1 - relationship_description: explanation as to why you think the source entity and the target entity are related to each other - relationship_strength: an integer score between 1 to 10, indicating strength of the relationship between the source entity and target entity -Format each relationship as ("relationship"{{tuple_delimiter}}{{tuple_delimiter}}{{tuple_delimiter}}{{tuple_delimiter}}) +Format each relationship as ("relationship"<|><|><|><|>) -3. Return output in {language} as a single list of all the entities and relationships identified in steps 1 and 2. Use {{record_delimiter}} as the list delimiter. +3. Return output in {language} as a single list of all the entities and relationships identified in steps 1 and 2. Use ## as the list delimiter. 4. If you have to translate into {language}, just translate the descriptions, nothing else! -5. When finished, output {{completion_delimiter}}. +5. When finished, output <|COMPLETE|>. ###################### -Examples- @@ -37,14 +37,14 @@ The Verdantis's Central Institution is scheduled to meet on Monday and Thursday, with the institution planning to release its latest policy decision on Thursday at 1:30 p.m. PDT, followed by a press conference where Central Institution Chair Martin Smith will take questions. Investors expect the Market Strategy Committee to hold its benchmark interest rate steady in a range of 3.5%-3.75%. 
###################### Output: -("entity"{{tuple_delimiter}}CENTRAL INSTITUTION{{tuple_delimiter}}ORGANIZATION{{tuple_delimiter}}The Central Institution is the Federal Reserve of Verdantis, which is setting interest rates on Monday and Thursday) -{{record_delimiter}} -("entity"{{tuple_delimiter}}MARTIN SMITH{{tuple_delimiter}}PERSON{{tuple_delimiter}}Martin Smith is the chair of the Central Institution) -{{record_delimiter}} -("entity"{{tuple_delimiter}}MARKET STRATEGY COMMITTEE{{tuple_delimiter}}ORGANIZATION{{tuple_delimiter}}The Central Institution committee makes key decisions about interest rates and the growth of Verdantis's money supply) -{{record_delimiter}} -("relationship"{{tuple_delimiter}}MARTIN SMITH{{tuple_delimiter}}CENTRAL INSTITUTION{{tuple_delimiter}}Martin Smith is the Chair of the Central Institution and will answer questions at a press conference{{tuple_delimiter}}9) -{{completion_delimiter}} +("entity"<|>CENTRAL INSTITUTION<|>ORGANIZATION<|>The Central Institution is the Federal Reserve of Verdantis, which is setting interest rates on Monday and Thursday) +## +("entity"<|>MARTIN SMITH<|>PERSON<|>Martin Smith is the chair of the Central Institution) +## +("entity"<|>MARKET STRATEGY COMMITTEE<|>ORGANIZATION<|>The Central Institution committee makes key decisions about interest rates and the growth of Verdantis's money supply) +## +("relationship"<|>MARTIN SMITH<|>CENTRAL INSTITUTION<|>Martin Smith is the Chair of the Central Institution and will answer questions at a press conference<|>9) +<|COMPLETE|> ###################### Example 2: @@ -55,12 +55,12 @@ TechGlobal, a formerly public company, was taken private by Vision Holdings in 2014. The well-established chip designer says it powers 85% of premium smartphones. ###################### Output: -("entity"{{tuple_delimiter}}TECHGLOBAL{{tuple_delimiter}}ORGANIZATION{{tuple_delimiter}}TechGlobal is a stock now listed on the Global Exchange which powers 85% of premium smartphones) -{{record_delimiter}} -("entity"{{tuple_delimiter}}VISION HOLDINGS{{tuple_delimiter}}ORGANIZATION{{tuple_delimiter}}Vision Holdings is a firm that previously owned TechGlobal) -{{record_delimiter}} -("relationship"{{tuple_delimiter}}TECHGLOBAL{{tuple_delimiter}}VISION HOLDINGS{{tuple_delimiter}}Vision Holdings formerly owned TechGlobal from 2014 until present{{tuple_delimiter}}5) -{{completion_delimiter}} +("entity"<|>TECHGLOBAL<|>ORGANIZATION<|>TechGlobal is a stock now listed on the Global Exchange which powers 85% of premium smartphones) +## +("entity"<|>VISION HOLDINGS<|>ORGANIZATION<|>Vision Holdings is a firm that previously owned TechGlobal) +## +("relationship"<|>TECHGLOBAL<|>VISION HOLDINGS<|>Vision Holdings formerly owned TechGlobal from 2014 until present<|>5) +<|COMPLETE|> ###################### Example 3: @@ -77,47 +77,47 @@ The Aurelians include 39-year-old businessman Samuel Namara, who has been held in Tiruzia's Alhamia Prison, as well as journalist Durke Bataglani, 59, and environmentalist Meggie Tazbah, 53, who also holds Bratinas nationality. 
###################### Output: -("entity"{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}GEO{{tuple_delimiter}}Firuzabad held Aurelians as hostages) -{{record_delimiter}} -("entity"{{tuple_delimiter}}AURELIA{{tuple_delimiter}}GEO{{tuple_delimiter}}Country seeking to release hostages) -{{record_delimiter}} -("entity"{{tuple_delimiter}}QUINTARA{{tuple_delimiter}}GEO{{tuple_delimiter}}Country that negotiated a swap of money in exchange for hostages) -{{record_delimiter}} -{{record_delimiter}} -("entity"{{tuple_delimiter}}TIRUZIA{{tuple_delimiter}}GEO{{tuple_delimiter}}Capital of Firuzabad where the Aurelians were being held) -{{record_delimiter}} -("entity"{{tuple_delimiter}}KROHAARA{{tuple_delimiter}}GEO{{tuple_delimiter}}Capital city in Quintara) -{{record_delimiter}} -("entity"{{tuple_delimiter}}CASHION{{tuple_delimiter}}GEO{{tuple_delimiter}}Capital city in Aurelia) -{{record_delimiter}} -("entity"{{tuple_delimiter}}SAMUEL NAMARA{{tuple_delimiter}}PERSON{{tuple_delimiter}}Aurelian who spent time in Tiruzia's Alhamia Prison) -{{record_delimiter}} -("entity"{{tuple_delimiter}}ALHAMIA PRISON{{tuple_delimiter}}GEO{{tuple_delimiter}}Prison in Tiruzia) -{{record_delimiter}} -("entity"{{tuple_delimiter}}DURKE BATAGLANI{{tuple_delimiter}}PERSON{{tuple_delimiter}}Aurelian journalist who was held hostage) -{{record_delimiter}} -("entity"{{tuple_delimiter}}MEGGIE TAZBAH{{tuple_delimiter}}PERSON{{tuple_delimiter}}Bratinas national and environmentalist who was held hostage) -{{record_delimiter}} -("relationship"{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}AURELIA{{tuple_delimiter}}Firuzabad negotiated a hostage exchange with Aurelia{{tuple_delimiter}}2) -{{record_delimiter}} -("relationship"{{tuple_delimiter}}QUINTARA{{tuple_delimiter}}AURELIA{{tuple_delimiter}}Quintara brokered the hostage exchange between Firuzabad and Aurelia{{tuple_delimiter}}2) -{{record_delimiter}} -("relationship"{{tuple_delimiter}}QUINTARA{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}Quintara brokered the hostage exchange between Firuzabad and Aurelia{{tuple_delimiter}}2) -{{record_delimiter}} -("relationship"{{tuple_delimiter}}SAMUEL NAMARA{{tuple_delimiter}}ALHAMIA PRISON{{tuple_delimiter}}Samuel Namara was a prisoner at Alhamia prison{{tuple_delimiter}}8) -{{record_delimiter}} -("relationship"{{tuple_delimiter}}SAMUEL NAMARA{{tuple_delimiter}}MEGGIE TAZBAH{{tuple_delimiter}}Samuel Namara and Meggie Tazbah were exchanged in the same hostage release{{tuple_delimiter}}2) -{{record_delimiter}} -("relationship"{{tuple_delimiter}}SAMUEL NAMARA{{tuple_delimiter}}DURKE BATAGLANI{{tuple_delimiter}}Samuel Namara and Durke Bataglani were exchanged in the same hostage release{{tuple_delimiter}}2) -{{record_delimiter}} -("relationship"{{tuple_delimiter}}MEGGIE TAZBAH{{tuple_delimiter}}DURKE BATAGLANI{{tuple_delimiter}}Meggie Tazbah and Durke Bataglani were exchanged in the same hostage release{{tuple_delimiter}}2) -{{record_delimiter}} -("relationship"{{tuple_delimiter}}SAMUEL NAMARA{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}Samuel Namara was a hostage in Firuzabad{{tuple_delimiter}}2) -{{record_delimiter}} -("relationship"{{tuple_delimiter}}MEGGIE TAZBAH{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}Meggie Tazbah was a hostage in Firuzabad{{tuple_delimiter}}2) -{{record_delimiter}} -("relationship"{{tuple_delimiter}}DURKE BATAGLANI{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}Durke Bataglani was a hostage in Firuzabad{{tuple_delimiter}}2) -{{completion_delimiter}} +("entity"<|>FIRUZABAD<|>GEO<|>Firuzabad held Aurelians as 
hostages) +## +("entity"<|>AURELIA<|>GEO<|>Country seeking to release hostages) +## +("entity"<|>QUINTARA<|>GEO<|>Country that negotiated a swap of money in exchange for hostages) +## +## +("entity"<|>TIRUZIA<|>GEO<|>Capital of Firuzabad where the Aurelians were being held) +## +("entity"<|>KROHAARA<|>GEO<|>Capital city in Quintara) +## +("entity"<|>CASHION<|>GEO<|>Capital city in Aurelia) +## +("entity"<|>SAMUEL NAMARA<|>PERSON<|>Aurelian who spent time in Tiruzia's Alhamia Prison) +## +("entity"<|>ALHAMIA PRISON<|>GEO<|>Prison in Tiruzia) +## +("entity"<|>DURKE BATAGLANI<|>PERSON<|>Aurelian journalist who was held hostage) +## +("entity"<|>MEGGIE TAZBAH<|>PERSON<|>Bratinas national and environmentalist who was held hostage) +## +("relationship"<|>FIRUZABAD<|>AURELIA<|>Firuzabad negotiated a hostage exchange with Aurelia<|>2) +## +("relationship"<|>QUINTARA<|>AURELIA<|>Quintara brokered the hostage exchange between Firuzabad and Aurelia<|>2) +## +("relationship"<|>QUINTARA<|>FIRUZABAD<|>Quintara brokered the hostage exchange between Firuzabad and Aurelia<|>2) +## +("relationship"<|>SAMUEL NAMARA<|>ALHAMIA PRISON<|>Samuel Namara was a prisoner at Alhamia prison<|>8) +## +("relationship"<|>SAMUEL NAMARA<|>MEGGIE TAZBAH<|>Samuel Namara and Meggie Tazbah were exchanged in the same hostage release<|>2) +## +("relationship"<|>SAMUEL NAMARA<|>DURKE BATAGLANI<|>Samuel Namara and Durke Bataglani were exchanged in the same hostage release<|>2) +## +("relationship"<|>MEGGIE TAZBAH<|>DURKE BATAGLANI<|>Meggie Tazbah and Durke Bataglani were exchanged in the same hostage release<|>2) +## +("relationship"<|>SAMUEL NAMARA<|>FIRUZABAD<|>Samuel Namara was a hostage in Firuzabad<|>2) +## +("relationship"<|>MEGGIE TAZBAH<|>FIRUZABAD<|>Meggie Tazbah was a hostage in Firuzabad<|>2) +## +("relationship"<|>DURKE BATAGLANI<|>FIRUZABAD<|>Durke Bataglani was a hostage in Firuzabad<|>2) +<|COMPLETE|> -Real Data- ###################### @@ -242,7 +242,7 @@ - entity_name: Name of the entity, capitalized - entity_type: Suggest several labels or categories for the entity. The categories should not be specific, but should be as general as possible. - entity_description: Comprehensive description of the entity's attributes and activities -Format each entity as ("entity"{{tuple_delimiter}}{{tuple_delimiter}}{{tuple_delimiter}}) +Format each entity as ("entity"<|><|><|>) 2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other. For each pair of related entities, extract the following information: @@ -250,13 +250,13 @@ - target_entity: name of the target entity, as identified in step 1 - relationship_description: explanation as to why you think the source entity and the target entity are related to each other - relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity -Format each relationship as ("relationship"{{tuple_delimiter}}{{tuple_delimiter}}{{tuple_delimiter}}{{tuple_delimiter}}) +Format each relationship as ("relationship"<|><|><|><|>) -3. Return output in {language} as a single list of all the entities and relationships identified in steps 1 and 2. Use **{{record_delimiter}}** as the list delimiter. +3. Return output in {language} as a single list of all the entities and relationships identified in steps 1 and 2. Use **##** as the list delimiter. 4. If you have to translate into {language}, just translate the descriptions, nothing else! -5. 
When finished, output {{completion_delimiter}}. +5. When finished, output <|COMPLETE|>. ###################### -Examples- @@ -266,14 +266,14 @@ The Verdantis's Central Institution is scheduled to meet on Monday and Thursday, with the institution planning to release its latest policy decision on Thursday at 1:30 p.m. PDT, followed by a press conference where Central Institution Chair Martin Smith will take questions. Investors expect the Market Strategy Committee to hold its benchmark interest rate steady in a range of 3.5%-3.75%. ###################### Output: -("entity"{{tuple_delimiter}}CENTRAL INSTITUTION{{tuple_delimiter}}ORGANIZATION{{tuple_delimiter}}The Central Institution is the Federal Reserve of Verdantis, which is setting interest rates on Monday and Thursday) -{{record_delimiter}} -("entity"{{tuple_delimiter}}MARTIN SMITH{{tuple_delimiter}}PERSON{{tuple_delimiter}}Martin Smith is the chair of the Central Institution) -{{record_delimiter}} -("entity"{{tuple_delimiter}}MARKET STRATEGY COMMITTEE{{tuple_delimiter}}ORGANIZATION{{tuple_delimiter}}The Central Institution committee makes key decisions about interest rates and the growth of Verdantis's money supply) -{{record_delimiter}} -("relationship"{{tuple_delimiter}}MARTIN SMITH{{tuple_delimiter}}CENTRAL INSTITUTION{{tuple_delimiter}}Martin Smith is the Chair of the Central Institution and will answer questions at a press conference{{tuple_delimiter}}9) -{{completion_delimiter}} +("entity"<|>CENTRAL INSTITUTION<|>ORGANIZATION<|>The Central Institution is the Federal Reserve of Verdantis, which is setting interest rates on Monday and Thursday) +## +("entity"<|>MARTIN SMITH<|>PERSON<|>Martin Smith is the chair of the Central Institution) +## +("entity"<|>MARKET STRATEGY COMMITTEE<|>ORGANIZATION<|>The Central Institution committee makes key decisions about interest rates and the growth of Verdantis's money supply) +## +("relationship"<|>MARTIN SMITH<|>CENTRAL INSTITUTION<|>Martin Smith is the Chair of the Central Institution and will answer questions at a press conference<|>9) +<|COMPLETE|> ###################### Example 2: @@ -283,12 +283,12 @@ TechGlobal, a formerly public company, was taken private by Vision Holdings in 2014. The well-established chip designer says it powers 85% of premium smartphones. 
###################### Output: -("entity"{{tuple_delimiter}}TECHGLOBAL{{tuple_delimiter}}ORGANIZATION{{tuple_delimiter}}TechGlobal is a stock now listed on the Global Exchange which powers 85% of premium smartphones) -{{record_delimiter}} -("entity"{{tuple_delimiter}}VISION HOLDINGS{{tuple_delimiter}}ORGANIZATION{{tuple_delimiter}}Vision Holdings is a firm that previously owned TechGlobal) -{{record_delimiter}} -("relationship"{{tuple_delimiter}}TECHGLOBAL{{tuple_delimiter}}VISION HOLDINGS{{tuple_delimiter}}Vision Holdings formerly owned TechGlobal from 2014 until present{{tuple_delimiter}}5) -{{completion_delimiter}} +("entity"<|>TECHGLOBAL<|>ORGANIZATION<|>TechGlobal is a stock now listed on the Global Exchange which powers 85% of premium smartphones) +## +("entity"<|>VISION HOLDINGS<|>ORGANIZATION<|>Vision Holdings is a firm that previously owned TechGlobal) +## +("relationship"<|>TECHGLOBAL<|>VISION HOLDINGS<|>Vision Holdings formerly owned TechGlobal from 2014 until present<|>5) +<|COMPLETE|> ###################### Example 3: @@ -304,47 +304,47 @@ The Aurelians include 39-year-old businessman Samuel Namara, who has been held in Tiruzia's Alhamia Prison, as well as journalist Durke Bataglani, 59, and environmentalist Meggie Tazbah, 53, who also holds Bratinas nationality. ###################### Output: -("entity"{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}GEO{{tuple_delimiter}}Firuzabad held Aurelians as hostages) -{{record_delimiter}} -("entity"{{tuple_delimiter}}AURELIA{{tuple_delimiter}}GEO{{tuple_delimiter}}Country seeking to release hostages) -{{record_delimiter}} -("entity"{{tuple_delimiter}}QUINTARA{{tuple_delimiter}}GEO{{tuple_delimiter}}Country that negotiated a swap of money in exchange for hostages) -{{record_delimiter}} -{{record_delimiter}} -("entity"{{tuple_delimiter}}TIRUZIA{{tuple_delimiter}}GEO{{tuple_delimiter}}Capital of Firuzabad where the Aurelians were being held) -{{record_delimiter}} -("entity"{{tuple_delimiter}}KROHAARA{{tuple_delimiter}}GEO{{tuple_delimiter}}Capital city in Quintara) -{{record_delimiter}} -("entity"{{tuple_delimiter}}CASHION{{tuple_delimiter}}GEO{{tuple_delimiter}}Capital city in Aurelia) -{{record_delimiter}} -("entity"{{tuple_delimiter}}SAMUEL NAMARA{{tuple_delimiter}}PERSON{{tuple_delimiter}}Aurelian who spent time in Tiruzia's Alhamia Prison) -{{record_delimiter}} -("entity"{{tuple_delimiter}}ALHAMIA PRISON{{tuple_delimiter}}GEO{{tuple_delimiter}}Prison in Tiruzia) -{{record_delimiter}} -("entity"{{tuple_delimiter}}DURKE BATAGLANI{{tuple_delimiter}}PERSON{{tuple_delimiter}}Aurelian journalist who was held hostage) -{{record_delimiter}} -("entity"{{tuple_delimiter}}MEGGIE TAZBAH{{tuple_delimiter}}PERSON{{tuple_delimiter}}Bratinas national and environmentalist who was held hostage) -{{record_delimiter}} -("relationship"{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}AURELIA{{tuple_delimiter}}Firuzabad negotiated a hostage exchange with Aurelia{{tuple_delimiter}}2) -{{record_delimiter}} -("relationship"{{tuple_delimiter}}QUINTARA{{tuple_delimiter}}AURELIA{{tuple_delimiter}}Quintara brokered the hostage exchange between Firuzabad and Aurelia{{tuple_delimiter}}2) -{{record_delimiter}} -("relationship"{{tuple_delimiter}}QUINTARA{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}Quintara brokered the hostage exchange between Firuzabad and Aurelia{{tuple_delimiter}}2) -{{record_delimiter}} -("relationship"{{tuple_delimiter}}SAMUEL NAMARA{{tuple_delimiter}}ALHAMIA PRISON{{tuple_delimiter}}Samuel Namara was a prisoner at Alhamia 
prison{{tuple_delimiter}}8) -{{record_delimiter}} -("relationship"{{tuple_delimiter}}SAMUEL NAMARA{{tuple_delimiter}}MEGGIE TAZBAH{{tuple_delimiter}}Samuel Namara and Meggie Tazbah were exchanged in the same hostage release{{tuple_delimiter}}2) -{{record_delimiter}} -("relationship"{{tuple_delimiter}}SAMUEL NAMARA{{tuple_delimiter}}DURKE BATAGLANI{{tuple_delimiter}}Samuel Namara and Durke Bataglani were exchanged in the same hostage release{{tuple_delimiter}}2) -{{record_delimiter}} -("relationship"{{tuple_delimiter}}MEGGIE TAZBAH{{tuple_delimiter}}DURKE BATAGLANI{{tuple_delimiter}}Meggie Tazbah and Durke Bataglani were exchanged in the same hostage release{{tuple_delimiter}}2) -{{record_delimiter}} -("relationship"{{tuple_delimiter}}SAMUEL NAMARA{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}Samuel Namara was a hostage in Firuzabad{{tuple_delimiter}}2) -{{record_delimiter}} -("relationship"{{tuple_delimiter}}MEGGIE TAZBAH{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}Meggie Tazbah was a hostage in Firuzabad{{tuple_delimiter}}2) -{{record_delimiter}} -("relationship"{{tuple_delimiter}}DURKE BATAGLANI{{tuple_delimiter}}FIRUZABAD{{tuple_delimiter}}Durke Bataglani was a hostage in Firuzabad{{tuple_delimiter}}2) -{{completion_delimiter}} +("entity"<|>FIRUZABAD<|>GEO<|>Firuzabad held Aurelians as hostages) +## +("entity"<|>AURELIA<|>GEO<|>Country seeking to release hostages) +## +("entity"<|>QUINTARA<|>GEO<|>Country that negotiated a swap of money in exchange for hostages) +## +## +("entity"<|>TIRUZIA<|>GEO<|>Capital of Firuzabad where the Aurelians were being held) +## +("entity"<|>KROHAARA<|>GEO<|>Capital city in Quintara) +## +("entity"<|>CASHION<|>GEO<|>Capital city in Aurelia) +## +("entity"<|>SAMUEL NAMARA<|>PERSON<|>Aurelian who spent time in Tiruzia's Alhamia Prison) +## +("entity"<|>ALHAMIA PRISON<|>GEO<|>Prison in Tiruzia) +## +("entity"<|>DURKE BATAGLANI<|>PERSON<|>Aurelian journalist who was held hostage) +## +("entity"<|>MEGGIE TAZBAH<|>PERSON<|>Bratinas national and environmentalist who was held hostage) +## +("relationship"<|>FIRUZABAD<|>AURELIA<|>Firuzabad negotiated a hostage exchange with Aurelia<|>2) +## +("relationship"<|>QUINTARA<|>AURELIA<|>Quintara brokered the hostage exchange between Firuzabad and Aurelia<|>2) +## +("relationship"<|>QUINTARA<|>FIRUZABAD<|>Quintara brokered the hostage exchange between Firuzabad and Aurelia<|>2) +## +("relationship"<|>SAMUEL NAMARA<|>ALHAMIA PRISON<|>Samuel Namara was a prisoner at Alhamia prison<|>8) +## +("relationship"<|>SAMUEL NAMARA<|>MEGGIE TAZBAH<|>Samuel Namara and Meggie Tazbah were exchanged in the same hostage release<|>2) +## +("relationship"<|>SAMUEL NAMARA<|>DURKE BATAGLANI<|>Samuel Namara and Durke Bataglani were exchanged in the same hostage release<|>2) +## +("relationship"<|>MEGGIE TAZBAH<|>DURKE BATAGLANI<|>Meggie Tazbah and Durke Bataglani were exchanged in the same hostage release<|>2) +## +("relationship"<|>SAMUEL NAMARA<|>FIRUZABAD<|>Samuel Namara was a hostage in Firuzabad<|>2) +## +("relationship"<|>MEGGIE TAZBAH<|>FIRUZABAD<|>Meggie Tazbah was a hostage in Firuzabad<|>2) +## +("relationship"<|>DURKE BATAGLANI<|>FIRUZABAD<|>Durke Bataglani was a hostage in Firuzabad<|>2) +<|COMPLETE|> ###################### -Real Data- diff --git a/graphrag/prompt_tune/template/extract_graph.py b/graphrag/prompt_tune/template/extract_graph.py index 32d8756ec2..58a095cd0d 100644 --- a/graphrag/prompt_tune/template/extract_graph.py +++ b/graphrag/prompt_tune/template/extract_graph.py @@ -12,7 +12,7 @@ - entity_name: Name of the 
entity, capitalized - entity_type: One of the following types: [{entity_types}] - entity_description: Comprehensive description of the entity's attributes and activities -Format each entity as ("entity"{{tuple_delimiter}}{{tuple_delimiter}}{{tuple_delimiter}}) +Format each entity as ("entity"<|><|><|>) 2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other. For each pair of related entities, extract the following information: @@ -20,13 +20,13 @@ - target_entity: name of the target entity, as identified in step 1 - relationship_description: explanation as to why you think the source entity and the target entity are related to each other - relationship_strength: an integer score between 1 to 10, indicating strength of the relationship between the source entity and target entity -Format each relationship as ("relationship"{{tuple_delimiter}}{{tuple_delimiter}}{{tuple_delimiter}}{{tuple_delimiter}}) +Format each relationship as ("relationship"<|><|><|><|>) -3. Return output in {language} as a single list of all the entities and relationships identified in steps 1 and 2. Use **{{record_delimiter}}** as the list delimiter. +3. Return output in {language} as a single list of all the entities and relationships identified in steps 1 and 2. Use **##** as the list delimiter. 4. If you have to translate into {language}, just translate the descriptions, nothing else! -5. When finished, output {{completion_delimiter}}. +5. When finished, output <|COMPLETE|>. -Examples- ###################### @@ -113,7 +113,7 @@ - entity_name: Name of the entity, capitalized - entity_type: Suggest several labels or categories for the entity. The categories should not be specific, but should be as general as possible. - entity_description: Comprehensive description of the entity's attributes and activities -Format each entity as ("entity"{{tuple_delimiter}}{{tuple_delimiter}}{{tuple_delimiter}}) +Format each entity as ("entity"<|><|><|>) 2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other. For each pair of related entities, extract the following information: @@ -121,13 +121,13 @@ - target_entity: name of the target entity, as identified in step 1 - relationship_description: explanation as to why you think the source entity and the target entity are related to each other - relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity -Format each relationship as ("relationship"{{tuple_delimiter}}{{tuple_delimiter}}{{tuple_delimiter}}{{tuple_delimiter}}) +Format each relationship as ("relationship"<|><|><|><|>) -3. Return output in {language} as a single list of all the entities and relationships identified in steps 1 and 2. Use **{{record_delimiter}}** as the list delimiter. +3. Return output in {language} as a single list of all the entities and relationships identified in steps 1 and 2. Use **##** as the list delimiter. 4. If you have to translate into {language}, just translate the descriptions, nothing else! -5. When finished, output {{completion_delimiter}}. +5. When finished, output <|COMPLETE|>. 
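
With the delimiters now spelled out literally, the record grammar these prompts demand is fixed: tuple fields joined by `<|>`, records joined by `##`, and a final `<|COMPLETE|>`. A minimal parsing sketch for that grammar follows; `parse_records` and its return shape are illustrative, not graphrag APIs, and the blank-record guard also absorbs the doubled `##` that survives in the Example 3 outputs:

```python
# Illustrative parser for the literal-delimiter format the prompts above
# specify. "<|>", "##", and "<|COMPLETE|>" come from the prompt text;
# parse_records itself is a hypothetical helper, not a graphrag function.
def parse_records(output: str) -> list[tuple[str, ...]]:
    records: list[tuple[str, ...]] = []
    for raw in output.replace("<|COMPLETE|>", "").split("##"):
        raw = raw.strip()
        if not (raw.startswith("(") and raw.endswith(")")):
            continue  # skips blanks, including a stray doubled "##"
        fields = tuple(f.strip().strip('"') for f in raw[1:-1].split("<|>"))
        records.append(fields)
    return records


sample = (
    '("entity"<|>MARTIN SMITH<|>PERSON<|>Martin Smith is the chair of the Central Institution)\n'
    "##\n"
    '("relationship"<|>MARTIN SMITH<|>CENTRAL INSTITUTION<|>Martin Smith will answer questions at a press conference<|>9)\n'
    "<|COMPLETE|>"
)
for record in parse_records(sample):
    print(record[0], record[1:])
```

The same split covers the claim tuples in extract_claims.py further down, since those reuse the identical delimiters with eight fields per record.
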
-Examples- ###################### diff --git a/graphrag/prompts/index/extract_claims.py b/graphrag/prompts/index/extract_claims.py index 5e0e5570c6..59b19c9061 100644 --- a/graphrag/prompts/index/extract_claims.py +++ b/graphrag/prompts/index/extract_claims.py @@ -22,11 +22,11 @@ - Claim Date: Period (start_date, end_date) when the claim was made. Both start_date and end_date should be in ISO-8601 format. If the claim was made on a single date rather than a date range, set the same date for both start_date and end_date. If date is unknown, return **NONE**. - Claim Source Text: List of **all** quotes from the original text that are relevant to the claim. -Format each claim as ({tuple_delimiter}{tuple_delimiter}{tuple_delimiter}{tuple_delimiter}{tuple_delimiter}{tuple_delimiter}{tuple_delimiter}) +Format each claim as (<|><|><|><|><|><|><|>) -3. Return output in English as a single list of all the claims identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter. +3. Return output in English as a single list of all the claims identified in steps 1 and 2. Use **##** as the list delimiter. -4. When finished, output {completion_delimiter} +4. When finished, output <|COMPLETE|> -Examples- Example 1: @@ -35,8 +35,8 @@ Text: According to an article on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B. The company is owned by Person C who was suspected of engaging in corruption activities in 2015. Output: -(COMPANY A{tuple_delimiter}GOVERNMENT AGENCY B{tuple_delimiter}ANTI-COMPETITIVE PRACTICES{tuple_delimiter}TRUE{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10{tuple_delimiter}According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.) -{completion_delimiter} +(COMPANY A<|>GOVERNMENT AGENCY B<|>ANTI-COMPETITIVE PRACTICES<|>TRUE<|>2022-01-10T00:00:00<|>2022-01-10T00:00:00<|>Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10<|>According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.) +<|COMPLETE|> Example 2: Entity specification: Company A, Person C @@ -44,10 +44,10 @@ Text: According to an article on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B. The company is owned by Person C who was suspected of engaging in corruption activities in 2015. 
Output: -(COMPANY A{tuple_delimiter}GOVERNMENT AGENCY B{tuple_delimiter}ANTI-COMPETITIVE PRACTICES{tuple_delimiter}TRUE{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10{tuple_delimiter}According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.) -{record_delimiter} -(PERSON C{tuple_delimiter}NONE{tuple_delimiter}CORRUPTION{tuple_delimiter}SUSPECTED{tuple_delimiter}2015-01-01T00:00:00{tuple_delimiter}2015-12-30T00:00:00{tuple_delimiter}Person C was suspected of engaging in corruption activities in 2015{tuple_delimiter}The company is owned by Person C who was suspected of engaging in corruption activities in 2015) -{completion_delimiter} +(COMPANY A<|>GOVERNMENT AGENCY B<|>ANTI-COMPETITIVE PRACTICES<|>TRUE<|>2022-01-10T00:00:00<|>2022-01-10T00:00:00<|>Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10<|>According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.) +## +(PERSON C<|>NONE<|>CORRUPTION<|>SUSPECTED<|>2015-01-01T00:00:00<|>2015-12-30T00:00:00<|>Person C was suspected of engaging in corruption activities in 2015<|>The company is owned by Person C who was suspected of engaging in corruption activities in 2015) +<|COMPLETE|> -Real Data- Use the following input for your answer. diff --git a/graphrag/prompts/index/extract_graph.py b/graphrag/prompts/index/extract_graph.py index a94b36142e..91157937d4 100644 --- a/graphrag/prompts/index/extract_graph.py +++ b/graphrag/prompts/index/extract_graph.py @@ -12,7 +12,7 @@ - entity_name: Name of the entity, capitalized - entity_type: One of the following types: [{entity_types}] - entity_description: Comprehensive description of the entity's attributes and activities -Format each entity as ("entity"{tuple_delimiter}{tuple_delimiter}{tuple_delimiter}) +Format each entity as ("entity"<|><|><|>) 2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other. For each pair of related entities, extract the following information: @@ -20,11 +20,11 @@ - target_entity: name of the target entity, as identified in step 1 - relationship_description: explanation as to why you think the source entity and the target entity are related to each other - relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity - Format each relationship as ("relationship"{tuple_delimiter}{tuple_delimiter}{tuple_delimiter}{tuple_delimiter}) + Format each relationship as ("relationship"<|><|><|><|>) -3. Return output in English as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter. +3. Return output in English as a single list of all the entities and relationships identified in steps 1 and 2. Use **##** as the list delimiter. -4. When finished, output {completion_delimiter} +4. 
When finished, output <|COMPLETE|> ###################### -Examples- @@ -35,14 +35,14 @@ The Verdantis's Central Institution is scheduled to meet on Monday and Thursday, with the institution planning to release its latest policy decision on Thursday at 1:30 p.m. PDT, followed by a press conference where Central Institution Chair Martin Smith will take questions. Investors expect the Market Strategy Committee to hold its benchmark interest rate steady in a range of 3.5%-3.75%. ###################### Output: -("entity"{tuple_delimiter}CENTRAL INSTITUTION{tuple_delimiter}ORGANIZATION{tuple_delimiter}The Central Institution is the Federal Reserve of Verdantis, which is setting interest rates on Monday and Thursday) -{record_delimiter} -("entity"{tuple_delimiter}MARTIN SMITH{tuple_delimiter}PERSON{tuple_delimiter}Martin Smith is the chair of the Central Institution) -{record_delimiter} -("entity"{tuple_delimiter}MARKET STRATEGY COMMITTEE{tuple_delimiter}ORGANIZATION{tuple_delimiter}The Central Institution committee makes key decisions about interest rates and the growth of Verdantis's money supply) -{record_delimiter} -("relationship"{tuple_delimiter}MARTIN SMITH{tuple_delimiter}CENTRAL INSTITUTION{tuple_delimiter}Martin Smith is the Chair of the Central Institution and will answer questions at a press conference{tuple_delimiter}9) -{completion_delimiter} +("entity"<|>CENTRAL INSTITUTION<|>ORGANIZATION<|>The Central Institution is the Federal Reserve of Verdantis, which is setting interest rates on Monday and Thursday) +## +("entity"<|>MARTIN SMITH<|>PERSON<|>Martin Smith is the chair of the Central Institution) +## +("entity"<|>MARKET STRATEGY COMMITTEE<|>ORGANIZATION<|>The Central Institution committee makes key decisions about interest rates and the growth of Verdantis's money supply) +## +("relationship"<|>MARTIN SMITH<|>CENTRAL INSTITUTION<|>Martin Smith is the Chair of the Central Institution and will answer questions at a press conference<|>9) +<|COMPLETE|> ###################### Example 2: @@ -53,12 +53,12 @@ TechGlobal, a formerly public company, was taken private by Vision Holdings in 2014. The well-established chip designer says it powers 85% of premium smartphones. ###################### Output: -("entity"{tuple_delimiter}TECHGLOBAL{tuple_delimiter}ORGANIZATION{tuple_delimiter}TechGlobal is a stock now listed on the Global Exchange which powers 85% of premium smartphones) -{record_delimiter} -("entity"{tuple_delimiter}VISION HOLDINGS{tuple_delimiter}ORGANIZATION{tuple_delimiter}Vision Holdings is a firm that previously owned TechGlobal) -{record_delimiter} -("relationship"{tuple_delimiter}TECHGLOBAL{tuple_delimiter}VISION HOLDINGS{tuple_delimiter}Vision Holdings formerly owned TechGlobal from 2014 until present{tuple_delimiter}5) -{completion_delimiter} +("entity"<|>TECHGLOBAL<|>ORGANIZATION<|>TechGlobal is a stock now listed on the Global Exchange which powers 85% of premium smartphones) +## +("entity"<|>VISION HOLDINGS<|>ORGANIZATION<|>Vision Holdings is a firm that previously owned TechGlobal) +## +("relationship"<|>TECHGLOBAL<|>VISION HOLDINGS<|>Vision Holdings formerly owned TechGlobal from 2014 until present<|>5) +<|COMPLETE|> ###################### Example 3: @@ -75,47 +75,47 @@ The Aurelians include 39-year-old businessman Samuel Namara, who has been held in Tiruzia's Alhamia Prison, as well as journalist Durke Bataglani, 59, and environmentalist Meggie Tazbah, 53, who also holds Bratinas nationality. 
###################### Output: -("entity"{tuple_delimiter}FIRUZABAD{tuple_delimiter}GEO{tuple_delimiter}Firuzabad held Aurelians as hostages) -{record_delimiter} -("entity"{tuple_delimiter}AURELIA{tuple_delimiter}GEO{tuple_delimiter}Country seeking to release hostages) -{record_delimiter} -("entity"{tuple_delimiter}QUINTARA{tuple_delimiter}GEO{tuple_delimiter}Country that negotiated a swap of money in exchange for hostages) -{record_delimiter} -{record_delimiter} -("entity"{tuple_delimiter}TIRUZIA{tuple_delimiter}GEO{tuple_delimiter}Capital of Firuzabad where the Aurelians were being held) -{record_delimiter} -("entity"{tuple_delimiter}KROHAARA{tuple_delimiter}GEO{tuple_delimiter}Capital city in Quintara) -{record_delimiter} -("entity"{tuple_delimiter}CASHION{tuple_delimiter}GEO{tuple_delimiter}Capital city in Aurelia) -{record_delimiter} -("entity"{tuple_delimiter}SAMUEL NAMARA{tuple_delimiter}PERSON{tuple_delimiter}Aurelian who spent time in Tiruzia's Alhamia Prison) -{record_delimiter} -("entity"{tuple_delimiter}ALHAMIA PRISON{tuple_delimiter}GEO{tuple_delimiter}Prison in Tiruzia) -{record_delimiter} -("entity"{tuple_delimiter}DURKE BATAGLANI{tuple_delimiter}PERSON{tuple_delimiter}Aurelian journalist who was held hostage) -{record_delimiter} -("entity"{tuple_delimiter}MEGGIE TAZBAH{tuple_delimiter}PERSON{tuple_delimiter}Bratinas national and environmentalist who was held hostage) -{record_delimiter} -("relationship"{tuple_delimiter}FIRUZABAD{tuple_delimiter}AURELIA{tuple_delimiter}Firuzabad negotiated a hostage exchange with Aurelia{tuple_delimiter}2) -{record_delimiter} -("relationship"{tuple_delimiter}QUINTARA{tuple_delimiter}AURELIA{tuple_delimiter}Quintara brokered the hostage exchange between Firuzabad and Aurelia{tuple_delimiter}2) -{record_delimiter} -("relationship"{tuple_delimiter}QUINTARA{tuple_delimiter}FIRUZABAD{tuple_delimiter}Quintara brokered the hostage exchange between Firuzabad and Aurelia{tuple_delimiter}2) -{record_delimiter} -("relationship"{tuple_delimiter}SAMUEL NAMARA{tuple_delimiter}ALHAMIA PRISON{tuple_delimiter}Samuel Namara was a prisoner at Alhamia prison{tuple_delimiter}8) -{record_delimiter} -("relationship"{tuple_delimiter}SAMUEL NAMARA{tuple_delimiter}MEGGIE TAZBAH{tuple_delimiter}Samuel Namara and Meggie Tazbah were exchanged in the same hostage release{tuple_delimiter}2) -{record_delimiter} -("relationship"{tuple_delimiter}SAMUEL NAMARA{tuple_delimiter}DURKE BATAGLANI{tuple_delimiter}Samuel Namara and Durke Bataglani were exchanged in the same hostage release{tuple_delimiter}2) -{record_delimiter} -("relationship"{tuple_delimiter}MEGGIE TAZBAH{tuple_delimiter}DURKE BATAGLANI{tuple_delimiter}Meggie Tazbah and Durke Bataglani were exchanged in the same hostage release{tuple_delimiter}2) -{record_delimiter} -("relationship"{tuple_delimiter}SAMUEL NAMARA{tuple_delimiter}FIRUZABAD{tuple_delimiter}Samuel Namara was a hostage in Firuzabad{tuple_delimiter}2) -{record_delimiter} -("relationship"{tuple_delimiter}MEGGIE TAZBAH{tuple_delimiter}FIRUZABAD{tuple_delimiter}Meggie Tazbah was a hostage in Firuzabad{tuple_delimiter}2) -{record_delimiter} -("relationship"{tuple_delimiter}DURKE BATAGLANI{tuple_delimiter}FIRUZABAD{tuple_delimiter}Durke Bataglani was a hostage in Firuzabad{tuple_delimiter}2) -{completion_delimiter} +("entity"<|>FIRUZABAD<|>GEO<|>Firuzabad held Aurelians as hostages) +## +("entity"<|>AURELIA<|>GEO<|>Country seeking to release hostages) +## +("entity"<|>QUINTARA<|>GEO<|>Country that negotiated a swap of money in exchange for hostages) +## +## 
+("entity"<|>TIRUZIA<|>GEO<|>Capital of Firuzabad where the Aurelians were being held) +## +("entity"<|>KROHAARA<|>GEO<|>Capital city in Quintara) +## +("entity"<|>CASHION<|>GEO<|>Capital city in Aurelia) +## +("entity"<|>SAMUEL NAMARA<|>PERSON<|>Aurelian who spent time in Tiruzia's Alhamia Prison) +## +("entity"<|>ALHAMIA PRISON<|>GEO<|>Prison in Tiruzia) +## +("entity"<|>DURKE BATAGLANI<|>PERSON<|>Aurelian journalist who was held hostage) +## +("entity"<|>MEGGIE TAZBAH<|>PERSON<|>Bratinas national and environmentalist who was held hostage) +## +("relationship"<|>FIRUZABAD<|>AURELIA<|>Firuzabad negotiated a hostage exchange with Aurelia<|>2) +## +("relationship"<|>QUINTARA<|>AURELIA<|>Quintara brokered the hostage exchange between Firuzabad and Aurelia<|>2) +## +("relationship"<|>QUINTARA<|>FIRUZABAD<|>Quintara brokered the hostage exchange between Firuzabad and Aurelia<|>2) +## +("relationship"<|>SAMUEL NAMARA<|>ALHAMIA PRISON<|>Samuel Namara was a prisoner at Alhamia prison<|>8) +## +("relationship"<|>SAMUEL NAMARA<|>MEGGIE TAZBAH<|>Samuel Namara and Meggie Tazbah were exchanged in the same hostage release<|>2) +## +("relationship"<|>SAMUEL NAMARA<|>DURKE BATAGLANI<|>Samuel Namara and Durke Bataglani were exchanged in the same hostage release<|>2) +## +("relationship"<|>MEGGIE TAZBAH<|>DURKE BATAGLANI<|>Meggie Tazbah and Durke Bataglani were exchanged in the same hostage release<|>2) +## +("relationship"<|>SAMUEL NAMARA<|>FIRUZABAD<|>Samuel Namara was a hostage in Firuzabad<|>2) +## +("relationship"<|>MEGGIE TAZBAH<|>FIRUZABAD<|>Meggie Tazbah was a hostage in Firuzabad<|>2) +## +("relationship"<|>DURKE BATAGLANI<|>FIRUZABAD<|>Durke Bataglani was a hostage in Firuzabad<|>2) +<|COMPLETE|> ###################### -Real Data- diff --git a/graphrag/query/llm/text_utils.py b/graphrag/query/llm/text_utils.py index 5ff1983aa1..ddd6abe1ef 100644 --- a/graphrag/query/llm/text_utils.py +++ b/graphrag/query/llm/text_utils.py @@ -11,7 +11,6 @@ from json_repair import repair_json -import graphrag.config.defaults as defs from graphrag.tokenizer.get_tokenizer import get_tokenizer from graphrag.tokenizer.tokenizer import Tokenizer @@ -36,7 +35,7 @@ def batched(iterable: Iterator, n: int): def chunk_text(text: str, max_tokens: int, tokenizer: Tokenizer | None = None): """Chunk text by token length.""" if tokenizer is None: - tokenizer = get_tokenizer(encoding_model=defs.ENCODING_MODEL) + tokenizer = get_tokenizer() tokens = tokenizer.encode(text) # type: ignore chunk_iterator = batched(iter(tokens), max_tokens) yield from (tokenizer.decode(list(chunk)) for chunk in chunk_iterator) diff --git a/tests/fixtures/min-csv/settings.yml b/tests/fixtures/min-csv/settings.yml index ee3ddd03f9..26143840e4 100644 --- a/tests/fixtures/min-csv/settings.yml +++ b/tests/fixtures/min-csv/settings.yml @@ -6,14 +6,12 @@ models: api_key: ${GRAPHRAG_API_KEY} api_base: ${GRAPHRAG_API_BASE} api_version: "2025-04-01-preview" - deployment_name: gpt-4.1 model: gpt-4.1 retry_strategy: exponential_backoff tokens_per_minute: null requests_per_minute: null model_supports_json: true concurrent_requests: 25 - async_mode: threaded default_embedding_model: azure_auth_type: api_key type: embedding @@ -21,13 +19,11 @@ models: api_key: ${GRAPHRAG_API_KEY} api_base: ${GRAPHRAG_API_BASE} api_version: "2025-04-01-preview" - deployment_name: text-embedding-ada-002 - model: text-embedding-ada-002 + model: text-embedding-3-large retry_strategy: exponential_backoff tokens_per_minute: null requests_per_minute: null concurrent_requests: 25 - 
async_mode: threaded vector_store: default_vector_store: diff --git a/tests/fixtures/text/settings.yml b/tests/fixtures/text/settings.yml index 163f92ccc4..633c12a03f 100644 --- a/tests/fixtures/text/settings.yml +++ b/tests/fixtures/text/settings.yml @@ -6,14 +6,12 @@ models: api_key: ${GRAPHRAG_API_KEY} api_base: ${GRAPHRAG_API_BASE} api_version: "2025-04-01-preview" - deployment_name: gpt-4.1 model: gpt-4.1 retry_strategy: exponential_backoff tokens_per_minute: null requests_per_minute: null model_supports_json: true concurrent_requests: 25 - async_mode: threaded default_embedding_model: azure_auth_type: api_key type: embedding @@ -21,13 +19,11 @@ models: api_key: ${GRAPHRAG_API_KEY} api_base: ${GRAPHRAG_API_BASE} api_version: "2025-04-01-preview" - deployment_name: text-embedding-ada-002 - model: text-embedding-ada-002 + model: text-embedding-3-large retry_strategy: exponential_backoff tokens_per_minute: null requests_per_minute: null concurrent_requests: 25 - async_mode: threaded vector_store: default_vector_store: diff --git a/tests/unit/config/utils.py b/tests/unit/config/utils.py index e322a2179d..f230692560 100644 --- a/tests/unit/config/utils.py +++ b/tests/unit/config/utils.py @@ -12,7 +12,8 @@ from graphrag.config.models.cluster_graph_config import ClusterGraphConfig from graphrag.config.models.community_reports_config import CommunityReportsConfig from graphrag.config.models.drift_search_config import DRIFTSearchConfig -from graphrag.config.models.extract_claims_config import ClaimExtractionConfig +from graphrag.config.models.embed_text_config import EmbedTextConfig +from graphrag.config.models.extract_claims_config import ExtractClaimsConfig from graphrag.config.models.extract_graph_config import ExtractGraphConfig from graphrag.config.models.extract_graph_nlp_config import ( ExtractGraphNLPConfig, @@ -30,7 +31,6 @@ from graphrag.config.models.summarize_descriptions_config import ( SummarizeDescriptionsConfig, ) -from graphrag.config.models.text_embedding_config import TextEmbeddingConfig from graphrag.config.models.vector_store_config import VectorStoreConfig FAKE_API_KEY = "NOT_AN_API_KEY" @@ -181,12 +181,11 @@ def assert_input_configs(actual: InputConfig, expected: InputConfig) -> None: def assert_text_embedding_configs( - actual: TextEmbeddingConfig, expected: TextEmbeddingConfig + actual: EmbedTextConfig, expected: EmbedTextConfig ) -> None: assert actual.batch_size == expected.batch_size assert actual.batch_max_tokens == expected.batch_max_tokens assert actual.names == expected.names - assert actual.strategy == expected.strategy assert actual.model_id == expected.model_id assert actual.vector_store_id == expected.vector_store_id @@ -213,7 +212,6 @@ def assert_extract_graph_configs( assert actual.prompt == expected.prompt assert actual.entity_types == expected.entity_types assert actual.max_gleanings == expected.max_gleanings - assert actual.strategy == expected.strategy assert actual.model_id == expected.model_id @@ -257,7 +255,6 @@ def assert_summarize_descriptions_configs( ) -> None: assert actual.prompt == expected.prompt assert actual.max_length == expected.max_length - assert actual.strategy == expected.strategy assert actual.model_id == expected.model_id @@ -268,18 +265,16 @@ def assert_community_reports_configs( assert actual.text_prompt == expected.text_prompt assert actual.max_length == expected.max_length assert actual.max_input_length == expected.max_input_length - assert actual.strategy == expected.strategy assert actual.model_id == expected.model_id def 
assert_extract_claims_configs( - actual: ClaimExtractionConfig, expected: ClaimExtractionConfig + actual: ExtractClaimsConfig, expected: ExtractClaimsConfig ) -> None: assert actual.enabled == expected.enabled assert actual.prompt == expected.prompt assert actual.description == expected.description assert actual.max_gleanings == expected.max_gleanings - assert actual.strategy == expected.strategy assert actual.model_id == expected.model_id diff --git a/tests/unit/indexing/text_splitting/test_text_splitting.py b/tests/unit/indexing/text_splitting/test_text_splitting.py index 8de8718a99..4ea8b25e5d 100644 --- a/tests/unit/indexing/text_splitting/test_text_splitting.py +++ b/tests/unit/indexing/text_splitting/test_text_splitting.py @@ -80,17 +80,6 @@ def test_token_text_splitter(mock_tokenizer, mock_split_text): mock_split_text.assert_called_once_with(text=text, tokenizer=mocked_tokenizer) -@mock.patch("tiktoken.encoding_for_model", side_effect=KeyError) -@mock.patch("tiktoken.get_encoding") -def test_model_name_exception(mock_get_encoding, mock_encoding_for_model): - mock_get_encoding.return_value = mock.MagicMock() - - TokenTextSplitter(model_name="mock_model", encoding_name="mock_encoding") - - mock_get_encoding.assert_called_once_with("mock_encoding") - mock_encoding_for_model.assert_called_once_with("mock_model") - - def test_split_single_text_on_tokens(): text = "This is a test text, meaning to be taken seriously by this test only." mocked_tokenizer = MockTokenizer() diff --git a/tests/unit/indexing/verbs/entities/extraction/strategies/graph_intelligence/test_gi_entity_extraction.py b/tests/unit/indexing/verbs/entities/extraction/strategies/graph_intelligence/test_gi_entity_extraction.py index 13e676e9fb..39f041b10c 100644 --- a/tests/unit/indexing/verbs/entities/extraction/strategies/graph_intelligence/test_gi_entity_extraction.py +++ b/tests/unit/indexing/verbs/entities/extraction/strategies/graph_intelligence/test_gi_entity_extraction.py @@ -2,24 +2,23 @@ # Licensed under the MIT License import unittest -from graphrag.index.operations.extract_graph.graph_intelligence_strategy import ( - run_extract_graph, -) +from graphrag.cache.factory import CacheFactory +from graphrag.index.operations.extract_graph.extract_graph import run_extract_graph from graphrag.index.operations.extract_graph.typing import ( Document, ) +from graphrag.prompts.index.extract_graph import GRAPH_EXTRACTION_PROMPT from tests.unit.indexing.verbs.helpers.mock_llm import create_mock_llm +_cache = CacheFactory.create_cache("none", kwargs={}) + class TestRunChain(unittest.IsolatedAsyncioTestCase): async def test_run_extract_graph_single_document_correct_entities_returned(self): results = await run_extract_graph( docs=[Document("test_text", "1")], entity_types=["person"], - args={ - "max_gleanings": 0, - "summarize_descriptions": False, - }, + max_gleanings=0, model=create_mock_llm( responses=[ """ @@ -36,6 +35,7 @@ async def test_run_extract_graph_single_document_correct_entities_returned(self) ], name="test_run_extract_graph_single_document_correct_entities_returned", ), + prompt=GRAPH_EXTRACTION_PROMPT, ) # self.assertItemsEqual isn't available yet, or I am just silly @@ -50,10 +50,7 @@ async def test_run_extract_graph_multiple_documents_correct_entities_returned( results = await run_extract_graph( docs=[Document("text_1", "1"), Document("text_2", "2")], entity_types=["person"], - args={ - "max_gleanings": 0, - "summarize_descriptions": False, - }, + max_gleanings=0, model=create_mock_llm( responses=[ """ @@ -74,6 
+71,7 @@ async def test_run_extract_graph_multiple_documents_correct_entities_returned( ], name="test_run_extract_graph_multiple_documents_correct_entities_returned", ), + prompt=GRAPH_EXTRACTION_PROMPT, ) # self.assertItemsEqual isn't available yet, or I am just silly @@ -86,10 +84,7 @@ async def test_run_extract_graph_multiple_documents_correct_edges_returned(self) results = await run_extract_graph( docs=[Document("text_1", "1"), Document("text_2", "2")], entity_types=["person"], - args={ - "max_gleanings": 0, - "summarize_descriptions": False, - }, + max_gleanings=0, model=create_mock_llm( responses=[ """ @@ -110,6 +105,7 @@ async def test_run_extract_graph_multiple_documents_correct_edges_returned(self) ], name="test_run_extract_graph_multiple_documents_correct_edges_returned", ), + prompt=GRAPH_EXTRACTION_PROMPT, ) # self.assertItemsEqual isn't available yet, or I am just silly @@ -130,10 +126,7 @@ async def test_run_extract_graph_multiple_documents_correct_entity_source_ids_ma results = await run_extract_graph( docs=[Document("text_1", "1"), Document("text_2", "2")], entity_types=["person"], - args={ - "max_gleanings": 0, - "summarize_descriptions": False, - }, + max_gleanings=0, model=create_mock_llm( responses=[ """ @@ -154,6 +147,7 @@ async def test_run_extract_graph_multiple_documents_correct_entity_source_ids_ma ], name="test_run_extract_graph_multiple_documents_correct_entity_source_ids_mapped", ), + prompt=GRAPH_EXTRACTION_PROMPT, ) graph = results.graph @@ -179,10 +173,7 @@ async def test_run_extract_graph_multiple_documents_correct_edge_source_ids_mapp results = await run_extract_graph( docs=[Document("text_1", "1"), Document("text_2", "2")], entity_types=["person"], - args={ - "max_gleanings": 0, - "summarize_descriptions": False, - }, + max_gleanings=0, model=create_mock_llm( responses=[ """ @@ -203,6 +194,7 @@ async def test_run_extract_graph_multiple_documents_correct_edge_source_ids_mapp ], name="test_run_extract_graph_multiple_documents_correct_edge_source_ids_mapped", ), + prompt=GRAPH_EXTRACTION_PROMPT, ) graph = results.graph diff --git a/tests/unit/utils/test_encoding.py b/tests/unit/utils/test_encoding.py index aca5575b2e..7ad83b5d7f 100644 --- a/tests/unit/utils/test_encoding.py +++ b/tests/unit/utils/test_encoding.py @@ -8,7 +8,9 @@ def test_encode_basic(): tokenizer = get_tokenizer() result = tokenizer.encode("abc def") - assert result == [13997, 711], "Encoding failed to return expected tokens" + assert result == [26682, 1056], ( + f"Encoding failed to return expected tokens, sent {result}" + ) def test_num_tokens_empty_input(): diff --git a/tests/verbs/test_create_community_reports.py b/tests/verbs/test_create_community_reports.py index 56fe4a6221..3947d4dc58 100644 --- a/tests/verbs/test_create_community_reports.py +++ b/tests/verbs/test_create_community_reports.py @@ -3,7 +3,6 @@ from graphrag.config.create_graphrag_config import create_graphrag_config -from graphrag.config.enums import ModelType from graphrag.data_model.schemas import COMMUNITY_REPORTS_FINAL_COLUMNS from graphrag.index.operations.summarize_communities.community_reports_extractor import ( CommunityReportResponse, @@ -52,17 +51,8 @@ async def test_create_community_reports(): ) config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG}) - llm_settings = config.get_language_model_config( - config.community_reports.model_id - ).model_dump() - llm_settings["type"] = ModelType.MockChat - llm_settings["responses"] = MOCK_RESPONSES - llm_settings["parse_json"] = True - 
config.community_reports.strategy = { - "type": "graph_intelligence", - "llm": llm_settings, - "graph_prompt": "", - } + config.models["default_chat_model"].type = "mock_chat" + config.models["default_chat_model"].responses = MOCK_RESPONSES # type: ignore await run_workflow(config, context) diff --git a/tests/verbs/test_extract_covariates.py b/tests/verbs/test_extract_covariates.py index b873186541..c9c74f0fa9 100644 --- a/tests/verbs/test_extract_covariates.py +++ b/tests/verbs/test_extract_covariates.py @@ -32,17 +32,11 @@ async def test_extract_covariates(): ) config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG}) - llm_settings = config.get_language_model_config( - config.extract_claims.model_id - ).model_dump() - llm_settings["type"] = ModelType.MockChat - llm_settings["responses"] = MOCK_LLM_RESPONSES config.extract_claims.enabled = True - config.extract_claims.strategy = { - "type": "graph_intelligence", - "llm": llm_settings, - "claim_description": "description", - } + config.extract_claims.description = "description" + llm_settings = config.get_language_model_config(config.extract_claims.model_id) + llm_settings.type = ModelType.MockChat + llm_settings.responses = MOCK_LLM_RESPONSES # type: ignore await run_workflow(config, context) diff --git a/tests/verbs/test_extract_graph.py b/tests/verbs/test_extract_graph.py index 145d161d51..28dbfcf3c0 100644 --- a/tests/verbs/test_extract_graph.py +++ b/tests/verbs/test_extract_graph.py @@ -9,7 +9,8 @@ from graphrag.utils.storage import load_table_from_storage from .util import ( - DEFAULT_MODEL_CONFIG, + DEFAULT_CHAT_MODEL_CONFIG, + DEFAULT_EMBEDDING_MODEL_CONFIG, create_test_context, ) @@ -39,27 +40,23 @@ async def test_extract_graph(): storage=["text_units"], ) - config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG}) - extract_claims_llm_settings = config.get_language_model_config( - config.extract_graph.model_id - ).model_dump() - extract_claims_llm_settings["type"] = ModelType.MockChat - extract_claims_llm_settings["responses"] = MOCK_LLM_ENTITY_RESPONSES - config.extract_graph.strategy = { - "type": "graph_intelligence", - "llm": extract_claims_llm_settings, - } + extraction_model = DEFAULT_CHAT_MODEL_CONFIG.copy() + extraction_model["type"] = ModelType.MockChat + extraction_model["responses"] = MOCK_LLM_ENTITY_RESPONSES # type: ignore + config = create_graphrag_config({ + "models": { + "default_chat_model": extraction_model, + "default_embedding_model": DEFAULT_EMBEDDING_MODEL_CONFIG, + } + }) + summarize_llm_settings = config.get_language_model_config( config.summarize_descriptions.model_id ).model_dump() summarize_llm_settings["type"] = ModelType.MockChat summarize_llm_settings["responses"] = MOCK_LLM_SUMMARIZATION_RESPONSES - config.summarize_descriptions.strategy = { - "type": "graph_intelligence", - "llm": summarize_llm_settings, - "max_input_tokens": 1000, - "max_summary_length": 100, - } + config.summarize_descriptions.max_input_tokens = 1000 + config.summarize_descriptions.max_length = 100 await run_workflow(config, context) diff --git a/tests/verbs/test_generate_text_embeddings.py b/tests/verbs/test_generate_text_embeddings.py index b0e47d1638..33254874e7 100644 --- a/tests/verbs/test_generate_text_embeddings.py +++ b/tests/verbs/test_generate_text_embeddings.py @@ -6,7 +6,6 @@ all_embeddings, ) from graphrag.config.enums import ModelType -from graphrag.index.operations.embed_text.embed_text import TextEmbedStrategyType from graphrag.index.workflows.generate_text_embeddings import ( run_workflow, ) @@ 
-30,15 +29,9 @@ async def test_generate_text_embeddings(): ) config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG}) - llm_settings = config.get_language_model_config( - config.embed_text.model_id - ).model_dump() - llm_settings["type"] = ModelType.MockEmbedding + llm_settings = config.get_language_model_config(config.embed_text.model_id) + llm_settings.type = ModelType.MockEmbedding - config.embed_text.strategy = { - "type": TextEmbedStrategyType.openai, - "llm": llm_settings, - } config.embed_text.names = list(all_embeddings) config.snapshots.embeddings = True
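
Across the verb tests the refactor follows one pattern: instead of packing a mock model into a per-workflow `strategy` dict, a test now mutates the typed model config directly and lets the workflow resolve it by `model_id`. A condensed sketch of that pattern under stated assumptions: `DEFAULT_MODEL_CONFIG` is the fixture imported via `.util` in tests/verbs, and `MOCK_LLM_RESPONSES` below is a placeholder, not the canned output any specific test uses:

```python
# Condensed from the test changes above. DEFAULT_MODEL_CONFIG comes from
# the tests/verbs fixtures; MOCK_LLM_RESPONSES here is a placeholder list,
# since each real test defines its own workflow-specific responses.
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.config.enums import ModelType

from tests.verbs.util import DEFAULT_MODEL_CONFIG

MOCK_LLM_RESPONSES = ["(placeholder mock response)"]

config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})

# Flip the model the workflow will look up into a mock and preload it.
llm_settings = config.get_language_model_config(config.extract_claims.model_id)
llm_settings.type = ModelType.MockChat
llm_settings.responses = MOCK_LLM_RESPONSES  # type: ignore
```
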
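
Relatedly, the rewritten assertion in test_encoding.py tracks the tokenizer default change: `chunk_text` now calls `get_tokenizer()` with no arguments, so both it and the test resolve to the package-wide o200k_base default rather than a threaded-through `defs.ENCODING_MODEL`. A quick tiktoken sketch reproduces both expected sequences; reading the old default as cl100k_base is an inference from the removed token ids, not something the diff states:

```python
# Reproduces the before/after expectations in test_encode_basic.
# Assumption: the old default encoding was cl100k_base; the new default
# is o200k_base (ENCODING_MODEL). Verify against your tiktoken version.
import tiktoken

old_enc = tiktoken.get_encoding("cl100k_base")
new_enc = tiktoken.get_encoding("o200k_base")

print(old_enc.encode("abc def"))  # matches the removed assertion: [13997, 711]
print(new_enc.encode("abc def"))  # matches the new assertion: [26682, 1056]
```
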