diff --git a/docs/config/models.md b/docs/config/models.md index 676cb4b2ff..ca3f044e3d 100644 --- a/docs/config/models.md +++ b/docs/config/models.md @@ -95,12 +95,12 @@ Many users have used platforms such as [ollama](https://ollama.com/) and [LiteLL ### Model Protocol -As of GraphRAG 2.0.0, we support model injection through the use of a standard chat and embedding Protocol and an accompanying ModelFactory that you can use to register your model implementation. This is not supported with the CLI, so you'll need to use GraphRAG as a library. +As of GraphRAG 2.0.0, we support model injection through the use of a standard chat and embedding Protocol and accompanying factories that you can use to register your model implementation. This is not supported with the CLI, so you'll need to use GraphRAG as a library. - Our Protocol is [defined here](https://github.com/microsoft/graphrag/blob/main/graphrag/language_model/protocol/base.py) - We have a simple mock implementation in our tests that you can [reference here](https://github.com/microsoft/graphrag/blob/main/tests/mock_provider.py) -Once you have a model implementation, you need to register it with our ModelFactory: +Once you have a model implementation, you need to register it with our ChatModelFactory or EmbeddingModelFactory: ```python class MyCustomModel: ... # implementation # elsewhere... -ModelFactory.register_chat("my-custom-chat-model", lambda **kwargs: MyCustomModel(**kwargs)) +ChatModelFactory().register("my-custom-chat-model", lambda **kwargs: MyCustomModel(**kwargs)) ``` Then in your config you can reference the type name you used: diff --git a/docs/examples_notebooks/custom_vector_store.ipynb b/docs/examples_notebooks/custom_vector_store.ipynb index bce47f8bac..fe8495bbae 100644 --- a/docs/examples_notebooks/custom_vector_store.ipynb +++ b/docs/examples_notebooks/custom_vector_store.ipynb @@ -155,18 +155,19 @@ " self.connected = True\n", " print(f\"āœ… Connected to in-memory vector store: {self.index_name}\")\n", "\n", - " def load_documents(\n", - " self, documents: list[VectorStoreDocument], overwrite: bool = True\n", - " ) -> None:\n", + " def create_index(self) -> None:\n", + " \"\"\"Create the index in the vector store (for this in-memory store, it simply resets any existing data).\"\"\"\n", + " self.documents.clear()\n", + " self.vectors.clear()\n", + "\n", + " print(f\"āœ… Index '{self.index_name}' is ready in in-memory vector store\")\n", + "\n", + " def load_documents(self, documents: list[VectorStoreDocument]) -> None:\n", " \"\"\"Load documents into the vector store.\"\"\"\n", " if not self.connected:\n", " msg = \"Vector store not connected. 
Call connect() first.\"\n", " raise RuntimeError(msg)\n", "\n", - " if overwrite:\n", - " self.documents.clear()\n", - " self.vectors.clear()\n", - "\n", " loaded_count = 0\n", " for doc in documents:\n", " if doc.vector is not None:\n", @@ -230,13 +231,6 @@ " # Use vector search with the embedding\n", " return self.similarity_search_by_vector(query_embedding, k, **kwargs)\n", "\n", - " def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:\n", - " \"\"\"Build a query filter to filter documents by id.\n", - "\n", - " For this simple implementation, we return the list of IDs as the filter.\n", - " \"\"\"\n", - " return [str(id_) for id_ in include_ids]\n", - "\n", " def search_by_id(self, id: str) -> VectorStoreDocument:\n", " \"\"\"Search for a document by id.\"\"\"\n", " doc_id = str(id)\n", @@ -281,15 +275,15 @@ "CUSTOM_VECTOR_STORE_TYPE = \"simple_memory\"\n", "\n", "# Register the vector store class\n", - "VectorStoreFactory.register(CUSTOM_VECTOR_STORE_TYPE, SimpleInMemoryVectorStore)\n", + "VectorStoreFactory().register(CUSTOM_VECTOR_STORE_TYPE, SimpleInMemoryVectorStore)\n", "\n", "print(f\"āœ… Registered custom vector store with type: '{CUSTOM_VECTOR_STORE_TYPE}'\")\n", "\n", "# Verify registration\n", - "available_types = VectorStoreFactory.get_vector_store_types()\n", + "available_types = VectorStoreFactory().keys()\n", "print(f\"\\nšŸ“‹ Available vector store types: {available_types}\")\n", "print(\n", - " f\"šŸ” Is our custom type supported? {VectorStoreFactory.is_supported_type(CUSTOM_VECTOR_STORE_TYPE)}\"\n", + " f\"šŸ” Is our custom type supported? {CUSTOM_VECTOR_STORE_TYPE in VectorStoreFactory()}\"\n", ")" ] }, @@ -347,8 +341,8 @@ "schema = VectorStoreSchemaConfig(index_name=\"test_collection\")\n", "\n", "# Create vector store instance using factory\n", - "vector_store = VectorStoreFactory.create_vector_store(\n", - " CUSTOM_VECTOR_STORE_TYPE, vector_store_schema_config=schema\n", + "vector_store = VectorStoreFactory().create(\n", + " CUSTOM_VECTOR_STORE_TYPE, {\"vector_store_schema_config\": schema}\n", ")\n", "\n", "print(f\"āœ… Created vector store instance: {type(vector_store).__name__}\")\n", @@ -363,6 +357,7 @@ "source": [ "# Connect and load documents\n", "vector_store.connect()\n", + "vector_store.create_index()\n", "vector_store.load_documents(sample_documents)\n", "\n", "print(f\"šŸ“Š Updated stats: {vector_store.get_stats()}\")" @@ -472,13 +467,12 @@ " # 1. GraphRAG creates vector store using factory\n", " schema = VectorStoreSchemaConfig(index_name=\"graphrag_entities\")\n", "\n", - " store = VectorStoreFactory.create_vector_store(\n", + " store = VectorStoreFactory().create(\n", " CUSTOM_VECTOR_STORE_TYPE,\n", - " vector_store_schema_config=schema,\n", - " similarity_threshold=0.3,\n", + " {\"vector_store_schema_config\": schema, \"similarity_threshold\": 0.3},\n", " )\n", " store.connect()\n", - "\n", + " store.create_index()\n", " print(\"āœ… Step 1: Vector store created and connected\")\n", "\n", " # 2. 
During indexing, GraphRAG loads extracted entities\n", @@ -534,12 +528,12 @@ "\n", " # Test 1: Basic functionality\n", " print(\"Test 1: Basic functionality\")\n", - " store = VectorStoreFactory.create_vector_store(\n", + " store = VectorStoreFactory().create(\n", " CUSTOM_VECTOR_STORE_TYPE,\n", - " vector_store_schema_config=VectorStoreSchemaConfig(index_name=\"test\"),\n", + " {\"vector_store_schema_config\": VectorStoreSchemaConfig(index_name=\"test\")},\n", " )\n", " store.connect()\n", - "\n", + " store.create_index()\n", " # Load test documents\n", " test_docs = sample_documents[:2]\n", " store.load_documents(test_docs)\n", @@ -575,17 +569,11 @@ "\n", " print(\"āœ… Search by ID test passed\")\n", "\n", - " # Test 4: Filter functionality\n", - " print(\"\\nTest 4: Filter functionality\")\n", - " filter_result = store.filter_by_id([\"doc_1\", \"doc_2\"])\n", - " assert filter_result == [\"doc_1\", \"doc_2\"], \"Should return filtered IDs\"\n", - " print(\"āœ… Filter functionality test passed\")\n", - "\n", - " # Test 5: Error handling\n", + " # Test 4: Error handling\n", - " print(\"\\nTest 5: Error handling\")\n", + " print(\"\\nTest 4: Error handling\")\n", - " disconnected_store = VectorStoreFactory.create_vector_store(\n", + " disconnected_store = VectorStoreFactory().create(\n", " CUSTOM_VECTOR_STORE_TYPE,\n", - " vector_store_schema_config=VectorStoreSchemaConfig(index_name=\"test2\"),\n", + " {\"vector_store_schema_config\": VectorStoreSchemaConfig(index_name=\"test2\")},\n", " )\n", "\n", " try:\n", diff --git a/docs/index/architecture.md b/docs/index/architecture.md index 199538a0c9..be73bae7ed 100644 --- a/docs/index/architecture.md +++ b/docs/index/architecture.md @@ -39,6 +39,7 @@ Several subsystems within GraphRAG use a factory pattern to register and retriev The following subsystems use a factory pattern that allows you to register your own implementations: - [language model](https://github.com/microsoft/graphrag/blob/main/graphrag/language_model/factory.py) - implement your own `chat` and `embed` methods to use a model provider of choice beyond the built-in OpenAI/Azure support +- [input reader](https://github.com/microsoft/graphrag/blob/main/graphrag/index/input/factory.py) - implement your own input document reader to support file types other than text, CSV, and JSON - [cache](https://github.com/microsoft/graphrag/blob/main/graphrag/cache/factory.py) - create your own cache storage location in addition to the file, blob, and CosmosDB ones we provide - [logger](https://github.com/microsoft/graphrag/blob/main/graphrag/logger/factory.py) - create your own log writing location in addition to the built-in file and blob storage - [storage](https://github.com/microsoft/graphrag/blob/main/graphrag/storage/factory.py) - create your own storage provider (database, etc.) beyond the file, blob, and CosmosDB ones built in diff --git a/docs/index/inputs.md b/docs/index/inputs.md index af8a310825..4bc5ad620a 100644 --- a/docs/index/inputs.md +++ b/docs/index/inputs.md @@ -20,6 +20,10 @@ Also see the [outputs](outputs.md) documentation for the final documents table s As of version 2.6.0, GraphRAG's [indexing API method](https://github.com/microsoft/graphrag/blob/main/graphrag/api/index.py) allows you to pass in your own pandas DataFrame and bypass all of the input loading/parsing described in the next section. This is convenient if you have content in a format or storage location we don't support out-of-the-box. 
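For example, a minimal conforming frame can be built directly with pandas. This sketch infers the column set (id, text, title, creation_date) from the built-in readers elsewhere in this patch, so verify it against the documents-table schema described above before relying on it:

```python
import pandas as pd

# One row per input document; metadata columns are optional.
my_documents = pd.DataFrame([
    {
        "id": "doc-1",  # any stable unique identifier
        "text": "GraphRAG builds a knowledge graph over your documents.",
        "title": "intro.txt",
        "creation_date": "2024-01-01",
    }
])

# The frame is then handed to the indexing API in place of file loading.
# The exact keyword argument on graphrag.api.build_index is not shown in
# this patch, so check graphrag/api/index.py for the parameter name.
```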
__You must ensure that your input DataFrame conforms to the schema described above.__ All of the chunking behavior described later will proceed exactly the same. +## Custom File Handling + +As of version 3.0.0, we have migrated to using an injectable InputReader provider class. This means you can implement any input file handling you want in a class that extends InputReader and register it with the InputReaderFactory. See the [architecture page](https://microsoft.github.io/graphrag/index/architecture/) for more info on our standard provider pattern. + ## Formats We support three file formats out-of-the-box. This covers the overwhelming majority of use cases we have encountered. If you have a different format, we recommend writing a script to convert to one of these, which are widely used and supported by many tools and libraries. diff --git a/graphrag/cache/factory.py b/graphrag/cache/factory.py index 331540260d..f406fc856d 100644 --- a/graphrag/cache/factory.py +++ b/graphrag/cache/factory.py @@ -5,23 +5,18 @@ from __future__ import annotations -from typing import TYPE_CHECKING, ClassVar - from graphrag.cache.json_pipeline_cache import JsonPipelineCache from graphrag.cache.memory_pipeline_cache import InMemoryCache from graphrag.cache.noop_pipeline_cache import NoopPipelineCache +from graphrag.cache.pipeline_cache import PipelineCache from graphrag.config.enums import CacheType +from graphrag.factory.factory import Factory from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage from graphrag.storage.file_pipeline_storage import FilePipelineStorage -if TYPE_CHECKING: - from collections.abc import Callable - - from graphrag.cache.pipeline_cache import PipelineCache - -class CacheFactory: +class CacheFactory(Factory[PipelineCache]): """A factory class for cache implementations. Includes a method for users to register a custom cache implementation. Configuration arguments are passed to each cache implementation as kwargs for individual enforcement of required/optional arguments. """ - _registry: ClassVar[dict[str, Callable[..., PipelineCache]]] = {} - - @classmethod - def register(cls, cache_type: str, creator: Callable[..., PipelineCache]) -> None: - """Register a custom cache implementation. - - Args: - cache_type: The type identifier for the cache. - creator: A class or callable that creates an instance of PipelineCache. - """ - cls._registry[cache_type] = creator - - @classmethod - def create_cache(cls, cache_type: str, kwargs: dict) -> PipelineCache: - """Create a cache object from the provided type. - - Args: - cache_type: The type of cache to create. - root_dir: The root directory for file-based caches. - kwargs: Additional keyword arguments for the cache constructor. - - Returns - ------- - A PipelineCache instance. - - Raises - ------ - ValueError: If the cache type is not registered. 
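To ground the new "Custom File Handling" section above, here is a minimal sketch of a custom reader built against the InputReader/InputReaderFactory API introduced in this patch. The MarkdownFileReader name and the "markdown" type key are illustrative, not part of the PR; the body mirrors the TextFileReader changes below:

```python
from pathlib import Path

import pandas as pd

from graphrag.index.input.factory import InputReaderFactory
from graphrag.index.input.input_reader import InputReader
from graphrag.index.utils.hashing import gen_sha512_hash


class MarkdownFileReader(InputReader):
    """Hypothetical reader for markdown files."""

    async def read_file(self, path: str) -> pd.DataFrame:
        # Storage and config are injected by the InputReader base class.
        text = await self._storage.get(path, encoding=self._config.encoding)
        new_item = {"text": text}
        new_item["id"] = gen_sha512_hash(new_item, new_item.keys())
        new_item["title"] = str(Path(path).name)
        new_item["creation_date"] = await self._storage.get_creation_date(path)
        return pd.DataFrame([new_item])


# Assumes factory state is shared across instances, as the built-in
# registrations in graphrag/index/input/factory.py imply.
InputReaderFactory().register("markdown", MarkdownFileReader)
```

You would also set `input.file_type` to the registered key and point `input.file_pattern` at your `.md` files in settings.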
- """ - if cache_type not in cls._registry: - msg = f"Unknown cache type: {cache_type}" - raise ValueError(msg) - - return cls._registry[cache_type](**kwargs) - - @classmethod - def get_cache_types(cls) -> list[str]: - """Get the registered cache implementations.""" - return list(cls._registry.keys()) - - @classmethod - def is_supported_type(cls, cache_type: str) -> bool: - """Check if the given cache type is supported.""" - return cache_type in cls._registry - # --- register built-in cache implementations --- def create_file_cache(root_dir: str, base_dir: str, **kwargs) -> PipelineCache: @@ -108,8 +58,9 @@ def create_memory_cache(**kwargs) -> PipelineCache: # --- register built-in cache implementations --- -CacheFactory.register(CacheType.none.value, create_noop_cache) -CacheFactory.register(CacheType.memory.value, create_memory_cache) -CacheFactory.register(CacheType.file.value, create_file_cache) -CacheFactory.register(CacheType.blob.value, create_blob_cache) -CacheFactory.register(CacheType.cosmosdb.value, create_cosmosdb_cache) +cache_factory = CacheFactory() +cache_factory.register(CacheType.none.value, create_noop_cache) +cache_factory.register(CacheType.memory.value, create_memory_cache) +cache_factory.register(CacheType.file.value, create_file_cache) +cache_factory.register(CacheType.blob.value, create_blob_cache) +cache_factory.register(CacheType.cosmosdb.value, create_cosmosdb_cache) diff --git a/graphrag/config/defaults.py b/graphrag/config/defaults.py index 6f36169afb..9c6a65008c 100644 --- a/graphrag/config/defaults.py +++ b/graphrag/config/defaults.py @@ -3,7 +3,6 @@ """Common default configuration values.""" -from collections.abc import Callable from dataclasses import dataclass, field from pathlib import Path from typing import ClassVar @@ -24,25 +23,6 @@ from graphrag.index.operations.build_noun_graph.np_extractors.stop_words import ( EN_STOP_WORDS, ) -from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter import ( - RateLimiter, -) -from graphrag.language_model.providers.litellm.services.rate_limiter.static_rate_limiter import ( - StaticRateLimiter, -) -from graphrag.language_model.providers.litellm.services.retry.exponential_retry import ( - ExponentialRetry, -) -from graphrag.language_model.providers.litellm.services.retry.incremental_wait_retry import ( - IncrementalWaitRetry, -) -from graphrag.language_model.providers.litellm.services.retry.native_wait_retry import ( - NativeRetry, -) -from graphrag.language_model.providers.litellm.services.retry.random_wait_retry import ( - RandomWaitRetry, -) -from graphrag.language_model.providers.litellm.services.retry.retry import Retry DEFAULT_OUTPUT_BASE_DIR = "output" DEFAULT_CHAT_MODEL_ID = "default_chat_model" @@ -60,17 +40,6 @@ DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"] -DEFAULT_RETRY_SERVICES: dict[str, Callable[..., Retry]] = { - "native": NativeRetry, - "exponential_backoff": ExponentialRetry, - "random_wait": RandomWaitRetry, - "incremental_wait": IncrementalWaitRetry, -} - -DEFAULT_RATE_LIMITER_SERVICES: dict[str, Callable[..., RateLimiter]] = { - "static": StaticRateLimiter, -} - @dataclass class BasicSearchDefaults: diff --git a/graphrag/config/models/graph_rag_config.py b/graphrag/config/models/graph_rag_config.py index c8cdca819c..6a6a98e973 100644 --- a/graphrag/config/models/graph_rag_config.py +++ b/graphrag/config/models/graph_rag_config.py @@ -104,8 +104,10 @@ def _validate_retry_services(self) -> None: _ = retry_factory.create( strategy=model.retry_strategy, 
- max_retries=model.max_retries, - max_retry_wait=model.max_retry_wait, + init_args={ + "max_retries": model.max_retries, + "max_retry_wait": model.max_retry_wait, + }, ) def _validate_rate_limiter_services(self) -> None: @@ -130,7 +132,8 @@ def _validate_rate_limiter_services(self) -> None: ) if rpm is not None or tpm is not None: _ = rate_limiter_factory.create( - strategy=model.rate_limit_strategy, rpm=rpm, tpm=tpm + strategy=model.rate_limit_strategy, + init_args={"rpm": rpm, "tpm": tpm}, ) input: InputConfig = Field( diff --git a/graphrag/config/models/language_model_config.py b/graphrag/config/models/language_model_config.py index 33aee99893..11c46d76c7 100644 --- a/graphrag/config/models/language_model_config.py +++ b/graphrag/config/models/language_model_config.py @@ -15,7 +15,7 @@ AzureApiVersionMissingError, ConflictingSettingsError, ) -from graphrag.language_model.factory import ModelFactory +from graphrag.language_model.factory import ChatModelFactory, EmbeddingModelFactory logger = logging.getLogger(__name__) @@ -91,8 +91,11 @@ def _validate_type(self) -> None: If the model name is not recognized. """ # Type should be contained by the registered models - if not ModelFactory.is_supported_model(self.type): - msg = f"Model type {self.type} is not recognized, must be one of {ModelFactory.get_chat_models() + ModelFactory.get_embedding_models()}." + if ( + self.type not in ChatModelFactory() + and self.type not in EmbeddingModelFactory() + ): + msg = f"Model type {self.type} is not recognized, must be one of {ChatModelFactory().keys() + EmbeddingModelFactory().keys()}." raise KeyError(msg) model_provider: str | None = Field( diff --git a/graphrag/factory/factory.py b/graphrag/factory/factory.py index 0624e9bfa1..3116c538ef 100644 --- a/graphrag/factory/factory.py +++ b/graphrag/factory/factory.py @@ -34,25 +34,25 @@ def keys(self) -> list[str]: """Get a list of registered strategy names.""" return list(self._services.keys()) - def register(self, *, strategy: str, service_initializer: Callable[..., T]) -> None: + def register(self, strategy: str, initializer: Callable[..., T]) -> None: """ Register a new service. Args ---- strategy: The name of the strategy. - service_initializer: A callable that creates an instance of T. + initializer: A callable that creates an instance of T. """ - self._services[strategy] = service_initializer + self._services[strategy] = initializer - def create(self, *, strategy: str, **kwargs: Any) -> T: + def create(self, strategy: str, init_args: dict[str, Any] | None = None) -> T: """ Create a service instance based on the strategy. Args ---- strategy: The name of the strategy. - **kwargs: Additional arguments to pass to the service initializer. + init_args: Dict of keyword arguments to pass to the service initializer. Returns ------- @@ -65,4 +65,4 @@ def create(self, *, strategy: str, **kwargs: Any) -> T: if strategy not in self._services: msg = f"Strategy '{strategy}' is not registered." 
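The new calling convention is the core of this refactor: register() takes positional (strategy, initializer) arguments, and create() takes the strategy name plus a single optional init_args dict instead of **kwargs. A self-contained sketch (names are illustrative):

```python
from graphrag.factory.factory import Factory


class Greeter:
    def __init__(self, name: str) -> None:
        self.name = name

    def greet(self) -> str:
        return f"hello, {self.name}"


class GreeterFactory(Factory[Greeter]):
    """Illustrative subclass of the generic Factory."""


factory = GreeterFactory()
factory.register("default", Greeter)

# Initializer kwargs now travel in one dict rather than **kwargs.
greeter = factory.create("default", init_args={"name": "world"})
assert greeter.greet() == "hello, world"
```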
raise ValueError(msg) - return self._services[strategy](**kwargs) + return self._services[strategy](**(init_args or {})) diff --git a/graphrag/index/input/csv.py b/graphrag/index/input/csv.py index 7db033debb..a70863e7ae 100644 --- a/graphrag/index/input/csv.py +++ b/graphrag/index/input/csv.py @@ -8,27 +8,28 @@ import pandas as pd -from graphrag.config.models.input_config import InputConfig -from graphrag.index.input.util import load_files, process_data_columns -from graphrag.storage.pipeline_storage import PipelineStorage +from graphrag.index.input.input_reader import InputReader +from graphrag.index.input.util import process_data_columns logger = logging.getLogger(__name__) -async def load_csv( - config: InputConfig, - storage: PipelineStorage, -) -> pd.DataFrame: - """Load csv inputs from a directory.""" - logger.info("Loading csv files from %s", config.storage.base_dir) +class CSVFileReader(InputReader): + """Reader implementation for csv files.""" - async def load_file(path: str) -> pd.DataFrame: - buffer = BytesIO(await storage.get(path, as_bytes=True)) - data = pd.read_csv(buffer, encoding=config.encoding) - data = process_data_columns(data, config, path) - creation_date = await storage.get_creation_date(path) - data["creation_date"] = data.apply(lambda _: creation_date, axis=1) + async def read_file(self, path: str) -> pd.DataFrame: + """Read a csv file into a DataFrame of documents. - return data + Args: + - path - The path to read the file from. - return await load_files(load_file, config, storage) + Returns + ------- + - output - DataFrame with a row for each document in the file. + """ + buffer = BytesIO(await self._storage.get(path, as_bytes=True)) + data = pd.read_csv(buffer, encoding=self._config.encoding) + data = process_data_columns(data, self._config, path) + creation_date = await self._storage.get_creation_date(path) + data["creation_date"] = data.apply(lambda _: creation_date, axis=1) + return data diff --git a/graphrag/index/input/factory.py b/graphrag/index/input/factory.py index bc4da8c7a1..bdabd0d7c7 100644 --- a/graphrag/index/input/factory.py +++ b/graphrag/index/input/factory.py @@ -4,53 +4,22 @@ -"""A module containing create_input method definition.""" +"""A module containing the InputReaderFactory.""" import logging -from collections.abc import Awaitable, Callable -from typing import cast - -import pandas as pd from graphrag.config.enums import InputFileType -from graphrag.config.models.input_config import InputConfig -from graphrag.index.input.csv import load_csv -from graphrag.index.input.json import load_json -from graphrag.index.input.text import load_text -from graphrag.storage.pipeline_storage import PipelineStorage +from graphrag.factory.factory import Factory +from graphrag.index.input.csv import CSVFileReader +from graphrag.index.input.input_reader import InputReader +from graphrag.index.input.json import JSONFileReader +from graphrag.index.input.text import TextFileReader logger = logging.getLogger(__name__) -loaders: dict[str, Callable[..., Awaitable[pd.DataFrame]]] = { - InputFileType.text: load_text, - InputFileType.csv: load_csv, - InputFileType.json: load_json, -} - - -async def create_input( - config: InputConfig, - storage: PipelineStorage, -) -> pd.DataFrame: - """Instantiate input data for a pipeline.""" - logger.info("loading input from root_dir=%s", config.storage.base_dir) - - if config.file_type in loaders: - logger.info("Loading Input %s", config.file_type) - loader = loaders[config.file_type] - result = await loader(config, storage) - # Convert metadata columns to strings and collapse them into a JSON object - if config.metadata: - if all(col in result.columns for col in config.metadata): - # Collapse the metadata columns into a single JSON object column - result["metadata"] = result[config.metadata].apply( - lambda row: row.to_dict(), axis=1 - ) - else: - value_error_msg = ( - "One or more metadata columns not found in the DataFrame." - ) - raise ValueError(value_error_msg) - - result[config.metadata] = result[config.metadata].astype(str) - - return cast("pd.DataFrame", result) - - msg = f"Unknown input type {config.file_type}" - raise ValueError(msg) + + +class InputReaderFactory(Factory[InputReader]): + """Factory for creating Input Reader instances.""" + + +input_reader_factory = InputReaderFactory() +input_reader_factory.register(InputFileType.text, TextFileReader) +input_reader_factory.register(InputFileType.csv, CSVFileReader) +input_reader_factory.register(InputFileType.json, JSONFileReader) diff --git a/graphrag/index/input/input_reader.py b/graphrag/index/input/input_reader.py new file mode 100644 index 0000000000..ed0add9f97 --- /dev/null +++ b/graphrag/index/input/input_reader.py @@ -0,0 +1,84 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing the 'InputReader' base class.""" + +from __future__ import annotations + +import logging +import re +from abc import ABCMeta, abstractmethod +from typing import TYPE_CHECKING + +import pandas as pd + +if TYPE_CHECKING: + from graphrag.config.models.input_config import InputConfig + from graphrag.storage.pipeline_storage import PipelineStorage + +logger = logging.getLogger(__name__) + + +class InputReader(metaclass=ABCMeta): + """Provide a file reading interface for pipeline inputs.""" + + def __init__(self, storage: PipelineStorage, config: InputConfig, **kwargs): + self._storage = storage + self._config = config + + async def read_files(self) -> pd.DataFrame: + """Load files from storage, reading each with this reader's read_file implementation. Process metadata on the results if needed.""" + files = list(self._storage.find(re.compile(self._config.file_pattern))) + + if len(files) == 0: + msg = f"No {self._config.file_type} files found in {self._config.storage.base_dir}" + raise ValueError(msg) + + files_loaded = [] + + for file in files: + try: + files_loaded.append(await self.read_file(file)) + except Exception as e: # noqa: BLE001 (catching Exception is fine here) + logger.warning("Warning! Error loading file %s. Skipping...", file) + logger.warning("Error: %s", e) + + logger.info( + "Found %d %s files, loading %d", + len(files), + self._config.file_type, + len(files_loaded), + ) + result = pd.concat(files_loaded) + total_files_log = ( + f"Total number of unfiltered {self._config.file_type} rows: {len(result)}" + ) + logger.info(total_files_log) + # Convert metadata columns to strings and collapse them into a JSON object + if self._config.metadata: + if all(col in result.columns for col in self._config.metadata): + # Collapse the metadata columns into a single JSON object column + result["metadata"] = result[self._config.metadata].apply( + lambda row: row.to_dict(), axis=1 + ) + else: + value_error_msg = ( + "One or more metadata columns not found in the DataFrame." + ) + raise ValueError(value_error_msg) + + result[self._config.metadata] = result[self._config.metadata].astype(str) + + return result + + @abstractmethod + async def read_file(self, path: str) -> pd.DataFrame: + """Read a file into a DataFrame of documents. + + Args: + - path - The path to read the file from. 
+ + Returns + ------- + - output - DataFrame with a row for each document in the file. + """ diff --git a/graphrag/index/input/json.py b/graphrag/index/input/json.py index 41343fd2b2..cae7db5dbc 100644 --- a/graphrag/index/input/json.py +++ b/graphrag/index/input/json.py @@ -8,30 +8,32 @@ import pandas as pd -from graphrag.config.models.input_config import InputConfig -from graphrag.index.input.util import load_files, process_data_columns -from graphrag.storage.pipeline_storage import PipelineStorage +from graphrag.index.input.input_reader import InputReader +from graphrag.index.input.util import process_data_columns logger = logging.getLogger(__name__) -async def load_json( - config: InputConfig, - storage: PipelineStorage, -) -> pd.DataFrame: - """Load json inputs from a directory.""" - logger.info("Loading json files from %s", config.storage.base_dir) +class JSONFileReader(InputReader): + """Reader implementation for json files.""" - async def load_file(path: str) -> pd.DataFrame: - text = await storage.get(path, encoding=config.encoding) + async def read_file(self, path: str) -> pd.DataFrame: + """Read a JSON file into a DataFrame of documents. + + Args: + - path - The path to read the file from. + + Returns + ------- + - output - DataFrame with a row for each document in the file. + """ + text = await self._storage.get(path, encoding=self._config.encoding) as_json = json.loads(text) # json file could just be a single object, or an array of objects rows = as_json if isinstance(as_json, list) else [as_json] data = pd.DataFrame(rows) - data = process_data_columns(data, config, path) - creation_date = await storage.get_creation_date(path) + data = process_data_columns(data, self._config, path) + creation_date = await self._storage.get_creation_date(path) data["creation_date"] = data.apply(lambda _: creation_date, axis=1) return data - - return await load_files(load_file, config, storage) diff --git a/graphrag/index/input/text.py b/graphrag/index/input/text.py index f1fc74352f..dac2c4c701 100644 --- a/graphrag/index/input/text.py +++ b/graphrag/index/input/text.py @@ -8,26 +8,28 @@ import pandas as pd -from graphrag.config.models.input_config import InputConfig -from graphrag.index.input.util import load_files +from graphrag.index.input.input_reader import InputReader from graphrag.index.utils.hashing import gen_sha512_hash -from graphrag.storage.pipeline_storage import PipelineStorage logger = logging.getLogger(__name__) -async def load_text( - config: InputConfig, - storage: PipelineStorage, -) -> pd.DataFrame: - """Load text inputs from a directory.""" +class TextFileReader(InputReader): + """Reader implementation for text files.""" - async def load_file(path: str) -> pd.DataFrame: - text = await storage.get(path, encoding=config.encoding) + async def read_file(self, path: str) -> pd.DataFrame: + """Read a text file into a DataFrame of documents. + + Args: + - path - The path to read the file from. + + Returns + ------- + - output - DataFrame with a row for each document in the file. 
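Taken together, the readers are consumed through the factory; this is the pattern the workflows in this patch adopt, condensed into one helper (a sketch, assuming a populated config and storage):

```python
import pandas as pd

from graphrag.config.models.input_config import InputConfig
from graphrag.index.input.factory import InputReaderFactory
from graphrag.storage.pipeline_storage import PipelineStorage


async def read_input_documents(
    config: InputConfig, storage: PipelineStorage
) -> pd.DataFrame:
    """Create the reader registered for the configured file type and run it."""
    reader = InputReaderFactory().create(
        config.file_type,
        {"storage": storage, "config": config},
    )
    # read_files() finds all files matching config.file_pattern and
    # concatenates the per-file frames, one row per document.
    return await reader.read_files()
```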
+ """ + text = await self._storage.get(path, encoding=self._config.encoding) new_item = {"text": text} new_item["id"] = gen_sha512_hash(new_item, new_item.keys()) new_item["title"] = str(Path(path).name) - new_item["creation_date"] = await storage.get_creation_date(path) + new_item["creation_date"] = await self._storage.get_creation_date(path) return pd.DataFrame([new_item]) - - return await load_files(load_file, config, storage) diff --git a/graphrag/index/input/util.py b/graphrag/index/input/util.py index 457c1864a9..2780909167 100644 --- a/graphrag/index/input/util.py +++ b/graphrag/index/input/util.py @@ -4,50 +4,15 @@ """Shared column processing for structured input files.""" import logging -import re -from typing import Any import pandas as pd from graphrag.config.models.input_config import InputConfig from graphrag.index.utils.hashing import gen_sha512_hash -from graphrag.storage.pipeline_storage import PipelineStorage logger = logging.getLogger(__name__) -async def load_files( - loader: Any, - config: InputConfig, - storage: PipelineStorage, -) -> pd.DataFrame: - """Load files from storage and apply a loader function.""" - files = list(storage.find(re.compile(config.file_pattern))) - - if len(files) == 0: - msg = f"No {config.file_type} files found in {config.storage.base_dir}" - raise ValueError(msg) - - files_loaded = [] - - for file in files: - try: - files_loaded.append(await loader(file)) - except Exception as e: # noqa: BLE001 (catching Exception is fine here) - logger.warning("Warning! Error loading file %s. Skipping...", file) - logger.warning("Error: %s", e) - - logger.info( - "Found %d %s files, loading %d", len(files), config.file_type, len(files_loaded) - ) - result = pd.concat(files_loaded) - total_files_log = ( - f"Total number of unfiltered {config.file_type} rows: {len(result)}" - ) - logger.info(total_files_log) - return result - - def process_data_columns( documents: pd.DataFrame, config: InputConfig, path: str ) -> pd.DataFrame: diff --git a/graphrag/index/workflows/finalize_graph.py b/graphrag/index/workflows/finalize_graph.py index 5681173c59..49529aea3a 100644 --- a/graphrag/index/workflows/finalize_graph.py +++ b/graphrag/index/workflows/finalize_graph.py @@ -41,7 +41,6 @@ async def run_workflow( ) if config.snapshots.graphml: - # todo: extract graphs at each level, and add in meta like descriptions graph = create_graph(final_relationships, edge_attr=["weight"]) await snapshot_graphml( diff --git a/graphrag/index/workflows/generate_text_embeddings.py b/graphrag/index/workflows/generate_text_embeddings.py index c15ff3e07f..63859faa49 100644 --- a/graphrag/index/workflows/generate_text_embeddings.py +++ b/graphrag/index/workflows/generate_text_embeddings.py @@ -236,8 +236,6 @@ def _create_vector_store( index_name: str, embedding_name: str | None = None, ) -> BaseVectorStore: - vector_store_type: str = str(vector_store_config.type) - embeddings_schema: dict[str, VectorStoreSchemaConfig] = ( vector_store_config.embeddings_schema ) @@ -259,10 +257,10 @@ def _create_vector_store( single_embedding_config.index_name = index_name args = vector_store_config.model_dump() - vector_store = VectorStoreFactory().create_vector_store( - vector_store_schema_config=single_embedding_config, - vector_store_type=vector_store_type, - **args, + args["vector_store_schema_config"] = single_embedding_config + vector_store = VectorStoreFactory().create( + vector_store_config.type, + args, ) vector_store.connect(**args) diff --git a/graphrag/index/workflows/load_input_documents.py 
b/graphrag/index/workflows/load_input_documents.py index 33e14d0cb2..228ca4d0c4 100644 --- a/graphrag/index/workflows/load_input_documents.py +++ b/graphrag/index/workflows/load_input_documents.py @@ -8,11 +8,10 @@ import pandas as pd from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.config.models.input_config import InputConfig -from graphrag.index.input.factory import create_input +from graphrag.index.input.factory import InputReaderFactory +from graphrag.index.input.input_reader import InputReader from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput -from graphrag.storage.pipeline_storage import PipelineStorage from graphrag.utils.storage import write_table_to_storage logger = logging.getLogger(__name__) @@ -23,11 +22,13 @@ async def run_workflow( context: PipelineRunContext, ) -> WorkflowFunctionOutput: """Load and parse input documents into a standard format.""" - output = await load_input_documents( - config.input, - context.input_storage, + input_reader = InputReaderFactory().create( + config.input.file_type, + {"storage": context.input_storage, "config": config.input}, ) + output = await load_input_documents(input_reader) + logger.info("Final # of rows loaded: %s", len(output)) context.stats.num_documents = len(output) @@ -36,8 +37,6 @@ async def run_workflow( return WorkflowFunctionOutput(result=output) -async def load_input_documents( - config: InputConfig, storage: PipelineStorage -) -> pd.DataFrame: +async def load_input_documents(input_reader: InputReader) -> pd.DataFrame: """Load and parse input documents into a standard format.""" - return await create_input(config, storage) + return await input_reader.read_files() diff --git a/graphrag/index/workflows/load_update_documents.py b/graphrag/index/workflows/load_update_documents.py index fbe48b6419..7755091017 100644 --- a/graphrag/index/workflows/load_update_documents.py +++ b/graphrag/index/workflows/load_update_documents.py @@ -8,8 +8,8 @@ import pandas as pd from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.config.models.input_config import InputConfig -from graphrag.index.input.factory import create_input +from graphrag.index.input.factory import InputReaderFactory +from graphrag.index.input.input_reader import InputReader from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.workflow import WorkflowFunctionOutput from graphrag.index.update.incremental_index import get_delta_docs @@ -24,9 +24,12 @@ async def run_workflow( context: PipelineRunContext, ) -> WorkflowFunctionOutput: """Load and parse update-only input documents into a standard format.""" + input_reader = InputReaderFactory().create( + config.input.file_type, + {"storage": context.input_storage, "config": config.input}, + ) output = await load_update_documents( - config.input, - context.input_storage, + input_reader, context.previous_storage, ) @@ -43,12 +46,11 @@ async def run_workflow( async def load_update_documents( - config: InputConfig, - input_storage: PipelineStorage, + input_reader: InputReader, previous_storage: PipelineStorage, ) -> pd.DataFrame: """Load and parse update-only input documents into a standard format.""" - input_documents = await create_input(config, input_storage) + input_documents = await input_reader.read_files() # previous storage is the output of the previous run # we'll use this to diff the input from the prior delta_documents = await 
get_delta_docs(input_documents, previous_storage) diff --git a/graphrag/language_model/factory.py b/graphrag/language_model/factory.py index 840dada3dc..6ff32b530e 100644 --- a/graphrag/language_model/factory.py +++ b/graphrag/language_model/factory.py @@ -3,10 +3,8 @@ """A package containing a factory for supported llm types.""" -from collections.abc import Callable -from typing import Any, ClassVar - from graphrag.config.enums import ModelType +from graphrag.factory.factory import Factory from graphrag.language_model.protocol.base import ChatModel, EmbeddingModel from graphrag.language_model.providers.litellm.chat_model import LitellmChatModel from graphrag.language_model.providers.litellm.embedding_model import ( @@ -14,90 +12,14 @@ ) -class ModelFactory: - """A factory for creating Model instances.""" - - _chat_registry: ClassVar[dict[str, Callable[..., ChatModel]]] = {} - _embedding_registry: ClassVar[dict[str, Callable[..., EmbeddingModel]]] = {} - - @classmethod - def register_chat(cls, model_type: str, creator: Callable[..., ChatModel]) -> None: - """Register a ChatModel implementation.""" - cls._chat_registry[model_type] = creator - - @classmethod - def register_embedding( - cls, model_type: str, creator: Callable[..., EmbeddingModel] - ) -> None: - """Register an EmbeddingModel implementation.""" - cls._embedding_registry[model_type] = creator - - @classmethod - def create_chat_model(cls, model_type: str, **kwargs: Any) -> ChatModel: - """ - Create a ChatModel instance. - - Args: - model_type: The type of ChatModel to create. - **kwargs: Additional keyword arguments for the ChatModel constructor. - - Returns - ------- - A ChatModel instance. - """ - if model_type not in cls._chat_registry: - msg = f"ChatMOdel implementation '{model_type}' is not registered." - raise ValueError(msg) - return cls._chat_registry[model_type](**kwargs) +class ChatModelFactory(Factory[ChatModel]): + """Singleton factory for creating ChatModel instances.""" - @classmethod - def create_embedding_model(cls, model_type: str, **kwargs: Any) -> EmbeddingModel: - """ - Create an EmbeddingModel instance. - Args: - model_type: The type of EmbeddingModel to create. - **kwargs: Additional keyword arguments for the EmbeddingLLM constructor. - - Returns - ------- - An EmbeddingLLM instance. - """ - if model_type not in cls._embedding_registry: - msg = f"EmbeddingModel implementation '{model_type}' is not registered." 
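This factory pair is what the models.md change at the top of this patch refers to. A compact sketch of the registration flow (the custom class is a stub; a real implementation must satisfy the ChatModel protocol):

```python
from graphrag.language_model.factory import ChatModelFactory


class MyCustomChatModel:
    """Stub stand-in; implement the ChatModel protocol methods in practice."""

    def __init__(self, **kwargs) -> None:
        self.init_kwargs = kwargs


# Lambda wrapper matches the style used in tests/__init__.py below.
ChatModelFactory().register(
    "my-custom-chat-model", lambda **kwargs: MyCustomChatModel(**kwargs)
)

# Assumes the factory shares registry state across instances, per its
# "Singleton factory" docstring.
model = ChatModelFactory().create("my-custom-chat-model", {"name": "default"})
```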
- raise ValueError(msg) - return cls._embedding_registry[model_type](**kwargs) - - @classmethod - def get_chat_models(cls) -> list[str]: - """Get the registered ChatModel implementations.""" - return list(cls._chat_registry.keys()) - - @classmethod - def get_embedding_models(cls) -> list[str]: - """Get the registered EmbeddingModel implementations.""" - return list(cls._embedding_registry.keys()) - - @classmethod - def is_supported_chat_model(cls, model_type: str) -> bool: - """Check if the given model type is supported.""" - return model_type in cls._chat_registry - - @classmethod - def is_supported_embedding_model(cls, model_type: str) -> bool: - """Check if the given model type is supported.""" - return model_type in cls._embedding_registry - - @classmethod - def is_supported_model(cls, model_type: str) -> bool: - """Check if the given model type is supported.""" - return cls.is_supported_chat_model( - model_type - ) or cls.is_supported_embedding_model(model_type) +class EmbeddingModelFactory(Factory[EmbeddingModel]): + """Singleton factory for creating EmbeddingModel instances.""" # --- Register default implementations --- -ModelFactory.register_chat(ModelType.Chat, lambda **kwargs: LitellmChatModel(**kwargs)) -ModelFactory.register_embedding( - ModelType.Embedding, lambda **kwargs: LitellmEmbeddingModel(**kwargs) -) +ChatModelFactory().register(ModelType.Chat, LitellmChatModel) +EmbeddingModelFactory().register(ModelType.Embedding, LitellmEmbeddingModel) diff --git a/graphrag/language_model/manager.py b/graphrag/language_model/manager.py index bc41235dda..29349cb278 100644 --- a/graphrag/language_model/manager.py +++ b/graphrag/language_model/manager.py @@ -13,7 +13,7 @@ from typing_extensions import Self -from graphrag.language_model.factory import ModelFactory +from graphrag.language_model.factory import ChatModelFactory, EmbeddingModelFactory if TYPE_CHECKING: from graphrag.language_model.protocol.base import ChatModel, EmbeddingModel @@ -54,9 +54,7 @@ def register_chat( **chat_kwargs: Additional parameters for instantiation. """ chat_kwargs["name"] = name - self.chat_models[name] = ModelFactory.create_chat_model( - model_type, **chat_kwargs - ) + self.chat_models[name] = ChatModelFactory().create(model_type, chat_kwargs) return self.chat_models[name] def register_embedding( @@ -71,8 +69,8 @@ def register_embedding( **embedding_kwargs: Additional parameters for instantiation. 
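Downstream, the manager resolves names through these factories. A sketch of the updated call path (the ModelManager class name is assumed from graphrag/language_model/manager.py; the MockChat type only resolves if the test providers from tests/__init__.py have been registered):

```python
from graphrag.config.enums import ModelType
from graphrag.language_model.manager import ModelManager

manager = ModelManager()
# register_chat() now forwards its kwargs to ChatModelFactory().create(...)
# as a single dict under the new convention.
chat_model = manager.register_chat("default_chat_model", ModelType.MockChat)
```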
""" embedding_kwargs["name"] = name - self.embedding_models[name] = ModelFactory.create_embedding_model( - model_type, **embedding_kwargs + self.embedding_models[name] = EmbeddingModelFactory().create( + model_type, embedding_kwargs ) return self.embedding_models[name] diff --git a/graphrag/language_model/providers/litellm/request_wrappers/with_rate_limiter.py b/graphrag/language_model/providers/litellm/request_wrappers/with_rate_limiter.py index c0e0728f2e..108369444f 100644 --- a/graphrag/language_model/providers/litellm/request_wrappers/with_rate_limiter.py +++ b/graphrag/language_model/providers/litellm/request_wrappers/with_rate_limiter.py @@ -55,7 +55,7 @@ def with_rate_limiter( raise ValueError(msg) rate_limiter_service = rate_limiter_factory.create( - strategy=model_config.rate_limit_strategy, rpm=rpm, tpm=tpm + strategy=model_config.rate_limit_strategy, init_args={"rpm": rpm, "tpm": tpm} ) max_tokens = model_config.max_completion_tokens or model_config.max_tokens or 0 diff --git a/graphrag/language_model/providers/litellm/request_wrappers/with_retries.py b/graphrag/language_model/providers/litellm/request_wrappers/with_retries.py index 1279f9e820..53e13f3fe9 100644 --- a/graphrag/language_model/providers/litellm/request_wrappers/with_retries.py +++ b/graphrag/language_model/providers/litellm/request_wrappers/with_retries.py @@ -39,8 +39,10 @@ def with_retries( retry_factory = RetryFactory() retry_service = retry_factory.create( strategy=model_config.retry_strategy, - max_retries=model_config.max_retries, - max_retry_wait=model_config.max_retry_wait, + init_args={ + "max_retries": model_config.max_retries, + "max_retry_wait": model_config.max_retry_wait, + }, ) def _wrapped_with_retries(**kwargs: Any) -> Any: diff --git a/graphrag/language_model/providers/litellm/services/rate_limiter/rate_limiter_factory.py b/graphrag/language_model/providers/litellm/services/rate_limiter/rate_limiter_factory.py index a6ef6880ef..5904be0b56 100644 --- a/graphrag/language_model/providers/litellm/services/rate_limiter/rate_limiter_factory.py +++ b/graphrag/language_model/providers/litellm/services/rate_limiter/rate_limiter_factory.py @@ -3,11 +3,13 @@ """LiteLLM Rate Limiter Factory.""" -from graphrag.config.defaults import DEFAULT_RATE_LIMITER_SERVICES from graphrag.factory.factory import Factory from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter import ( RateLimiter, ) +from graphrag.language_model.providers.litellm.services.rate_limiter.static_rate_limiter import ( + StaticRateLimiter, +) class RateLimiterFactory(Factory[RateLimiter]): @@ -15,8 +17,4 @@ class RateLimiterFactory(Factory[RateLimiter]): rate_limiter_factory = RateLimiterFactory() - -for service_name, service_cls in DEFAULT_RATE_LIMITER_SERVICES.items(): - rate_limiter_factory.register( - strategy=service_name, service_initializer=service_cls - ) +rate_limiter_factory.register("static", StaticRateLimiter) diff --git a/graphrag/language_model/providers/litellm/services/retry/retry_factory.py b/graphrag/language_model/providers/litellm/services/retry/retry_factory.py index 15b3318630..9acdf21425 100644 --- a/graphrag/language_model/providers/litellm/services/retry/retry_factory.py +++ b/graphrag/language_model/providers/litellm/services/retry/retry_factory.py @@ -3,8 +3,19 @@ """LiteLLM Retry Factory.""" -from graphrag.config.defaults import DEFAULT_RETRY_SERVICES from graphrag.factory.factory import Factory +from graphrag.language_model.providers.litellm.services.retry.exponential_retry import ( + 
ExponentialRetry, +) +from graphrag.language_model.providers.litellm.services.retry.incremental_wait_retry import ( + IncrementalWaitRetry, +) +from graphrag.language_model.providers.litellm.services.retry.native_wait_retry import ( + NativeRetry, +) +from graphrag.language_model.providers.litellm.services.retry.random_wait_retry import ( + RandomWaitRetry, +) from graphrag.language_model.providers.litellm.services.retry.retry import Retry @@ -14,5 +25,7 @@ class RetryFactory(Factory[Retry]): retry_factory = RetryFactory() -for service_name, service_cls in DEFAULT_RETRY_SERVICES.items(): - retry_factory.register(strategy=service_name, service_initializer=service_cls) +retry_factory.register("native", NativeRetry) +retry_factory.register("exponential_backoff", ExponentialRetry) +retry_factory.register("random_wait", RandomWaitRetry) +retry_factory.register("incremental_wait", IncrementalWaitRetry) diff --git a/graphrag/logger/factory.py b/graphrag/logger/factory.py index 03245d1bb0..56800733f5 100644 --- a/graphrag/logger/factory.py +++ b/graphrag/logger/factory.py @@ -7,18 +7,15 @@ import logging from pathlib import Path -from typing import TYPE_CHECKING, ClassVar from graphrag.config.enums import ReportingType - -if TYPE_CHECKING: - from collections.abc import Callable +from graphrag.factory.factory import Factory LOG_FORMAT = "%(asctime)s.%(msecs)04d - %(levelname)s - %(name)s - %(message)s" DATE_FORMAT = "%Y-%m-%d %H:%M:%S" -class LoggerFactory: +class LoggerFactory(Factory[logging.Handler]): """A factory class for logger implementations. Includes a method for users to register a custom logger implementation. Configuration arguments are passed to each implementation as kwargs for individual enforcement of required/optional arguments. Note that this factory does not write log records itself; it merely configures the logger to your specified storage location. """ - _registry: ClassVar[dict[str, Callable[..., logging.Handler]]] = {} - - @classmethod - def register( - cls, reporting_type: str, creator: Callable[..., logging.Handler] - ) -> None: - """Register a custom logger implementation. - - Args: - reporting_type: The type identifier for the logger. - creator: A class or callable that initializes logging. - """ - cls._registry[reporting_type] = creator - - @classmethod - def create_logger(cls, reporting_type: str, kwargs: dict) -> logging.Handler: - """Create a logger from the provided type. - - Args: - reporting_type: The type of logger to create. - logger: The logger instance for the application. - kwargs: Additional keyword arguments for the constructor. - - Returns - ------- - A logger instance. - - Raises - ------ - ValueError: If the logger type is not registered. 
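With LoggerFactory now extending the generic Factory, a custom log destination reduces to registering a handler creator. A minimal sketch (the "console" key and creator are illustrative, and assume factory state is shared across instances as standard_logging.py's usage implies):

```python
import logging

from graphrag.logger.factory import LoggerFactory


def create_console_logger(**kwargs) -> logging.Handler:
    """Illustrative creator; ReportingConfig kwargs are accepted and ignored."""
    return logging.StreamHandler()


LoggerFactory().register("console", create_console_logger)
handler = LoggerFactory().create("console", {"root_dir": "."})
```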
- """ - if reporting_type not in cls._registry: - msg = f"Unknown reporting type: {reporting_type}" - raise ValueError(msg) - - return cls._registry[reporting_type](**kwargs) - - @classmethod - def get_logger_types(cls) -> list[str]: - """Get the registered logger implementations.""" - return list(cls._registry.keys()) - - @classmethod - def is_supported_type(cls, reporting_type: str) -> bool: - """Check if the given logger type is supported.""" - return reporting_type in cls._registry - # --- register built-in logger implementations --- def create_file_logger(**kwargs) -> logging.Handler: @@ -109,5 +59,6 @@ def create_blob_logger(**kwargs) -> logging.Handler: # --- register built-in implementations --- -LoggerFactory.register(ReportingType.file.value, create_file_logger) -LoggerFactory.register(ReportingType.blob.value, create_blob_logger) +logger_factory = LoggerFactory() +logger_factory.register(ReportingType.file.value, create_file_logger) +logger_factory.register(ReportingType.blob.value, create_blob_logger) diff --git a/graphrag/logger/standard_logging.py b/graphrag/logger/standard_logging.py index fea504c02f..31296e12d9 100644 --- a/graphrag/logger/standard_logging.py +++ b/graphrag/logger/standard_logging.py @@ -79,5 +79,5 @@ def init_loggers( config_dict = reporting_config.model_dump() args = {**config_dict, "root_dir": config.root_dir, "filename": filename} - handler = LoggerFactory.create_logger(reporting_config.type, args) + handler = LoggerFactory().create(reporting_config.type, args) logger.addHandler(handler) diff --git a/graphrag/prompt_tune/loader/input.py b/graphrag/prompt_tune/loader/input.py index bdc42e5b5c..5e9fccb440 100644 --- a/graphrag/prompt_tune/loader/input.py +++ b/graphrag/prompt_tune/loader/input.py @@ -12,7 +12,7 @@ from graphrag.cache.noop_pipeline_cache import NoopPipelineCache from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.index.input.factory import create_input +from graphrag.index.input.factory import InputReaderFactory from graphrag.index.operations.embed_text.run_embed_text import ( run_embed_text, ) @@ -64,7 +64,11 @@ async def load_docs_in_chunks( ) tokenizer = get_tokenizer(embeddings_llm_settings) input_storage = create_storage_from_config(config.input.storage) - dataset = await create_input(config.input, input_storage) + input_reader = InputReaderFactory().create( + config.input.file_type, + {"storage": input_storage, "config": config.input}, + ) + dataset = await input_reader.read_files() chunk_config = config.chunks chunks_df = create_base_text_units( documents=dataset, diff --git a/graphrag/storage/factory.py b/graphrag/storage/factory.py index 89c4eeee0d..d2dfa6645c 100644 --- a/graphrag/storage/factory.py +++ b/graphrag/storage/factory.py @@ -5,21 +5,16 @@ from __future__ import annotations -from typing import TYPE_CHECKING, ClassVar - from graphrag.config.enums import StorageType +from graphrag.factory.factory import Factory from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage from graphrag.storage.file_pipeline_storage import FilePipelineStorage from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage +from graphrag.storage.pipeline_storage import PipelineStorage -if TYPE_CHECKING: - from collections.abc import Callable - - from graphrag.storage.pipeline_storage import PipelineStorage - -class StorageFactory: +class 
StorageFactory(Factory[PipelineStorage]): """A factory class for storage implementations. Includes a method for users to register a custom storage implementation. Configuration arguments are passed to each storage implementation as kwargs for individual enforcement of required/optional arguments. """ - _registry: ClassVar[dict[str, Callable[..., PipelineStorage]]] = {} - - @classmethod - def register( - cls, storage_type: str, creator: Callable[..., PipelineStorage] - ) -> None: - """Register a custom storage implementation. - - Args: - storage_type: The type identifier for the storage. - creator: A class or callable that creates an instance of PipelineStorage. - - """ - cls._registry[storage_type] = creator - - @classmethod - def create_storage(cls, storage_type: str, kwargs: dict) -> PipelineStorage: - """Create a storage object from the provided type. - - Args: - storage_type: The type of storage to create. - kwargs: Additional keyword arguments for the storage constructor. - - Returns - ------- - A PipelineStorage instance. - - Raises - ------ - ValueError: If the storage type is not registered. - """ - if storage_type not in cls._registry: - msg = f"Unknown storage type: {storage_type}" - raise ValueError(msg) - - return cls._registry[storage_type](**kwargs) - - @classmethod - def get_storage_types(cls) -> list[str]: - """Get the registered storage implementations.""" - return list(cls._registry.keys()) - - @classmethod - def is_supported_type(cls, storage_type: str) -> bool: - """Check if the given storage type is supported.""" - return storage_type in cls._registry - # --- register built-in storage implementations --- -StorageFactory.register(StorageType.blob.value, BlobPipelineStorage) -StorageFactory.register(StorageType.cosmosdb.value, CosmosDBPipelineStorage) -StorageFactory.register(StorageType.file.value, FilePipelineStorage) -StorageFactory.register(StorageType.memory.value, MemoryPipelineStorage) +storage_factory = StorageFactory() +storage_factory.register(StorageType.blob.value, BlobPipelineStorage) +storage_factory.register(StorageType.cosmosdb.value, CosmosDBPipelineStorage) +storage_factory.register(StorageType.file.value, FilePipelineStorage) +storage_factory.register(StorageType.memory.value, MemoryPipelineStorage) diff --git a/graphrag/utils/api.py b/graphrag/utils/api.py index 16a5f9ed52..db4a90a3a5 100644 --- a/graphrag/utils/api.py +++ b/graphrag/utils/api.py @@ -49,10 +49,9 @@ def get_embedding_store( if embedding_config.index_name is None: embedding_config.index_name = index_name - embedding_store = VectorStoreFactory().create_vector_store( - vector_store_type=vector_store_type, - vector_store_schema_config=embedding_config, - **store, + embedding_store = VectorStoreFactory().create( + vector_store_type, + {**store, "vector_store_schema_config": embedding_config}, ) embedding_store.connect(**store) @@ -107,19 +106,19 @@ def load_search_prompt(root_dir: str, prompt_config: str | None) -> str | None: def create_storage_from_config(output: StorageConfig) -> PipelineStorage: """Create a storage object from the config.""" storage_config = output.model_dump() - return StorageFactory().create_storage( - storage_type=storage_config["type"], - kwargs=storage_config, + return StorageFactory().create( + storage_config["type"], + storage_config, ) def create_cache_from_config(cache: CacheConfig, root_dir: str) -> PipelineCache: """Create a cache object from the config.""" cache_config = cache.model_dump() - kwargs = {**cache_config, "root_dir": root_dir} - return CacheFactory().create_cache( - cache_type=cache_config["type"], - kwargs=kwargs, + args = {**cache_config, "root_dir": root_dir} + return CacheFactory().create( + strategy=cache_config["type"], + init_args=args, ) diff --git a/graphrag/vector_stores/azure_ai_search.py b/graphrag/vector_stores/azure_ai_search.py index 3d08f2fc11..e6193e09d9 100644 --- a/graphrag/vector_stores/azure_ai_search.py +++ b/graphrag/vector_stores/azure_ai_search.py @@ -22,7 +22,6 @@ ) from azure.search.documents.models import VectorizedQuery -from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig from graphrag.data_model.types import TextEmbedder from graphrag.vector_stores.base import ( BaseVectorStore, @@ -36,13 +35,6 @@ class AzureAISearchVectorStore(BaseVectorStore): index_client: SearchIndexClient - def __init__( - self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any - ) -> None: - super().__init__( - vector_store_schema_config=vector_store_schema_config, **kwargs - ) - def connect(self, **kwargs: Any) -> Any: """Connect to AI search vector storage.""" url = kwargs["url"] diff --git a/graphrag/vector_stores/cosmosdb.py b/graphrag/vector_stores/cosmosdb.py index 3010e517ec..23b2c8f821 100644 --- a/graphrag/vector_stores/cosmosdb.py +++ b/graphrag/vector_stores/cosmosdb.py @@ -10,7 +10,6 @@ from azure.cosmos.partition_key import PartitionKey from azure.identity import DefaultAzureCredential -from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig from graphrag.data_model.types import TextEmbedder from graphrag.vector_stores.base import ( BaseVectorStore, @@ -26,13 +25,6 @@ class CosmosDBVectorStore(BaseVectorStore): _database_client: DatabaseProxy _container_client: ContainerProxy - def __init__( - self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any - ) -> None: - super().__init__( - vector_store_schema_config=vector_store_schema_config, **kwargs - ) - def connect(self, **kwargs: Any) -> Any: """Connect to CosmosDB vector storage.""" connection_string = kwargs.get("connection_string") diff --git a/graphrag/vector_stores/factory.py b/graphrag/vector_stores/factory.py index 8e4f7baa30..90b004cb52 100644 --- a/graphrag/vector_stores/factory.py +++ b/graphrag/vector_stores/factory.py @@ -5,23 +5,15 @@ from __future__ import annotations -from typing import TYPE_CHECKING, ClassVar - from graphrag.config.enums import VectorStoreType +from graphrag.factory.factory import Factory from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore +from graphrag.vector_stores.base import BaseVectorStore from graphrag.vector_stores.cosmosdb import CosmosDBVectorStore from graphrag.vector_stores.lancedb import LanceDBVectorStore -if TYPE_CHECKING: - from collections.abc import Callable - - from graphrag.config.models.vector_store_schema_config import ( - VectorStoreSchemaConfig, - ) - from graphrag.vector_stores.base import BaseVectorStore - -class VectorStoreFactory: +class VectorStoreFactory(Factory[BaseVectorStore]): """A factory for vector stores. Includes a method for users to register a custom vector store implementation. Configuration arguments are passed to each vector store implementation as kwargs for individual enforcement of required/optional arguments. """ - _registry: ClassVar[dict[str, Callable[..., BaseVectorStore]]] = {} - - @classmethod - def register( - cls, vector_store_type: str, creator: Callable[..., BaseVectorStore] - ) -> None: - """Register a custom vector store implementation. - - Args: - vector_store_type: The type identifier for the vector store. 
- creator: A class or callable that creates an instance of BaseVectorStore. - - Raises - ------ - TypeError: If creator is a class type instead of a factory function. - """ - cls._registry[vector_store_type] = creator - - @classmethod - def create_vector_store( - cls, - vector_store_type: str, - vector_store_schema_config: VectorStoreSchemaConfig, - **kwargs: dict, - ) -> BaseVectorStore: - """Create a vector store object from the provided type. - - Args: - vector_store_type: The type of vector store to create. - kwargs: Additional keyword arguments for the vector store constructor. - - Returns - ------- - A BaseVectorStore instance. - - Raises - ------ - ValueError: If the vector store type is not registered. - """ - if vector_store_type not in cls._registry: - msg = f"Unknown vector store type: {vector_store_type}" - raise ValueError(msg) - - return cls._registry[vector_store_type]( - vector_store_schema_config=vector_store_schema_config, **kwargs - ) - - @classmethod - def get_vector_store_types(cls) -> list[str]: - """Get the registered vector store implementations.""" - return list(cls._registry.keys()) - - @classmethod - def is_supported_type(cls, vector_store_type: str) -> bool: - """Check if the given vector store type is supported.""" - return vector_store_type in cls._registry - # --- register built-in vector store implementations --- -VectorStoreFactory.register(VectorStoreType.LanceDB.value, LanceDBVectorStore) -VectorStoreFactory.register( +vector_store_factory = VectorStoreFactory() +vector_store_factory.register(VectorStoreType.LanceDB.value, LanceDBVectorStore) +vector_store_factory.register( VectorStoreType.AzureAISearch.value, AzureAISearchVectorStore ) -VectorStoreFactory.register(VectorStoreType.CosmosDB.value, CosmosDBVectorStore) +vector_store_factory.register(VectorStoreType.CosmosDB.value, CosmosDBVectorStore) diff --git a/graphrag/vector_stores/lancedb.py b/graphrag/vector_stores/lancedb.py index 629df0d8f9..2b589d5bb0 100644 --- a/graphrag/vector_stores/lancedb.py +++ b/graphrag/vector_stores/lancedb.py @@ -9,7 +9,6 @@ import numpy as np import pyarrow as pa -from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig from graphrag.data_model.types import TextEmbedder from graphrag.vector_stores.base import ( BaseVectorStore, @@ -21,13 +20,6 @@ class LanceDBVectorStore(BaseVectorStore): """LanceDB vector storage implementation.""" - def __init__( - self, vector_store_schema_config: VectorStoreSchemaConfig, **kwargs: Any - ) -> None: - super().__init__( - vector_store_schema_config=vector_store_schema_config, **kwargs - ) - def connect(self, **kwargs: Any) -> Any: """Connect to the vector storage.""" self.db_connection = lancedb.connect(kwargs["db_uri"]) diff --git a/tests/__init__.py b/tests/__init__.py index cbb9376b53..e5c93e3503 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -6,10 +6,10 @@ # Register MOCK providers from graphrag.config.enums import ModelType -from graphrag.language_model.factory import ModelFactory +from graphrag.language_model.factory import ChatModelFactory, EmbeddingModelFactory from tests.mock_provider import MockChatLLM, MockEmbeddingLLM -ModelFactory.register_chat(ModelType.MockChat, lambda **kwargs: MockChatLLM(**kwargs)) -ModelFactory.register_embedding( +ChatModelFactory().register(ModelType.MockChat, lambda **kwargs: MockChatLLM(**kwargs)) +EmbeddingModelFactory().register( ModelType.MockEmbedding, lambda **kwargs: MockEmbeddingLLM(**kwargs) ) diff --git a/tests/fixtures/azure/config.json 
b/tests/fixtures/azure/config.json index 8adced2c08..91d9fc51d0 100644 --- a/tests/fixtures/azure/config.json +++ b/tests/fixtures/azure/config.json @@ -1,6 +1,7 @@ { "input_path": "./tests/fixtures/azure", "input_file_type": "text", + "index_method": "standard", "workflow_config": { "skip_assert": true, "azure": { diff --git a/tests/fixtures/min-csv/config.json b/tests/fixtures/min-csv/config.json index 64f90a3f07..ca2bd7d823 100644 --- a/tests/fixtures/min-csv/config.json +++ b/tests/fixtures/min-csv/config.json @@ -1,6 +1,7 @@ { "input_path": "./tests/fixtures/min-csv", "input_file_type": "text", + "index_method": "standard", "workflow_config": { "load_input_documents": { "max_runtime": 30 @@ -20,10 +21,6 @@ 100, 750 ], - "nan_allowed_columns": [ - "x", - "y" - ], "max_runtime": 30, "expected_artifacts": [ "entities.parquet", @@ -54,7 +51,7 @@ "period", "size" ], - "max_runtime": 360, + "max_runtime": 1200, "expected_artifacts": ["community_reports.parquet"] }, "create_final_text_units": { diff --git a/tests/fixtures/min-csv/settings.yml b/tests/fixtures/min-csv/settings.yml index 26143840e4..88f23ad061 100644 --- a/tests/fixtures/min-csv/settings.yml +++ b/tests/fixtures/min-csv/settings.yml @@ -8,8 +8,8 @@ models: api_version: "2025-04-01-preview" model: gpt-4.1 retry_strategy: exponential_backoff - tokens_per_minute: null - requests_per_minute: null + tokens_per_minute: 250_000 + requests_per_minute: 250 model_supports_json: true concurrent_requests: 25 default_embedding_model: @@ -21,8 +21,8 @@ models: api_version: "2025-04-01-preview" model: text-embedding-3-large retry_strategy: exponential_backoff - tokens_per_minute: null - requests_per_minute: null + tokens_per_minute: 250_000 + requests_per_minute: 250 concurrent_requests: 25 vector_store: diff --git a/tests/fixtures/text/config.json b/tests/fixtures/text/config.json index f7278a2308..eef957da2f 100644 --- a/tests/fixtures/text/config.json +++ b/tests/fixtures/text/config.json @@ -1,6 +1,7 @@ { "input_path": "./tests/fixtures/text", "input_file_type": "text", + "index_method": "fast", "workflow_config": { "load_input_documents": { "max_runtime": 30 @@ -8,17 +9,16 @@ "create_base_text_units": { "max_runtime": 30 }, - "extract_graph": { - "max_runtime": 500 + "extract_graph_nlp": { + "max_runtime": 30 + }, + "prune_graph": { + "max_runtime": 30 }, "finalize_graph": { "row_range": [ 10, - 200 - ], - "nan_allowed_columns": [ - "x", - "y" + 300 ], "max_runtime": 30, "expected_artifacts": [ @@ -26,23 +26,6 @@ "relationships.parquet" ] }, - "extract_covariates": { - "row_range": [ - 10, - 100 - ], - "nan_allowed_columns": [ - "type", - "description", - "object_id", - "status", - "start_date", - "end_date", - "source_text" - ], - "max_runtime": 360, - "expected_artifacts": ["covariates.parquet"] - }, "create_communities": { "row_range": [ 1, @@ -51,7 +34,7 @@ "max_runtime": 30, "expected_artifacts": ["communities.parquet"] }, - "create_community_reports": { + "create_community_reports_text": { "row_range": [ 1, 30 @@ -67,7 +50,7 @@ "period", "size" ], - "max_runtime": 360, + "max_runtime": 1200, "expected_artifacts": ["community_reports.parquet"] }, "create_final_text_units": { diff --git a/tests/fixtures/text/settings.yml b/tests/fixtures/text/settings.yml index 633c12a03f..f7602ccb73 100644 --- a/tests/fixtures/text/settings.yml +++ b/tests/fixtures/text/settings.yml @@ -8,8 +8,8 @@ models: api_version: "2025-04-01-preview" model: gpt-4.1 retry_strategy: exponential_backoff - tokens_per_minute: null - requests_per_minute: null + 
tokens_per_minute: 250_000 + requests_per_minute: 250 model_supports_json: true concurrent_requests: 25 default_embedding_model: @@ -21,8 +21,8 @@ models: api_version: "2025-04-01-preview" model: text-embedding-3-large retry_strategy: exponential_backoff - tokens_per_minute: null - requests_per_minute: null + tokens_per_minute: 250_000 + requests_per_minute: 250 concurrent_requests: 25 vector_store: @@ -32,9 +32,6 @@ vector_store: api_key: ${AZURE_AI_SEARCH_API_KEY} container_name: "simple_text_ci" -extract_claims: - enabled: true - community_reports: prompt: "prompts/community_report.txt" max_length: 2000 diff --git a/tests/integration/cache/test_factory.py b/tests/integration/cache/test_factory.py index 5dde4e2635..c7a7964471 100644 --- a/tests/integration/cache/test_factory.py +++ b/tests/integration/cache/test_factory.py @@ -2,7 +2,7 @@ # Licensed under the MIT License """CacheFactory Tests. -These tests will test the CacheFactory class and the creation of each cache type that is natively supported. +These tests will test the CacheFactory() class and the creation of each cache type that is natively supported. """ import sys @@ -23,30 +23,30 @@ def test_create_noop_cache(): - kwargs = {} - cache = CacheFactory.create_cache(CacheType.none.value, kwargs) + cache = CacheFactory().create(strategy=CacheType.none.value) assert isinstance(cache, NoopPipelineCache) def test_create_memory_cache(): - kwargs = {} - cache = CacheFactory.create_cache(CacheType.memory.value, kwargs) + cache = CacheFactory().create(strategy=CacheType.memory.value) assert isinstance(cache, InMemoryCache) def test_create_file_cache(): - kwargs = {"root_dir": "/tmp", "base_dir": "testcache"} - cache = CacheFactory.create_cache(CacheType.file.value, kwargs) + cache = CacheFactory().create( + strategy=CacheType.file.value, + init_args={"root_dir": "/tmp", "base_dir": "testcache"}, + ) assert isinstance(cache, JsonPipelineCache) def test_create_blob_cache(): - kwargs = { + init_args = { "connection_string": WELL_KNOWN_BLOB_STORAGE_KEY, "container_name": "testcontainer", "base_dir": "testcache", } - cache = CacheFactory.create_cache(CacheType.blob.value, kwargs) + cache = CacheFactory().create(strategy=CacheType.blob.value, init_args=init_args) assert isinstance(cache, JsonPipelineCache) @@ -55,12 +55,14 @@ def test_create_blob_cache(): reason="cosmosdb emulator is only available on windows runners at this time", ) def test_create_cosmosdb_cache(): - kwargs = { + init_args = { "connection_string": WELL_KNOWN_COSMOS_CONNECTION_STRING, "base_dir": "testdatabase", "container_name": "testcontainer", } - cache = CacheFactory.create_cache(CacheType.cosmosdb.value, kwargs) + cache = CacheFactory().create( + strategy=CacheType.cosmosdb.value, init_args=init_args + ) assert isinstance(cache, JsonPipelineCache) @@ -75,8 +77,11 @@ def test_register_and_create_custom_cache(): instance.initialized = True custom_cache_class.return_value = instance - CacheFactory.register("custom", lambda **kwargs: custom_cache_class(**kwargs)) - cache = CacheFactory.create_cache("custom", {}) + CacheFactory().register( + strategy="custom", + initializer=lambda **kwargs: custom_cache_class(**kwargs), + ) + cache = CacheFactory().create(strategy="custom") assert custom_cache_class.called assert cache is instance @@ -84,47 +89,34 @@ def test_register_and_create_custom_cache(): assert cache.initialized is True # type: ignore # Attribute only exists on our mock # Check if it's in the list of registered cache types - assert "custom" in 
CacheFactory.get_cache_types() - assert CacheFactory.is_supported_type("custom") - - -def test_get_cache_types(): - cache_types = CacheFactory.get_cache_types() - # Check that built-in types are registered - assert CacheType.none.value in cache_types - assert CacheType.memory.value in cache_types - assert CacheType.file.value in cache_types - assert CacheType.blob.value in cache_types - assert CacheType.cosmosdb.value in cache_types + assert "custom" in CacheFactory() def test_create_unknown_cache(): - with pytest.raises(ValueError, match="Unknown cache type: unknown"): - CacheFactory.create_cache("unknown", {}) + with pytest.raises(ValueError, match="Strategy 'unknown' is not registered\\."): + CacheFactory().create(strategy="unknown") def test_is_supported_type(): # Test built-in types - assert CacheFactory.is_supported_type(CacheType.none.value) - assert CacheFactory.is_supported_type(CacheType.memory.value) - assert CacheFactory.is_supported_type(CacheType.file.value) - assert CacheFactory.is_supported_type(CacheType.blob.value) - assert CacheFactory.is_supported_type(CacheType.cosmosdb.value) + assert CacheType.none.value in CacheFactory() + assert CacheType.memory.value in CacheFactory() + assert CacheType.file.value in CacheFactory() + assert CacheType.blob.value in CacheFactory() + assert CacheType.cosmosdb.value in CacheFactory() # Test unknown type - assert not CacheFactory.is_supported_type("unknown") + assert "unknown" not in CacheFactory() def test_enum_and_string_compatibility(): """Test that both enum and string types work for cache creation.""" - kwargs = {} - # Test with enum - cache_enum = CacheFactory.create_cache(CacheType.memory, kwargs) + cache_enum = CacheFactory().create(strategy=CacheType.memory) assert isinstance(cache_enum, InMemoryCache) # Test with string - cache_str = CacheFactory.create_cache("memory", kwargs) + cache_str = CacheFactory().create(strategy="memory") assert isinstance(cache_str, InMemoryCache) # Both should create the same type @@ -132,7 +124,7 @@ def test_enum_and_string_compatibility(): def test_register_class_directly_works(): - """Test that registering a class directly works (CacheFactory allows this).""" + """Test that registering a class directly works (CacheFactory() allows this).""" from graphrag.cache.pipeline_cache import PipelineCache class CustomCache(PipelineCache): @@ -157,13 +149,12 @@ async def clear(self): def child(self, name: str): return self - # CacheFactory allows registering classes directly (no TypeError) - CacheFactory.register("custom_class", CustomCache) + # CacheFactory() allows registering classes directly (no TypeError) + CacheFactory().register("custom_class", CustomCache) # Verify it was registered - assert "custom_class" in CacheFactory.get_cache_types() - assert CacheFactory.is_supported_type("custom_class") + assert "custom_class" in CacheFactory() # Test creating an instance - cache = CacheFactory.create_cache("custom_class", {}) + cache = CacheFactory().create(strategy="custom_class") assert isinstance(cache, CustomCache) diff --git a/tests/integration/language_model/test_factory.py b/tests/integration/language_model/test_factory.py index af503265d6..9757c68d99 100644 --- a/tests/integration/language_model/test_factory.py +++ b/tests/integration/language_model/test_factory.py @@ -9,7 +9,7 @@ from collections.abc import AsyncGenerator, Generator from typing import Any -from graphrag.language_model.factory import ModelFactory +from graphrag.language_model.factory import ChatModelFactory, EmbeddingModelFactory 
 from graphrag.language_model.manager import ModelManager
 from graphrag.language_model.response.base import (
     BaseModelOutput,
@@ -48,7 +48,7 @@ def chat_stream(
         self, prompt: str, history: list | None = None, **kwargs: Any
     ) -> Generator[str, None]: ...
 
-    ModelFactory.register_chat("custom_chat", CustomChatModel)
+    ChatModelFactory().register("custom_chat", CustomChatModel)
     model = ModelManager().get_or_create_chat_model("custom", "custom_chat")
     assert isinstance(model, CustomChatModel)
     response = await model.achat("prompt")
@@ -81,7 +81,7 @@ async def aembed_batch(
 
         def embed_batch(self, text_list: list[str], **kwargs) -> list[list[float]]:
             return [[1.0]]
 
-    ModelFactory.register_embedding("custom_embedding", CustomEmbeddingModel)
+    EmbeddingModelFactory().register("custom_embedding", CustomEmbeddingModel)
     llm = ModelManager().get_or_create_embedding_model("custom", "custom_embedding")
     assert isinstance(llm, CustomEmbeddingModel)
     response = await llm.aembed("text")
diff --git a/tests/unit/litellm_services/test_rate_limiter.py b/tests/integration/language_model/test_rate_limiter.py
similarity index 89%
rename from tests/unit/litellm_services/test_rate_limiter.py
rename to tests/integration/language_model/test_rate_limiter.py
index ffe144212d..ec482b7240 100644
--- a/tests/unit/litellm_services/test_rate_limiter.py
+++ b/tests/integration/language_model/test_rate_limiter.py
@@ -16,7 +16,7 @@
 from graphrag.language_model.providers.litellm.services.rate_limiter.rate_limiter_factory import (
     RateLimiterFactory,
 )
-from tests.unit.litellm_services.utils import (
+from tests.integration.language_model.utils import (
     assert_max_num_values_per_period,
     assert_stagger,
     bin_time_intervals,
@@ -51,7 +51,7 @@ def test_rate_limiter_validation():
 
     # Valid parameters
     rate_limiter = rate_limiter_factory.create(
-        strategy="static", rpm=60, tpm=10000, period_in_seconds=60
+        strategy="static", init_args={"rpm": 60, "tpm": 10000, "period_in_seconds": 60}
     )
     assert rate_limiter is not None
 
@@ -60,7 +60,9 @@
         ValueError,
         match=r"Strategy 'invalid_strategy' is not registered.",
     ):
-        rate_limiter_factory.create(strategy="invalid_strategy", rpm=60, tpm=10000)
+        rate_limiter_factory.create(
+            strategy="invalid_strategy", init_args={"rpm": 60, "tpm": 10000}
+        )
 
     # Both rpm and tpm are None
     with pytest.raises(
@@ -74,26 +76,29 @@
         ValueError,
         match=r"RPM and TPM must be either None \(disabled\) or positive integers.",
     ):
-        rate_limiter_factory.create(strategy="static", rpm=-10)
+        rate_limiter_factory.create(strategy="static", init_args={"rpm": -10})
 
     # Invalid tpm
     with pytest.raises(
         ValueError,
         match=r"RPM and TPM must be either None \(disabled\) or positive integers.",
     ):
-        rate_limiter_factory.create(strategy="static", tpm=-10)
+        rate_limiter_factory.create(strategy="static", init_args={"tpm": -10})
 
     # Invalid period_in_seconds
     with pytest.raises(
         ValueError, match=r"Period in seconds must be a positive integer."
     ):
-        rate_limiter_factory.create(strategy="static", rpm=10, period_in_seconds=-10)
+        rate_limiter_factory.create(
+            strategy="static", init_args={"rpm": 10, "period_in_seconds": -10}
+        )
 
 
 def test_rpm():
     """Test that the rate limiter enforces RPM limits."""
     rate_limiter = rate_limiter_factory.create(
-        strategy="static", rpm=_rpm, period_in_seconds=_period_in_seconds
+        strategy="static",
+        init_args={"rpm": _rpm, "period_in_seconds": _period_in_seconds},
     )
 
     time_values: list[float] = []
@@ -121,7 +126,8 @@ def test_rpm():
 
 def test_tpm():
     """Test that the rate limiter enforces TPM limits."""
     rate_limiter = rate_limiter_factory.create(
-        strategy="static", tpm=_tpm, period_in_seconds=_period_in_seconds
+        strategy="static",
+        init_args={"tpm": _tpm, "period_in_seconds": _period_in_seconds},
     )
 
     time_values: list[float] = []
@@ -154,7 +160,8 @@ def test_token_in_request_exceeds_tpm():
 
     In this case, the request should still be allowed to proceed but may take up its own rate limit bin.
     """
     rate_limiter = rate_limiter_factory.create(
-        strategy="static", tpm=_tpm, period_in_seconds=_period_in_seconds
+        strategy="static",
+        init_args={"tpm": _tpm, "period_in_seconds": _period_in_seconds},
     )
 
     time_values: list[float] = []
@@ -178,7 +185,8 @@ def test_token_in_request_exceeds_tpm():
 
 def test_rpm_and_tpm_with_rpm_as_limiting_factor():
     """Test that the rate limiter enforces RPM and TPM limits."""
     rate_limiter = rate_limiter_factory.create(
-        strategy="static", rpm=_rpm, tpm=_tpm, period_in_seconds=_period_in_seconds
+        strategy="static",
+        init_args={"rpm": _rpm, "tpm": _tpm, "period_in_seconds": _period_in_seconds},
     )
 
     time_values: list[float] = []
@@ -207,7 +215,8 @@ def test_rpm_and_tpm_with_rpm_as_limiting_factor():
 
 def test_rpm_and_tpm_with_tpm_as_limiting_factor():
     """Test that the rate limiter enforces TPM limits."""
     rate_limiter = rate_limiter_factory.create(
-        strategy="static", rpm=_rpm, tpm=_tpm, period_in_seconds=_period_in_seconds
+        strategy="static",
+        init_args={"rpm": _rpm, "tpm": _tpm, "period_in_seconds": _period_in_seconds},
     )
 
     time_values: list[float] = []
@@ -251,7 +260,8 @@ def _run_rate_limiter(
 
 def test_rpm_threaded():
     """Test that the rate limiter enforces RPM limits in a threaded environment."""
     rate_limiter = rate_limiter_factory.create(
-        strategy="static", rpm=_rpm, tpm=_tpm, period_in_seconds=_period_in_seconds
+        strategy="static",
+        init_args={"rpm": _rpm, "tpm": _tpm, "period_in_seconds": _period_in_seconds},
     )
 
     input_queue: Queue[int | None] = Queue()
@@ -311,7 +321,8 @@ def test_rpm_threaded():
 
 def test_tpm_threaded():
     """Test that the rate limiter enforces TPM limits in a threaded environment."""
     rate_limiter = rate_limiter_factory.create(
-        strategy="static", rpm=_rpm, tpm=_tpm, period_in_seconds=_period_in_seconds
+        strategy="static",
+        init_args={"rpm": _rpm, "tpm": _tpm, "period_in_seconds": _period_in_seconds},
     )
 
     input_queue: Queue[int | None] = Queue()
diff --git a/tests/unit/litellm_services/test_retries.py b/tests/integration/language_model/test_retries.py
similarity index 95%
rename from tests/unit/litellm_services/test_retries.py
rename to tests/integration/language_model/test_retries.py
index 3db2a39027..617e11a8cb 100644
--- a/tests/unit/litellm_services/test_retries.py
+++ b/tests/integration/language_model/test_retries.py
@@ -57,8 +57,10 @@ def test_retries(
     """
     retry_service = retry_factory.create(
         strategy=strategy,
-        max_retries=max_retries,
-        max_retry_wait=max_retry_wait,
+        init_args={
+            "max_retries": max_retries,
+            "max_retry_wait": max_retry_wait,
+        },
     )
 
     retries = 0
@@ -124,8 +126,10 @@ async def test_retries_async(
     """
     retry_service = retry_factory.create(
         strategy=strategy,
-        max_retries=max_retries,
-        max_retry_wait=max_retry_wait,
+        init_args={
+            "max_retries": max_retries,
+            "max_retry_wait": max_retry_wait,
+        },
     )
 
     retries = 0
diff --git a/tests/unit/litellm_services/utils.py b/tests/integration/language_model/utils.py
similarity index 100%
rename from tests/unit/litellm_services/utils.py
rename to tests/integration/language_model/utils.py
diff --git a/tests/integration/logging/test_factory.py b/tests/integration/logging/test_factory.py
index bacb4f0712..cb22ca8674 100644
--- a/tests/integration/logging/test_factory.py
+++ b/tests/integration/logging/test_factory.py
@@ -27,7 +27,7 @@ def test_create_blob_logger():
         "base_dir": "testbasedir",
         "container_name": "testcontainer",
     }
-    logger = LoggerFactory.create_logger(ReportingType.blob.value, kwargs)
+    logger = LoggerFactory().create(ReportingType.blob.value, kwargs)
     assert isinstance(logger, BlobWorkflowLogger)
 
 
@@ -40,8 +40,8 @@ def test_register_and_create_custom_logger():
     instance.initialized = True
     custom_logger_class.return_value = instance
 
-    LoggerFactory.register("custom", lambda **kwargs: custom_logger_class(**kwargs))
-    logger = LoggerFactory.create_logger("custom", {})
+    LoggerFactory().register("custom", lambda **kwargs: custom_logger_class(**kwargs))
+    logger = LoggerFactory().create("custom")
 
     assert custom_logger_class.called
     assert logger is instance
@@ -49,17 +49,15 @@ def test_register_and_create_custom_logger():
     assert logger.initialized is True  # type: ignore # Attribute only exists on our mock
 
     # Check if it's in the list of registered logger types
-    assert "custom" in LoggerFactory.get_logger_types()
-    assert LoggerFactory.is_supported_type("custom")
+    assert "custom" in LoggerFactory()
 
 
 def test_get_logger_types():
-    logger_types = LoggerFactory.get_logger_types()
     # Check that built-in types are registered
-    assert ReportingType.file.value in logger_types
-    assert ReportingType.blob.value in logger_types
+    assert ReportingType.file.value in LoggerFactory()
+    assert ReportingType.blob.value in LoggerFactory()
 
 
 def test_create_unknown_logger():
-    with pytest.raises(ValueError, match="Unknown reporting type: unknown"):
-        LoggerFactory.create_logger("unknown", {})
+    with pytest.raises(ValueError, match="Strategy 'unknown' is not registered\\."):
+        LoggerFactory().create("unknown")
diff --git a/tests/integration/storage/test_factory.py b/tests/integration/storage/test_factory.py
index 0b59366ef9..8817b25036 100644
--- a/tests/integration/storage/test_factory.py
+++ b/tests/integration/storage/test_factory.py
@@ -31,7 +31,7 @@ def test_create_blob_storage():
         "base_dir": "testbasedir",
         "container_name": "testcontainer",
     }
-    storage = StorageFactory.create_storage(StorageType.blob.value, kwargs)
+    storage = StorageFactory().create(StorageType.blob.value, kwargs)
     assert isinstance(storage, BlobPipelineStorage)
 
 
@@ -46,19 +46,19 @@ def test_create_cosmosdb_storage():
         "base_dir": "testdatabase",
         "container_name": "testcontainer",
     }
-    storage = StorageFactory.create_storage(StorageType.cosmosdb.value, kwargs)
+    storage = StorageFactory().create(StorageType.cosmosdb.value, kwargs)
     assert isinstance(storage, CosmosDBPipelineStorage)
 
 
 def test_create_file_storage():
     kwargs = {"type": "file", "base_dir": "/tmp/teststorage"}
-    storage = StorageFactory.create_storage(StorageType.file.value, kwargs)
+    storage = StorageFactory().create(StorageType.file.value, kwargs)
     assert isinstance(storage, FilePipelineStorage)
 
 
 def test_create_memory_storage():
     kwargs = {}  # MemoryPipelineStorage doesn't accept any constructor parameters
-    storage = StorageFactory.create_storage(StorageType.memory.value, kwargs)
+    storage = StorageFactory().create(StorageType.memory.value, kwargs)
     assert isinstance(storage, MemoryPipelineStorage)
 
 
@@ -74,8 +74,8 @@ def test_register_and_create_custom_storage():
     instance.initialized = True
     custom_storage_class.return_value = instance
 
-    StorageFactory.register("custom", lambda **kwargs: custom_storage_class(**kwargs))
-    storage = StorageFactory.create_storage("custom", {})
+    StorageFactory().register("custom", lambda **kwargs: custom_storage_class(**kwargs))
+    storage = StorageFactory().create("custom", {})
 
     assert custom_storage_class.called
     assert storage is instance
@@ -83,22 +83,20 @@ def test_register_and_create_custom_storage():
     assert storage.initialized is True  # type: ignore # Attribute only exists on our mock
 
     # Check if it's in the list of registered storage types
-    assert "custom" in StorageFactory.get_storage_types()
-    assert StorageFactory.is_supported_type("custom")
+    assert "custom" in StorageFactory()
 
 
 def test_get_storage_types():
-    storage_types = StorageFactory.get_storage_types()
     # Check that built-in types are registered
-    assert StorageType.file.value in storage_types
-    assert StorageType.memory.value in storage_types
-    assert StorageType.blob.value in storage_types
-    assert StorageType.cosmosdb.value in storage_types
+    assert StorageType.file.value in StorageFactory()
+    assert StorageType.memory.value in StorageFactory()
+    assert StorageType.blob.value in StorageFactory()
+    assert StorageType.cosmosdb.value in StorageFactory()
 
 
 def test_create_unknown_storage():
-    with pytest.raises(ValueError, match="Unknown storage type: unknown"):
-        StorageFactory.create_storage("unknown", {})
+    with pytest.raises(ValueError, match="Strategy 'unknown' is not registered\\."):
+        StorageFactory().create("unknown")
 
 
 def test_register_class_directly_works():
@@ -148,12 +146,11 @@ async def get_creation_date(self, key: str) -> str:
         return "2024-01-01 00:00:00 +0000"
 
     # StorageFactory allows registering classes directly (no TypeError)
-    StorageFactory.register("custom_class", CustomStorage)
+    StorageFactory().register("custom_class", CustomStorage)
 
     # Verify it was registered
-    assert "custom_class" in StorageFactory.get_storage_types()
-    assert StorageFactory.is_supported_type("custom_class")
+    assert "custom_class" in StorageFactory()
 
     # Test creating an instance
-    storage = StorageFactory.create_storage("custom_class", {})
+    storage = StorageFactory().create("custom_class")
     assert isinstance(storage, CustomStorage)
diff --git a/tests/integration/vector_stores/test_factory.py b/tests/integration/vector_stores/test_factory.py
index 2a2c9c7173..69720c9664 100644
--- a/tests/integration/vector_stores/test_factory.py
+++ b/tests/integration/vector_stores/test_factory.py
@@ -19,13 +19,13 @@ def test_create_lancedb_vector_store():
     kwargs = {
         "db_uri": "/tmp/lancedb",
-    }
-    vector_store = VectorStoreFactory.create_vector_store(
-        vector_store_type=VectorStoreType.LanceDB.value,
-        vector_store_schema_config=VectorStoreSchemaConfig(
+        "vector_store_schema_config": VectorStoreSchemaConfig(
             index_name="test_collection"
         ),
-        kwargs=kwargs,
+    }
+    vector_store = VectorStoreFactory().create(
+        VectorStoreType.LanceDB.value,
+        kwargs,
     )
     assert isinstance(vector_store, LanceDBVectorStore)
     assert vector_store.index_name == "test_collection"
@@ -36,13 +36,13 @@ def test_create_azure_ai_search_vector_store():
     kwargs = {
         "url": "https://test.search.windows.net",
         "api_key": "test_key",
-    }
-    vector_store = VectorStoreFactory.create_vector_store(
-        vector_store_type=VectorStoreType.AzureAISearch.value,
-        vector_store_schema_config=VectorStoreSchemaConfig(
+        "vector_store_schema_config": VectorStoreSchemaConfig(
             index_name="test_collection"
         ),
-        kwargs=kwargs,
+    }
+    vector_store = VectorStoreFactory().create(
+        VectorStoreType.AzureAISearch.value,
+        kwargs,
     )
     assert isinstance(vector_store, AzureAISearchVectorStore)
@@ -52,14 +52,14 @@ def test_create_cosmosdb_vector_store():
     kwargs = {
         "connection_string": "AccountEndpoint=https://test.documents.azure.com:443/;AccountKey=test_key==",
         "database_name": "test_db",
-    }
-
-    vector_store = VectorStoreFactory.create_vector_store(
-        vector_store_type=VectorStoreType.CosmosDB.value,
-        vector_store_schema_config=VectorStoreSchemaConfig(
+        "vector_store_schema_config": VectorStoreSchemaConfig(
             index_name="test_collection"
         ),
-        kwargs=kwargs,
+    }
+
+    vector_store = VectorStoreFactory().create(
+        VectorStoreType.CosmosDB.value,
+        kwargs,
     )
     assert isinstance(vector_store, CosmosDBVectorStore)
@@ -76,12 +76,12 @@ def test_register_and_create_custom_vector_store():
     instance.initialized = True
     custom_vector_store_class.return_value = instance
 
-    VectorStoreFactory.register(
+    VectorStoreFactory().register(
         "custom", lambda **kwargs: custom_vector_store_class(**kwargs)
     )
-    vector_store = VectorStoreFactory.create_vector_store(
-        vector_store_type="custom", vector_store_schema_config=VectorStoreSchemaConfig()
+    vector_store = VectorStoreFactory().create(
+        "custom", {"vector_store_schema_config": VectorStoreSchemaConfig()}
     )
 
     assert custom_vector_store_class.called
@@ -90,38 +90,26 @@ def test_register_and_create_custom_vector_store():
     assert vector_store.initialized is True  # type: ignore # Attribute only exists on our mock
 
     # Check if it's in the list of registered vector store types
-    assert "custom" in VectorStoreFactory.get_vector_store_types()
-    assert VectorStoreFactory.is_supported_type("custom")
-
-
-def test_get_vector_store_types():
-    vector_store_types = VectorStoreFactory.get_vector_store_types()
-    # Check that built-in types are registered
-    assert VectorStoreType.LanceDB.value in vector_store_types
-    assert VectorStoreType.AzureAISearch.value in vector_store_types
-    assert VectorStoreType.CosmosDB.value in vector_store_types
+    assert "custom" in VectorStoreFactory()
 
 
 def test_create_unknown_vector_store():
-    with pytest.raises(ValueError, match="Unknown vector store type: unknown"):
-        VectorStoreFactory.create_vector_store(
-            vector_store_type="unknown",
-            vector_store_schema_config=VectorStoreSchemaConfig(),
-        )
+    with pytest.raises(ValueError, match="Strategy 'unknown' is not registered\\."):
+        VectorStoreFactory().create("unknown")
 
 
 def test_is_supported_type():
     # Test built-in types
-    assert VectorStoreFactory.is_supported_type(VectorStoreType.LanceDB.value)
-    assert VectorStoreFactory.is_supported_type(VectorStoreType.AzureAISearch.value)
-    assert VectorStoreFactory.is_supported_type(VectorStoreType.CosmosDB.value)
+    assert VectorStoreType.LanceDB.value in VectorStoreFactory()
+    assert VectorStoreType.AzureAISearch.value in VectorStoreFactory()
+    assert VectorStoreType.CosmosDB.value in VectorStoreFactory()
 
     # Test unknown type
-    assert not VectorStoreFactory.is_supported_type("unknown")
+    assert "unknown" not in VectorStoreFactory()
 
 
 def test_register_class_directly_works():
-    """Test that registering a class directly works (VectorStoreFactory allows this)."""
+    """Test that registering a class directly works."""
     from graphrag.vector_stores.base import BaseVectorStore
 
     class CustomVectorStore(BaseVectorStore):
@@ -143,25 +131,21 @@ def similarity_search_by_vector(self, query_embedding, k=10, **kwargs):
         def similarity_search_by_text(self, text, text_embedder, k=10, **kwargs):
             return []
 
-        def filter_by_id(self, include_ids):
-            return {}
-
         def search_by_id(self, id):
             from graphrag.vector_stores.base import VectorStoreDocument
 
             return VectorStoreDocument(id=id, vector=None)
 
     # VectorStoreFactory allows registering classes directly (no TypeError)
-    VectorStoreFactory.register("custom_class", CustomVectorStore)
+    VectorStoreFactory().register("custom_class", CustomVectorStore)
 
     # Verify it was registered
-    assert "custom_class" in VectorStoreFactory.get_vector_store_types()
-    assert VectorStoreFactory.is_supported_type("custom_class")
+    assert "custom_class" in VectorStoreFactory()
 
     # Test creating an instance
-    vector_store = VectorStoreFactory.create_vector_store(
-        vector_store_type="custom_class",
-        vector_store_schema_config=VectorStoreSchemaConfig(),
+    vector_store = VectorStoreFactory().create(
+        "custom_class",
+        {"vector_store_schema_config": VectorStoreSchemaConfig()},
     )
     assert isinstance(vector_store, CustomVectorStore)
diff --git a/tests/smoke/test_fixtures.py b/tests/smoke/test_fixtures.py
index 7a81bf2d16..e0dd6b7b37 100644
--- a/tests/smoke/test_fixtures.py
+++ b/tests/smoke/test_fixtures.py
@@ -128,6 +128,7 @@ def __run_indexer(
         self,
         root: Path,
         input_file_type: str,
+        index_method: str,
     ):
         command = [
             "uv",
@@ -138,7 +139,7 @@ def __run_indexer(
             "--root",
             root.resolve().as_posix(),
             "--method",
-            "standard",
+            index_method,
         ]
         command = [arg for arg in command if arg]
         logger.info("running command ", " ".join(command))
@@ -234,6 +235,7 @@ def test_fixture(
         self,
         input_path: str,
         input_file_type: str,
+        index_method: str,
         workflow_config: dict[str, dict[str, Any]],
         query_config: list[dict[str, str]],
     ):
@@ -248,7 +250,7 @@ def test_fixture(
             dispose = asyncio.run(prepare_azurite_data(input_path, azure))
 
         print("running indexer")
-        self.__run_indexer(root, input_file_type)
+        self.__run_indexer(root, input_file_type, index_method)
         print("indexer complete")
 
         if dispose is not None:
diff --git a/tests/unit/indexing/input/test_csv_loader.py b/tests/unit/indexing/input/test_csv_loader.py
index 965f836676..8a6b0e351d 100644
--- a/tests/unit/indexing/input/test_csv_loader.py
+++ b/tests/unit/indexing/input/test_csv_loader.py
@@ -4,7 +4,7 @@
 from graphrag.config.enums import InputFileType
 from graphrag.config.models.input_config import InputConfig
 from graphrag.config.models.storage_config import StorageConfig
-from graphrag.index.input.factory import create_input
+from graphrag.index.input.factory import InputReaderFactory
 from graphrag.utils.api import create_storage_from_config
 
 
@@ -17,7 +17,11 @@ async def test_csv_loader_one_file():
         file_pattern=".*\\.csv$",
     )
     storage = create_storage_from_config(config.storage)
-    documents = await create_input(config=config, storage=storage)
+    documents = (
+        await InputReaderFactory()
+        .create(config.file_type, {"storage": storage, "config": config})
+        .read_files()
+    )
     assert documents.shape == (2, 4)
     assert documents["title"].iloc[0] == "input.csv"
@@ -32,7 +36,11 @@ async def test_csv_loader_one_file_with_title():
         title_column="title",
     )
     storage = create_storage_from_config(config.storage)
-    documents = await create_input(config=config, storage=storage)
+    documents = (
+        await InputReaderFactory()
+        .create(config.file_type, {"storage": storage, "config": config})
+        .read_files()
+    )
     assert documents.shape == (2, 4)
     assert documents["title"].iloc[0] == "Hello"
@@ -48,7 +56,11 @@ async def test_csv_loader_one_file_with_metadata():
         metadata=["title"],
     )
     storage = create_storage_from_config(config.storage)
-    documents = await create_input(config=config, storage=storage)
+    documents = (
+        await InputReaderFactory()
+        .create(config.file_type, {"storage": storage, "config": config})
+        .read_files()
+    )
     assert documents.shape == (2, 5)
     assert documents["metadata"][0] == {"title": "Hello"}
@@ -62,5 +75,9 @@ async def test_csv_loader_multiple_files():
         file_pattern=".*\\.csv$",
     )
     storage = create_storage_from_config(config.storage)
-    documents = await create_input(config=config, storage=storage)
+    documents = (
+        await InputReaderFactory()
+        .create(config.file_type, {"storage": storage, "config": config})
+        .read_files()
+    )
     assert documents.shape == (4, 4)
diff --git a/tests/unit/indexing/input/test_json_loader.py b/tests/unit/indexing/input/test_json_loader.py
index c97d38d4a0..1ce7001aab 100644
--- a/tests/unit/indexing/input/test_json_loader.py
+++ b/tests/unit/indexing/input/test_json_loader.py
@@ -4,7 +4,7 @@
 from graphrag.config.enums import InputFileType
 from graphrag.config.models.input_config import InputConfig
 from graphrag.config.models.storage_config import StorageConfig
-from graphrag.index.input.factory import create_input
+from graphrag.index.input.factory import InputReaderFactory
 from graphrag.utils.api import create_storage_from_config
 
 
@@ -17,7 +17,11 @@ async def test_json_loader_one_file_one_object():
         file_pattern=".*\\.json$",
     )
     storage = create_storage_from_config(config.storage)
-    documents = await create_input(config=config, storage=storage)
+    documents = (
+        await InputReaderFactory()
+        .create(config.file_type, {"storage": storage, "config": config})
+        .read_files()
+    )
     assert documents.shape == (1, 4)
     assert documents["title"].iloc[0] == "input.json"
@@ -31,7 +35,11 @@ async def test_json_loader_one_file_multiple_objects():
         file_pattern=".*\\.json$",
     )
     storage = create_storage_from_config(config.storage)
-    documents = await create_input(config=config, storage=storage)
+    documents = (
+        await InputReaderFactory()
+        .create(config.file_type, {"storage": storage, "config": config})
+        .read_files()
+    )
     print(documents)
     assert documents.shape == (3, 4)
     assert documents["title"].iloc[0] == "input.json"
@@ -47,7 +55,11 @@ async def test_json_loader_one_file_with_title():
         title_column="title",
     )
     storage = create_storage_from_config(config.storage)
-    documents = await create_input(config=config, storage=storage)
+    documents = (
+        await InputReaderFactory()
+        .create(config.file_type, {"storage": storage, "config": config})
+        .read_files()
+    )
     assert documents.shape == (1, 4)
     assert documents["title"].iloc[0] == "Hello"
@@ -63,7 +75,11 @@ async def test_json_loader_one_file_with_metadata():
         metadata=["title"],
     )
     storage = create_storage_from_config(config.storage)
-    documents = await create_input(config=config, storage=storage)
+    documents = (
+        await InputReaderFactory()
+        .create(config.file_type, {"storage": storage, "config": config})
+        .read_files()
+    )
     assert documents.shape == (1, 5)
     assert documents["metadata"][0] == {"title": "Hello"}
@@ -77,5 +93,9 @@ async def test_json_loader_multiple_files():
         file_pattern=".*\\.json$",
     )
     storage = create_storage_from_config(config.storage)
-    documents = await create_input(config=config, storage=storage)
+    documents = (
+        await InputReaderFactory()
+        .create(config.file_type, {"storage": storage, "config": config})
+        .read_files()
+    )
     assert documents.shape == (4, 4)
diff --git a/tests/unit/indexing/input/test_txt_loader.py b/tests/unit/indexing/input/test_txt_loader.py
index 6b82a408fb..239f622d72 100644
--- a/tests/unit/indexing/input/test_txt_loader.py
+++ b/tests/unit/indexing/input/test_txt_loader.py
@@ -4,7 +4,7 @@
 from graphrag.config.enums import InputFileType
 from graphrag.config.models.input_config import InputConfig
 from graphrag.config.models.storage_config import StorageConfig
-from graphrag.index.input.factory import create_input
+from graphrag.index.input.factory import InputReaderFactory
 from graphrag.utils.api import create_storage_from_config
 
 
@@ -17,7 +17,11 @@ async def test_txt_loader_one_file():
         file_pattern=".*\\.txt$",
     )
     storage = create_storage_from_config(config.storage)
-    documents = await create_input(config=config, storage=storage)
+    documents = (
+        await InputReaderFactory()
+        .create(config.file_type, {"storage": storage, "config": config})
+        .read_files()
+    )
     assert documents.shape == (1, 4)
     assert documents["title"].iloc[0] == "input.txt"
@@ -32,7 +36,11 @@ async def test_txt_loader_one_file_with_metadata():
         metadata=["title"],
     )
     storage = create_storage_from_config(config.storage)
-    documents = await create_input(config=config, storage=storage)
+    documents = (
+        await InputReaderFactory()
+        .create(config.file_type, {"storage": storage, "config": config})
+        .read_files()
+    )
     assert documents.shape == (1, 5)
     # unlike csv, we cannot set the title to anything other than the filename
     assert documents["metadata"][0] == {"title": "input.txt"}
@@ -47,5 +55,9 @@ async def test_txt_loader_multiple_files():
         file_pattern=".*\\.txt$",
    )
     storage = create_storage_from_config(config.storage)
-    documents = await create_input(config=config, storage=storage)
+    documents = (
+        await InputReaderFactory()
+        .create(config.file_type, {"storage": storage, "config": config})
+        .read_files()
+    )
     assert documents.shape == (2, 4)
diff --git a/tests/unit/indexing/verbs/entities/extraction/strategies/graph_intelligence/test_gi_entity_extraction.py b/tests/unit/indexing/verbs/entities/extraction/strategies/graph_intelligence/test_gi_entity_extraction.py
index 39f041b10c..8390bfe8aa 100644
--- a/tests/unit/indexing/verbs/entities/extraction/strategies/graph_intelligence/test_gi_entity_extraction.py
+++ b/tests/unit/indexing/verbs/entities/extraction/strategies/graph_intelligence/test_gi_entity_extraction.py
@@ -2,7 +2,6 @@
 # Licensed under the MIT License
 import unittest
 
-from graphrag.cache.factory import CacheFactory
 from graphrag.index.operations.extract_graph.extract_graph import run_extract_graph
 from graphrag.index.operations.extract_graph.typing import (
     Document,
@@ -10,8 +9,6 @@
 from graphrag.prompts.index.extract_graph import GRAPH_EXTRACTION_PROMPT
 from tests.unit.indexing.verbs.helpers.mock_llm import create_mock_llm
 
-_cache = CacheFactory.create_cache("none", kwargs={})
-
 
 class TestRunChain(unittest.IsolatedAsyncioTestCase):
     async def test_run_extract_graph_single_document_correct_entities_returned(self):
diff --git a/tests/unit/litellm_services/__init__.py b/tests/unit/litellm_services/__init__.py
deleted file mode 100644
index 0a3e38adfb..0000000000
--- a/tests/unit/litellm_services/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
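
---

Editor's note: the migrations above all converge on one generic `Factory` surface: `register(strategy, initializer)`, `create(strategy, init_args)`, `keys()`, and `in` for support checks. Below is a minimal sketch of the before/after call patterns, distilled from the tests in this diff. It assumes registrations are shared across factory instances (the tests register on one `CacheFactory()` and create from another, which implies a class-level or singleton registry), and `index_name="demo"` is an illustrative value.

```python
from graphrag.cache.factory import CacheFactory
from graphrag.config.enums import CacheType, VectorStoreType
from graphrag.config.models.vector_store_schema_config import VectorStoreSchemaConfig
from graphrag.vector_stores.factory import VectorStoreFactory

# Old: CacheFactory.create_cache(cache_type, kwargs) classmethod.
# New: instantiate the factory and call the shared create(strategy, init_args),
# where init_args is a single dict of constructor kwargs.
cache = CacheFactory().create(
    strategy=CacheType.file.value,
    init_args={"root_dir": "/tmp", "base_dir": "testcache"},
)

# Vector stores now take vector_store_schema_config inside the same init_args
# dict instead of as a dedicated create_vector_store() parameter.
store = VectorStoreFactory().create(
    VectorStoreType.LanceDB.value,
    {
        "db_uri": "/tmp/lancedb",
        "vector_store_schema_config": VectorStoreSchemaConfig(index_name="demo"),
    },
)

# get_*_types() / is_supported_type() collapse into membership and keys().
assert CacheType.file.value in CacheFactory()
assert VectorStoreType.LanceDB.value in VectorStoreFactory()
```

Unknown strategies now raise the base class's uniform `ValueError` ("Strategy '...' is not registered.") rather than per-factory "Unknown ... type" messages, which is why every negative test's `match=` pattern changed.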
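The input-loader tests likewise replace the `create_input` free function with an `InputReaderFactory`. A hedged sketch of the new read path follows; the `StorageConfig` base_dir and the `InputFileType.csv` member are illustrative values, and `read_files()` returning a documents DataFrame is inferred from the shape asserts in the updated tests.

```python
from graphrag.config.enums import InputFileType
from graphrag.config.models.input_config import InputConfig
from graphrag.config.models.storage_config import StorageConfig
from graphrag.index.input.factory import InputReaderFactory
from graphrag.utils.api import create_storage_from_config


async def load_documents():
    # Assumed illustrative config; mirrors the fixtures used by the loader tests.
    config = InputConfig(
        storage=StorageConfig(base_dir="./tests/unit/indexing/input/data/one-csv"),
        file_type=InputFileType.csv,
        file_pattern=".*\\.csv$",
    )
    storage = create_storage_from_config(config.storage)

    # Old: documents = await create_input(config=config, storage=storage)
    # New: the factory builds a reader for the configured file type, and the
    # reader's read_files() coroutine produces the documents frame.
    return (
        await InputReaderFactory()
        .create(config.file_type, {"storage": storage, "config": config})
        .read_files()
    )
```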