2626)
2727class TestLLMObsBedrock :
2828 @staticmethod
29- def expected_llmobs_span_event (span , n_output , message = False , metadata = None , token_metrics = None ):
29+ def expected_llmobs_span_event (
30+ span , n_output , input_message = False , output_message = False , metadata = None , token_metrics = None
31+ ):
3032 expected_input = [{"content" : mock .ANY }]
31- if message :
33+ if input_message :
3234 expected_input = [{"content" : mock .ANY , "role" : "user" }]
35+ expected_output = []
36+ if output_message :
37+ expected_output = [{"content" : mock .ANY } for _ in range (n_output )]
3338
3439 # Use empty dicts as defaults for _expected_llmobs_llm_span_event to avoid None issues
3540 expected_parameters = metadata if metadata is not None else {}
@@ -40,7 +45,7 @@ def expected_llmobs_span_event(span, n_output, message=False, metadata=None, tok
4045 model_name = span .get_tag ("bedrock.request.model" ),
4146 model_provider = span .get_tag ("bedrock.request.model_provider" ),
4247 input_messages = expected_input ,
43- output_messages = [{ "content" : mock . ANY } for _ in range ( n_output )] ,
48+ output_messages = expected_output ,
4449 metadata = expected_parameters ,
4550 token_metrics = expected_token_metrics ,
4651 tags = {"service" : "aws.bedrock-runtime" , "ml_app" : "<ml-app-name>" },
@@ -86,7 +91,7 @@ def _test_llmobs_invoke(cls, provider, bedrock_client, mock_tracer, llmobs_event
8691
8792 assert len (llmobs_events ) == 1
8893 assert llmobs_events [0 ] == cls .expected_llmobs_span_event (
89- span , n_output , message = "message" in provider , metadata = expected_metadata
94+ span , n_output , input_message = "message" in provider , output_message = True , metadata = expected_metadata
9095 )
9196 LLMObs .disable ()
9297
@@ -121,7 +126,7 @@ def _test_llmobs_invoke_stream(
121126
122127 assert len (llmobs_events ) == 1
123128 assert llmobs_events [0 ] == cls .expected_llmobs_span_event (
124- span , n_output , message = "message" in provider , metadata = expected_metadata
129+ span , n_output , input_message = "message" in provider , output_message = True , metadata = expected_metadata
125130 )
126131
127132 def test_llmobs_ai21_invoke (self , ddtrace_global_config , bedrock_client , mock_tracer , llmobs_events ):
@@ -156,6 +161,24 @@ def test_llmobs_cohere_multi_output_invoke(self, ddtrace_global_config, bedrock_
156161 def test_llmobs_meta_invoke (self , ddtrace_global_config , bedrock_client , mock_tracer , llmobs_events ):
157162 self ._test_llmobs_invoke ("meta" , bedrock_client , mock_tracer , llmobs_events )
158163
164+ def test_llmobs_cohere_rerank_invoke (self , ddtrace_global_config , bedrock_client , mock_tracer , llmobs_events ):
165+ cassette_name = "cohere_rerank_invoke.yaml"
166+ model = "cohere.rerank-v3-5:0"
167+ prompt_data = "What is the capital of the United States?"
168+ documents = [
169+ "Carson City is the capital city of the American state of Nevada." ,
170+ "The Commonwealth of the Northern Mariana Islands's capital is Saipan." ,
171+ ]
172+ body = json .dumps ({"query" : prompt_data , "documents" : documents , "api_version" : 2 , "top_n" : 3 })
173+ with get_request_vcr ().use_cassette (cassette_name ):
174+ response = bedrock_client .invoke_model (body = body , modelId = model )
175+ json .loads (response .get ("body" ).read ())
176+ span = mock_tracer .pop_traces ()[0 ][0 ]
177+
178+ assert len (llmobs_events ) == 1
179+ assert llmobs_events [0 ] == self .expected_llmobs_span_event (span , 1 )
180+ LLMObs .disable ()
181+
159182 def test_llmobs_amazon_invoke_stream (self , ddtrace_global_config , bedrock_client , mock_tracer , llmobs_events ):
160183 self ._test_llmobs_invoke_stream ("amazon" , bedrock_client , mock_tracer , llmobs_events )
161184
0 commit comments