diff --git a/packages/opentelemetry-instrumentation-milvus/opentelemetry/instrumentation/milvus/wrapper.py b/packages/opentelemetry-instrumentation-milvus/opentelemetry/instrumentation/milvus/wrapper.py index c493a91d83..b14b9cd9ba 100644 --- a/packages/opentelemetry-instrumentation-milvus/opentelemetry/instrumentation/milvus/wrapper.py +++ b/packages/opentelemetry-instrumentation-milvus/opentelemetry/instrumentation/milvus/wrapper.py @@ -5,7 +5,7 @@ from opentelemetry.instrumentation.utils import ( _SUPPRESS_INSTRUMENTATION_KEY, ) -from opentelemetry.semconv_ai import Events +from opentelemetry.semconv_ai import Events, EventAttributes from opentelemetry.semconv_ai import SpanAttributes as AISpanAttributes @@ -55,9 +55,13 @@ def _wrap(tracer, to_wrap, wrapped, instance, args, kwargs): _set_create_collection_attributes(span, kwargs) return_value = wrapped(*args, **kwargs) + if to_wrap.get("method") == "query": _add_query_result_events(span, return_value) + if to_wrap.get("method") == "search": + _add_search_result_events(span, return_value) + return return_value @@ -198,7 +202,9 @@ def _set_search_attributes(span, kwargs): count_or_none(kwargs.get("output_fields")), ) _set_span_attribute( - span, AISpanAttributes.MILVUS_SEARCH_SEARCH_PARAMS, kwargs.get("search_params") + span, + AISpanAttributes.MILVUS_SEARCH_SEARCH_PARAMS, + _encode_include(kwargs.get("search_params")), ) _set_span_attribute( span, AISpanAttributes.MILVUS_SEARCH_TIMEOUT, kwargs.get("timeout") @@ -211,6 +217,19 @@ def _set_search_attributes(span, kwargs): _set_span_attribute( span, AISpanAttributes.MILVUS_SEARCH_ANNS_FIELD, kwargs.get("anns_field") ) + _set_span_attribute( + span, + AISpanAttributes.MILVUS_SEARCH_PARTITION_NAMES, + _encode_partition_name(kwargs.get("partition_names")), + ) + query_vectors = kwargs.get("data", []) + vector_dims = [len(vec) for vec in query_vectors] + + _set_span_attribute( + span, + AISpanAttributes.MILVUS_SEARCH_QUERY_VECTOR_DIMENSION, + _encode_include(vector_dims), + ) @dont_throw @@ -248,6 +267,60 @@ def _add_query_result_events(span, kwargs): span.add_event(name=Events.DB_QUERY_RESULT.value, attributes=element) +@dont_throw +def _add_search_result_events(span, kwargs): + + all_distances = [] + total_matches = 0 + + single_query = len(kwargs) == 1 + + def set_query_stats(query_idx, distances, match_ids): + """Helper function to set per-query stats in the span.""" + + _set_span_attribute( + span, + f"{AISpanAttributes.MILVUS_SEARCH_RESULT_COUNT}_{query_idx}", + len(distances), + ) + + def set_global_stats(): + """Helper function to set global stats for a single query.""" + _set_span_attribute( + span, AISpanAttributes.MILVUS_SEARCH_RESULT_COUNT, total_matches + ) + + for query_idx, query_results in enumerate(kwargs): + + query_distances = [] + query_match_ids = [] + + for match in query_results: + distance = float(match["distance"]) + query_distances.append(distance) + all_distances.append(distance) + total_matches += 1 + query_match_ids.append(match["id"]) + + span.add_event( + Events.DB_SEARCH_RESULT.value, + attributes={ + EventAttributes.DB_SEARCH_RESULT_QUERY_ID.value: query_idx, + EventAttributes.DB_SEARCH_RESULT_ID.value: match["id"], + EventAttributes.DB_SEARCH_RESULT_DISTANCE.value: str(distance), + EventAttributes.DB_SEARCH_RESULT_ENTITY.value: _encode_include( + match["entity"] + ), + }, + ) + + if not single_query: + set_query_stats(query_idx, query_distances, query_match_ids) + + if single_query: + set_global_stats() + + @dont_throw def _set_upsert_attributes(span, kwargs): _set_span_attribute( diff --git a/packages/opentelemetry-instrumentation-milvus/tests/test_search.py b/packages/opentelemetry-instrumentation-milvus/tests/test_search.py new file mode 100644 index 0000000000..d58b0337cc --- /dev/null +++ b/packages/opentelemetry-instrumentation-milvus/tests/test_search.py @@ -0,0 +1,200 @@ +import os +import random + +import pymilvus +import pytest +from opentelemetry.semconv_ai import Events, SpanAttributes, EventAttributes + +path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "milvus.db") +milvus = pymilvus.MilvusClient(uri=path) + + +@pytest.fixture +def collection(): + collection_name = "Colors" + milvus.create_collection(collection_name=collection_name, dimension=5) + yield collection_name + milvus.drop_collection(collection_name=collection_name) + + +def insert_data(collection): + colors = [ + "green", + "blue", + "yellow", + "red", + "black", + "white", + "purple", + "pink", + "orange", + "grey", + ] + data = [ + { + "id": i, + "vector": [random.uniform(-1, 1) for _ in range(5)], + "color": random.choice(colors), + "tag": random.randint(1000, 9999), + } + for i in range(1000) + ] + data += [ + { + "id": 1000, + "vector": [random.uniform(-1, 1) for _ in range(5)], + "color": "brown", + "tag": 1234, + }, + { + "id": 1001, + "vector": [random.uniform(-1, 1) for _ in range(5)], + "color": "brown", + "tag": 5678, + }, + { + "id": 1002, + "vector": [random.uniform(-1, 1) for _ in range(5)], + "color": "brown", + "tag": 9101, + }, + ] + for i in data: + i["color_tag"] = "{}_{}".format(i["color"], i["tag"]) + milvus.insert(collection_name=collection, data=data) + + +def test_milvus_single_vector_search(exporter, collection): + insert_data(collection) + + query_vectors = [ + [random.uniform(-1, 1) for _ in range(5)], # Random query vector for the search + ] + search_params = {"radius": 0.5, "metric_type": "COSINE", "index_type": "IVF_FLAT"} + milvus.search( + collection_name=collection, + data=query_vectors, + anns_field="vector", + search_params=search_params, + output_fields=["color_tag"], + limit=3, + timeout=10, + ) + + # Get finished spans + spans = exporter.get_finished_spans() + span = next(span for span in spans if span.name == "milvus.search") + + # Check the span attributes related to search + assert span.attributes.get(SpanAttributes.VECTOR_DB_VENDOR) == "milvus" + assert span.attributes.get(SpanAttributes.VECTOR_DB_OPERATION) == "search" + assert ( + span.attributes.get(SpanAttributes.MILVUS_SEARCH_COLLECTION_NAME) == collection + ) + assert span.attributes.get(SpanAttributes.MILVUS_SEARCH_OUTPUT_FIELDS_COUNT) == 1 + assert span.attributes.get(SpanAttributes.MILVUS_SEARCH_LIMIT) == 3 + assert span.attributes.get(SpanAttributes.MILVUS_SEARCH_TIMEOUT) == 10 + assert span.attributes.get(SpanAttributes.MILVUS_SEARCH_ANNS_FIELD) == "vector" + assert ( + span.attributes.get(SpanAttributes.MILVUS_SEARCH_QUERY_VECTOR_DIMENSION) + == "[5]" + ) + distances = [] + ids = [] + + events = span.events + + for event in events: + assert event.name == Events.DB_SEARCH_RESULT.value + _id = event.attributes.get(EventAttributes.DB_SEARCH_RESULT_ID.value) + distance = event.attributes.get(EventAttributes.DB_SEARCH_RESULT_DISTANCE.value) + + assert isinstance(_id, int) + assert isinstance(distance, str) + + # Collect the distances and IDs for further computation + distances.append( + float(distance) + ) # Convert the distance to a float for computation + ids.append(_id) + + # Now compute dynamic stats from the distances + total_matches = len(events) + + assert ( + span.attributes.get(SpanAttributes.MILVUS_SEARCH_RESULT_COUNT) == total_matches + ) + + +def test_milvus_multiple_vector_search(exporter, collection): + insert_data(collection) + + query_vectors = [ + [random.uniform(-1, 1) for _ in range(5)], # Random query vector for the search + [random.uniform(-1, 1) for _ in range(5)], # Another query vector + [ + random.uniform(-1, 1) for _ in range(5) + ], # Another query vector (you can add more as needed) + ] + search_params = {"radius": 0.5, "metric_type": "COSINE", "index_type": "IVF_FLAT"} + milvus.search( + collection_name=collection, + data=query_vectors, + anns_field="vector", + search_params=search_params, + output_fields=["color_tag"], + limit=3, + timeout=10, + ) + + # Get finished spans + spans = exporter.get_finished_spans() + span = next(span for span in spans if span.name == "milvus.search") + + # Check the span attributes related to search + assert span.attributes.get(SpanAttributes.VECTOR_DB_VENDOR) == "milvus" + assert span.attributes.get(SpanAttributes.VECTOR_DB_OPERATION) == "search" + assert ( + span.attributes.get(SpanAttributes.MILVUS_SEARCH_COLLECTION_NAME) == collection + ) + assert span.attributes.get(SpanAttributes.MILVUS_SEARCH_OUTPUT_FIELDS_COUNT) == 1 + assert span.attributes.get(SpanAttributes.MILVUS_SEARCH_LIMIT) == 3 + assert span.attributes.get(SpanAttributes.MILVUS_SEARCH_TIMEOUT) == 10 + assert span.attributes.get(SpanAttributes.MILVUS_SEARCH_ANNS_FIELD) == "vector" + assert ( + span.attributes.get(SpanAttributes.MILVUS_SEARCH_QUERY_VECTOR_DIMENSION) + == "[5, 5, 5]" + ) + + distances_dict = {} + ids_dict = {} + + events = span.events + for event in events: + assert event.name == Events.DB_SEARCH_RESULT.value + query_idx = event.attributes.get( + EventAttributes.DB_SEARCH_RESULT_QUERY_ID.value + ) + _id = event.attributes.get(EventAttributes.DB_SEARCH_RESULT_ID.value) + distance = event.attributes.get(EventAttributes.DB_SEARCH_RESULT_DISTANCE.value) + + assert isinstance(_id, int) + assert isinstance(distance, str) + + distance = float(distance) + + if query_idx not in distances_dict: + distances_dict[query_idx] = [] + ids_dict[query_idx] = [] + + distances_dict[query_idx].append(distance) + ids_dict[query_idx].append(_id) + + for query_idx in distances_dict: + distances = distances_dict[query_idx] + + total_matches = len(distances) + + count_key = f"{SpanAttributes.MILVUS_SEARCH_RESULT_COUNT}_{query_idx}" + + assert span.attributes.get(count_key) == total_matches