Skip to content

Commit

Permalink
Collection options is a dataclass and not a dict anymore (#269)
Browse files Browse the repository at this point in the history
* full usage of structures in collection options/descriptors

* adapted coll.lifecycle int.tests to structured coll.options

* add flatten method to collection options structures
  • Loading branch information
hemidactylus authored Mar 27, 2024
1 parent a569a38 commit ff3933c
Show file tree
Hide file tree
Showing 8 changed files with 518 additions and 231 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,10 @@ from astrapy.info import (
AdminDatabaseInfo,
DatabaseInfo,
CollectionInfo,
CollectionDefaultIDOptions,
CollectionVectorOptions,
CollectionOptions,
CollectionDescriptor,
)
```

Expand Down
40 changes: 19 additions & 21 deletions astrapy/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
BulkWriteResult,
)
from astrapy.cursors import AsyncCursor, Cursor
from astrapy.info import CollectionInfo
from astrapy.info import CollectionInfo, CollectionOptions


if TYPE_CHECKING:
Expand Down Expand Up @@ -347,7 +347,7 @@ def set_caller(
caller_version=caller_version,
)

def options(self, *, max_time_ms: Optional[int] = None) -> Dict[str, Any]:
def options(self, *, max_time_ms: Optional[int] = None) -> CollectionOptions:
"""
Get the collection options, i.e. its configuration as read from the database.
Expand All @@ -359,22 +359,21 @@ def options(self, *, max_time_ms: Optional[int] = None) -> Dict[str, Any]:
max_time_ms: a timeout, in milliseconds, for the underlying HTTP request.
Returns:
a dictionary expressing the collection as a set of key-value pairs
matching the arguments of a `create_collection` call.
a CollectionOptions instance describing the collection.
(See also the database `list_collections` method.)
Example:
>>> my_coll.options()
{'name': 'my_v_collection', 'dimension': 3, 'metric': 'cosine'}
CollectionOptions(vector=CollectionVectorOptions(dimension=3, metric='cosine'))
"""

self_dicts = [
coll_dict
for coll_dict in self.database.list_collections(max_time_ms=max_time_ms)
if coll_dict["name"] == self.name
self_descriptors = [
coll_desc
for coll_desc in self.database.list_collections(max_time_ms=max_time_ms)
if coll_desc.name == self.name
]
if self_dicts:
return self_dicts[0] # type: ignore[no-any-return]
if self_descriptors:
return self_descriptors[0].options # type: ignore[no-any-return]
else:
raise CollectionNotFoundException(
text=f"Collection {self.namespace}.{self.name} not found.",
Expand Down Expand Up @@ -2411,7 +2410,7 @@ def set_caller(
caller_version=caller_version,
)

async def options(self, *, max_time_ms: Optional[int] = None) -> Dict[str, Any]:
async def options(self, *, max_time_ms: Optional[int] = None) -> CollectionOptions:
"""
Get the collection options, i.e. its configuration as read from the database.
Expand All @@ -2423,24 +2422,23 @@ async def options(self, *, max_time_ms: Optional[int] = None) -> Dict[str, Any]:
max_time_ms: a timeout, in milliseconds, for the underlying HTTP request.
Returns:
a dictionary expressing the collection as a set of key-value pairs
matching the arguments of a `create_collection` call.
a CollectionOptions instance describing the collection.
(See also the database `list_collections` method.)
Example:
>>> asyncio.run(my_async_coll.options())
{'name': 'my_v_collection', 'dimension': 3, 'metric': 'cosine'}
CollectionOptions(vector=CollectionVectorOptions(dimension=3, metric='cosine'))
"""

self_dicts = [
coll_dict
async for coll_dict in self.database.list_collections(
self_descriptors = [
coll_desc
async for coll_desc in self.database.list_collections(
max_time_ms=max_time_ms
)
if coll_dict["name"] == self.name
if coll_desc.name == self.name
]
if self_dicts:
return self_dicts[0] # type: ignore[no-any-return]
if self_descriptors:
return self_descriptors[0].options # type: ignore[no-any-return]
else:
raise CollectionNotFoundException(
text=f"Collection {self.namespace}.{self.name} not found.",
Expand Down
67 changes: 15 additions & 52 deletions astrapy/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
base_timeout_info,
)
from astrapy.cursors import AsyncCommandCursor, CommandCursor
from astrapy.info import DatabaseInfo
from astrapy.info import DatabaseInfo, CollectionDescriptor
from astrapy.admin import parse_api_endpoint, fetch_database_info

if TYPE_CHECKING:
Expand Down Expand Up @@ -70,41 +70,6 @@ def _validate_create_collection_options(
)


def _recast_api_collection_dict(api_coll_dict: Dict[str, Any]) -> Dict[str, Any]:
_name = api_coll_dict["name"]
_options = api_coll_dict.get("options") or {}
_v_options0 = _options.get("vector") or {}
_indexing = _options.get("indexing") or {}
_v_dimension = _v_options0.get("dimension")
_v_metric = _v_options0.get("metric")
_default_id = _options.get("defaultId")
# defaultId may potentially in the future have other subfields than 'type':
if _default_id:
_default_id_type = _default_id.get("type")
_rest_default_id = {k: v for k, v in _default_id.items() if k != "type"}
else:
_default_id_type = None
_rest_default_id = None
_additional_options = {
**{
k: v
for k, v in _options.items()
if k not in {"vector", "indexing", "defaultId"}
},
**({"defaultId": _rest_default_id} if _rest_default_id else {}),
}
recast_dict0 = {
"name": _name,
"dimension": _v_dimension,
"metric": _v_metric,
"indexing": _indexing,
"default_id_type": _default_id_type,
"additional_options": _additional_options,
}
recast_dict = {k: v for k, v in recast_dict0.items() if v}
return recast_dict


class Database:
"""
A Data API database. This is the entry-point object for doing database-level
Expand Down Expand Up @@ -592,7 +557,7 @@ def list_collections(
*,
namespace: Optional[str] = None,
max_time_ms: Optional[int] = None,
) -> CommandCursor[Dict[str, Any]]:
) -> CommandCursor[CollectionDescriptor]:
"""
List all collections in a given namespace for this database.
Expand All @@ -602,20 +567,19 @@ def list_collections(
max_time_ms: a timeout, in milliseconds, for the underlying HTTP request.
Returns:
a `CommandCursor` to iterate over dictionaries, each
expressing a collection as a set of key-value pairs
matching the arguments of a `create_collection` call.
a `CommandCursor` to iterate over CollectionDescriptor instances,
each corresponding to a collection.
Example:
>>> ccur = my_db.list_collections()
>>> ccur
<astrapy.cursors.CommandCursor object at ...>
>>> list(ccur)
[{'name': 'my_v_col'}]
[CollectionDescriptor(name='my_v_col', options=CollectionOptions())]
>>> for coll_dict in my_db.list_collections():
... print(coll_dict)
...
{'name': 'my_v_col'}
CollectionDescriptor(name='my_v_col', options=CollectionOptions())
"""

if namespace:
Expand All @@ -631,11 +595,11 @@ def list_collections(
raw_response=gc_response,
)
else:
# we know this is a list of dicts which need a little adjusting
# we know this is a list of dicts, to marshal into "descriptors"
return CommandCursor(
address=self._astra_db.base_url,
items=[
_recast_api_collection_dict(col_dict)
CollectionDescriptor.from_dict(col_dict)
for col_dict in gc_response["status"]["collections"]
],
)
Expand Down Expand Up @@ -1286,7 +1250,7 @@ def list_collections(
*,
namespace: Optional[str] = None,
max_time_ms: Optional[int] = None,
) -> AsyncCommandCursor[Dict[str, Any]]:
) -> AsyncCommandCursor[CollectionDescriptor]:
"""
List all collections in a given namespace for this database.
Expand All @@ -1296,9 +1260,8 @@ def list_collections(
max_time_ms: a timeout, in milliseconds, for the underlying HTTP request.
Returns:
an `AsyncCommandCursor` to iterate over dictionaries, each
expressing a collection as a set of key-value pairs
matching the arguments of a `create_collection` call.
an `AsyncCommandCursor` to iterate over CollectionDescriptor instances,
each corresponding to a collection.
Example:
>>> async def a_list_colls(adb: AsyncDatabase) -> None:
Expand All @@ -1310,8 +1273,8 @@ def list_collections(
...
>>> asyncio.run(a_list_colls(my_async_db))
* a_ccur: <astrapy.cursors.AsyncCommandCursor object at ...>
* list: [{'name': 'my_v_col'}]
* coll: {'name': 'my_v_col'}
* list: [CollectionDescriptor(name='my_v_col', options=CollectionOptions())]
* coll: CollectionDescriptor(name='my_v_col', options=CollectionOptions())
"""

_client: AsyncAstraDB
Expand All @@ -1329,11 +1292,11 @@ def list_collections(
raw_response=gc_response,
)
else:
# we know this is a list of dicts which need a little adjusting
# we know this is a list of dicts, to marshal into "descriptors"
return AsyncCommandCursor(
address=self._astra_db.base_url,
items=[
_recast_api_collection_dict(col_dict)
CollectionDescriptor.from_dict(col_dict)
for col_dict in gc_response["status"]["collections"]
],
)
Expand Down
Loading

0 comments on commit ff3933c

Please sign in to comment.