Skip to content

Commit

Permalink
Merge remote-tracking branch 'chroma/main' into fix/dependency-mgmt
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffchuber committed Nov 3, 2023
2 parents f25236c + cdcafc8 commit dff2238
Show file tree
Hide file tree
Showing 59 changed files with 3,271 additions and 835 deletions.
16 changes: 14 additions & 2 deletions .github/workflows/chroma-release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: Chroma Release
on:
push:
tags:
- '*'
- "*"
branches:
- main

Expand Down Expand Up @@ -43,7 +43,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.10'
python-version: "3.10"
- name: Install Client Dev Dependencies
run: python -m pip install -r requirements_dev.txt
- name: Build Client
Expand Down Expand Up @@ -141,3 +141,15 @@ jobs:
artifacts: "dist/chroma-${{steps.version.outputs.version}}.tar.gz"
allowUpdates: true
prerelease: true
- name: Trigger Hosted Chroma Release
uses: actions/github-script@v6
with:
github-token: ${{ secrets.HOSTED_CHROMA_WORKFLOW_DISPATCH_TOKEN }}
script: |
const result = await github.rest.actions.createWorkflowDispatch({
owner: 'chroma-core',
repo: 'hosted-chroma',
workflow_id: 'build-and-publish-image.yaml',
ref: 'main'
})
console.log(result)
6 changes: 4 additions & 2 deletions chromadb/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ def create_collection(
@abstractmethod
def get_collection(
self,
name: str,
name: Optional[str] = None,
id: Optional[UUID] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
) -> Collection:
"""Get a collection with the given name.
Expand Down Expand Up @@ -496,7 +497,8 @@ def create_collection(
@override
def get_collection(
self,
name: str,
name: Optional[str] = None,
id: Optional[UUID] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
Expand Down
4 changes: 3 additions & 1 deletion chromadb/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,12 @@ def create_collection(
@override
def get_collection(
self,
name: str,
name: Optional[str] = None,
id: Optional[UUID] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
) -> Collection:
return self._server.get_collection(
id=id,
name=name,
embedding_function=embedding_function,
tenant=self.tenant,
Expand Down
12 changes: 9 additions & 3 deletions chromadb/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,15 +250,21 @@ def create_collection(
@override
def get_collection(
self,
name: str,
name: Optional[str] = None,
id: Optional[UUID] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> Collection:
"""Returns a collection"""
if (name is None and id is None) or (name is not None and id is not None):
raise ValueError("Name or id must be specified, but not both")

_params = {"tenant": tenant, "database": database}
if id is not None:
_params["type"] = str(id)
resp = self._session.get(
self._api_url + "/collections/" + name,
params={"tenant": tenant, "database": database},
self._api_url + "/collections/" + name if name else str(id), params=_params
)
raise_chroma_error(resp)
resp_json = resp.json()
Expand Down
8 changes: 7 additions & 1 deletion chromadb/api/models/Collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class Collection(BaseModel):
name: str
id: UUID
metadata: Optional[CollectionMetadata] = None
tenant: Optional[str] = None
database: Optional[str] = None
_client: "ServerAPI" = PrivateAttr()
_embedding_function: Optional[EmbeddingFunction] = PrivateAttr()

Expand All @@ -49,9 +51,13 @@ def __init__(
name: str,
id: UUID,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
tenant: Optional[str] = None,
database: Optional[str] = None,
metadata: Optional[CollectionMetadata] = None,
):
super().__init__(name=name, metadata=metadata, id=id)
super().__init__(
name=name, metadata=metadata, id=id, tenant=tenant, database=database
)
self._client = client
self._embedding_function = embedding_function

Expand Down
17 changes: 14 additions & 3 deletions chromadb/api/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ def create_collection(
name=name,
metadata=coll["metadata"], # type: ignore
embedding_function=embedding_function,
tenant=tenant,
database=database,
)

@trace_method(
Expand Down Expand Up @@ -214,13 +216,16 @@ def get_or_create_collection(
@override
def get_collection(
self,
name: str,
name: Optional[str] = None,
id: Optional[UUID] = None,
embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> Collection:
if id is None and name is None or (id is not None and name is not None):
raise ValueError("Name or id must be specified, but not both")
existing = self._sysdb.get_collections(
name=name, tenant=tenant, database=database
id=id, name=name, tenant=tenant, database=database
)

if existing:
Expand All @@ -230,6 +235,8 @@ def get_collection(
name=existing[0]["name"],
metadata=existing[0]["metadata"], # type: ignore
embedding_function=embedding_function,
tenant=tenant,
database=database,
)
else:
raise ValueError(f"Collection {name} does not exist.")
Expand All @@ -250,6 +257,8 @@ def list_collections(
id=db_collection["id"],
name=db_collection["name"],
metadata=db_collection["metadata"], # type: ignore
tenant=db_collection["tenant"],
database=db_collection["database"],
)
)
return collections
Expand Down Expand Up @@ -486,7 +495,9 @@ def _get(
embeddings=[r["embedding"] for r in vectors]
if "embeddings" in include
else None,
metadatas=_clean_metadatas(metadatas) if "metadatas" in include else None, # type: ignore
metadatas=_clean_metadatas(metadatas)
if "metadatas" in include
else None, # type: ignore
documents=documents if "documents" in include else None, # type: ignore
)

Expand Down
9 changes: 9 additions & 0 deletions chromadb/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,15 @@ def from_function_kwargs(**kwargs: Any) -> Callable[..., str]:
lambda **kwargs: kwargs["function_kwargs"][kwargs["arg_name"]], **kwargs
)

@staticmethod
def dict_from_function_kwargs(**kwargs: Any) -> Callable[..., Dict[str, Any]]:
return partial(
lambda **kwargs: {
k: kwargs["function_kwargs"][k] for k in kwargs["arg_names"]
},
**kwargs,
)


@dataclass
class AuthzAction:
Expand Down
4 changes: 4 additions & 0 deletions chromadb/auth/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def decorator(f: Callable[..., Any]) -> Callable[..., Any]:
@wraps(f)
def wrapped(*args: Any, **kwargs: Dict[Any, Any]) -> Any:
_dynamic_kwargs = {
"api": args[0]._api,
"function": f,
"function_args": args,
"function_kwargs": kwargs,
Expand Down Expand Up @@ -190,6 +191,9 @@ def wrapped(*args: Any, **kwargs: Dict[Any, Any]) -> Any:
tenant=request.state.user_identity.get_user_tenant()
if hasattr(request.state, "user_identity")
else DEFAULT_TENANT,
attributes=request.state.user_identity.get_user_attributes()
if hasattr(request.state, "user_identity")
else {},
),
resource=_resource,
action=_action,
Expand Down
29 changes: 25 additions & 4 deletions chromadb/auth/fastapi_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from functools import partial
from typing import Any, Callable, Dict
from typing import Any, Callable, Dict, Optional, Sequence, cast
from chromadb.api import ServerAPI
from chromadb.auth import AuthzResourceTypes


Expand All @@ -25,7 +26,27 @@ def find_key_with_value_of_type(


def attr_from_resource_object(
type: AuthzResourceTypes, **kwargs: Any
type: AuthzResourceTypes,
additional_attrs: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> Callable[..., Dict[str, Any]]:
def _wrap(**wkwargs: Any) -> Dict[str, Any]:
obj = find_key_with_value_of_type(type, **wkwargs)
if additional_attrs:
obj.update({k: wkwargs["function_kwargs"][k]
for k in additional_attrs})
return obj

return partial(_wrap, **kwargs)


def attr_from_collection_lookup(
collection_id_arg: str, **kwargs: Any
) -> Callable[..., Dict[str, Any]]:
obj = find_key_with_value_of_type(type, **kwargs)
return partial(lambda **kwargs: obj, **kwargs)
def _wrap(**kwargs: Any) -> Dict[str, Any]:
_api = cast(ServerAPI, kwargs["api"])
col = _api.get_collection(
id=kwargs["function_kwargs"][collection_id_arg])
return {"tenant": col.tenant, "database": col.database}

return partial(_wrap, **kwargs)
2 changes: 2 additions & 0 deletions chromadb/db/impl/grpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,8 @@ def CreateCollection(
metadata=from_proto_metadata(request.metadata),
dimension=request.dimension,
topic=self._assignment_policy.assign_collection(id),
database=database,
tenant=tenant,
)
collections[request.id] = new_collection
return CreateCollectionResponse(
Expand Down
15 changes: 14 additions & 1 deletion chromadb/db/mixins/sysdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,13 @@ def create_collection(

topic = self._assignment_policy.assign_collection(id)
collection = Collection(
id=id, topic=topic, name=name, metadata=metadata, dimension=dimension
id=id,
topic=topic,
name=name,
metadata=metadata,
dimension=dimension,
tenant=tenant,
database=database,
)

with self.tx() as cur:
Expand Down Expand Up @@ -379,6 +385,7 @@ def get_collections(

collections_t = Table("collections")
metadata_t = Table("collection_metadata")
databases_t = Table("databases")
q = (
self.querybuilder()
.from_(collections_t)
Expand All @@ -387,13 +394,17 @@ def get_collections(
collections_t.name,
collections_t.topic,
collections_t.dimension,
databases_t.name,
databases_t.tenant_id,
metadata_t.key,
metadata_t.str_value,
metadata_t.int_value,
metadata_t.float_value,
)
.left_join(metadata_t)
.on(collections_t.id == metadata_t.collection_id)
.left_join(databases_t)
.on(collections_t.database_id == databases_t.id)
.orderby(collections_t.id)
)
if id:
Expand Down Expand Up @@ -433,6 +444,8 @@ def get_collections(
name=name,
metadata=metadata,
dimension=dimension,
tenant=str(rows[0][5]),
database=str(rows[0][4]),
)
)

Expand Down
Loading

0 comments on commit dff2238

Please sign in to comment.