Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from opentelemetry.instrumentation.utils import (
_SUPPRESS_INSTRUMENTATION_KEY,
)
from opentelemetry.semconv_ai import Events
from opentelemetry.semconv_ai import Events, EventAttributes
from opentelemetry.semconv_ai import SpanAttributes as AISpanAttributes


Expand Down Expand Up @@ -55,9 +55,13 @@ def _wrap(tracer, to_wrap, wrapped, instance, args, kwargs):
_set_create_collection_attributes(span, kwargs)

return_value = wrapped(*args, **kwargs)

if to_wrap.get("method") == "query":
_add_query_result_events(span, return_value)

if to_wrap.get("method") == "search":
_add_search_result_events(span, return_value)

return return_value


Expand Down Expand Up @@ -198,7 +202,9 @@ def _set_search_attributes(span, kwargs):
count_or_none(kwargs.get("output_fields")),
)
_set_span_attribute(
span, AISpanAttributes.MILVUS_SEARCH_SEARCH_PARAMS, kwargs.get("search_params")
span,
AISpanAttributes.MILVUS_SEARCH_SEARCH_PARAMS,
_encode_include(kwargs.get("search_params")),
)
_set_span_attribute(
span, AISpanAttributes.MILVUS_SEARCH_TIMEOUT, kwargs.get("timeout")
Expand All @@ -211,6 +217,19 @@ def _set_search_attributes(span, kwargs):
_set_span_attribute(
span, AISpanAttributes.MILVUS_SEARCH_ANNS_FIELD, kwargs.get("anns_field")
)
_set_span_attribute(
span,
AISpanAttributes.MILVUS_SEARCH_PARTITION_NAMES,
_encode_partition_name(kwargs.get("partition_names")),
)
query_vectors = kwargs.get("data", [])
vector_dims = [len(vec) for vec in query_vectors]

_set_span_attribute(
span,
AISpanAttributes.MILVUS_SEARCH_QUERY_VECTOR_DIMENSION,
_encode_include(vector_dims),
)


@dont_throw
Expand Down Expand Up @@ -248,6 +267,60 @@ def _add_query_result_events(span, kwargs):
span.add_event(name=Events.DB_QUERY_RESULT.value, attributes=element)


@dont_throw
def _add_search_result_events(span, kwargs):

all_distances = []
total_matches = 0

single_query = len(kwargs) == 1

def set_query_stats(query_idx, distances, match_ids):
"""Helper function to set per-query stats in the span."""

_set_span_attribute(
span,
f"{AISpanAttributes.MILVUS_SEARCH_RESULT_COUNT}_{query_idx}",
len(distances),
)

def set_global_stats():
"""Helper function to set global stats for a single query."""
_set_span_attribute(
span, AISpanAttributes.MILVUS_SEARCH_RESULT_COUNT, total_matches
)

for query_idx, query_results in enumerate(kwargs):

query_distances = []
query_match_ids = []

for match in query_results:
distance = float(match["distance"])
query_distances.append(distance)
all_distances.append(distance)
total_matches += 1
query_match_ids.append(match["id"])

span.add_event(
Events.DB_SEARCH_RESULT.value,
attributes={
EventAttributes.DB_SEARCH_RESULT_QUERY_ID.value: query_idx,
EventAttributes.DB_SEARCH_RESULT_ID.value: match["id"],
EventAttributes.DB_SEARCH_RESULT_DISTANCE.value: str(distance),
EventAttributes.DB_SEARCH_RESULT_ENTITY.value: _encode_include(
match["entity"]
),
},
)

if not single_query:
set_query_stats(query_idx, query_distances, query_match_ids)

if single_query:
set_global_stats()


@dont_throw
def _set_upsert_attributes(span, kwargs):
_set_span_attribute(
Expand Down
200 changes: 200 additions & 0 deletions packages/opentelemetry-instrumentation-milvus/tests/test_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
import os
import random

import pymilvus
import pytest
from opentelemetry.semconv_ai import Events, SpanAttributes, EventAttributes

path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "milvus.db")
milvus = pymilvus.MilvusClient(uri=path)


@pytest.fixture
def collection():
collection_name = "Colors"
milvus.create_collection(collection_name=collection_name, dimension=5)
yield collection_name
milvus.drop_collection(collection_name=collection_name)


def insert_data(collection):
colors = [
"green",
"blue",
"yellow",
"red",
"black",
"white",
"purple",
"pink",
"orange",
"grey",
]
data = [
{
"id": i,
"vector": [random.uniform(-1, 1) for _ in range(5)],
"color": random.choice(colors),
"tag": random.randint(1000, 9999),
}
for i in range(1000)
]
data += [
{
"id": 1000,
"vector": [random.uniform(-1, 1) for _ in range(5)],
"color": "brown",
"tag": 1234,
},
{
"id": 1001,
"vector": [random.uniform(-1, 1) for _ in range(5)],
"color": "brown",
"tag": 5678,
},
{
"id": 1002,
"vector": [random.uniform(-1, 1) for _ in range(5)],
"color": "brown",
"tag": 9101,
},
]
for i in data:
i["color_tag"] = "{}_{}".format(i["color"], i["tag"])
milvus.insert(collection_name=collection, data=data)


def test_milvus_single_vector_search(exporter, collection):
insert_data(collection)

query_vectors = [
[random.uniform(-1, 1) for _ in range(5)], # Random query vector for the search
]
search_params = {"radius": 0.5, "metric_type": "COSINE", "index_type": "IVF_FLAT"}
milvus.search(
collection_name=collection,
data=query_vectors,
anns_field="vector",
search_params=search_params,
output_fields=["color_tag"],
limit=3,
timeout=10,
)

# Get finished spans
spans = exporter.get_finished_spans()
span = next(span for span in spans if span.name == "milvus.search")

# Check the span attributes related to search
assert span.attributes.get(SpanAttributes.VECTOR_DB_VENDOR) == "milvus"
assert span.attributes.get(SpanAttributes.VECTOR_DB_OPERATION) == "search"
assert (
span.attributes.get(SpanAttributes.MILVUS_SEARCH_COLLECTION_NAME) == collection
)
assert span.attributes.get(SpanAttributes.MILVUS_SEARCH_OUTPUT_FIELDS_COUNT) == 1
assert span.attributes.get(SpanAttributes.MILVUS_SEARCH_LIMIT) == 3
assert span.attributes.get(SpanAttributes.MILVUS_SEARCH_TIMEOUT) == 10
assert span.attributes.get(SpanAttributes.MILVUS_SEARCH_ANNS_FIELD) == "vector"
assert (
span.attributes.get(SpanAttributes.MILVUS_SEARCH_QUERY_VECTOR_DIMENSION)
== "[5]"
)
distances = []
ids = []

events = span.events

for event in events:
assert event.name == Events.DB_SEARCH_RESULT.value
_id = event.attributes.get(EventAttributes.DB_SEARCH_RESULT_ID.value)
distance = event.attributes.get(EventAttributes.DB_SEARCH_RESULT_DISTANCE.value)

assert isinstance(_id, int)
assert isinstance(distance, str)

# Collect the distances and IDs for further computation
distances.append(
float(distance)
) # Convert the distance to a float for computation
ids.append(_id)

# Now compute dynamic stats from the distances
total_matches = len(events)

assert (
span.attributes.get(SpanAttributes.MILVUS_SEARCH_RESULT_COUNT) == total_matches
)


def test_milvus_multiple_vector_search(exporter, collection):
insert_data(collection)

query_vectors = [
[random.uniform(-1, 1) for _ in range(5)], # Random query vector for the search
[random.uniform(-1, 1) for _ in range(5)], # Another query vector
[
random.uniform(-1, 1) for _ in range(5)
], # Another query vector (you can add more as needed)
]
search_params = {"radius": 0.5, "metric_type": "COSINE", "index_type": "IVF_FLAT"}
milvus.search(
collection_name=collection,
data=query_vectors,
anns_field="vector",
search_params=search_params,
output_fields=["color_tag"],
limit=3,
timeout=10,
)

# Get finished spans
spans = exporter.get_finished_spans()
span = next(span for span in spans if span.name == "milvus.search")

# Check the span attributes related to search
assert span.attributes.get(SpanAttributes.VECTOR_DB_VENDOR) == "milvus"
assert span.attributes.get(SpanAttributes.VECTOR_DB_OPERATION) == "search"
assert (
span.attributes.get(SpanAttributes.MILVUS_SEARCH_COLLECTION_NAME) == collection
)
assert span.attributes.get(SpanAttributes.MILVUS_SEARCH_OUTPUT_FIELDS_COUNT) == 1
assert span.attributes.get(SpanAttributes.MILVUS_SEARCH_LIMIT) == 3
assert span.attributes.get(SpanAttributes.MILVUS_SEARCH_TIMEOUT) == 10
assert span.attributes.get(SpanAttributes.MILVUS_SEARCH_ANNS_FIELD) == "vector"
assert (
span.attributes.get(SpanAttributes.MILVUS_SEARCH_QUERY_VECTOR_DIMENSION)
== "[5, 5, 5]"
)

distances_dict = {}
ids_dict = {}

events = span.events
for event in events:
assert event.name == Events.DB_SEARCH_RESULT.value
query_idx = event.attributes.get(
EventAttributes.DB_SEARCH_RESULT_QUERY_ID.value
)
_id = event.attributes.get(EventAttributes.DB_SEARCH_RESULT_ID.value)
distance = event.attributes.get(EventAttributes.DB_SEARCH_RESULT_DISTANCE.value)

assert isinstance(_id, int)
assert isinstance(distance, str)

distance = float(distance)

if query_idx not in distances_dict:
distances_dict[query_idx] = []
ids_dict[query_idx] = []

distances_dict[query_idx].append(distance)
ids_dict[query_idx].append(_id)

for query_idx in distances_dict:
distances = distances_dict[query_idx]

total_matches = len(distances)

count_key = f"{SpanAttributes.MILVUS_SEARCH_RESULT_COUNT}_{query_idx}"

assert span.attributes.get(count_key) == total_matches