Skip to content

Commit

Permalink
feat: Add query by id for MatchingEngineIndexEndpoint `find_neighbo…
Browse files Browse the repository at this point in the history
…rs()` public endpoint query.

PiperOrigin-RevId: 599497930
  • Loading branch information
lingyinw authored and copybara-github committed Jan 18, 2024
1 parent 67e593b commit 42c7e08
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,13 @@ class MatchNeighbor:
Required. The id of the neighbor.
distance (float):
Required. The distance to the query embedding.
feature_vector (List[float]):
Optional. The feature vector of the matching datapoint.
"""

id: str
distance: float
feature_vector: Optional[List[float]] = None


@dataclass
Expand Down Expand Up @@ -1185,14 +1188,15 @@ def find_neighbors(
self,
*,
deployed_index_id: str,
queries: List[List[float]],
queries: Optional[List[List[float]]] = None,
num_neighbors: int = 10,
filter: Optional[List[Namespace]] = None,
per_crowding_attribute_neighbor_count: Optional[int] = None,
approx_num_neighbors: Optional[int] = None,
fraction_leaf_nodes_to_search_override: Optional[float] = None,
return_full_datapoint: bool = False,
numeric_filter: Optional[List[NumericNamespace]] = None,
embedding_ids: Optional[List[str]] = None,
) -> List[List[MatchNeighbor]]:
"""Retrieves nearest neighbors for the given embedding queries on the
specified deployed index which is deployed to either public or private
Expand Down Expand Up @@ -1243,11 +1247,18 @@ def find_neighbors(
Note that returning full datapoint will significantly increase the
latency and cost of the query.
numeric_filter (Optional[list[NumericNamespace]]):
numeric_filter (list[NumericNamespace]):
Optional. A list of NumericNamespaces for filtering the matching
results. For example:
[NumericNamespace(name="cost", value_int=5, op="GREATER")]
will match datapoints that its cost is greater than 5.
embedding_ids (List[str]):
Optional. Datapoint IDs to query by, used only when `queries` is not
set. If `queries` is set, it is used for the nearest neighbor search
and `embedding_ids` is ignored. Otherwise, the embedding for each ID in
`embedding_ids` is looked up in the dataset and, if it exists, used for
the nearest neighbor search.
Returns:
List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
"""
Expand All @@ -1262,7 +1273,6 @@ def find_neighbors(
per_crowding_attribute_num_neighbors=per_crowding_attribute_neighbor_count,
approx_num_neighbors=approx_num_neighbors,
fraction_leaf_nodes_to_search_override=fraction_leaf_nodes_to_search_override,
return_full_datapoint=return_full_datapoint,
)

# Create the FindNeighbors request
Expand All @@ -1271,50 +1281,60 @@ def find_neighbors(
find_neighbors_request.deployed_index_id = deployed_index_id
find_neighbors_request.return_full_datapoint = return_full_datapoint

for query in queries:
find_neighbors_query = (
gca_match_service_v1beta1.FindNeighborsRequest.Query()
)
find_neighbors_query.neighbor_count = num_neighbors
find_neighbors_query.per_crowding_attribute_neighbor_count = (
per_crowding_attribute_neighbor_count
)
find_neighbors_query.approximate_neighbor_count = approx_num_neighbors
find_neighbors_query.fraction_leaf_nodes_to_search_override = (
fraction_leaf_nodes_to_search_override
# Token restricts
restricts = []
if filter:
for namespace in filter:
restrict = gca_index_v1beta1.IndexDatapoint.Restriction()
restrict.namespace = namespace.name
restrict.allow_list.extend(namespace.allow_tokens)
restrict.deny_list.extend(namespace.deny_tokens)
restricts.append(restrict)
# Numeric restricts
numeric_restricts = []
if numeric_filter:
for numeric_namespace in numeric_filter:
numeric_restrict = gca_index_v1beta1.IndexDatapoint.NumericRestriction()
numeric_restrict.namespace = numeric_namespace.name
numeric_restrict.op = numeric_namespace.op
numeric_restrict.value_int = numeric_namespace.value_int
numeric_restrict.value_float = numeric_namespace.value_float
numeric_restrict.value_double = numeric_namespace.value_double
numeric_restricts.append(numeric_restrict)
# Queries
query_by_id = False if queries else True
queries = queries if queries else embedding_ids
if queries:
for query in queries:
find_neighbors_query = gca_match_service_v1beta1.FindNeighborsRequest.Query(
neighbor_count=num_neighbors,
per_crowding_attribute_neighbor_count=per_crowding_attribute_neighbor_count,
approximate_neighbor_count=approx_num_neighbors,
fraction_leaf_nodes_to_search_override=fraction_leaf_nodes_to_search_override,
)
datapoint = gca_index_v1beta1.IndexDatapoint(
datapoint_id=query if query_by_id else None,
feature_vector=None if query_by_id else query,
)
datapoint.restricts.extend(restricts)
datapoint.numeric_restricts.extend(numeric_restricts)
find_neighbors_query.datapoint = datapoint
find_neighbors_request.queries.append(find_neighbors_query)
else:
raise ValueError(
"To find neighbors using matching engine,"
"please specify `queries` or `embedding_ids`"
)
datapoint = gca_index_v1beta1.IndexDatapoint(feature_vector=query)
# Token restricts
if filter:
for namespace in filter:
restrict = gca_index_v1beta1.IndexDatapoint.Restriction()
restrict.namespace = namespace.name
restrict.allow_list.extend(namespace.allow_tokens)
restrict.deny_list.extend(namespace.deny_tokens)
datapoint.restricts.append(restrict)
# Numeric restricts
if numeric_filter:
for numeric_namespace in numeric_filter:
numeric_restrict = (
gca_index_v1beta1.IndexDatapoint.NumericRestriction()
)
numeric_restrict.namespace = numeric_namespace.name
numeric_restrict.op = numeric_namespace.op
numeric_restrict.value_int = numeric_namespace.value_int
numeric_restrict.value_float = numeric_namespace.value_float
numeric_restrict.value_double = numeric_namespace.value_double
datapoint.numeric_restricts.append(numeric_restrict)

find_neighbors_query.datapoint = datapoint
find_neighbors_request.queries.append(find_neighbors_query)

response = self._public_match_client.find_neighbors(find_neighbors_request)

# Wrap the results in MatchNeighbor objects and return
return [
[
MatchNeighbor(
id=neighbor.datapoint.datapoint_id, distance=neighbor.distance
id=neighbor.datapoint.datapoint_id,
distance=neighbor.distance,
feature_vector=neighbor.datapoint.feature_vector,
)
for neighbor in embedding_neighbors.neighbors
]
Expand Down Expand Up @@ -1429,13 +1449,12 @@ def _batch_get_embeddings(
def match(
self,
deployed_index_id: str,
queries: Optional[List[List[float]]] = None,
queries: List[List[float]] = None,
num_neighbors: int = 1,
filter: Optional[List[Namespace]] = None,
per_crowding_attribute_num_neighbors: Optional[int] = None,
approx_num_neighbors: Optional[int] = None,
fraction_leaf_nodes_to_search_override: Optional[float] = None,
return_full_datapoint: bool = False,
low_level_batch_size: int = 0,
) -> List[List[MatchNeighbor]]:
"""Retrieves nearest neighbors for the given embedding queries on the
Expand Down Expand Up @@ -1468,11 +1487,6 @@ def match(
query time allows user to tune search performance. This value
increase result in both search accuracy and latency increase.
The value should be between 0.0 and 1.0.
return_full_datapoint (bool):
Optional. If set to true, the full datapoints (including all
vector values and of the nearest neighbors are returned.
Note that returning full datapoint will significantly increase the
latency and cost of the query.
low_level_batch_size (int):
Optional. Selects the optimal batch size to use for low-level
batching. Queries within each low level batch are executed
Expand Down Expand Up @@ -1518,9 +1532,13 @@ def match(
per_crowding_attribute_num_neighbors=per_crowding_attribute_num_neighbors,
approx_num_neighbors=approx_num_neighbors,
fraction_leaf_nodes_to_search_override=fraction_leaf_nodes_to_search_override,
embedding_enabled=return_full_datapoint,
)
requests.append(request)
else:
raise ValueError(
"To find neighbors using matching engine,"
"please specify `queries` or `embedding_ids`"
)

batch_request_for_index.requests.extend(requests)
batch_request.requests.append(batch_request_for_index)
Expand All @@ -1531,8 +1549,11 @@ def match(
# Wrap the results in MatchNeighbor objects and return
return [
[
MatchNeighbor(id=neighbor.id, distance=neighbor.distance)
for neighbor in embedding_neighbors.neighbor
MatchNeighbor(
id=embedding_neighbors.neighbor[i].id,
distance=embedding_neighbors.neighbor[i].distance,
)
for i in range(len(embedding_neighbors.neighbor))
]
for embedding_neighbors in response.responses[0].responses
]
56 changes: 52 additions & 4 deletions tests/unit/aiplatform/test_matching_engine_index_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@
-0.021106,
]
]
_TEST_QUERY_IDS = ["1", "2"]
_TEST_NUM_NEIGHBOURS = 1
_TEST_FILTER = [
Namespace(name="class", allow_tokens=["token_1"], deny_tokens=["token_2"])
Expand Down Expand Up @@ -1044,7 +1045,7 @@ def test_index_endpoint_match_queries_backward_compatibility(
index_endpoint_match_queries_mock.assert_called_with(batch_request)

@pytest.mark.usefixtures("get_index_endpoint_mock")
def test_private_index_endpoint_match_queries(
def test_private_service_access_index_endpoint_match_queries(
self, index_endpoint_match_queries_mock
):
aiplatform.init(project=_TEST_PROJECT)
Expand All @@ -1061,7 +1062,6 @@ def test_private_index_endpoint_match_queries(
per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
return_full_datapoint=_TEST_RETURN_FULL_DATAPOINT,
low_level_batch_size=_TEST_LOW_LEVEL_BATCH_SIZE,
)

Expand All @@ -1085,7 +1085,6 @@ def test_private_index_endpoint_match_queries(
per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
embedding_enabled=_TEST_RETURN_FULL_DATAPOINT,
)
for i in range(len(_TEST_QUERIES))
],
Expand Down Expand Up @@ -1135,7 +1134,6 @@ def test_private_index_endpoint_find_neighbor_queries(
per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
embedding_enabled=_TEST_RETURN_FULL_DATAPOINT,
)
for test_query in _TEST_QUERIES
],
Expand Down Expand Up @@ -1241,6 +1239,56 @@ def test_index_public_endpoint_find_neighbors_queries(
find_neighbors_request
)

@pytest.mark.usefixtures("get_index_public_endpoint_mock")
def test_index_public_endpoint_find_neiggbor_query_by_id(
    self, index_public_endpoint_match_queries_mock
):
    """Check the request built by `find_neighbors` when querying by ID.

    `find_neighbors` is called with `embedding_ids` and no `queries`. The
    request passed to the mocked call is expected to contain one query per
    ID, each with a datapoint that carries the ID in `datapoint_id` together
    with the token restriction derived from `_TEST_FILTER`.
    """
    aiplatform.init(project=_TEST_PROJECT)

    public_endpoint = aiplatform.MatchingEngineIndexEndpoint(
        index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
    )

    # Only `embedding_ids` is supplied; `queries` is deliberately omitted.
    public_endpoint.find_neighbors(
        deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
        embedding_ids=_TEST_QUERY_IDS,
        num_neighbors=_TEST_NUM_NEIGHBOURS,
        filter=_TEST_FILTER,
        per_crowding_attribute_neighbor_count=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
        approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
        fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
        return_full_datapoint=_TEST_RETURN_FULL_DATAPOINT,
    )

    # Build the expected per-ID queries one at a time.
    expected_queries = []
    for query_id in _TEST_QUERY_IDS:
        token_restrict = gca_index_v1beta1.IndexDatapoint.Restriction(
            namespace="class",
            allow_list=["token_1"],
            deny_list=["token_2"],
        )
        expected_datapoint = gca_index_v1beta1.IndexDatapoint(
            datapoint_id=query_id,
            restricts=[token_restrict],
        )
        expected_queries.append(
            gca_match_service_v1beta1.FindNeighborsRequest.Query(
                datapoint=expected_datapoint,
                neighbor_count=_TEST_NUM_NEIGHBOURS,
                per_crowding_attribute_neighbor_count=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
                approximate_neighbor_count=_TEST_APPROX_NUM_NEIGHBORS,
                fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
            )
        )

    expected_request = gca_match_service_v1beta1.FindNeighborsRequest(
        index_endpoint=public_endpoint.resource_name,
        deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
        queries=expected_queries,
        return_full_datapoint=_TEST_RETURN_FULL_DATAPOINT,
    )

    index_public_endpoint_match_queries_mock.assert_called_with(expected_request)

@pytest.mark.usefixtures("get_index_public_endpoint_mock")
def test_index_public_endpoint_match_queries_with_numeric_filtering(
self, index_public_endpoint_match_queries_mock
Expand Down

0 comments on commit 42c7e08

Please sign in to comment.