Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: vec-327 vec-372 deprecate field_names, add include_fields and exclude_fields optional arguments to client.get a… #51

Merged
merged 8 commits into from
Sep 27, 2024
67 changes: 59 additions & 8 deletions src/aerospike_vector_search/aio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import sys
from typing import Any, Optional, Union
import warnings

import grpc
import numpy as np
Expand Down Expand Up @@ -269,9 +270,12 @@ async def get(
*,
namespace: str,
key: Union[int, str, bytes, bytearray, np.generic, np.ndarray],
field_names: Optional[list[str]] = None,
include_fields: Optional[list[str]] = None,
exclude_fields: Optional[list[str]] = None,
set_name: Optional[str] = None,
timeout: Optional[int] = None,
# field_names is deprecated, use include_fields
field_names: Optional[list[str]] = None,
) -> types.RecordWithKey:
"""
Read a record from Aerospike Vector Search.
Expand All @@ -282,16 +286,29 @@ async def get(
:param key: The key for the record.
:type key: Union[int, str, bytes, bytearray, np.generic, np.ndarray]

:param field_names: A list of field names to retrieve from the record.
:param include_fields: A list of field names to retrieve from the record.
When used, fields that are not included are not sent by the server,
saving on network traffic.
DomPeliniAerospike marked this conversation as resolved.
Show resolved Hide resolved
If a field is listed in both include_fields and exclude_fields,
exclude_fields takes priority, and the field is not returned.
If None, all fields are retrieved. Defaults to None.
:type field_names: Optional[list[str]]
:type include_fields: Optional[list[str]]

:param exclude_fields: A list of field names to exclude from the record.
When used, the excluded fields are not sent by the server,
saving on network traffic.
If None, no fields are excluded. Defaults to None.
:type exclude_fields: Optional[list[str]]

:param set_name: The name of the set from which to read the record. Defaults to None.
:type set_name: Optional[str]

:param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError <aerospike_vector_search.types.AVSServerError>`. Defaults to None.
:type timeout: Optional[int]

:param field_names: Deprecated, use include_fields instead.
:type field_names: Optional[list[str]]

Returns:
types.RecordWithKey: A record with its associated key.

Expand All @@ -301,10 +318,18 @@ async def get(

"""

# TODO remove this when 'field_names' is removed
if field_names is not None:
warnings.warn(
DomPeliniAerospike marked this conversation as resolved.
Show resolved Hide resolved
"The 'field_names' argument is deprecated. Use 'include_fields' instead",
FutureWarning,
)
include_fields = field_names

await self._channel_provider._is_ready()

(transact_stub, key, get_request, kwargs) = self._prepare_get(
namespace, key, field_names, set_name, timeout, logger
namespace, key, include_fields, exclude_fields, set_name, timeout, logger
)

try:
Expand Down Expand Up @@ -472,8 +497,11 @@ async def vector_search(
query: list[Union[bool, float]],
limit: int,
search_params: Optional[types.HnswSearchParams] = None,
field_names: Optional[list[str]] = None,
include_fields: Optional[list[str]] = None,
exclude_fields: Optional[list[str]] = None,
timeout: Optional[int] = None,
# field_names is deprecated, use include_fields
field_names: Optional[list[str]] = None,
) -> list[types.Neighbor]:
"""
Perform a Hierarchical Navigable Small World (HNSW) vector search in Aerospike Vector Search.
Expand All @@ -494,20 +522,42 @@ async def vector_search(
If None, the default parameters for the index are used. Defaults to None.
:type search_params: Optional[types_pb2.HnswSearchParams]

:param field_names: A list of field names to retrieve from the results.
:param include_fields: A list of field names to retrieve from the results.
When used, fields that are not included are not sent by the server,
saving on network traffic.
DomPeliniAerospike marked this conversation as resolved.
Show resolved Hide resolved
If a field is listed in both include_fields and exclude_fields,
exclude_fields takes priority, and the field is not returned.
If None, all fields are retrieved. Defaults to None.
:type field_names: Optional[list[str]]
:type include_fields: Optional[list[str]]

:param exclude_fields: A list of field names to exclude from the results.
When used, the excluded fields are not sent by the server,
saving on network traffic.
If None, no fields are excluded. Defaults to None.
:type exclude_fields: Optional[list[str]]

:param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError <aerospike_vector_search.types.AVSServerError>`. Defaults to None.
:type timeout: Optional[int]

:param field_names: Deprecated, use include_fields instead.
:type field_names: Optional[list[str]]

Returns:
list[types.Neighbor]: A list of neighbor records found by the search.

Raises:
AVSServerError: Raised if an error occurs during the RPC communication with the server while attempting to vector search.
This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters.
"""

