6 changes: 3 additions & 3 deletions docs/config/models.md
@@ -95,20 +95,20 @@ Many users have used platforms such as [ollama](https://ollama.com/) and [LiteLL

### Model Protocol

As of GraphRAG 2.0.0, we support model injection through the use of a standard chat and embedding Protocol and an accompanying ModelFactory that you can use to register your model implementation. This is not supported with the CLI, so you'll need to use GraphRAG as a library.
As of GraphRAG 2.0.0, we support model injection through the use of a standard chat and embedding Protocol and accompanying factories that you can use to register your model implementation. This is not supported with the CLI, so you'll need to use GraphRAG as a library.

- Our Protocol is [defined here](https://github.com/microsoft/graphrag/blob/main/graphrag/language_model/protocol/base.py)
- We have a simple mock implementation in our tests that you can [reference here](https://github.com/microsoft/graphrag/blob/main/tests/mock_provider.py)

Once you have a model implementation, you need to register it with our ModelFactory:
Once you have a model implementation, you need to register it with our ChatModelFactory or EmbeddingModelFactory:

```python
class MyCustomModel:
    ...
    # implementation

# elsewhere...
ModelFactory.register_chat("my-custom-chat-model", lambda **kwargs: MyCustomModel(**kwargs))
ChatModelFactory.register("my-custom-chat-model", lambda **kwargs: MyCustomModel(**kwargs))
```
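
Once registered, the factory can also construct your model by its type name. A minimal sketch, assuming the generic `create(strategy, init_args)` signature the other GraphRAG factories use; the kwargs forwarded to `MyCustomModel` are illustrative:

```python
# Hedged sketch: build a registered model by type name.
# create(strategy, init_args) mirrors the other factories in this codebase;
# the init_args dict is unpacked as kwargs for the registered creator.
model = ChatModelFactory().create(
    "my-custom-chat-model",
    {"name": "my_model_instance"},  # illustrative constructor kwargs
)
```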

Then in your config you can reference the type name you used:
58 changes: 23 additions & 35 deletions docs/examples_notebooks/custom_vector_store.ipynb
@@ -155,18 +155,19 @@
" self.connected = True\n",
" print(f\"✅ Connected to in-memory vector store: {self.index_name}\")\n",
"\n",
" def load_documents(\n",
" self, documents: list[VectorStoreDocument], overwrite: bool = True\n",
" ) -> None:\n",
" def create_index(self) -> None:\n",
" \"\"\"Create index in the vector store (no-op for in-memory store).\"\"\"\n",
" self.documents.clear()\n",
" self.vectors.clear()\n",
"\n",
" print(f\"✅ Index '{self.index_name}' is ready in in-memory vector store\")\n",
"\n",
" def load_documents(self, documents: list[VectorStoreDocument]) -> None:\n",
" \"\"\"Load documents into the vector store.\"\"\"\n",
" if not self.connected:\n",
" msg = \"Vector store not connected. Call connect() first.\"\n",
" raise RuntimeError(msg)\n",
"\n",
" if overwrite:\n",
" self.documents.clear()\n",
" self.vectors.clear()\n",
"\n",
" loaded_count = 0\n",
" for doc in documents:\n",
" if doc.vector is not None:\n",
@@ -230,13 +231,6 @@
" # Use vector search with the embedding\n",
" return self.similarity_search_by_vector(query_embedding, k, **kwargs)\n",
"\n",
" def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:\n",
" \"\"\"Build a query filter to filter documents by id.\n",
"\n",
" For this simple implementation, we return the list of IDs as the filter.\n",
" \"\"\"\n",
" return [str(id_) for id_ in include_ids]\n",
"\n",
" def search_by_id(self, id: str) -> VectorStoreDocument:\n",
" \"\"\"Search for a document by id.\"\"\"\n",
" doc_id = str(id)\n",
@@ -281,15 +275,15 @@
"CUSTOM_VECTOR_STORE_TYPE = \"simple_memory\"\n",
"\n",
"# Register the vector store class\n",
"VectorStoreFactory.register(CUSTOM_VECTOR_STORE_TYPE, SimpleInMemoryVectorStore)\n",
"VectorStoreFactory().register(CUSTOM_VECTOR_STORE_TYPE, SimpleInMemoryVectorStore)\n",
"\n",
"print(f\"✅ Registered custom vector store with type: '{CUSTOM_VECTOR_STORE_TYPE}'\")\n",
"\n",
"# Verify registration\n",
"available_types = VectorStoreFactory.get_vector_store_types()\n",
"available_types = VectorStoreFactory().keys()\n",
"print(f\"\\n📋 Available vector store types: {available_types}\")\n",
"print(\n",
" f\"🔍 Is our custom type supported? {VectorStoreFactory.is_supported_type(CUSTOM_VECTOR_STORE_TYPE)}\"\n",
" f\"🔍 Is our custom type supported? {CUSTOM_VECTOR_STORE_TYPE in VectorStoreFactory()}\"\n",
")"
]
},
@@ -347,8 +341,8 @@
"schema = VectorStoreSchemaConfig(index_name=\"test_collection\")\n",
"\n",
"# Create vector store instance using factory\n",
"vector_store = VectorStoreFactory.create_vector_store(\n",
" CUSTOM_VECTOR_STORE_TYPE, vector_store_schema_config=schema\n",
"vector_store = VectorStoreFactory().create(\n",
" CUSTOM_VECTOR_STORE_TYPE, {\"vector_store_schema_config\": schema}\n",
")\n",
"\n",
"print(f\"✅ Created vector store instance: {type(vector_store).__name__}\")\n",
@@ -363,6 +357,7 @@
"source": [
"# Connect and load documents\n",
"vector_store.connect()\n",
"vector_store.create_index()\n",
"vector_store.load_documents(sample_documents)\n",
"\n",
"print(f\"📊 Updated stats: {vector_store.get_stats()}\")"
@@ -472,13 +467,12 @@
" # 1. GraphRAG creates vector store using factory\n",
" schema = VectorStoreSchemaConfig(index_name=\"graphrag_entities\")\n",
"\n",
" store = VectorStoreFactory.create_vector_store(\n",
" store = VectorStoreFactory().create(\n",
" CUSTOM_VECTOR_STORE_TYPE,\n",
" vector_store_schema_config=schema,\n",
" similarity_threshold=0.3,\n",
" {\"vector_store_schema_config\": schema, \"similarity_threshold\": 0.3},\n",
" )\n",
" store.connect()\n",
"\n",
" store.create_index()\n",
" print(\"✅ Step 1: Vector store created and connected\")\n",
"\n",
" # 2. During indexing, GraphRAG loads extracted entities\n",
@@ -534,12 +528,12 @@
"\n",
" # Test 1: Basic functionality\n",
" print(\"Test 1: Basic functionality\")\n",
" store = VectorStoreFactory.create_vector_store(\n",
" store = VectorStoreFactory().create(\n",
" CUSTOM_VECTOR_STORE_TYPE,\n",
" vector_store_schema_config=VectorStoreSchemaConfig(index_name=\"test\"),\n",
" {\"vector_store_schema_config\": VectorStoreSchemaConfig(index_name=\"test\")},\n",
" )\n",
" store.connect()\n",
"\n",
" store.create_index()\n",
" # Load test documents\n",
" test_docs = sample_documents[:2]\n",
" store.load_documents(test_docs)\n",
@@ -575,17 +569,11 @@
"\n",
" print(\"✅ Search by ID test passed\")\n",
"\n",
" # Test 4: Filter functionality\n",
" print(\"\\nTest 4: Filter functionality\")\n",
" filter_result = store.filter_by_id([\"doc_1\", \"doc_2\"])\n",
" assert filter_result == [\"doc_1\", \"doc_2\"], \"Should return filtered IDs\"\n",
" print(\"✅ Filter functionality test passed\")\n",
"\n",
" # Test 5: Error handling\n",
" # Test 4: Error handling\n",
" print(\"\\nTest 5: Error handling\")\n",
" disconnected_store = VectorStoreFactory.create_vector_store(\n",
" disconnected_store = VectorStoreFactory().create(\n",
" CUSTOM_VECTOR_STORE_TYPE,\n",
" vector_store_schema_config=VectorStoreSchemaConfig(index_name=\"test2\"),\n",
" {\"vector_store_schema_config\": VectorStoreSchemaConfig(index_name=\"test2\")},\n",
" )\n",
"\n",
" try:\n",
1 change: 1 addition & 0 deletions docs/index/architecture.md
@@ -39,6 +39,7 @@ Several subsystems within GraphRAG use a factory pattern to register and retriev
The following subsystems use a factory pattern that allows you to register your own implementations:

- [language model](https://github.com/microsoft/graphrag/blob/main/graphrag/language_model/factory.py) - implement your own `chat` and `embed` methods to use a model provider of choice beyond the built-in OpenAI/Azure support
- [input reader](https://github.com/microsoft/graphrag/blob/main/graphrag/index/input/factory.py) - implement your own input document reader to support file types other than text, CSV, and JSON
- [cache](https://github.com/microsoft/graphrag/blob/main/graphrag/cache/factory.py) - create your own cache storage location in addition to the file, blob, and CosmosDB ones we provide
- [logger](https://github.com/microsoft/graphrag/blob/main/graphrag/logger/factory.py) - create your own log writing location in addition to the built-in file and blob storage
- [storage](https://github.com/microsoft/graphrag/blob/main/graphrag/storage/factory.py) - create your own storage provider (database, etc.) beyond the file, blob, and CosmosDB ones built in
4 changes: 4 additions & 0 deletions docs/index/inputs.md
@@ -20,6 +20,10 @@ Also see the [outputs](outputs.md) documentation for the final documents table s

As of version 2.6.0, GraphRAG's [indexing API method](https://github.com/microsoft/graphrag/blob/main/graphrag/api/index.py) allows you to pass in your own pandas DataFrame and bypass all of the input loading/parsing described in the next section. This is convenient if you have content in a format or storage location we don't support out-of-the-box. __You must ensure that your input DataFrame conforms to the schema described above.__ All of the chunking behavior described later will proceed exactly the same.
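
A minimal sketch of this bypass is below; the DataFrame keyword passed to `build_index` is an assumption, so check [the method signature](https://github.com/microsoft/graphrag/blob/main/graphrag/api/index.py) for the actual parameter name:

```python
# Hedged sketch: index a pre-built DataFrame, skipping input loading/parsing.
import pandas as pd

import graphrag.api as api

documents = pd.DataFrame({
    "id": ["1"],
    "title": ["doc-1"],
    "text": ["Some document text to index."],
    "creation_date": ["2024-01-01"],  # must conform to the documents schema
})

# Inside an async context, with `config` a loaded GraphRagConfig;
# the `input_documents` keyword is illustrative, not the confirmed name.
results = await api.build_index(config=config, input_documents=documents)
```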

## Custom File Handling

As of version 3.0.0, we have migrated to using an injectable InputReader provider class. This means you can implement any input file handling you want in a class that extends InputReader and register it with the InputReaderFactory. See the [architecture page](https://microsoft.github.io/graphrag/index/architecture/) for more info on our standard provider pattern.
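
A minimal sketch of what this could look like; the reader's method name and registration call below are assumptions, so consult the [input reader factory](https://github.com/microsoft/graphrag/blob/main/graphrag/index/input/factory.py) for the actual protocol:

```python
# Hedged sketch: a custom reader for markdown files.
# The `read` method name and its arguments are illustrative only.
import pandas as pd


class MarkdownInputReader:  # would extend InputReader
    async def read(self, storage) -> pd.DataFrame:
        """Load .md files from storage into a DataFrame matching the documents schema."""
        ...


# Register under a type name you can then reference from your config:
InputReaderFactory().register("markdown", lambda **kwargs: MarkdownInputReader(**kwargs))
```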

## Formats

We support three file formats out-of-the-box. This covers the overwhelming majority of use cases we have encountered. If you have a different format, we recommend writing a script to convert to one of these, which are widely used and supported by many tools and libraries.
67 changes: 9 additions & 58 deletions graphrag/cache/factory.py
@@ -5,23 +5,18 @@

from __future__ import annotations

from typing import TYPE_CHECKING, ClassVar

from graphrag.cache.json_pipeline_cache import JsonPipelineCache
from graphrag.cache.memory_pipeline_cache import InMemoryCache
from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.config.enums import CacheType
from graphrag.factory.factory import Factory
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage
from graphrag.storage.file_pipeline_storage import FilePipelineStorage

if TYPE_CHECKING:
    from collections.abc import Callable

    from graphrag.cache.pipeline_cache import PipelineCache


class CacheFactory:
class CacheFactory(Factory[PipelineCache]):
"""A factory class for cache implementations.

Includes a method for users to register a custom cache implementation.
Expand All @@ -30,51 +25,6 @@ class CacheFactory:
for individual enforcement of required/optional arguments.
"""

    _registry: ClassVar[dict[str, Callable[..., PipelineCache]]] = {}

    @classmethod
    def register(cls, cache_type: str, creator: Callable[..., PipelineCache]) -> None:
        """Register a custom cache implementation.

        Args:
            cache_type: The type identifier for the cache.
            creator: A class or callable that creates an instance of PipelineCache.
        """
        cls._registry[cache_type] = creator

    @classmethod
    def create_cache(cls, cache_type: str, kwargs: dict) -> PipelineCache:
        """Create a cache object from the provided type.

        Args:
            cache_type: The type of cache to create.
            root_dir: The root directory for file-based caches.
            kwargs: Additional keyword arguments for the cache constructor.

        Returns
        -------
        A PipelineCache instance.

        Raises
        ------
        ValueError: If the cache type is not registered.
        """
        if cache_type not in cls._registry:
            msg = f"Unknown cache type: {cache_type}"
            raise ValueError(msg)

        return cls._registry[cache_type](**kwargs)

    @classmethod
    def get_cache_types(cls) -> list[str]:
        """Get the registered cache implementations."""
        return list(cls._registry.keys())

    @classmethod
    def is_supported_type(cls, cache_type: str) -> bool:
        """Check if the given cache type is supported."""
        return cache_type in cls._registry


# --- register built-in cache implementations ---
def create_file_cache(root_dir: str, base_dir: str, **kwargs) -> PipelineCache:
@@ -108,8 +58,9 @@ def create_memory_cache(**kwargs) -> PipelineCache:


# --- register built-in cache implementations ---
CacheFactory.register(CacheType.none.value, create_noop_cache)
CacheFactory.register(CacheType.memory.value, create_memory_cache)
CacheFactory.register(CacheType.file.value, create_file_cache)
CacheFactory.register(CacheType.blob.value, create_blob_cache)
CacheFactory.register(CacheType.cosmosdb.value, create_cosmosdb_cache)
cache_factory = CacheFactory()
cache_factory.register(CacheType.none.value, create_noop_cache)
cache_factory.register(CacheType.memory.value, create_memory_cache)
cache_factory.register(CacheType.file.value, create_file_cache)
cache_factory.register(CacheType.blob.value, create_blob_cache)
cache_factory.register(CacheType.cosmosdb.value, create_cosmosdb_cache)
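

# Downstream code can extend the same registry with a custom implementation.
# A minimal sketch -- RedisPipelineCache is hypothetical, not shipped with
# GraphRAG, and create(strategy, init_args) mirrors the generic Factory API.
class RedisPipelineCache(PipelineCache):
    """Stub for a Redis-backed PipelineCache implementation."""


cache_factory.register("redis", lambda **kwargs: RedisPipelineCache(**kwargs))
redis_cache = cache_factory.create(
    "redis", {"connection_string": "redis://localhost:6379"}
)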
31 changes: 0 additions & 31 deletions graphrag/config/defaults.py
@@ -3,7 +3,6 @@

"""Common default configuration values."""

from collections.abc import Callable
from dataclasses import dataclass, field
from pathlib import Path
from typing import ClassVar
@@ -24,25 +23,6 @@
from graphrag.index.operations.build_noun_graph.np_extractors.stop_words import (
    EN_STOP_WORDS,
)
from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter import (
    RateLimiter,
)
from graphrag.language_model.providers.litellm.services.rate_limiter.static_rate_limiter import (
    StaticRateLimiter,
)
from graphrag.language_model.providers.litellm.services.retry.exponential_retry import (
    ExponentialRetry,
)
from graphrag.language_model.providers.litellm.services.retry.incremental_wait_retry import (
    IncrementalWaitRetry,
)
from graphrag.language_model.providers.litellm.services.retry.native_wait_retry import (
    NativeRetry,
)
from graphrag.language_model.providers.litellm.services.retry.random_wait_retry import (
    RandomWaitRetry,
)
from graphrag.language_model.providers.litellm.services.retry.retry import Retry

DEFAULT_OUTPUT_BASE_DIR = "output"
DEFAULT_CHAT_MODEL_ID = "default_chat_model"
@@ -60,17 +40,6 @@

DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"]

DEFAULT_RETRY_SERVICES: dict[str, Callable[..., Retry]] = {
    "native": NativeRetry,
    "exponential_backoff": ExponentialRetry,
    "random_wait": RandomWaitRetry,
    "incremental_wait": IncrementalWaitRetry,
}

DEFAULT_RATE_LIMITER_SERVICES: dict[str, Callable[..., RateLimiter]] = {
    "static": StaticRateLimiter,
}


@dataclass
class BasicSearchDefaults:
9 changes: 6 additions & 3 deletions graphrag/config/models/graph_rag_config.py
@@ -104,8 +104,10 @@ def _validate_retry_services(self) -> None:

            _ = retry_factory.create(
                strategy=model.retry_strategy,
                max_retries=model.max_retries,
                max_retry_wait=model.max_retry_wait,
                init_args={
                    "max_retries": model.max_retries,
                    "max_retry_wait": model.max_retry_wait,
                },
            )

    def _validate_rate_limiter_services(self) -> None:
@@ -130,7 +132,8 @@ def _validate_rate_limiter_services(self) -> None:
            )
            if rpm is not None or tpm is not None:
                _ = rate_limiter_factory.create(
                    strategy=model.rate_limit_strategy, rpm=rpm, tpm=tpm
                    strategy=model.rate_limit_strategy,
                    init_args={"rpm": rpm, "tpm": tpm},
                )

    input: InputConfig = Field(
9 changes: 6 additions & 3 deletions graphrag/config/models/language_model_config.py
@@ -15,7 +15,7 @@
AzureApiVersionMissingError,
ConflictingSettingsError,
)
from graphrag.language_model.factory import ModelFactory
from graphrag.language_model.factory import ChatModelFactory, EmbeddingModelFactory

logger = logging.getLogger(__name__)

@@ -91,8 +91,11 @@ def _validate_type(self) -> None:
        If the model name is not recognized.
        """
        # Type should be contained by the registered models
        if not ModelFactory.is_supported_model(self.type):
            msg = f"Model type {self.type} is not recognized, must be one of {ModelFactory.get_chat_models() + ModelFactory.get_embedding_models()}."
        if (
            self.type not in ChatModelFactory()
            and self.type not in EmbeddingModelFactory()
        ):
            msg = f"Model type {self.type} is not recognized, must be one of {ChatModelFactory().keys() + EmbeddingModelFactory().keys()}."
        raise KeyError(msg)

    model_provider: str | None = Field(