From e6764e1827de1cf856acb4040d1e2c8c230ef0bc Mon Sep 17 00:00:00 2001 From: Duda Nogueira Date: Mon, 12 Aug 2024 16:39:30 -0300 Subject: [PATCH 1/4] Fix tokenization parameter in to_dict of a property. ISSUE #1237 --- weaviate/collections/classes/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/weaviate/collections/classes/config.py b/weaviate/collections/classes/config.py index e0b7700e3..9230390f8 100644 --- a/weaviate/collections/classes/config.py +++ b/weaviate/collections/classes/config.py @@ -1017,7 +1017,7 @@ def to_dict(self) -> Dict[str, Any]: out["dataType"] = [self.data_type.value] out["indexFilterable"] = self.index_filterable out["indexVector"] = self.index_searchable - out["tokenizer"] = self.tokenization.value if self.tokenization else None + out["tokenization"] = self.tokenization.value if self.tokenization else None module_config: Dict[str, Any] = {} if self.vectorizer is not None: From 6d1135fbb3e68f7e0ca8b9847a18e0d92a63ec85 Mon Sep 17 00:00:00 2001 From: Duda Nogueira Date: Thu, 15 Aug 2024 17:15:25 -0300 Subject: [PATCH 2/4] add some tests --- integration/conftest.py | 11 +++-- integration/test_collection_config.py | 65 ++++++++++++++++++++++++++- 2 files changed, 72 insertions(+), 4 deletions(-) diff --git a/integration/conftest.py b/integration/conftest.py index f41c7ebaf..d19f3830d 100644 --- a/integration/conftest.py +++ b/integration/conftest.py @@ -47,13 +47,14 @@ def __call__( vector_index_config: Optional[_VectorIndexConfigCreate] = None, description: Optional[str] = None, reranker_config: Optional[_RerankerConfigCreate] = None, - ) -> Collection[Any, Any]: + return_client: bool = False + ) -> Union[Collection[Any, Any], Tuple[Collection[Any, Any], weaviate.WeaviateClient]]: """Typing for fixture.""" ... @pytest.fixture -def collection_factory(request: SubRequest) -> Generator[CollectionFactory, None, None]: +def collection_factory(request: SubRequest) -> Generator[Union[Collection[Any, Any], Tuple[Collection[Any, Any], weaviate.WeaviateClient]], None, None]: name_fixture: Optional[str] = None client_fixture: Optional[weaviate.WeaviateClient] = None @@ -75,6 +76,7 @@ def _factory( vector_index_config: Optional[_VectorIndexConfigCreate] = None, description: Optional[str] = None, reranker_config: Optional[_RerankerConfigCreate] = None, + return_client: bool = False ) -> Collection[Any, Any]: nonlocal client_fixture, name_fixture name_fixture = _sanitize_collection_name(request.node.name) + name @@ -101,7 +103,10 @@ def _factory( vector_index_config=vector_index_config, reranker_config=reranker_config, ) - return collection + if return_client: + return collection, client_fixture + else: + return collection try: yield _factory diff --git a/integration/test_collection_config.py b/integration/test_collection_config.py index 7c099734d..3699b315f 100644 --- a/integration/test_collection_config.py +++ b/integration/test_collection_config.py @@ -27,6 +27,7 @@ GenerativeSearches, Rerankers, _RerankerConfigCreate, + Tokenization ) from weaviate.collections.classes.tenants import Tenant @@ -589,7 +590,69 @@ def test_collection_config_get_shards_multi_tenancy(collection_factory: Collecti assert "tenant1" in [shard.name for shard in shards] assert "tenant2" in [shard.name for shard in shards] - +def test_collection_config_create_from_dict(collection_factory: CollectionFactory) -> None: + collection, client = collection_factory( + inverted_index_config=Configure.inverted_index(bm25_b=0.8, bm25_k1=1.3), + multi_tenancy_config=Configure.multi_tenancy(enabled=True), + generative_config=Configure.Generative.openai(model="gpt-4"), + vectorizer_config=Configure.Vectorizer.text2vec_openai(model="ada"), + vector_index_config=Configure.VectorIndex.flat( + vector_cache_max_objects=234, + quantizer=Configure.VectorIndex.Quantizer.bq(rescore_limit=456), + ), + description="Some description", + reranker_config=Configure.Reranker.cohere(model="rerank-english-v2.0"), + properties=[ + Property(name="field_tokenization", data_type=DataType.TEXT, tokenization=Tokenization.FIELD), + Property(name="field_description", data_type=DataType.TEXT, + tokenization=Tokenization.FIELD, description="field desc"), + Property(name="field_index_filterable", data_type=DataType.TEXT, + index_filterable=False), + Property(name="field_skip_vectorization", data_type=DataType.TEXT, + skip_vectorization=True), + Property(name="text", data_type=DataType.TEXT), + Property(name="texts", data_type=DataType.TEXT_ARRAY), + Property(name="number", data_type=DataType.NUMBER), + Property(name="numbers", data_type=DataType.NUMBER_ARRAY), + Property(name="int", data_type=DataType.INT), + Property(name="ints", data_type=DataType.INT_ARRAY), + Property(name="date", data_type=DataType.DATE), + Property(name="dates", data_type=DataType.DATE_ARRAY), + Property(name="boolean", data_type=DataType.BOOL), + Property(name="booleans", data_type=DataType.BOOL_ARRAY), + Property(name="geo", data_type=DataType.GEO_COORDINATES), + Property(name="phone", data_type=DataType.PHONE_NUMBER), + # TODO: this will fail + # Property(name="field_index_searchable", data_type=DataType.TEXT, + # index_searchable=False), + # Property(name="field_skip_vectorization", data_type=DataType.TEXT, + # vectorize_property_name=False), + # Property( + # name="name", + # data_type=DataType.OBJECT, + # nested_properties=[ + # Property(name="first", data_type=DataType.TEXT), + # Property(name="last", data_type=DataType.TEXT), + # ], + # ), + ], + return_client=True + ) + old_dict = collection.config.get().to_dict() + new_dict = old_dict + new_collection_name = collection.name + "_FROM_DICT" + client.collections.delete(new_collection_name) + new_dict["class"] = new_collection_name + new_collection = client.collections.create_from_dict(new_dict) + new_collection_dict = new_collection.config.get().to_dict() + # make the same name for collections + new_collection_dict["class"] = collection.name + old_dict["class"] = collection.name + # check if both dict are the same + assert new_collection_dict == old_dict + # remove the created collection + client.collections.delete(new_collection_name) + def test_config_vector_index_flat_and_quantizer_bq(collection_factory: CollectionFactory) -> None: collection = collection_factory( vector_index_config=Configure.VectorIndex.flat( From 82fc2ceb3e1679b8186f24ea502a3972d728fd4a Mon Sep 17 00:00:00 2001 From: Duda Nogueira Date: Thu, 15 Aug 2024 17:41:26 -0300 Subject: [PATCH 3/4] improve tests and fix indexSearchable not being passed --- integration/test_collection_config.py | 19 +++++++++++++------ weaviate/collections/classes/config.py | 2 +- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/integration/test_collection_config.py b/integration/test_collection_config.py index 3699b315f..1c9b4d7bf 100644 --- a/integration/test_collection_config.py +++ b/integration/test_collection_config.py @@ -595,7 +595,12 @@ def test_collection_config_create_from_dict(collection_factory: CollectionFactor inverted_index_config=Configure.inverted_index(bm25_b=0.8, bm25_k1=1.3), multi_tenancy_config=Configure.multi_tenancy(enabled=True), generative_config=Configure.Generative.openai(model="gpt-4"), - vectorizer_config=Configure.Vectorizer.text2vec_openai(model="ada"), + vectorizer_config=Configure.Vectorizer.text2vec_openai( + model="text-embedding-3-small", + base_url="http://weaviate.io", + vectorize_collection_name=False, + dimensions=512 + ), vector_index_config=Configure.VectorIndex.flat( vector_cache_max_objects=234, quantizer=Configure.VectorIndex.Quantizer.bq(rescore_limit=456), @@ -621,12 +626,12 @@ def test_collection_config_create_from_dict(collection_factory: CollectionFactor Property(name="boolean", data_type=DataType.BOOL), Property(name="booleans", data_type=DataType.BOOL_ARRAY), Property(name="geo", data_type=DataType.GEO_COORDINATES), - Property(name="phone", data_type=DataType.PHONE_NUMBER), + Property(name="phone", data_type=DataType.PHONE_NUMBER), + Property(name="vectorize_property_name", data_type=DataType.TEXT, + vectorize_property_name=False), + Property(name="field_index_searchable", data_type=DataType.TEXT, + index_searchable=False), # TODO: this will fail - # Property(name="field_index_searchable", data_type=DataType.TEXT, - # index_searchable=False), - # Property(name="field_skip_vectorization", data_type=DataType.TEXT, - # vectorize_property_name=False), # Property( # name="name", # data_type=DataType.OBJECT, @@ -649,6 +654,8 @@ def test_collection_config_create_from_dict(collection_factory: CollectionFactor new_collection_dict["class"] = collection.name old_dict["class"] = collection.name # check if both dict are the same + #print("old", old_dict) + #print("new", new_collection_dict) assert new_collection_dict == old_dict # remove the created collection client.collections.delete(new_collection_name) diff --git a/weaviate/collections/classes/config.py b/weaviate/collections/classes/config.py index 9230390f8..e57b71127 100644 --- a/weaviate/collections/classes/config.py +++ b/weaviate/collections/classes/config.py @@ -1016,7 +1016,7 @@ def to_dict(self) -> Dict[str, Any]: out = super().to_dict() out["dataType"] = [self.data_type.value] out["indexFilterable"] = self.index_filterable - out["indexVector"] = self.index_searchable + out["indexSearchable"] = self.index_searchable out["tokenization"] = self.tokenization.value if self.tokenization else None module_config: Dict[str, Any] = {} From aa6c1faa339150473f1186f13b5bd88886bbfabb Mon Sep 17 00:00:00 2001 From: Duda Nogueira Date: Thu, 15 Aug 2024 18:06:03 -0300 Subject: [PATCH 4/4] black linting --- integration/conftest.py | 10 ++++-- integration/test_collection_config.py | 48 +++++++++++++++++---------- 2 files changed, 38 insertions(+), 20 deletions(-) diff --git a/integration/conftest.py b/integration/conftest.py index d19f3830d..f0708e386 100644 --- a/integration/conftest.py +++ b/integration/conftest.py @@ -47,14 +47,18 @@ def __call__( vector_index_config: Optional[_VectorIndexConfigCreate] = None, description: Optional[str] = None, reranker_config: Optional[_RerankerConfigCreate] = None, - return_client: bool = False + return_client: bool = False, ) -> Union[Collection[Any, Any], Tuple[Collection[Any, Any], weaviate.WeaviateClient]]: """Typing for fixture.""" ... @pytest.fixture -def collection_factory(request: SubRequest) -> Generator[Union[Collection[Any, Any], Tuple[Collection[Any, Any], weaviate.WeaviateClient]], None, None]: +def collection_factory( + request: SubRequest, +) -> Generator[ + Union[Collection[Any, Any], Tuple[Collection[Any, Any], weaviate.WeaviateClient]], None, None +]: name_fixture: Optional[str] = None client_fixture: Optional[weaviate.WeaviateClient] = None @@ -76,7 +80,7 @@ def _factory( vector_index_config: Optional[_VectorIndexConfigCreate] = None, description: Optional[str] = None, reranker_config: Optional[_RerankerConfigCreate] = None, - return_client: bool = False + return_client: bool = False, ) -> Collection[Any, Any]: nonlocal client_fixture, name_fixture name_fixture = _sanitize_collection_name(request.node.name) + name diff --git a/integration/test_collection_config.py b/integration/test_collection_config.py index 5a65937dc..1bd8e2c29 100644 --- a/integration/test_collection_config.py +++ b/integration/test_collection_config.py @@ -30,7 +30,7 @@ GenerativeSearches, Rerankers, _RerankerConfigCreate, - Tokenization + Tokenization, ) from weaviate.collections.classes.tenants import Tenant from weaviate.exceptions import UnexpectedStatusCodeError, WeaviateInvalidInputError @@ -684,6 +684,7 @@ def test_collection_config_get_shards_multi_tenancy(collection_factory: Collecti assert "tenant1" in [shard.name for shard in shards] assert "tenant2" in [shard.name for shard in shards] + def test_collection_config_create_from_dict(collection_factory: CollectionFactory) -> None: collection, client = collection_factory( inverted_index_config=Configure.inverted_index(bm25_b=0.8, bm25_k1=1.3), @@ -693,7 +694,7 @@ def test_collection_config_create_from_dict(collection_factory: CollectionFactor model="text-embedding-3-small", base_url="http://weaviate.io", vectorize_collection_name=False, - dimensions=512 + dimensions=512, ), vector_index_config=Configure.VectorIndex.flat( vector_cache_max_objects=234, @@ -702,13 +703,21 @@ def test_collection_config_create_from_dict(collection_factory: CollectionFactor description="Some description", reranker_config=Configure.Reranker.cohere(model="rerank-english-v2.0"), properties=[ - Property(name="field_tokenization", data_type=DataType.TEXT, tokenization=Tokenization.FIELD), - Property(name="field_description", data_type=DataType.TEXT, - tokenization=Tokenization.FIELD, description="field desc"), - Property(name="field_index_filterable", data_type=DataType.TEXT, - index_filterable=False), - Property(name="field_skip_vectorization", data_type=DataType.TEXT, - skip_vectorization=True), + Property( + name="field_tokenization", data_type=DataType.TEXT, tokenization=Tokenization.FIELD + ), + Property( + name="field_description", + data_type=DataType.TEXT, + tokenization=Tokenization.FIELD, + description="field desc", + ), + Property( + name="field_index_filterable", data_type=DataType.TEXT, index_filterable=False + ), + Property( + name="field_skip_vectorization", data_type=DataType.TEXT, skip_vectorization=True + ), Property(name="text", data_type=DataType.TEXT), Property(name="texts", data_type=DataType.TEXT_ARRAY), Property(name="number", data_type=DataType.NUMBER), @@ -721,10 +730,14 @@ def test_collection_config_create_from_dict(collection_factory: CollectionFactor Property(name="booleans", data_type=DataType.BOOL_ARRAY), Property(name="geo", data_type=DataType.GEO_COORDINATES), Property(name="phone", data_type=DataType.PHONE_NUMBER), - Property(name="vectorize_property_name", data_type=DataType.TEXT, - vectorize_property_name=False), - Property(name="field_index_searchable", data_type=DataType.TEXT, - index_searchable=False), + Property( + name="vectorize_property_name", + data_type=DataType.TEXT, + vectorize_property_name=False, + ), + Property( + name="field_index_searchable", data_type=DataType.TEXT, index_searchable=False + ), # TODO: this will fail # Property( # name="name", @@ -735,7 +748,7 @@ def test_collection_config_create_from_dict(collection_factory: CollectionFactor # ], # ), ], - return_client=True + return_client=True, ) old_dict = collection.config.get().to_dict() new_dict = old_dict @@ -748,12 +761,13 @@ def test_collection_config_create_from_dict(collection_factory: CollectionFactor new_collection_dict["class"] = collection.name old_dict["class"] = collection.name # check if both dict are the same - #print("old", old_dict) - #print("new", new_collection_dict) + # print("old", old_dict) + # print("new", new_collection_dict) assert new_collection_dict == old_dict # remove the created collection client.collections.delete(new_collection_name) - + + def test_config_vector_index_flat_and_quantizer_bq(collection_factory: CollectionFactory) -> None: collection = collection_factory( vector_index_config=Configure.VectorIndex.flat(