diff --git a/chromadb/api/__init__.py b/chromadb/api/__init__.py index 40662d82a09..60865835cc4 100644 --- a/chromadb/api/__init__.py +++ b/chromadb/api/__init__.py @@ -270,6 +270,7 @@ def _query( self, collection_id: UUID, query_embeddings: Embeddings, + ids: Optional[IDs] = None, n_results: int = 10, where: Optional[Where] = None, where_document: Optional[WhereDocument] = None, @@ -280,6 +281,7 @@ def _query( Args: collection_id: The UUID of the collection to query. query_embeddings: The embeddings to use as the query. + ids: The IDs to filter by during the query. Defaults to None. n_results: The number of results to return. Defaults to 10. where: Conditional filtering on metadata. Defaults to None. where_document: Conditional filtering on documents. Defaults to None. @@ -734,6 +736,7 @@ def _query( self, collection_id: UUID, query_embeddings: Embeddings, + ids: Optional[IDs] = None, n_results: int = 10, where: Optional[Where] = None, where_document: Optional[WhereDocument] = None, diff --git a/chromadb/api/async_api.py b/chromadb/api/async_api.py index 8561661fca9..f3eb365cc7c 100644 --- a/chromadb/api/async_api.py +++ b/chromadb/api/async_api.py @@ -264,6 +264,7 @@ async def _query( self, collection_id: UUID, query_embeddings: Embeddings, + ids: Optional[IDs] = None, n_results: int = 10, where: Optional[Where] = None, where_document: Optional[WhereDocument] = None, @@ -728,6 +729,7 @@ async def _query( self, collection_id: UUID, query_embeddings: Embeddings, + ids: Optional[IDs] = None, n_results: int = 10, where: Optional[Where] = None, where_document: Optional[WhereDocument] = None, diff --git a/chromadb/api/async_client.py b/chromadb/api/async_client.py index b787b3884a1..3259e7a88db 100644 --- a/chromadb/api/async_client.py +++ b/chromadb/api/async_client.py @@ -403,6 +403,7 @@ async def _query( self, collection_id: UUID, query_embeddings: Embeddings, + ids: Optional[IDs] = None, n_results: int = 10, where: Optional[Where] = None, where_document: Optional[WhereDocument] = None, @@ -411,6 +412,7 @@ async def _query( return await self._server._query( collection_id=collection_id, query_embeddings=query_embeddings, + ids=ids, n_results=n_results, where=where, where_document=where_document, diff --git a/chromadb/api/async_fastapi.py b/chromadb/api/async_fastapi.py index afa753deb28..4772e21cb84 100644 --- a/chromadb/api/async_fastapi.py +++ b/chromadb/api/async_fastapi.py @@ -617,6 +617,7 @@ async def _query( self, collection_id: UUID, query_embeddings: Embeddings, + ids: Optional[IDs] = None, n_results: int = 10, where: Optional[Where] = None, where_document: Optional[WhereDocument] = None, @@ -631,6 +632,7 @@ async def _query( "post", f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/query", json={ + "ids": ids, "query_embeddings": convert_np_embeddings_to_list(query_embeddings) if query_embeddings is not None else None, diff --git a/chromadb/api/client.py b/chromadb/api/client.py index ef0dd350354..1f132077c41 100644 --- a/chromadb/api/client.py +++ b/chromadb/api/client.py @@ -374,6 +374,7 @@ def _query( self, collection_id: UUID, query_embeddings: Embeddings, + ids: Optional[IDs] = None, n_results: int = 10, where: Optional[Where] = None, where_document: Optional[WhereDocument] = None, @@ -381,6 +382,7 @@ def _query( ) -> QueryResult: return self._server._query( collection_id=collection_id, + ids=ids, tenant=self.tenant, database=self.database, query_embeddings=query_embeddings, diff --git a/chromadb/api/fastapi.py b/chromadb/api/fastapi.py index 5ddf348c3d1..137b82515fa 100644 --- a/chromadb/api/fastapi.py +++ b/chromadb/api/fastapi.py @@ -588,6 +588,7 @@ def _query( self, collection_id: UUID, query_embeddings: Embeddings, + ids: Optional[IDs] = None, n_results: int = 10, where: Optional[Where] = None, where_document: Optional[WhereDocument] = None, @@ -603,6 +604,7 @@ def _query( "post", f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/query", json={ + "ids": ids, "query_embeddings": convert_np_embeddings_to_list(query_embeddings) if query_embeddings is not None else None, diff --git a/chromadb/api/models/AsyncCollection.py b/chromadb/api/models/AsyncCollection.py index 7713de48385..542c0a6090e 100644 --- a/chromadb/api/models/AsyncCollection.py +++ b/chromadb/api/models/AsyncCollection.py @@ -169,6 +169,7 @@ async def query( query_texts: Optional[OneOrMany[Document]] = None, query_images: Optional[OneOrMany[Image]] = None, query_uris: Optional[OneOrMany[URI]] = None, + ids: Optional[OneOrMany[ID]] = None, n_results: int = 10, where: Optional[Where] = None, where_document: Optional[WhereDocument] = None, @@ -184,6 +185,7 @@ async def query( query_embeddings: The embeddings to get the closes neighbors of. Optional. query_texts: The document texts to get the closes neighbors of. Optional. query_images: The images to get the closes neighbors of. Optional. + ids: A subset of ids to search within. Optional. n_results: The number of neighbors to return for each query_embedding or query_texts. Optional. where: A Where type dict used to filter results by. E.g. `{"$and": [{"color" : "red"}, {"price": {"$gte": 4.20}}]}`. Optional. where_document: A WhereDocument type dict used to filter by the documents. E.g. `{"$contains": "hello"}`. Optional. @@ -205,6 +207,7 @@ async def query( query_texts=query_texts, query_images=query_images, query_uris=query_uris, + ids=ids, n_results=n_results, where=where, where_document=where_document, @@ -213,6 +216,7 @@ async def query( query_results = await self._client._query( collection_id=self.id, + ids=query_request["ids"], query_embeddings=query_request["embeddings"], n_results=query_request["n_results"], where=query_request["where"], @@ -279,7 +283,7 @@ async def fork( client=self._client, model=model, embedding_function=self._embedding_function, - data_loader=self._data_loader + data_loader=self._data_loader, ) async def update( diff --git a/chromadb/api/models/Collection.py b/chromadb/api/models/Collection.py index 9a9b8f1c163..b42c2ff64ec 100644 --- a/chromadb/api/models/Collection.py +++ b/chromadb/api/models/Collection.py @@ -172,6 +172,7 @@ def query( query_texts: Optional[OneOrMany[Document]] = None, query_images: Optional[OneOrMany[Image]] = None, query_uris: Optional[OneOrMany[URI]] = None, + ids: Optional[OneOrMany[ID]] = None, n_results: int = 10, where: Optional[Where] = None, where_document: Optional[WhereDocument] = None, @@ -188,6 +189,7 @@ def query( query_texts: The document texts to get the closes neighbors of. Optional. query_images: The images to get the closes neighbors of. Optional. query_uris: The URIs to be used with data loader. Optional. + ids: A subset of ids to search within. Optional. n_results: The number of neighbors to return for each query_embedding or query_texts. Optional. where: A Where type dict used to filter results by. E.g. `{"$and": [{"color" : "red"}, {"price": {"$gte": 4.20}}]}`. Optional. where_document: A WhereDocument type dict used to filter by the documents. E.g. `{"$contains": "hello"}`. Optional. @@ -209,6 +211,7 @@ def query( query_texts=query_texts, query_images=query_images, query_uris=query_uris, + ids=ids, n_results=n_results, where=where, where_document=where_document, @@ -217,6 +220,7 @@ def query( query_results = self._client._query( collection_id=self.id, + ids=query_request["ids"], query_embeddings=query_request["embeddings"], n_results=query_request["n_results"], where=query_request["where"], @@ -285,7 +289,7 @@ def fork( client=self._client, model=model, embedding_function=self._embedding_function, - data_loader=self._data_loader + data_loader=self._data_loader, ) def update( diff --git a/chromadb/api/models/CollectionCommon.py b/chromadb/api/models/CollectionCommon.py index 68d65ad8e2a..d7b1cebd9fb 100644 --- a/chromadb/api/models/CollectionCommon.py +++ b/chromadb/api/models/CollectionCommon.py @@ -294,6 +294,7 @@ def _validate_and_prepare_query_request( query_texts: Optional[OneOrMany[Document]], query_images: Optional[OneOrMany[Image]], query_uris: Optional[OneOrMany[URI]], + ids: Optional[OneOrMany[ID]], n_results: int, where: Optional[Where], where_document: Optional[WhereDocument], @@ -307,6 +308,8 @@ def _validate_and_prepare_query_request( uris=query_uris, ) + filter_ids = maybe_cast_one_to_many(ids) + filters = FilterSet( where=where, where_document=where_document, @@ -335,6 +338,7 @@ def _validate_and_prepare_query_request( return QueryRequest( embeddings=request_embeddings, + ids=filter_ids, where=request_where, where_document=request_where_document, include=request_include, diff --git a/chromadb/api/rust.py b/chromadb/api/rust.py index 18a2a6037bd..465c6dc2634 100644 --- a/chromadb/api/rust.py +++ b/chromadb/api/rust.py @@ -482,6 +482,7 @@ def _query( self, collection_id: UUID, query_embeddings: Embeddings, + ids: Optional[IDs] = None, n_results: int = 10, where: Optional[Where] = None, where_document: Optional[WhereDocument] = None, @@ -490,10 +491,12 @@ def _query( database: str = DEFAULT_DATABASE, ) -> QueryResult: query_amount = len(query_embeddings) + filtered_ids_amount = len(ids) if ids else 0 self.product_telemetry_client.capture( CollectionQueryEvent( collection_uuid=str(collection_id), query_amount=query_amount, + filtered_ids_amount=filtered_ids_amount, n_results=n_results, with_metadata_filter=query_amount if where is not None else 0, with_document_filter=query_amount if where_document is not None else 0, @@ -506,6 +509,7 @@ def _query( rust_response = self.bindings.query( str(collection_id), + ids, query_embeddings, n_results, json.dumps(where) if where else None, diff --git a/chromadb/api/segment.py b/chromadb/api/segment.py index e0acb1ef892..c7b900323f8 100644 --- a/chromadb/api/segment.py +++ b/chromadb/api/segment.py @@ -785,6 +785,7 @@ def _query( self, collection_id: UUID, query_embeddings: Embeddings, + ids: Optional[IDs] = None, n_results: int = 10, where: Optional[Where] = None, where_document: Optional[WhereDocument] = None, @@ -801,10 +802,12 @@ def _query( ) query_amount = len(query_embeddings) + ids_amount = len(ids) if ids else 0 self._product_telemetry_client.capture( CollectionQueryEvent( collection_uuid=str(collection_id), query_amount=query_amount, + filtered_ids_amount=ids_amount, n_results=n_results, with_metadata_filter=query_amount if where is not None else 0, with_document_filter=query_amount if where_document is not None else 0, diff --git a/chromadb/api/types.py b/chromadb/api/types.py index 7d64e6707d1..bb723cee180 100644 --- a/chromadb/api/types.py +++ b/chromadb/api/types.py @@ -396,6 +396,7 @@ class GetResult(TypedDict): class QueryRequest(TypedDict): embeddings: Embeddings + ids: Optional[IDs] where: Optional[Where] where_document: Optional[WhereDocument] include: Include diff --git a/chromadb/telemetry/product/events.py b/chromadb/telemetry/product/events.py index 568a84ca7c7..8ea7fe8bed9 100644 --- a/chromadb/telemetry/product/events.py +++ b/chromadb/telemetry/product/events.py @@ -137,6 +137,7 @@ class CollectionQueryEvent(ProductTelemetryEvent): batch_size: int collection_uuid: str query_amount: int + filtered_ids_amount: int with_metadata_filter: int with_document_filter: int n_results: int @@ -149,6 +150,7 @@ def __init__( self, collection_uuid: str, query_amount: int, + filtered_ids_amount: int, with_metadata_filter: int, with_document_filter: int, n_results: int, @@ -161,6 +163,7 @@ def __init__( super().__init__() self.collection_uuid = collection_uuid self.query_amount = query_amount + self.filtered_ids_amount = filtered_ids_amount self.with_metadata_filter = with_metadata_filter self.with_document_filter = with_document_filter self.n_results = n_results @@ -182,6 +185,7 @@ def batch(self, other: "ProductTelemetryEvent") -> "CollectionQueryEvent": return CollectionQueryEvent( collection_uuid=self.collection_uuid, query_amount=total_amount, + filtered_ids_amount=self.filtered_ids_amount + other.filtered_ids_amount, with_metadata_filter=self.with_metadata_filter + other.with_metadata_filter, with_document_filter=self.with_document_filter + other.with_document_filter, n_results=self.n_results + other.n_results, diff --git a/chromadb/test/property/test_filtering.py b/chromadb/test/property/test_filtering.py index 984a59b0f77..8fc80969c6b 100644 --- a/chromadb/test/property/test_filtering.py +++ b/chromadb/test/property/test_filtering.py @@ -19,7 +19,6 @@ import chromadb.test.property.strategies as strategies import hypothesis.strategies as st import logging -import random import re from chromadb.test.utils.wait_for_version_increase import wait_for_version_increase import numpy as np @@ -325,6 +324,7 @@ def test_filterable_metadata_get_limit_offset( min_size=1, ), should_compact=st.booleans(), + data=st.data(), ) def test_filterable_metadata_query( caplog: pytest.LogCaptureFixture, @@ -333,6 +333,7 @@ def test_filterable_metadata_query( record_set: strategies.RecordSet, filters: List[strategies.Filter], should_compact: bool, + data: st.DataObject, ) -> None: caplog.set_level(logging.ERROR) @@ -355,19 +356,21 @@ def test_filterable_metadata_query( wait_for_version_increase(client, collection.name, initial_version) # type: ignore total_count = len(normalized_record_set["ids"]) - # Pick a random vector + # Pick a random vector using Hypothesis data random_query: Embedding + + query_index = data.draw(st.integers(min_value=0, max_value=total_count - 1)) if collection.has_embeddings: assert normalized_record_set["embeddings"] is not None assert all(isinstance(e, list) for e in normalized_record_set["embeddings"]) - random_query = normalized_record_set["embeddings"][ - random.randint(0, total_count - 1) - ] + # Use data.draw to select index + random_query = normalized_record_set["embeddings"][query_index] else: assert isinstance(normalized_record_set["documents"], list) assert collection.embedding_function is not None + # Use data.draw to select index random_query = collection.embedding_function( - [normalized_record_set["documents"][random.randint(0, total_count - 1)]] + [normalized_record_set["documents"][query_index]] )[0] for filter in filters: result_ids = set( @@ -402,7 +405,7 @@ def test_empty_filter(client: ClientAPI) -> None: query_embeddings=test_query_embedding, where={"q": {"$eq": 4}}, # type: ignore[dict-item] n_results=3, - include=["embeddings", "distances", "metadatas"], # type: ignore[list-item] + include=["embeddings", "distances", "metadatas"], ) assert res["ids"] == [[]] if res["embeddings"] is not None: @@ -459,9 +462,107 @@ def check_empty_res(res: GetResult) -> None: coll.add(ids=test_ids, embeddings=test_embeddings, metadatas=test_metadatas) - res = coll.get(ids=["nope"], include=["embeddings", "metadatas", "documents"]) # type: ignore[list-item] + res = coll.get(ids=["nope"], include=["embeddings", "metadatas", "documents"]) check_empty_res(res) res = coll.get( - include=["embeddings", "metadatas", "documents"], where={"test": 100} # type: ignore[list-item] + include=["embeddings", "metadatas", "documents"], where={"test": 100} ) check_empty_res(res) + + +@settings( + deadline=90000, + suppress_health_check=[ + HealthCheck.function_scoped_fixture, + HealthCheck.large_base_example, + ], +) +@given( + collection=collection_st, + record_set=recordset_st, + n_results_st=st.integers(min_value=1, max_value=100), + should_compact=st.booleans(), + data=st.data(), +) +def test_query_ids_filter_property( + caplog: pytest.LogCaptureFixture, + client: ClientAPI, + collection: strategies.Collection, + record_set: strategies.RecordSet, + n_results_st: int, + should_compact: bool, + data: st.DataObject, +) -> None: + """Property test for querying with only the ids filter.""" + if ( + client.get_settings().chroma_api_impl + == "chromadb.api.async_fastapi.AsyncFastAPI" + ): + pytest.skip( + "Skipping test for async client due to potential resource/timeout issues" + ) + caplog.set_level(logging.ERROR) + reset(client) + coll = client.create_collection( + name=collection.name, + metadata=collection.metadata, # type: ignore + embedding_function=collection.embedding_function, + ) + initial_version = coll.get_model()["version"] + normalized_record_set = invariants.wrap_all(record_set) + + if len(normalized_record_set["ids"]) == 0: + # Cannot add empty record set + return + + coll.add(**record_set) # type: ignore[arg-type] + + if not NOT_CLUSTER_ONLY: + if should_compact and len(normalized_record_set["ids"]) > 10: + wait_for_version_increase(client, collection.name, initial_version) # type: ignore + + total_count = len(normalized_record_set["ids"]) + n_results = min(n_results_st, total_count) + + # Generate a random subset of ids to filter on using Hypothesis data + ids_to_query = data.draw( + st.lists( + st.sampled_from(normalized_record_set["ids"]), + min_size=0, + max_size=total_count, + unique=True, + ) + ) + + # Pick a random query vector using Hypothesis data + random_query: Embedding + query_index = data.draw(st.integers(min_value=0, max_value=total_count - 1)) + if collection.has_embeddings: + assert normalized_record_set["embeddings"] is not None + assert all(isinstance(e, list) for e in normalized_record_set["embeddings"]) + # Use data.draw to select index + random_query = normalized_record_set["embeddings"][query_index] + else: + assert isinstance(normalized_record_set["documents"], list) + assert collection.embedding_function is not None + # Use data.draw to select index + random_query = collection.embedding_function( + [normalized_record_set["documents"][query_index]] + )[0] + + # Perform the query with only the ids filter + result = coll.query( + query_embeddings=[random_query], + ids=ids_to_query, + n_results=n_results, + ) + + result_ids = set(result["ids"][0]) + filter_ids_set = set(ids_to_query) + + # The core assertion: all returned IDs must be within the filter set + assert result_ids.issubset(filter_ids_set) + + # Also check that the number of results is reasonable + assert len(result_ids) <= n_results + assert len(result_ids) <= len(filter_ids_set) diff --git a/chromadb/test/test_api.py b/chromadb/test/test_api.py index 8650b46ab9f..bf7d91be877 100644 --- a/chromadb/test/test_api.py +++ b/chromadb/test/test_api.py @@ -1636,3 +1636,158 @@ def test_ssl_self_signed_without_ssl_verify(client_ssl): ) client_ssl.clear_system_cache() assert "CERTIFICATE_VERIFY_FAILED" in "".join(stack_trace) + + +def test_query_id_filtering_small_dataset(client): + client.reset() + collection = client.create_collection("test_query_id_filtering_small") + + num_vectors = 100 + dim = 512 + small_records = np.random.rand(100, 512).astype(np.float32).tolist() + ids = [f"{i}" for i in range(num_vectors)] + + collection.add( + embeddings=small_records, + ids=ids, + ) + + query_ids = [f"{i}" for i in range(0, num_vectors, 10)] + query_embedding = np.random.rand(dim).astype(np.float32).tolist() + results = collection.query( + query_embeddings=query_embedding, + ids=query_ids, + n_results=num_vectors, + include=[], + ) + + all_returned_ids = [item for sublist in results["ids"] for item in sublist] + assert all(id in query_ids for id in all_returned_ids) + + +def test_query_id_filtering_medium_dataset(client): + client.reset() + collection = client.create_collection("test_query_id_filtering_medium") + + num_vectors = 1000 + dim = 512 + medium_records = np.random.rand(num_vectors, dim).astype(np.float32).tolist() + ids = [f"{i}" for i in range(num_vectors)] + + collection.add( + embeddings=medium_records, + ids=ids, + ) + + query_ids = [f"{i}" for i in range(0, num_vectors, 10)] + + query_embedding = np.random.rand(dim).astype(np.float32).tolist() + results = collection.query( + query_embeddings=query_embedding, + ids=query_ids, + n_results=num_vectors, + include=[], + ) + + all_returned_ids = [item for sublist in results["ids"] for item in sublist] + assert all(id in query_ids for id in all_returned_ids) + + multi_query_embeddings = [ + np.random.rand(dim).astype(np.float32).tolist() for _ in range(3) + ] + multi_results = collection.query( + query_embeddings=multi_query_embeddings, + ids=query_ids, + n_results=10, + include=[], + ) + + for result_set in multi_results["ids"]: + assert all(id in query_ids for id in result_set) + + +def test_query_id_filtering_e2e(client): + client.reset() + collection = client.create_collection("test_query_id_filtering_e2e") + + dim = 512 + num_vectors = 100 + embeddings = np.random.rand(num_vectors, dim).astype(np.float32).tolist() + ids = [f"{i}" for i in range(num_vectors)] + metadatas = [{"index": i} for i in range(num_vectors)] + + collection.add( + embeddings=embeddings, + ids=ids, + metadatas=metadatas, + ) + + ids_to_delete = [f"{i}" for i in range(10, 30)] + collection.delete(ids=ids_to_delete) + + # modify some existing ids, and add some new ones to check query returns updated metadata + ids_to_upsert_existing = [f"{i}" for i in range(30, 50)] + new_num_vectors = num_vectors + 20 + ids_to_upsert_new = [f"{i}" for i in range(num_vectors, new_num_vectors)] + + upsert_embeddings = ( + np.random.rand(len(ids_to_upsert_existing) + len(ids_to_upsert_new), dim) + .astype(np.float32) + .tolist() + ) + upsert_metadatas = [ + {"index": i, "upserted": True} for i in range(len(upsert_embeddings)) + ] + + collection.upsert( + embeddings=upsert_embeddings, + ids=ids_to_upsert_existing + ids_to_upsert_new, + metadatas=upsert_metadatas, + ) + + valid_query_ids = ( + [f"{i}" for i in range(5, 10)] # subset of existing ids + + [f"{i}" for i in range(35, 45)] # subset of existing, but upserted + + [ + f"{i}" for i in range(num_vectors + 5, num_vectors + 15) + ] # subset of new upserted ids + ) + + includes = ["metadatas"] + query_embedding = np.random.rand(dim).astype(np.float32).tolist() + results = collection.query( + query_embeddings=query_embedding, + ids=valid_query_ids, + n_results=new_num_vectors, + include=includes, + ) + + all_returned_ids = [item for sublist in results["ids"] for item in sublist] + assert all(id in valid_query_ids for id in all_returned_ids) + + for result_index, id_list in enumerate(results["ids"]): + for item_index, item_id in enumerate(id_list): + if item_id in ids_to_upsert_existing or item_id in ids_to_upsert_new: + # checks if metadata correctly has upserted flag + assert results["metadatas"][result_index][item_index]["upserted"] + + upserted_id = ids_to_upsert_existing[0] + # test single id filtering + results = collection.query( + query_embeddings=query_embedding, + ids=upserted_id, + n_results=1, + include=includes, + ) + assert results["metadatas"][0][0]["upserted"] + + deleted_id = ids_to_delete[0] + # test deleted id filter raises + with pytest.raises(Exception) as error: + collection.query( + query_embeddings=query_embedding, + ids=deleted_id, + n_results=1, + include=includes, + ) + assert "Error finding id" in str(error.value) diff --git a/clients/js/packages/chromadb-core/src/Collection.ts b/clients/js/packages/chromadb-core/src/Collection.ts index 89388e27417..8e13c92200f 100644 --- a/clients/js/packages/chromadb-core/src/Collection.ts +++ b/clients/js/packages/chromadb-core/src/Collection.ts @@ -250,6 +250,7 @@ export class Collection { * @param {string | string[]} [params.queryTexts] - Optional query text(s) to search for in the collection. * @param {WhereDocument} [params.whereDocument] - Optional query condition to filter results based on document content. * @param {IncludeEnum[]} [params.include] - Optional array of fields to include in the result, such as "metadata" and "document". + * @param {IDs} [params.ids] - Optional IDs to filter on before querying. * * @returns {Promise} A promise that resolves to the query results. * @throws {Error} If there is an issue executing the query. @@ -280,6 +281,7 @@ export class Collection { include, queryTexts, queryEmbeddings, + ids, }: QueryRecordsParams): Promise { await this.client.init(); @@ -300,6 +302,11 @@ export class Collection { ); } + let filter_ids: string[] | null = null; + if (ids) { + filter_ids = toArray(ids); + } + const resp = await this.client.api.collectionQuery( this.client.tenant, this.client.database, @@ -308,6 +315,7 @@ export class Collection { undefined, { query_embeddings: embeddings, + ids: filter_ids, n_results: nResults, where, where_document: whereDocument, diff --git a/clients/js/packages/chromadb-core/src/types.ts b/clients/js/packages/chromadb-core/src/types.ts index d688447679d..a99b6de0803 100644 --- a/clients/js/packages/chromadb-core/src/types.ts +++ b/clients/js/packages/chromadb-core/src/types.ts @@ -219,6 +219,7 @@ export type ForkCollectionParams = { }; export type BaseQueryParams = { + ids?: ID | IDs; nResults?: PositiveInteger; where?: Where; queryTexts?: string | string[]; diff --git a/clients/js/packages/chromadb-core/test/query.collection.test.ts b/clients/js/packages/chromadb-core/test/query.collection.test.ts index d70a0e45c30..e08e24d5a38 100644 --- a/clients/js/packages/chromadb-core/test/query.collection.test.ts +++ b/clients/js/packages/chromadb-core/test/query.collection.test.ts @@ -230,4 +230,247 @@ describe("query records", () => { await collection.query({ queryEmbeddings: [1, 2, 3] }); }).rejects.toThrow(ChromaNotFoundError); }); + + test("it should query a collection with specific IDs", async () => { + const collection = await client.createCollection({ + name: "test", + embeddingFunction: new TestEmbeddingFunction(), + }); + await collection.add({ + ids: IDS, + embeddings: EMBEDDINGS, + metadatas: METADATAS, + documents: DOCUMENTS, + }); + + const results = await collection.query({ + queryEmbeddings: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + nResults: 3, + ids: ["test1", "test3"], + }); + + expect(results).toBeDefined(); + expect(results.ids[0]).toHaveLength(2); + expect(results.ids[0]).toEqual(expect.arrayContaining(["test1", "test3"])); + expect(results.ids[0]).not.toContain("test2"); + + expect(results.documents[0]).toEqual( + expect.arrayContaining(["This is a test", "This is a third test"]), + ); + expect(results.metadatas[0]).toEqual( + expect.arrayContaining([ + { test: "test1", float_value: -2 }, + { test: "test3", float_value: 2 }, + ]), + ); + }); +}); + +describe("id filtering", () => { + const client = new ChromaClient({ + path: process.env.DEFAULT_CHROMA_INSTANCE_URL, + }); + + beforeEach(async () => { + await client.reset(); + }); + + test("it should filter by IDs in a small dataset", async () => { + const collection = await client.createCollection({ + name: "test_id_filtering_small", + }); + + const numVectors = 100; + const dim = 10; + const smallRecords: number[][] = []; + const ids: string[] = []; + + for (let i = 0; i < numVectors; i++) { + const embedding = Array.from({ length: dim }, () => Math.random()); + smallRecords.push(embedding); + ids.push(`id_${i}`); + } + + await collection.add({ + ids: ids, + embeddings: smallRecords, + }); + + const queryIds = ids.filter((_, i) => i % 10 === 0); + + const queryEmbedding = Array.from({ length: dim }, () => Math.random()); + const results = await collection.query({ + queryEmbeddings: queryEmbedding, + ids: queryIds, + nResults: numVectors, + }); + + const allReturnedIds = results.ids[0]; + allReturnedIds.forEach((id) => { + expect(queryIds).toContain(id); + }); + }); + + test("it should filter by IDs in a medium dataset", async () => { + const collection = await client.createCollection({ + name: "test_id_filtering_medium", + }); + + const numVectors = 1000; + const dim = 10; + const mediumRecords: number[][] = []; + const ids: string[] = []; + + for (let i = 0; i < numVectors; i++) { + const embedding = Array.from({ length: dim }, () => Math.random()); + mediumRecords.push(embedding); + ids.push(`id_${i}`); + } + + await collection.add({ + ids: ids, + embeddings: mediumRecords, + }); + + const queryIds = ids.filter((_, i) => i % 10 === 0); + + const queryEmbedding = Array.from({ length: dim }, () => Math.random()); + const results = await collection.query({ + queryEmbeddings: queryEmbedding, + ids: queryIds, + nResults: numVectors, + }); + + const allReturnedIds = results.ids[0]; + allReturnedIds.forEach((id) => { + expect(queryIds).toContain(id); + }); + + const multiQueryEmbeddings = [ + Array.from({ length: dim }, () => Math.random()), + Array.from({ length: dim }, () => Math.random()), + Array.from({ length: dim }, () => Math.random()), + ]; + + const multiResults = await collection.query({ + queryEmbeddings: multiQueryEmbeddings, + ids: queryIds, + nResults: 10, + }); + + expect(multiResults.ids.length).toBe(multiQueryEmbeddings.length); + multiResults.ids.forEach((idSet) => { + idSet.forEach((id) => { + expect(queryIds).toContain(id); + }); + }); + }); + + test("it should handle ID filtering with deleted and upserted IDs", async () => { + const collection = await client.createCollection({ + name: "test_id_filtering_e2e", + }); + + const dim = 10; + const numVectors = 100; + const embeddings: number[][] = []; + const ids: string[] = []; + const metadatas: Record[] = []; + + for (let i = 0; i < numVectors; i++) { + const embedding = Array.from({ length: dim }, () => Math.random()); + embeddings.push(embedding); + ids.push(`id_${i}`); + metadatas.push({ index: i }); + } + + await collection.add({ + embeddings: embeddings, + ids: ids, + metadatas: metadatas, + }); + + const idsToDelete = ids.slice(10, 30); + await collection.delete({ ids: idsToDelete }); + + const idsToUpsertExisting = ids.slice(30, 50); + const idsToUpsertNew = Array.from( + { length: 20 }, + (_, i) => `id_${numVectors + i}`, + ); + + const upsertEmbeddings: number[][] = []; + const upsertMetadatas: Record[] = []; + + for ( + let i = 0; + i < idsToUpsertExisting.length + idsToUpsertNew.length; + i++ + ) { + const embedding = Array.from({ length: dim }, () => Math.random()); + upsertEmbeddings.push(embedding); + upsertMetadatas.push({ index: i, upserted: true }); + } + + await collection.upsert({ + embeddings: upsertEmbeddings, + ids: [...idsToUpsertExisting, ...idsToUpsertNew], + metadatas: upsertMetadatas, + }); + + const validQueryIds = [ + ...ids.slice(5, 10), + ...ids.slice(35, 45), + ...idsToUpsertNew.slice(5, 15), + ]; + + const queryEmbedding = Array.from({ length: dim }, () => Math.random()); + const results = await collection.query({ + queryEmbeddings: queryEmbedding, + ids: validQueryIds, + nResults: validQueryIds.length, + include: [IncludeEnum.Metadatas], + }); + + const allReturnedIds = results.ids[0]; + + allReturnedIds.forEach((id) => { + expect(validQueryIds).toContain(id); + }); + + // Verify upserted IDs have updated metadata + results.ids[0].forEach((id, idx) => { + if (idsToUpsertExisting.includes(id) || idsToUpsertNew.includes(id)) { + const metadata = results.metadatas?.[0]?.[idx]; + if (metadata) { + expect(metadata.upserted).toBe(true); + } + } + }); + + // Test querying a specific upserted ID + const upsertedId = idsToUpsertExisting[0]; + const upsertResults = await collection.query({ + queryEmbeddings: queryEmbedding, + ids: upsertedId, + nResults: 1, + include: [IncludeEnum.Metadatas], + }); + + const firstMetadata = upsertResults.metadatas?.[0]?.[0]; + expect(firstMetadata).toBeTruthy(); + if (firstMetadata) { + expect(firstMetadata.upserted).toBe(true); + } + + const deletedId = idsToDelete[0]; + await expect(async () => { + await collection.query({ + queryEmbeddings: queryEmbedding, + ids: deletedId, + nResults: 1, + include: [IncludeEnum.Metadatas], + }); + }).rejects.toThrow(); + }); }); diff --git a/docs/docs.trychroma.com/markdoc/content/docs/querying-collections/query-and-get.md b/docs/docs.trychroma.com/markdoc/content/docs/querying-collections/query-and-get.md index 55b06adb977..968d009ff70 100644 --- a/docs/docs.trychroma.com/markdoc/content/docs/querying-collections/query-and-get.md +++ b/docs/docs.trychroma.com/markdoc/content/docs/querying-collections/query-and-get.md @@ -12,7 +12,8 @@ collection.query( query_embeddings=[[11.1, 12.1, 13.1],[1.1, 2.3, 3.2], ...], n_results=10, where={"metadata_field": "is_equal_to_this"}, - where_document={"$contains":"search_string"} + where_document={"$contains":"search_string"}, + ids=["id1", "id2", ...] ) ``` {% /Tab %} @@ -23,6 +24,8 @@ const result = await collection.query({ queryEmbeddings: [[11.1, 12.1, 13.1],[1.1, 2.3, 3.2], ...], nResults: 10, where: {"metadata_field": "is_equal_to_this"}, + whereDocument: {"$contains": "search_string"}, + ids: ["id1", "id2", ...] }) ``` {% /Tab %} @@ -32,6 +35,7 @@ const result = await collection.query({ The query will return the `n results` closest matches to each `query embedding`, in order. An optional `where` filter dictionary can be supplied to filter by the `metadata` associated with each document. Additionally, an optional `where document` filter dictionary can be supplied to filter by contents of the document. +An optional `ids` list can be provided to filter results to only include documents with those specific IDs before performing the query. If the supplied `query embeddings` are not the same dimension as the collection, an exception will be raised. @@ -45,7 +49,8 @@ collection.query( query_texts=["doc10", "thus spake zarathustra", ...], n_results=10, where={"metadata_field": "is_equal_to_this"}, - where_document={"$contains":"search_string"} + where_document={"$contains":"search_string"}, + ids=["id1", "id2", ...] ) ``` {% /Tab %} @@ -56,7 +61,8 @@ await collection.query({ queryTexts: ["doc10", "thus spake zarathustra", ...], nResults: 10, where: {"metadata_field": "is_equal_to_this"}, - whereDocument: {"$contains": "search_string"} + whereDocument: {"$contains": "search_string"}, + ids: ["id1", "id2", ...] }) ``` {% /Tab %} diff --git a/rust/frontend/src/quota/mod.rs b/rust/frontend/src/quota/mod.rs index 686d88c8f00..b4ffbe4b566 100644 --- a/rust/frontend/src/quota/mod.rs +++ b/rust/frontend/src/quota/mod.rs @@ -82,6 +82,7 @@ pub struct QuotaPayload<'other> { pub offset: Option, pub n_results: Option, pub query_embeddings: Option<&'other [Vec]>, + pub query_ids: Option<&'other [String]>, pub collection_uuid: Option, } @@ -107,6 +108,7 @@ impl<'other> QuotaPayload<'other> { offset: None, n_results: None, query_embeddings: None, + query_ids: None, collection_uuid: None, } } @@ -203,6 +205,11 @@ impl<'other> QuotaPayload<'other> { self } + pub fn with_query_ids(mut self, query_ids: &'other [String]) -> Self { + self.query_ids = Some(query_ids); + self + } + pub fn with_collection_uuid(mut self, collection_uuid: CollectionUuid) -> Self { self.collection_uuid = Some(collection_uuid); self @@ -232,6 +239,7 @@ pub enum UsageType { CollectionSizeRecords, // Number of records in the collection NumCollections, // Total number of collections for a tenant NumDatabases, // Total number of databases for a tenant + NumQueryIDs, // Number of IDs to filter by in a query } impl fmt::Display for UsageType { @@ -260,6 +268,7 @@ impl fmt::Display for UsageType { UsageType::CollectionSizeRecords => write!(f, "Collection size (records)"), UsageType::NumCollections => write!(f, "Number of collections"), UsageType::NumDatabases => write!(f, "Number of databases"), + UsageType::NumQueryIDs => write!(f, "Number of IDs to filter by in a query"), } } } @@ -288,6 +297,7 @@ impl TryFrom<&str> for UsageType { "collection_size_records" => Ok(UsageType::CollectionSizeRecords), "num_collections" => Ok(UsageType::NumCollections), "num_databases" => Ok(UsageType::NumDatabases), + "num_query_ids" => Ok(UsageType::NumQueryIDs), _ => Err(format!("Invalid UsageType: {}", value)), } } @@ -315,6 +325,7 @@ lazy_static::lazy_static! { m.insert(UsageType::CollectionSizeRecords, 1_000_000); m.insert(UsageType::NumCollections, 1_000_000); m.insert(UsageType::NumDatabases, 10); + m.insert(UsageType::NumQueryIDs, 1000); m }; } diff --git a/rust/frontend/src/server.rs b/rust/frontend/src/server.rs index 99cc896c87f..18f20ee6c4e 100644 --- a/rust/frontend/src/server.rs +++ b/rust/frontend/src/server.rs @@ -1808,6 +1808,9 @@ async fn collection_query( if let Some(n_results) = payload.n_results { quota_payload = quota_payload.with_n_results(n_results); } + if let Some(ids) = &payload.ids { + quota_payload = quota_payload.with_query_ids(ids); + } server.quota_enforcer.enforce("a_payload).await?; tracing::info!( "Querying records from collection [{collection_id}] in database [{database}] for tenant [{tenant}]", diff --git a/rust/python_bindings/src/bindings.rs b/rust/python_bindings/src/bindings.rs index 2a15cf9ec00..05f461968b3 100644 --- a/rust/python_bindings/src/bindings.rs +++ b/rust/python_bindings/src/bindings.rs @@ -615,12 +615,13 @@ impl Bindings { } #[pyo3( - signature = (collection_id, query_embeddings, n_results, r#where = None, where_document = None, include = ["metadatas".to_string(), "documents".to_string()].to_vec(), tenant = DEFAULT_TENANT.to_string(), database = DEFAULT_DATABASE.to_string()) + signature = (collection_id, ids, query_embeddings, n_results, r#where = None, where_document = None, include = ["metadatas".to_string(), "documents".to_string()].to_vec(), tenant = DEFAULT_TENANT.to_string(), database = DEFAULT_DATABASE.to_string()) )] #[allow(clippy::too_many_arguments)] fn query( &self, collection_id: String, + ids: Option>, query_embeddings: Vec>, n_results: u32, r#where: Option, @@ -646,7 +647,7 @@ impl Bindings { tenant, database, collection_id, - None, + ids, r#where, query_embeddings, n_results,