Skip to content

Commit

Permalink
feat: Add hybrid query example to vector search sample.
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 643779254
  • Loading branch information
lingyinw authored and copybara-github committed Jun 16, 2024
1 parent 32e3b22 commit 510da5e
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 4 deletions.
20 changes: 20 additions & 0 deletions samples/model-builder/test_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,26 @@
# Identifiers and query fixtures shared by the vector search sample tests.
VECTOR_SEARCH_INDEX_ENDPOINT = "456"
VECTOR_SEARCH_DEPLOYED_INDEX_ID = "789"
# NOTE: "SERACH" is a misspelling of "SEARCH". The names are kept as-is
# because the sample tests reference them verbatim.
VECTOR_SERACH_INDEX_QUERIES = [[0.1]]
# One HybridQuery per keyword set below; the order matters to tests that
# compare this list against the queries issued by the sample.
VECTOR_SERACH_INDEX_HYBRID_QUERIES = [
    aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery(
        **query_kwargs
    )
    for query_kwargs in (
        # Dense + sparse embeddings with an explicit rrf_ranking_alpha.
        dict(
            dense_embedding=[1, 2, 3],
            sparse_embedding_dimensions=[10, 20, 30],
            sparse_embedding_values=[1.0, 1.0, 1.0],
            rrf_ranking_alpha=0.5,
        ),
        # Dense + sparse embeddings, rrf_ranking_alpha left unset.
        dict(
            dense_embedding=[1, 2, 3],
            sparse_embedding_dimensions=[10, 20, 30],
            sparse_embedding_values=[0.1, 0.2, 0.3],
        ),
        # Sparse embedding only.
        dict(
            sparse_embedding_dimensions=[10, 20, 30],
            sparse_embedding_values=[0.1, 0.2, 0.3],
        ),
        # Dense embedding only.
        dict(dense_embedding=[1, 2, 3]),
    )
]
VECTOR_SEARCH_INDEX_DISPLAY_NAME = "my-vector-search-index"
VECTOR_SEARCH_GCS_URI = "gs://fake-dir"
VECTOR_SEARCH_INDEX_ENDPOINT_DISPLAY_NAME = "my-vector-search-index-endpoint"
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,32 @@ def vector_search_find_neighbors(
)
print(resp)

# Query hybrid datapoints, sparse-only datapoints, and dense-only datapoints.
hybrid_queries = [
aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery(
dense_embedding=[1, 2, 3],
sparse_embedding_dimensions=[10, 20, 30],
sparse_embedding_values=[1.0, 1.0, 1.0],
rrf_ranking_alpha=0.5,
),
aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery(
dense_embedding=[1, 2, 3],
sparse_embedding_dimensions=[10, 20, 30],
sparse_embedding_values=[0.1, 0.2, 0.3],
),
aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery(
sparse_embedding_dimensions=[10, 20, 30],
sparse_embedding_values=[0.1, 0.2, 0.3],
),
aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery(
dense_embedding=[1, 2, 3]
),
]

hybrid_resp = my_index_endpoint.find_neighbors(
deployed_index_id=deployed_index_id,
queries=hybrid_queries,
num_neighbors=num_neighbors,)
print(hybrid_resp)

# [END aiplatform_sdk_vector_search_find_neighbors_sample]
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import call

import test_constants as constants
from vector_search import vector_search_find_neighbors_sample

Expand All @@ -38,8 +40,18 @@ def test_vector_search_find_neighbors_sample(
index_endpoint_name=constants.VECTOR_SEARCH_INDEX_ENDPOINT)

# Check index_endpoint.find_neighbors is called with right params.
mock_index_endpoint_find_neighbors.assert_called_with(
deployed_index_id=constants.VECTOR_SEARCH_DEPLOYED_INDEX_ID,
queries=constants.VECTOR_SERACH_INDEX_QUERIES,
num_neighbors=10
mock_index_endpoint_find_neighbors.assert_has_calls(
[
call(
deployed_index_id=constants.VECTOR_SEARCH_DEPLOYED_INDEX_ID,
queries=constants.VECTOR_SERACH_INDEX_QUERIES,
num_neighbors=10,
),
call(
deployed_index_id=constants.VECTOR_SEARCH_DEPLOYED_INDEX_ID,
queries=constants.VECTOR_SERACH_INDEX_HYBRID_QUERIES,
num_neighbors=10,
),
],
any_order=False,
)

0 comments on commit 510da5e

Please sign in to comment.