Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ members = [
[workspace.dependencies]
arrow = "52.2.0"
async-trait = "0.1"
axum = { version = "0.8", features = ["macros"] }
axum = { version = "0.8", features = ["macros", "http1", "http2"] }
bytes = "1.10"
chrono = { version = "0.4", features = ["serde"] }
clap = { version = "4", features = ["derive"] }
Expand Down
10 changes: 10 additions & 0 deletions chromadb/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,11 @@ def get_user_identity(self) -> UserIdentity:
"""
pass

@abstractmethod
def close(self) -> None:
    """Release every resource held by this client.

    Abstract hook: concrete API implementations supply the actual teardown.
    """


class ClientAPI(BaseAPI, ABC):
tenant: str
Expand Down Expand Up @@ -559,6 +564,11 @@ def get_tenant(self, name: str) -> Tenant:
"""
pass

@abstractmethod
def close(self) -> None:
    """Release every resource held by this client.

    Abstract hook: concrete implementations supply the actual teardown.
    """


class ServerAPI(BaseAPI, AdminAPI, Component):
"""An API instance that extends the relevant Base API methods by passing
Expand Down
10 changes: 10 additions & 0 deletions chromadb/api/async_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,11 @@ async def get_user_identity(self) -> UserIdentity:
"""
pass

@abstractmethod
async def close(self) -> None:
    """Release every resource held by this async client.

    Abstract hook: concrete API implementations supply the actual teardown.
    """


class AsyncClientAPI(AsyncBaseAPI, ABC):
tenant: str
Expand Down Expand Up @@ -553,6 +558,11 @@ async def get_tenant(self, name: str) -> Tenant:
"""
pass

@abstractmethod
async def close(self) -> None:
    """Release every resource held by this async client.

    Abstract hook: concrete implementations supply the actual teardown.
    """


class AsyncServerAPI(AsyncBaseAPI, AsyncAdminAPI, Component):
"""An API instance that extends the relevant Base API methods by passing
Expand Down
9 changes: 9 additions & 0 deletions chromadb/api/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,11 @@ def get_settings(self) -> Settings:
async def get_max_batch_size(self) -> int:
return await self._server.get_max_batch_size()

@override
async def close(self) -> None:
    """Close the underlying server API and then the admin client.

    The admin client is closed even when closing the server raises, so a
    failure in the first step cannot leave the second resource open. Any
    exception from either step still propagates to the caller.
    """
    try:
        await self._server.close()
    finally:
        # Always attempt the second close; otherwise an error above would
        # leave the admin client open.
        await self._admin_client.close()

# endregion


Expand Down Expand Up @@ -485,3 +490,7 @@ def from_system(
SharedSystemClient._populate_data_from_system(system)
instance = cls(settings=system.settings)
return instance

@override
async def close(self) -> None:
    """Close this client by closing the server API it delegates to."""
    server = self._server
    await server.close()
54 changes: 54 additions & 0 deletions chromadb/api/async_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ async def _cleanup(self) -> None:
async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
    """Async context-manager exit: run this object's cleanup routine.

    The exception details are accepted but not inspected; nothing is
    returned, so an in-flight exception is not suppressed.
    """
    await self._cleanup()

@override
def start(self) -> None:
    """Start this component by delegating to the parent implementation."""
    super().start()

@override
def stop(self) -> None:
super().stop()
Expand Down Expand Up @@ -152,6 +156,7 @@ async def _make_request(
@trace_method("AsyncFastAPI.heartbeat", OpenTelemetryGranularity.OPERATION)
@override
async def heartbeat(self) -> int:
    """Return the server heartbeat as an integer.

    Issues a GET against the API root and converts the value stored under
    the "nanosecond heartbeat" key of the response to ``int``.
    """
    # Presumably raises when the component has already been stopped/closed.
    self.raise_if_stopped()
    payload = await self._make_request("get", "")
    nanos = payload["nanosecond heartbeat"]
    return int(nanos)

Expand All @@ -162,6 +167,7 @@ async def create_database(
name: str,
tenant: str = DEFAULT_TENANT,
) -> None:
self.raise_if_stopped()
await self._make_request(
"post",
f"/tenants/{tenant}/databases",
Expand All @@ -175,6 +181,7 @@ async def get_database(
name: str,
tenant: str = DEFAULT_TENANT,
) -> Database:
self.raise_if_stopped()
response = await self._make_request(
"get",
f"/tenants/{tenant}/databases/{name}",
Expand All @@ -192,6 +199,7 @@ async def delete_database(
name: str,
tenant: str = DEFAULT_TENANT,
) -> None:
self.raise_if_stopped()
await self._make_request(
"delete",
f"/tenants/{tenant}/databases/{name}",
Expand All @@ -205,6 +213,7 @@ async def list_databases(
offset: Optional[int] = None,
tenant: str = DEFAULT_TENANT,
) -> Sequence[Database]:
self.raise_if_stopped()
response = await self._make_request(
"get",
f"/tenants/{tenant}/databases",
Expand All @@ -224,6 +233,7 @@ async def list_databases(
@trace_method("AsyncFastAPI.create_tenant", OpenTelemetryGranularity.OPERATION)
@override
async def create_tenant(self, name: str) -> None:
self.raise_if_stopped()
await self._make_request(
"post",
"/tenants",
Expand All @@ -233,6 +243,7 @@ async def create_tenant(self, name: str) -> None:
@trace_method("AsyncFastAPI.get_tenant", OpenTelemetryGranularity.OPERATION)
@override
async def get_tenant(self, name: str) -> Tenant:
self.raise_if_stopped()
resp_json = await self._make_request(
"get",
"/tenants/" + name,
Expand All @@ -243,6 +254,7 @@ async def get_tenant(self, name: str) -> Tenant:
@trace_method("AsyncFastAPI.get_user_identity", OpenTelemetryGranularity.OPERATION)
@override
async def get_user_identity(self) -> UserIdentity:
    """Fetch the caller's identity from ``/auth/identity``.

    The response mapping is expanded as keyword arguments into
    ``UserIdentity``.
    """
    # Presumably raises when the component has already been stopped/closed.
    self.raise_if_stopped()
    identity_fields = await self._make_request("get", "/auth/identity")
    return UserIdentity(**identity_fields)

@trace_method("AsyncFastAPI.list_collections", OpenTelemetryGranularity.OPERATION)
Expand Down Expand Up @@ -275,6 +287,7 @@ async def list_collections(
async def count_collections(
self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE
) -> int:
self.raise_if_stopped()
resp_json = await self._make_request(
"get",
f"/tenants/{tenant}/databases/{database}/collections_count",
Expand All @@ -294,6 +307,7 @@ async def create_collection(
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
"""Creates a collection"""
self.raise_if_stopped()
config_json = (
create_collection_configuration_to_json(configuration)
if configuration
Expand Down Expand Up @@ -321,6 +335,7 @@ async def get_collection(
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
self.raise_if_stopped()
resp_json = await self._make_request(
"get",
f"/tenants/{tenant}/databases/{database}/collections/{name}",
Expand All @@ -342,6 +357,7 @@ async def get_or_create_collection(
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
self.raise_if_stopped()
return await self.create_collection(
name=name,
configuration=configuration,
Expand All @@ -362,6 +378,7 @@ async def _modify(
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> None:
self.raise_if_stopped()
await self._make_request(
"put",
f"/tenants/{tenant}/databases/{database}/collections/{id}",
Expand All @@ -385,6 +402,7 @@ async def _fork(
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
self.raise_if_stopped()
resp_json = await self._make_request(
"post",
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/fork",
Expand All @@ -401,6 +419,7 @@ async def delete_collection(
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> None:
self.raise_if_stopped()
await self._make_request(
"delete",
f"/tenants/{tenant}/databases/{database}/collections/{name}",
Expand All @@ -415,6 +434,7 @@ async def _count(
database: str = DEFAULT_DATABASE,
) -> int:
"""Returns the number of embeddings in the database"""
self.raise_if_stopped()
resp_json = await self._make_request(
"get",
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/count",
Expand All @@ -431,6 +451,7 @@ async def _peek(
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> GetResult:
self.raise_if_stopped()
resp = await self._get(
collection_id,
tenant=tenant,
Expand Down Expand Up @@ -492,6 +513,7 @@ async def _delete(
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> None:
self.raise_if_stopped()
await self._make_request(
"post",
f"/tenants/{tenant}/databases/{database}/collections/{collection_id}/delete",
Expand All @@ -514,6 +536,7 @@ async def _submit_batch(
"""
Submits a batch of embeddings to the database
"""
self.raise_if_stopped()
return await self._make_request(
"post",
url,
Expand All @@ -539,6 +562,7 @@ async def _add(
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> bool:
self.raise_if_stopped()
batch = (
ids,
convert_np_embeddings_to_list(embeddings),
Expand Down Expand Up @@ -566,6 +590,7 @@ async def _update(
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> bool:
self.raise_if_stopped()
batch = (
ids,
convert_np_embeddings_to_list(embeddings)
Expand Down Expand Up @@ -597,6 +622,7 @@ async def _upsert(
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> bool:
self.raise_if_stopped()
batch = (
ids,
convert_np_embeddings_to_list(embeddings),
Expand Down Expand Up @@ -624,6 +650,7 @@ async def _query(
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> QueryResult:
self.raise_if_stopped()
# Servers do not support the "data" include, as that is hydrated on the client side
filtered_include = [i for i in include if i != "data"]

Expand Down Expand Up @@ -655,23 +682,50 @@ async def _query(
@trace_method("AsyncFastAPI.reset", OpenTelemetryGranularity.ALL)
@override
async def reset(self) -> bool:
    """POST to ``/reset`` and return the response, typed as ``bool``."""
    # Presumably raises when the component has already been stopped/closed.
    self.raise_if_stopped()
    outcome = await self._make_request("post", "/reset")
    # cast() only informs the type checker; the value is returned unchanged.
    return cast(bool, outcome)

@trace_method("AsyncFastAPI.get_version", OpenTelemetryGranularity.OPERATION)
@override
async def get_version(self) -> str:
    """GET ``/version`` and return the response, typed as ``str``."""
    # Presumably raises when the component has already been stopped/closed.
    self.raise_if_stopped()
    version = await self._make_request("get", "/version")
    # cast() only informs the type checker; the value is returned unchanged.
    return cast(str, version)

@override
def get_settings(self) -> Settings:
    """Return the settings object stored on this instance."""
    # Presumably raises when the component has already been stopped/closed.
    self.raise_if_stopped()
    current_settings = self._settings
    return current_settings

@trace_method("AsyncFastAPI.get_max_batch_size", OpenTelemetryGranularity.OPERATION)
@override
async def get_max_batch_size(self) -> int:
    """Return the maximum batch size, fetching it from the server once.

    ``-1`` in ``self._max_batch_size`` marks "not fetched yet"; in that case
    the value is read from ``/pre-flight-checks`` and cached on the instance.
    """
    # Presumably raises when the component has already been stopped/closed.
    self.raise_if_stopped()

    # Fast path: a previously cached value needs no request.
    if self._max_batch_size != -1:
        return self._max_batch_size

    preflight = await self._make_request("get", "/pre-flight-checks")
    self._max_batch_size = cast(int, preflight["max_batch_size"])
    return self._max_batch_size

@trace_method("AsyncFastAPI.close", OpenTelemetryGranularity.OPERATION)
@override
async def close(self) -> None:
    """Close the HTTP client registered for the current event loop and stop.

    Calling this on a component that is not running is a no-op, so repeated
    calls do not raise. Only the client stored under the current loop's key
    in ``self._clients`` is closed; entries registered under other keys are
    left untouched.
    """
    # Already stopped: nothing to release, and we must not raise.
    if not self.is_running():
        return

    try:
        loop_hash = asyncio.get_event_loop().__hash__()
    except RuntimeError:
        # No event loop could be obtained; fall back to the key 0.
        loop_hash = 0

    # A single membership test covers both the per-loop key and the 0
    # fallback. The former separate "loop_hash == 0" branch could never run,
    # because that case already satisfies this condition.
    if loop_hash in self._clients:
        client_to_close = self._clients.pop(loop_hash)
        # Presumably an httpx.AsyncClient; confirm against where
        # self._clients is populated.
        await client_to_close.aclose()

    super().stop()
24 changes: 24 additions & 0 deletions chromadb/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,22 @@ def get_settings(self) -> Settings:
def get_max_batch_size(self) -> int:
return self._server.get_max_batch_size()

@override
def close(self) -> None:
    """Close the server API and then the admin client, best effort.

    ``close`` may be invoked several times, so a ``RuntimeError`` from a
    component (taken to mean it is already stopped) is ignored and the next
    component is still processed.
    """
    # Resolve each attribute just before closing it, server first.
    for component_attr in ("_server", "_admin_client"):
        try:
            getattr(self, component_attr).close()
        except RuntimeError:
            # This component may already be closed.
            continue

# endregion

# region ClientAPI Methods
Expand Down Expand Up @@ -491,3 +507,11 @@ def from_system(
SharedSystemClient._populate_data_from_system(system)
instance = cls(settings=system.settings)
return instance

@override
def close(self) -> None:
    """Close the underlying server API, ignoring an already-closed server."""
    try:
        server = self._server
        server.close()
    except RuntimeError:
        # The server may already be closed; nothing further to do.
        return
Loading
Loading