 from client import LlamaStackClientHolder
 from configuration import configuration
 from app.endpoints.conversations import conversation_id_to_agent_id
+import metrics
 from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse
 from models.requests import QueryRequest, Attachment
 import constants
@@ -122,14 +123,18 @@ def query_endpoint_handler(
     try:
         # try to get Llama Stack client
         client = LlamaStackClientHolder().get_client()
-        model_id = select_model_id(client.models.list(), query_request)
+        model_id, provider_id = select_model_and_provider_id(
+            client.models.list(), query_request
+        )
         response, conversation_id = retrieve_response(
             client,
             model_id,
             query_request,
             token,
             mcp_headers=mcp_headers,
         )
+        # Update metrics for the LLM call
+        metrics.llm_calls_total.labels(provider_id, model_id).inc()
 
         if not is_transcripts_enabled():
             logger.debug("Transcript collection is disabled in the configuration")
@@ -150,6 +155,8 @@ def query_endpoint_handler(
 
     # connection to Llama Stack server
     except APIConnectionError as e:
+        # Update metrics for the LLM call failure
+        metrics.llm_calls_failures_total.inc()
         logger.error("Unable to connect to Llama Stack: %s", e)
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -160,8 +167,10 @@ def query_endpoint_handler(
         ) from e
 
 
-def select_model_id(models: ModelListResponse, query_request: QueryRequest) -> str:
-    """Select the model ID based on the request or available models."""
+def select_model_and_provider_id(
+    models: ModelListResponse, query_request: QueryRequest
+) -> tuple[str, str | None]:
+    """Select the model ID and provider ID based on the request or available models."""
     model_id = query_request.model
     provider_id = query_request.provider
@@ -173,9 +182,11 @@ def select_model_id(models: ModelListResponse, query_request: QueryRequest) -> s
             m
             for m in models
             if m.model_type == "llm"  # pyright: ignore[reportAttributeAccessIssue]
-        ).identifier
+        )
+        model_id = model.identifier
+        provider_id = model.provider_id
         logger.info("Selected model: %s", model)
-        return model
+        return model_id, provider_id
     except (StopIteration, AttributeError) as e:
         message = "No LLM model found in available models"
         logger.error(message)
@@ -201,7 +212,7 @@ def select_model_id(models: ModelListResponse, query_request: QueryRequest) -> s
             },
         )
 
-    return model_id
+    return model_id, provider_id
 
 
 def _is_inout_shield(shield: Shield) -> bool:
@@ -218,7 +229,7 @@ def is_input_shield(shield: Shield) -> bool:
     return _is_inout_shield(shield) or not is_output_shield(shield)
 
 
-def retrieve_response(
+def retrieve_response(  # pylint: disable=too-many-locals
     client: LlamaStackClient,
     model_id: str,
     query_request: QueryRequest,
@@ -288,6 +299,14 @@ def retrieve_response(
         toolgroups=toolgroups or None,
     )
 
+    # Check for validation errors in the response
+    steps = getattr(response, "steps", [])
+    for step in steps:
+        if step.step_type == "shield_call" and step.violation:
+            # Metric for LLM validation errors
+            metrics.llm_calls_validation_errors_total.inc()
+            break
+
     return str(response.output_message.content), conversation_id  # type: ignore[union-attr]
 
 
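For reference, here is a minimal sketch of the `metrics` module this diff relies on, assuming Prometheus-style counters from `prometheus_client`. The counter names and the provider/model label pair are taken from the diff itself; the label names and help strings are assumptions, not the project's actual definitions.

from prometheus_client import Counter

# Successful LLM calls, labeled by provider and model; matches the
# metrics.llm_calls_total.labels(provider_id, model_id).inc() call above.
llm_calls_total = Counter(
    "llm_calls_total",
    "Total number of LLM calls",
    ["provider", "model"],
)

# Calls that never reached the model (incremented on APIConnectionError).
llm_calls_failures_total = Counter(
    "llm_calls_failures_total",
    "Total number of failed LLM calls",
)

# Responses whose shield_call step reported a violation.
llm_calls_validation_errors_total = Counter(
    "llm_calls_validation_errors_total",
    "Total number of LLM responses that triggered a shield violation",
)

Keeping the failure and validation counters unlabeled avoids extra label cardinality; only the success counter is broken down per provider and model.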