# TODO remove this when 'field_names' is removed
if field_names is not None:
warnings.warn(
"The 'field_names' argument is deprecated. Use 'include_fields' instead",
FutureWarning,
)
include_fields = field_names

await self._channel_provider._is_ready()

(transact_stub, vector_search_request, kwargs) = self._prepare_vector_search(
Expand All @@ -516,7 +566,8 @@ async def vector_search(
query,
limit,
search_params,
field_names,
include_fields,
exclude_fields,
timeout,
logger,
)
Expand Down
66 changes: 58 additions & 8 deletions src/aerospike_vector_search/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sys
import time
from typing import Any, Optional, Union
import warnings

import grpc

Expand Down Expand Up @@ -258,9 +259,12 @@ def get(
*,
namespace: str,
key: Union[int, str, bytes, bytearray],
field_names: Optional[list[str]] = None,
include_fields: Optional[list[str]] = None,
exclude_fields: Optional[list[str]] = None,
set_name: Optional[str] = None,
timeout: Optional[int] = None,
# field_names is deprecated, use include_fields
field_names: Optional[list[str]] = None,
) -> types.RecordWithKey:
"""
Read a record from Aerospike Vector Search.
Expand All @@ -271,17 +275,29 @@ def get(
:param key: The key for the record.
:type key: Union[int, str, bytes, bytearray, np.generic, np.ndarray]

:param include_fields: A list of field names to retrieve from the record.
When used, fields that are not included are not sent by the server,
saving on network traffic.
If a field is listed in both include_fields and exclude_fields,
exclude_fields takes priority, and the field is not returned.
If None, all fields are retrieved. Defaults to None.
DomPeliniAerospike marked this conversation as resolved.
Show resolved Hide resolved
:type include_fields: Optional[list[str]]

:param field_names: A list of field names to retrieve from the record.
:param exclude_fields: A list of field names to exclude from the record.
When used, the excluded fields are not sent by the server,
saving on network traffic.
If None, all fields are retrieved. Defaults to None.
:type field_names: Optional[list[str]]
:type exclude_fields: Optional[list[str]]

:param set_name: The name of the set from which to read the record. Defaults to None.
:type set_name: Optional[str]

:param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError <aerospike_vector_search.types.AVSServerError>`. Defaults to None.
:type timeout: Optional[int]

:param field_names: Deprecated, use include_fields instead.
:type field_names: Optional[list[str]]

Returns:
types.RecordWithKey: A record with its associated key.

Expand All @@ -290,8 +306,17 @@ def get(
This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters.
"""

# TODO remove this when 'field_names' is removed
if field_names is not None:
warnings.warn(
"The 'field_names' argument is deprecated. Use 'include_fields' instead",
FutureWarning,
)
include_fields = field_names


(transact_stub, key, get_request, kwargs) = self._prepare_get(
namespace, key, field_names, set_name, timeout, logger
namespace, key, include_fields, exclude_fields, set_name, timeout, logger
)

try:
Expand Down Expand Up @@ -451,8 +476,11 @@ def vector_search(
query: list[Union[bool, float]],
limit: int,
search_params: Optional[types.HnswSearchParams] = None,
field_names: Optional[list[str]] = None,
include_fields: Optional[list[str]] = None,
exclude_fields: Optional[list[str]] = None,
timeout: Optional[int] = None,
# field_names is deprecated, use include_fields
field_names: Optional[list[str]] = None,
) -> list[types.Neighbor]:
"""
Perform a Hierarchical Navigable Small World (HNSW) vector search in Aerospike Vector Search.
Expand All @@ -473,13 +501,26 @@ def vector_search(
If None, the default parameters for the index are used. Defaults to None.
:type search_params: Optional[types_pb2.HnswSearchParams]

:param field_names: A list of field names to retrieve from the results.
:param include_fields: A list of field names to retrieve from the results.
When used, fields that are not included are not sent by the server,
saving on network traffic.
DomPeliniAerospike marked this conversation as resolved.
Show resolved Hide resolved
If a field is listed in both include_fields and exclude_fields,
exclude_fields takes priority, and the field is not returned.
If None, all fields are retrieved. Defaults to None.
:type field_names: Optional[list[str]]
:type include_fields: Optional[list[str]]

:param exclude_fields: A list of field names to exclude from the results.
When used, the excluded fields are not sent by the server,
saving on network traffic.
If None, no fields are excluded. Defaults to None.
:type exclude_fields: Optional[list[str]]

:param timeout: Time in seconds this operation will wait before raising an :class:`AVSServerError <aerospike_vector_search.types.AVSServerError>`. Defaults to None.
:type timeout: Optional[int]

:param field_names: Deprecated, use include_fields instead.
:type field_names: Optional[list[str]]

Returns:
list[types.Neighbor]: A list of neighbor records found by the search.

Expand All @@ -488,13 +529,22 @@ def vector_search(
This error could occur due to various reasons such as network issues, server-side failures, or invalid request parameters.
"""

# TODO remove this when 'field_names' is removed
if field_names is not None:
warnings.warn(
"The 'field_names' argument is deprecated. Use 'include_fields' instead",
FutureWarning,
)
include_fields = field_names

(transact_stub, vector_search_request, kwargs) = self._prepare_vector_search(
namespace,
index_name,
query,
limit,
search_params,
field_names,
include_fields,
exclude_fields,
timeout,
logger,
)
Expand Down
45 changes: 22 additions & 23 deletions src/aerospike_vector_search/shared/client_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,14 +131,15 @@ def _prepare_upsert(
)

def _prepare_get(
self, namespace, key, field_names, set_name, timeout, logger
self, namespace, key, include_fields, exclude_fields, set_name, timeout, logger
) -> None:

logger.debug(
"Getting record: namespace=%s, key=%s, field_names:%s, set_name:%s, timeout:%s",
"Getting record: namespace=%s, key=%s, include_fields:%s, exclude_fields:%s, set_name:%s, timeout:%s",
namespace,
key,
field_names,
include_fields,
exclude_fields,
set_name,
timeout,
)
Expand All @@ -148,7 +149,7 @@ def _prepare_get(
kwargs["timeout"] = timeout

key = self._get_key(namespace, set_name, key)
projection_spec = self._get_projection_spec(field_names=field_names)
projection_spec = self._get_projection_spec(include_fields=include_fields, exclude_fields=exclude_fields)

transact_stub = self._get_transact_stub()
get_request = transact_pb2.GetRequest(key=key, projection=projection_spec)
Expand Down Expand Up @@ -232,7 +233,8 @@ def _prepare_vector_search(
query,
limit,
search_params,
field_names,
include_fields,
exclude_fields,
timeout,
logger,
) -> None:
Expand All @@ -242,20 +244,21 @@ def _prepare_vector_search(
kwargs["timeout"] = timeout

logger.debug(
"Performing vector search: namespace=%s, index_name=%s, query=%s, limit=%s, search_params=%s, field_names=%s, timeout:%s",
"Performing vector search: namespace=%s, index_name=%s, query=%s, limit=%s, search_params=%s, include_fields=%s, exclude_fields=%s, timeout:%s",
namespace,
index_name,
query,
limit,
search_params,
field_names,
include_fields,
exclude_fields,
timeout,
)

if search_params != None:
search_params = search_params._to_pb2()

projection_spec = self._get_projection_spec(field_names=field_names)
projection_spec = self._get_projection_spec(include_fields=include_fields, exclude_fields=exclude_fields)

index = types_pb2.IndexId(namespace=namespace, name=index_name)

Expand Down Expand Up @@ -299,30 +302,26 @@ def _respond_neighbor(self, response) -> None:
def _get_projection_spec(
self,
*,
field_names: Optional[list] = None,
exclude_field_names: Optional[list] = None,
include_fields: Optional[list] = None,
exclude_fields: Optional[list] = None,
):

if field_names:
# include all fields by default
if include_fields is None:
include = transact_pb2.ProjectionFilter(
type=transact_pb2.ProjectionType.SPECIFIED, fields=field_names
)
exclude = transact_pb2.ProjectionFilter(
type=transact_pb2.ProjectionType.NONE, fields=None
type=transact_pb2.ProjectionType.ALL, fields=None
)
elif exclude_field_names:
else:
include = transact_pb2.ProjectionFilter(
type=transact_pb2.ProjectionType.NONE, fields=None
type=transact_pb2.ProjectionType.SPECIFIED, fields=include_fields
)

if exclude_fields is None:
exclude = transact_pb2.ProjectionFilter(
type=transact_pb2.ProjectionType.SPECIFIED, fields=exclude_field_names
type=transact_pb2.ProjectionType.NONE, fields=None
)
else:
include = transact_pb2.ProjectionFilter(
type=transact_pb2.ProjectionType.ALL, fields=None
)
exclude = transact_pb2.ProjectionFilter(
type=transact_pb2.ProjectionType.NONE, fields=None
type=transact_pb2.ProjectionType.SPECIFIED, fields=exclude_fields
)

projection_spec = transact_pb2.ProjectionSpec(include=include, exclude=exclude)
Expand Down
Loading
Loading