diff --git a/integration/conftest.py b/integration/conftest.py index d9104244c..f0708e386 100644 --- a/integration/conftest.py +++ b/integration/conftest.py @@ -1,23 +1,11 @@ import os -from typing import ( - Any, - AsyncGenerator, - Optional, - List, - Generator, - Protocol, - Type, - Dict, - Tuple, - Union, -) +from typing import Any, Optional, List, Generator, Protocol, Type, Dict, Tuple, Union import pytest -import pytest_asyncio from _pytest.fixtures import SubRequest import weaviate -from weaviate.collections import Collection, CollectionAsync +from weaviate.collections import Collection from weaviate.collections.classes.config import ( Property, _VectorizerConfigCreate, @@ -31,10 +19,11 @@ _VectorIndexConfigCreate, _RerankerConfigCreate, ) -from weaviate.collections.classes.config_named_vectors import _NamedVectorConfigCreate from weaviate.collections.classes.types import Properties from weaviate.config import AdditionalConfig +from weaviate.collections.classes.config_named_vectors import _NamedVectorConfigCreate + class CollectionFactory(Protocol): """Typing for fixture.""" @@ -58,55 +47,20 @@ def __call__( vector_index_config: Optional[_VectorIndexConfigCreate] = None, description: Optional[str] = None, reranker_config: Optional[_RerankerConfigCreate] = None, - ) -> Collection[Any, Any]: - """Typing for fixture.""" - ... - - -class ClientFactory(Protocol): - """Typing for fixture.""" - - def __call__( - self, - headers: Optional[Dict[str, str]] = None, - ports: Tuple[int, int] = (8080, 50051), - ) -> weaviate.WeaviateClient: + return_client: bool = False, + ) -> Union[Collection[Any, Any], Tuple[Collection[Any, Any], weaviate.WeaviateClient]]: """Typing for fixture.""" ... -@pytest.fixture -def client_factory() -> Generator[ClientFactory, None, None]: - client_fixture: Optional[weaviate.WeaviateClient] = None - - def _factory( - headers: Optional[Dict[str, str]] = None, - ports: Tuple[int, int] = (8080, 50051), - ) -> weaviate.WeaviateClient: - nonlocal client_fixture - if client_fixture is None: - client_fixture = weaviate.connect_to_local( - headers=headers, - grpc_port=ports[1], - port=ports[0], - additional_config=AdditionalConfig(timeout=(60, 120)), # for image tests - ) - return client_fixture - - try: - yield _factory - finally: - if client_fixture is not None: - client_fixture.close() - - @pytest.fixture def collection_factory( - request: SubRequest, client_factory: ClientFactory -) -> Generator[CollectionFactory, None, None]: - name_fixtures: List[str] = [] + request: SubRequest, +) -> Generator[ + Union[Collection[Any, Any], Tuple[Collection[Any, Any], weaviate.WeaviateClient]], None, None +]: + name_fixture: Optional[str] = None client_fixture: Optional[weaviate.WeaviateClient] = None - call_counter: int = 0 def _factory( name: str = "", @@ -126,180 +80,44 @@ def _factory( vector_index_config: Optional[_VectorIndexConfigCreate] = None, description: Optional[str] = None, reranker_config: Optional[_RerankerConfigCreate] = None, + return_client: bool = False, ) -> Collection[Any, Any]: - try: - nonlocal client_fixture, name_fixtures, call_counter - call_counter += 1 - name_fixture = ( - _sanitize_collection_name(request.node.fspath.basename + "_" + request.node.name) - + name - + "_" - + str(call_counter) - ) - name_fixtures.append(name_fixture) - client_fixture = client_factory( - headers=headers, - ports=ports, - ) - collection: Collection[Any, Any] = client_fixture.collections.create( - name=name_fixture, - description=description, - vectorizer_config=vectorizer_config or Configure.Vectorizer.none(), - properties=properties, - references=references, - inverted_index_config=inverted_index_config, - multi_tenancy_config=multi_tenancy_config, - generative_config=generative_config, - data_model_properties=data_model_properties, - data_model_references=data_model_refs, - replication_config=replication_config, - vector_index_config=vector_index_config, - reranker_config=reranker_config, - ) - return collection - except Exception as e: - print("Got exception in _factory", e) - raise e - - try: - yield _factory - except Exception as e: - print("Got exception in collection_factory", e) - raise e - finally: - if client_fixture is not None and name_fixtures is not None: - for name_fixture in name_fixtures: - client_fixture.collections.delete(name_fixture) - - -class AsyncCollectionFactory(Protocol): - """Typing for fixture.""" - - async def __call__( - self, - name: str = "", - properties: Optional[List[Property]] = None, - references: Optional[List[_ReferencePropertyBase]] = None, - vectorizer_config: Optional[ - Union[_VectorizerConfigCreate, List[_NamedVectorConfigCreate]] - ] = None, - inverted_index_config: Optional[_InvertedIndexConfigCreate] = None, - multi_tenancy_config: Optional[_MultiTenancyConfigCreate] = None, - generative_config: Optional[_GenerativeConfigCreate] = None, - headers: Optional[Dict[str, str]] = None, - ports: Tuple[int, int] = (8080, 50051), - data_model_properties: Optional[Type[Properties]] = None, - data_model_refs: Optional[Type[Properties]] = None, - replication_config: Optional[_ReplicationConfigCreate] = None, - vector_index_config: Optional[_VectorIndexConfigCreate] = None, - description: Optional[str] = None, - reranker_config: Optional[_RerankerConfigCreate] = None, - ) -> CollectionAsync[Any, Any]: - """Typing for fixture.""" - ... - - -class AsyncClientFactory(Protocol): - """Typing for fixture.""" - - async def __call__( - self, - headers: Optional[Dict[str, str]] = None, - ports: Tuple[int, int] = (8080, 50051), - ) -> weaviate.WeaviateAsyncClient: - """Typing for fixture.""" - ... - - -@pytest_asyncio.fixture -async def async_client_factory() -> AsyncGenerator[AsyncClientFactory, None]: - client_fixture: Optional[weaviate.WeaviateAsyncClient] = None - - async def _factory( - headers: Optional[Dict[str, str]] = None, - ports: Tuple[int, int] = (8080, 50051), - ) -> weaviate.WeaviateAsyncClient: - nonlocal client_fixture - if client_fixture is None: - client_fixture = weaviate.use_async_with_local( - headers=headers, - grpc_port=ports[1], - port=ports[0], - additional_config=AdditionalConfig(timeout=(60, 120)), # for image tests - ) - await client_fixture.connect() - return client_fixture - - try: - yield _factory - finally: - if client_fixture is not None: - await client_fixture.close() - - -@pytest_asyncio.fixture -async def async_collection_factory( - request: SubRequest, async_client_factory: AsyncClientFactory -) -> AsyncGenerator[AsyncCollectionFactory, None]: - name_fixtures: List[str] = [] - client_fixture: Optional[weaviate.WeaviateAsyncClient] = None + nonlocal client_fixture, name_fixture + name_fixture = _sanitize_collection_name(request.node.name) + name + client_fixture = weaviate.connect_to_local( + headers=headers, + grpc_port=ports[1], + port=ports[0], + additional_config=AdditionalConfig(timeout=(60, 120)), # for image tests + ) + client_fixture.collections.delete(name_fixture) - async def _factory( - name: str = "", - properties: Optional[List[Property]] = None, - references: Optional[List[_ReferencePropertyBase]] = None, - vectorizer_config: Optional[ - Union[_VectorizerConfigCreate, List[_NamedVectorConfigCreate]] - ] = None, - inverted_index_config: Optional[_InvertedIndexConfigCreate] = None, - multi_tenancy_config: Optional[_MultiTenancyConfigCreate] = None, - generative_config: Optional[_GenerativeConfigCreate] = None, - headers: Optional[Dict[str, str]] = None, - ports: Tuple[int, int] = (8080, 50051), - data_model_properties: Optional[Type[Properties]] = None, - data_model_refs: Optional[Type[Properties]] = None, - replication_config: Optional[_ReplicationConfigCreate] = None, - vector_index_config: Optional[_VectorIndexConfigCreate] = None, - description: Optional[str] = None, - reranker_config: Optional[_RerankerConfigCreate] = None, - ) -> CollectionAsync[Any, Any]: - try: - nonlocal client_fixture, name_fixtures - name_fixture = _sanitize_collection_name(request.node.name) + name - name_fixtures.append(name_fixture) - client_fixture = await async_client_factory( - headers=headers, - ports=ports, - ) - collection: CollectionAsync[Any, Any] = await client_fixture.collections.create( - name=name_fixture, - description=description, - vectorizer_config=vectorizer_config or Configure.Vectorizer.none(), - properties=properties, - references=references, - inverted_index_config=inverted_index_config, - multi_tenancy_config=multi_tenancy_config, - generative_config=generative_config, - data_model_properties=data_model_properties, - data_model_references=data_model_refs, - replication_config=replication_config, - vector_index_config=vector_index_config, - reranker_config=reranker_config, - ) + collection: Collection[Any, Any] = client_fixture.collections.create( + name=name_fixture, + description=description, + vectorizer_config=vectorizer_config or Configure.Vectorizer.none(), + properties=properties, + references=references, + inverted_index_config=inverted_index_config, + multi_tenancy_config=multi_tenancy_config, + generative_config=generative_config, + data_model_properties=data_model_properties, + data_model_references=data_model_refs, + replication_config=replication_config, + vector_index_config=vector_index_config, + reranker_config=reranker_config, + ) + if return_client: + return collection, client_fixture + else: return collection - except Exception as e: - print("Got exception in _factory", e) - raise e try: yield _factory - except Exception as e: - print("Got exception in collection_factory", e) - raise e finally: - if client_fixture is not None and name_fixtures is not None: - for name_fixture in name_fixtures: - await client_fixture.collections.delete(name_fixture) + if client_fixture is not None and name_fixture is not None: + client_fixture.collections.delete(name_fixture) + client_fixture.close() class OpenAICollection(Protocol): @@ -351,55 +169,6 @@ def _factory( yield _factory -class AsyncOpenAICollectionFactory(Protocol): - """Typing for fixture.""" - - async def __call__( - self, - name: str = "", - vectorizer_config: Optional[ - Union[_VectorizerConfigCreate, List[_NamedVectorConfigCreate]] - ] = None, - ) -> CollectionAsync[Any, Any]: - """Typing for fixture.""" - ... - - -@pytest_asyncio.fixture -async def async_openai_collection( - async_collection_factory: AsyncCollectionFactory, -) -> AsyncGenerator[AsyncOpenAICollectionFactory, None]: - async def _factory( - name: str = "", - vectorizer_config: Optional[ - Union[_VectorizerConfigCreate, List[_NamedVectorConfigCreate]] - ] = None, - ) -> CollectionAsync[Any, Any]: - api_key = os.environ.get("OPENAI_APIKEY") - if api_key is None: - pytest.skip("No OpenAI API key found.") - - if vectorizer_config is None: - vectorizer_config = Configure.Vectorizer.none() - - collection = await async_collection_factory( - name=name, - vectorizer_config=vectorizer_config or Configure.Vectorizer.none(), - properties=[ - Property(name="text", data_type=DataType.TEXT), - Property(name="content", data_type=DataType.TEXT), - Property(name="extra", data_type=DataType.TEXT), - ], - generative_config=Configure.Generative.openai(), - ports=(8086, 50057), - headers={"X-OpenAI-Api-Key": api_key}, - ) - - return collection - - yield _factory - - class CollectionFactoryGet(Protocol): """Typing for fixture.""" @@ -415,9 +184,8 @@ def __call__( @pytest.fixture -def collection_factory_get( - client_factory: ClientFactory, -) -> Generator[CollectionFactoryGet, None, None]: +def collection_factory_get() -> Generator[CollectionFactoryGet, None, None]: + client_fixture: Optional[weaviate.WeaviateClient] = None name_fixture: Optional[str] = None def _factory( @@ -426,9 +194,11 @@ def _factory( data_model_refs: Optional[Type[Properties]] = None, skip_argument_validation: bool = False, ) -> Collection[Any, Any]: - nonlocal name_fixture + nonlocal client_fixture, name_fixture name_fixture = _sanitize_collection_name(name) - collection: Collection[Any, Any] = client_factory().collections.get( + client_fixture = weaviate.connect_to_local() + + collection: Collection[Any, Any] = client_fixture.collections.get( name=name_fixture, data_model_properties=data_model_props, data_model_references=data_model_refs, @@ -436,7 +206,12 @@ def _factory( ) return collection - yield _factory + try: + yield _factory + finally: + if client_fixture is not None and name_fixture is not None: + client_fixture.collections.delete(name_fixture) + client_fixture.close() def _sanitize_collection_name(name: str) -> str: diff --git a/integration/test_collection_config.py b/integration/test_collection_config.py index d20a768ab..1bd8e2c29 100644 --- a/integration/test_collection_config.py +++ b/integration/test_collection_config.py @@ -30,6 +30,7 @@ GenerativeSearches, Rerankers, _RerankerConfigCreate, + Tokenization, ) from weaviate.collections.classes.tenants import Tenant from weaviate.exceptions import UnexpectedStatusCodeError, WeaviateInvalidInputError @@ -684,6 +685,89 @@ def test_collection_config_get_shards_multi_tenancy(collection_factory: Collecti assert "tenant2" in [shard.name for shard in shards] +def test_collection_config_create_from_dict(collection_factory: CollectionFactory) -> None: + collection, client = collection_factory( + inverted_index_config=Configure.inverted_index(bm25_b=0.8, bm25_k1=1.3), + multi_tenancy_config=Configure.multi_tenancy(enabled=True), + generative_config=Configure.Generative.openai(model="gpt-4"), + vectorizer_config=Configure.Vectorizer.text2vec_openai( + model="text-embedding-3-small", + base_url="http://weaviate.io", + vectorize_collection_name=False, + dimensions=512, + ), + vector_index_config=Configure.VectorIndex.flat( + vector_cache_max_objects=234, + quantizer=Configure.VectorIndex.Quantizer.bq(rescore_limit=456), + ), + description="Some description", + reranker_config=Configure.Reranker.cohere(model="rerank-english-v2.0"), + properties=[ + Property( + name="field_tokenization", data_type=DataType.TEXT, tokenization=Tokenization.FIELD + ), + Property( + name="field_description", + data_type=DataType.TEXT, + tokenization=Tokenization.FIELD, + description="field desc", + ), + Property( + name="field_index_filterable", data_type=DataType.TEXT, index_filterable=False + ), + Property( + name="field_skip_vectorization", data_type=DataType.TEXT, skip_vectorization=True + ), + Property(name="text", data_type=DataType.TEXT), + Property(name="texts", data_type=DataType.TEXT_ARRAY), + Property(name="number", data_type=DataType.NUMBER), + Property(name="numbers", data_type=DataType.NUMBER_ARRAY), + Property(name="int", data_type=DataType.INT), + Property(name="ints", data_type=DataType.INT_ARRAY), + Property(name="date", data_type=DataType.DATE), + Property(name="dates", data_type=DataType.DATE_ARRAY), + Property(name="boolean", data_type=DataType.BOOL), + Property(name="booleans", data_type=DataType.BOOL_ARRAY), + Property(name="geo", data_type=DataType.GEO_COORDINATES), + Property(name="phone", data_type=DataType.PHONE_NUMBER), + Property( + name="vectorize_property_name", + data_type=DataType.TEXT, + vectorize_property_name=False, + ), + Property( + name="field_index_searchable", data_type=DataType.TEXT, index_searchable=False + ), + # TODO: this will fail + # Property( + # name="name", + # data_type=DataType.OBJECT, + # nested_properties=[ + # Property(name="first", data_type=DataType.TEXT), + # Property(name="last", data_type=DataType.TEXT), + # ], + # ), + ], + return_client=True, + ) + old_dict = collection.config.get().to_dict() + new_dict = old_dict + new_collection_name = collection.name + "_FROM_DICT" + client.collections.delete(new_collection_name) + new_dict["class"] = new_collection_name + new_collection = client.collections.create_from_dict(new_dict) + new_collection_dict = new_collection.config.get().to_dict() + # make the same name for collections + new_collection_dict["class"] = collection.name + old_dict["class"] = collection.name + # check if both dict are the same + # print("old", old_dict) + # print("new", new_collection_dict) + assert new_collection_dict == old_dict + # remove the created collection + client.collections.delete(new_collection_name) + + def test_config_vector_index_flat_and_quantizer_bq(collection_factory: CollectionFactory) -> None: collection = collection_factory( vector_index_config=Configure.VectorIndex.flat( diff --git a/weaviate/collections/classes/config.py b/weaviate/collections/classes/config.py index 2b8d84203..3ca1664e5 100644 --- a/weaviate/collections/classes/config.py +++ b/weaviate/collections/classes/config.py @@ -1219,8 +1219,8 @@ def to_dict(self) -> Dict[str, Any]: out = super().to_dict() out["dataType"] = [self.data_type.value] out["indexFilterable"] = self.index_filterable - out["indexVector"] = self.index_searchable - out["tokenizer"] = self.tokenization.value if self.tokenization else None + out["indexSearchable"] = self.index_searchable + out["tokenization"] = self.tokenization.value if self.tokenization else None module_config: Dict[str, Any] = {} if self.vectorizer is not None: