Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions chromadb/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ def _query(
self,
collection_id: UUID,
query_embeddings: Embeddings,
ids: Optional[IDs] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add to docstring(?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added, thanks

n_results: int = 10,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
Expand All @@ -280,6 +281,7 @@ def _query(
Args:
collection_id: The UUID of the collection to query.
query_embeddings: The embeddings to use as the query.
ids: The IDs to filter by during the query. Defaults to None.
n_results: The number of results to return. Defaults to 10.
where: Conditional filtering on metadata. Defaults to None.
where_document: Conditional filtering on documents. Defaults to None.
Expand Down Expand Up @@ -734,6 +736,7 @@ def _query(
self,
collection_id: UUID,
query_embeddings: Embeddings,
ids: Optional[IDs] = None,
n_results: int = 10,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
Expand Down
2 changes: 2 additions & 0 deletions chromadb/api/async_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ async def _query(
self,
collection_id: UUID,
query_embeddings: Embeddings,
ids: Optional[IDs] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docstring(?)

n_results: int = 10,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
Expand Down Expand Up @@ -728,6 +729,7 @@ async def _query(
self,
collection_id: UUID,
query_embeddings: Embeddings,
ids: Optional[IDs] = None,
n_results: int = 10,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
Expand Down
2 changes: 2 additions & 0 deletions chromadb/api/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,7 @@ async def _query(
self,
collection_id: UUID,
query_embeddings: Embeddings,
ids: Optional[IDs] = None,
n_results: int = 10,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
Expand All @@ -411,6 +412,7 @@ async def _query(
return await self._server._query(
collection_id=collection_id,
query_embeddings=query_embeddings,
ids=ids,
n_results=n_results,
where=where,
where_document=where_document,
Expand Down
2 changes: 2 additions & 0 deletions chromadb/api/async_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,7 @@ async def _query(
self,
collection_id: UUID,
query_embeddings: Embeddings,
ids: Optional[IDs] = None,
n_results: int = 10,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
Expand All @@ -631,6 +632,7 @@ async def _query(
"post",
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/query",
json={
"ids": ids,
"query_embeddings": convert_np_embeddings_to_list(query_embeddings)
if query_embeddings is not None
else None,
Expand Down
2 changes: 2 additions & 0 deletions chromadb/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,13 +374,15 @@ def _query(
self,
collection_id: UUID,
query_embeddings: Embeddings,
ids: Optional[IDs] = None,
n_results: int = 10,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
include: Include = IncludeMetadataDocumentsDistances,
) -> QueryResult:
return self._server._query(
collection_id=collection_id,
ids=ids,
tenant=self.tenant,
database=self.database,
query_embeddings=query_embeddings,
Expand Down
2 changes: 2 additions & 0 deletions chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,7 @@ def _query(
self,
collection_id: UUID,
query_embeddings: Embeddings,
ids: Optional[IDs] = None,
n_results: int = 10,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
Expand All @@ -603,6 +604,7 @@ def _query(
"post",
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/query",
json={
"ids": ids,
"query_embeddings": convert_np_embeddings_to_list(query_embeddings)
if query_embeddings is not None
else None,
Expand Down
6 changes: 5 additions & 1 deletion chromadb/api/models/AsyncCollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ async def query(
query_texts: Optional[OneOrMany[Document]] = None,
query_images: Optional[OneOrMany[Image]] = None,
query_uris: Optional[OneOrMany[URI]] = None,
ids: Optional[OneOrMany[ID]] = None,
n_results: int = 10,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
Expand All @@ -184,6 +185,7 @@ async def query(
query_embeddings: The embeddings to get the closest neighbors of. Optional.
query_texts: The document texts to get the closest neighbors of. Optional.
query_images: The images to get the closest neighbors of. Optional.
ids: A subset of ids to search within. Optional.
n_results: The number of neighbors to return for each query_embedding or query_texts. Optional.
where: A Where type dict used to filter results by. E.g. `{"$and": [{"color" : "red"}, {"price": {"$gte": 4.20}}]}`. Optional.
where_document: A WhereDocument type dict used to filter by the documents. E.g. `{"$contains": "hello"}`. Optional.
Expand All @@ -205,6 +207,7 @@ async def query(
query_texts=query_texts,
query_images=query_images,
query_uris=query_uris,
ids=ids,
n_results=n_results,
where=where,
where_document=where_document,
Expand All @@ -213,6 +216,7 @@ async def query(

query_results = await self._client._query(
collection_id=self.id,
ids=query_request["ids"],
query_embeddings=query_request["embeddings"],
n_results=query_request["n_results"],
where=query_request["where"],
Expand Down Expand Up @@ -279,7 +283,7 @@ async def fork(
client=self._client,
model=model,
embedding_function=self._embedding_function,
data_loader=self._data_loader
data_loader=self._data_loader,
)

async def update(
Expand Down
6 changes: 5 additions & 1 deletion chromadb/api/models/Collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def query(
query_texts: Optional[OneOrMany[Document]] = None,
query_images: Optional[OneOrMany[Image]] = None,
query_uris: Optional[OneOrMany[URI]] = None,
ids: Optional[OneOrMany[ID]] = None,
n_results: int = 10,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
Expand All @@ -188,6 +189,7 @@ def query(
query_texts: The document texts to get the closest neighbors of. Optional.
query_images: The images to get the closest neighbors of. Optional.
query_uris: The URIs to be used with data loader. Optional.
ids: A subset of ids to search within. Optional.
n_results: The number of neighbors to return for each query_embedding or query_texts. Optional.
where: A Where type dict used to filter results by. E.g. `{"$and": [{"color" : "red"}, {"price": {"$gte": 4.20}}]}`. Optional.
where_document: A WhereDocument type dict used to filter by the documents. E.g. `{"$contains": "hello"}`. Optional.
Expand All @@ -209,6 +211,7 @@ def query(
query_texts=query_texts,
query_images=query_images,
query_uris=query_uris,
ids=ids,
n_results=n_results,
where=where,
where_document=where_document,
Expand All @@ -217,6 +220,7 @@ def query(

query_results = self._client._query(
collection_id=self.id,
ids=query_request["ids"],
query_embeddings=query_request["embeddings"],
n_results=query_request["n_results"],
where=query_request["where"],
Expand Down Expand Up @@ -285,7 +289,7 @@ def fork(
client=self._client,
model=model,
embedding_function=self._embedding_function,
data_loader=self._data_loader
data_loader=self._data_loader,
)

def update(
Expand Down
4 changes: 4 additions & 0 deletions chromadb/api/models/CollectionCommon.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ def _validate_and_prepare_query_request(
query_texts: Optional[OneOrMany[Document]],
query_images: Optional[OneOrMany[Image]],
query_uris: Optional[OneOrMany[URI]],
ids: Optional[OneOrMany[ID]],
n_results: int,
where: Optional[Where],
where_document: Optional[WhereDocument],
Expand All @@ -307,6 +308,8 @@ def _validate_and_prepare_query_request(
uris=query_uris,
)

filter_ids = maybe_cast_one_to_many(ids)

filters = FilterSet(
where=where,
where_document=where_document,
Expand Down Expand Up @@ -335,6 +338,7 @@ def _validate_and_prepare_query_request(

return QueryRequest(
embeddings=request_embeddings,
ids=filter_ids,
where=request_where,
where_document=request_where_document,
include=request_include,
Expand Down
4 changes: 4 additions & 0 deletions chromadb/api/rust.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,7 @@ def _query(
self,
collection_id: UUID,
query_embeddings: Embeddings,
ids: Optional[IDs] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add to telemetry capture call below? just observing the pattern

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added, thanks

n_results: int = 10,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
Expand All @@ -490,10 +491,12 @@ def _query(
database: str = DEFAULT_DATABASE,
) -> QueryResult:
query_amount = len(query_embeddings)
filtered_ids_amount = len(ids) if ids else 0
self.product_telemetry_client.capture(
CollectionQueryEvent(
collection_uuid=str(collection_id),
query_amount=query_amount,
filtered_ids_amount=filtered_ids_amount,
n_results=n_results,
with_metadata_filter=query_amount if where is not None else 0,
with_document_filter=query_amount if where_document is not None else 0,
Expand All @@ -506,6 +509,7 @@ def _query(

rust_response = self.bindings.query(
str(collection_id),
ids,
query_embeddings,
n_results,
json.dumps(where) if where else None,
Expand Down
3 changes: 3 additions & 0 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,7 @@ def _query(
self,
collection_id: UUID,
query_embeddings: Embeddings,
ids: Optional[IDs] = None,
n_results: int = 10,
where: Optional[Where] = None,
where_document: Optional[WhereDocument] = None,
Expand All @@ -801,10 +802,12 @@ def _query(
)

query_amount = len(query_embeddings)
ids_amount = len(ids) if ids else 0
self._product_telemetry_client.capture(
CollectionQueryEvent(
collection_uuid=str(collection_id),
query_amount=query_amount,
filtered_ids_amount=ids_amount,
n_results=n_results,
with_metadata_filter=query_amount if where is not None else 0,
with_document_filter=query_amount if where_document is not None else 0,
Expand Down
1 change: 1 addition & 0 deletions chromadb/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@ class GetResult(TypedDict):

class QueryRequest(TypedDict):
embeddings: Embeddings
ids: Optional[IDs]
where: Optional[Where]
where_document: Optional[WhereDocument]
include: Include
Expand Down
4 changes: 4 additions & 0 deletions chromadb/telemetry/product/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ class CollectionQueryEvent(ProductTelemetryEvent):
batch_size: int
collection_uuid: str
query_amount: int
filtered_ids_amount: int
with_metadata_filter: int
with_document_filter: int
n_results: int
Expand All @@ -149,6 +150,7 @@ def __init__(
self,
collection_uuid: str,
query_amount: int,
filtered_ids_amount: int,
with_metadata_filter: int,
with_document_filter: int,
n_results: int,
Expand All @@ -161,6 +163,7 @@ def __init__(
super().__init__()
self.collection_uuid = collection_uuid
self.query_amount = query_amount
self.filtered_ids_amount = filtered_ids_amount
self.with_metadata_filter = with_metadata_filter
self.with_document_filter = with_document_filter
self.n_results = n_results
Expand All @@ -182,6 +185,7 @@ def batch(self, other: "ProductTelemetryEvent") -> "CollectionQueryEvent":
return CollectionQueryEvent(
collection_uuid=self.collection_uuid,
query_amount=total_amount,
filtered_ids_amount=self.filtered_ids_amount + other.filtered_ids_amount,
with_metadata_filter=self.with_metadata_filter + other.with_metadata_filter,
with_document_filter=self.with_document_filter + other.with_document_filter,
n_results=self.n_results + other.n_results,
Expand Down
Loading
Loading