14 | 14 | # limitations under the License. |
15 | 15 |
16 | 16 | import logging |
| 17 | +import os |
17 | 18 | import re |
18 | | -from typing import List, Optional, Tuple, Union |
| 19 | +from typing import List, Literal, Optional, Tuple, Union |
19 | 20 |
| 21 | +import aiohttp |
20 | 22 | from langchain_core.language_models.llms import BaseLLM |
21 | 23 |
22 | 24 | from nemoguardrails.actions import action |
@@ -106,5 +108,140 @@ async def patronus_lynx_check_output_hallucination( |
106 | 108 | ) |
107 | 109 |
108 | 110 | hallucination, reasoning = parse_patronus_lynx_response(result) |
109 | | - print(f"Hallucination: {hallucination}, Reasoning: {reasoning}") |
110 | 111 | return {"hallucination": hallucination, "reasoning": reasoning} |
| 112 | + |
| 113 | + |
| 114 | +def check_guardrail_pass( |
| 115 | + response: Optional[dict], success_strategy: Literal["all_pass", "any_pass"] |
| 116 | +) -> bool: |
| 117 | + """ |
| 118 | + Check if evaluations in the Patronus API response pass based on the success strategy. |
| 119 | + "all_pass" requires all evaluators to pass for success. |
| 120 | + "any_pass" requires only one evaluator to pass for success. |
| 121 | + """ |
| 122 | + if not response or "results" not in response: |
| 123 | + return False |
| 124 | + |
| 125 | + evaluations = response["results"] |
| 126 | + |
| 127 | + if success_strategy == "all_pass": |
| 128 | + return all( |
| 129 | + "evaluation_result" in result |
| 130 | + and isinstance(result["evaluation_result"], dict) |
| 131 | + and result["evaluation_result"].get("pass", False) |
| 132 | + for result in evaluations |
| 133 | + ) |
| 134 | + return any( |
| 135 | + "evaluation_result" in result |
| 136 | + and isinstance(result["evaluation_result"], dict) |
| 137 | + and result["evaluation_result"].get("pass", False) |
| 138 | + for result in evaluations |
| 139 | + ) |
| 140 | + |
| 141 | + |
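As a quick illustration of the two strategies, here is a sketch against a response trimmed down to the only fields `check_guardrail_pass` reads; the shape is inferred from the parsing above, not from the full Evaluate API schema:

```python
# Hypothetical response containing only the fields check_guardrail_pass
# inspects; a real Evaluate API payload carries more fields per result.
sample_response = {
    "results": [
        {"evaluation_result": {"pass": True}},
        {"evaluation_result": {"pass": False}},
    ]
}

assert check_guardrail_pass(sample_response, "all_pass") is False  # one evaluator failed
assert check_guardrail_pass(sample_response, "any_pass") is True   # at least one passed
assert check_guardrail_pass(None, "any_pass") is False             # no response never passes
```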
| 142 | +async def patronus_evaluate_request( |
| 143 | + api_params: dict, |
| 144 | + user_input: Optional[str] = None, |
| 145 | + bot_response: Optional[str] = None, |
| 146 | + provided_context: Optional[Union[str, List[str]]] = None, |
| 147 | +) -> Optional[dict]: |
| 148 | + """ |
| 149 | + Make a call to the Patronus Evaluate API. |
| 150 | +
| 151 | + Returns a dictionary of the API response JSON if successful, or None if a server error occurs. |
| 152 | + * Server errors will cause the guardrail to block the bot response |
| 153 | +
| 154 | + Raises a ValueError for client errors (400-499), as these indicate invalid requests. |
| 155 | + """ |
| 156 | + api_key = os.environ.get("PATRONUS_API_KEY") |
| 157 | + |
| 158 | + if api_key is None: |
| 159 | + raise ValueError("PATRONUS_API_KEY environment variable not set.") |
| 160 | + |
| 161 | + if "evaluators" not in api_params: |
| 162 | + raise ValueError( |
| 163 | + "The Patronus Evaluate API parameters must contain an 'evaluators' field" |
| 164 | + ) |
| 165 | + evaluators = api_params["evaluators"] |
| 166 | + if not isinstance(evaluators, list): |
| 167 | + raise ValueError( |
| 168 | + "The Patronus Evaluate API parameter 'evaluators' must be a list" |
| 169 | + ) |
| 170 | + |
| 171 | + for evaluator in evaluators: |
| 172 | + if not isinstance(evaluator, dict): |
| 173 | + raise ValueError( |
| 174 | + "Each object in the 'evaluators' list must be a dictionary" |
| 175 | + ) |
| 176 | + if "evaluator" not in evaluator: |
| 177 | + raise ValueError( |
| 178 | + "Each dictionary in the 'evaluators' list must contain the 'evaluator' field" |
| 179 | + ) |
| 180 | + |
| 181 | + data = { |
| 182 | + **api_params, |
| 183 | + "evaluated_model_input": user_input, |
| 184 | + "evaluated_model_output": bot_response, |
| 185 | + "evaluated_model_retrieved_context": provided_context, |
| 186 | + } |
| 187 | + |
| 188 | + url = "https://api.patronus.ai/v1/evaluate" |
| 189 | + headers = { |
| 190 | + "X-API-KEY": api_key, |
| 191 | + "Content-Type": "application/json", |
| 192 | + } |
| 193 | + |
| 194 | + async with aiohttp.ClientSession() as session: |
| 195 | + async with session.post( |
| 196 | + url=url, |
| 197 | + headers=headers, |
| 198 | + json=data, |
| 199 | + ) as response: |
| 200 | + if 400 <= response.status < 500: |
| 201 | + raise ValueError( |
| 202 | + f"The Patronus Evaluate API call failed with status code {response.status}. " |
| 203 | + f"Details: {await response.text()}" |
| 204 | + ) |
| 205 | + |
| 206 | + if response.status != 200: |
| 207 | + log.error( |
| 208 | + "The Patronus Evaluate API call failed with status code %s. Details: %s", |
| 209 | + response.status, |
| 210 | + await response.text(), |
| 211 | + ) |
| 212 | + return None |
| 213 | + |
| 214 | + response_json = await response.json() |
| 215 | + return response_json |
| 216 | + |
| 217 | + |
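For reference, a minimal sketch of calling `patronus_evaluate_request` directly, assuming `PATRONUS_API_KEY` is exported; the evaluator name `lynx` is illustrative, so substitute whatever evaluators your Patronus account exposes:

```python
import asyncio


async def main():
    # "lynx" is an illustrative evaluator name; consult the Patronus docs
    # for the evaluators actually available to your account.
    response = await patronus_evaluate_request(
        api_params={"evaluators": [{"evaluator": "lynx"}]},
        user_input="What is the capital of France?",
        bot_response="Paris is the capital of France.",
        provided_context="Paris is the capital and largest city of France.",
    )
    # None signals a Patronus server error, which the guardrail treats as a fail.
    print(check_guardrail_pass(response, success_strategy="all_pass"))


asyncio.run(main())
```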
| 218 | +@action(name="patronus_api_check_output") |
| 219 | +async def patronus_api_check_output( |
| 220 | + llm_task_manager: LLMTaskManager, |
| 221 | + context: Optional[dict] = None, |
| 222 | +) -> dict: |
| 223 | + """ |
| 224 | + Check the user message, bot response, and/or provided context |
| 225 | + for issues using the Patronus Evaluate API. |
| 226 | + """ |
| 227 | + user_input = context.get("user_message") |
| 228 | + bot_response = context.get("bot_message") |
| 229 | + provided_context = context.get("relevant_chunks") |
| 230 | + |
| 231 | + patronus_config = llm_task_manager.config.rails.config.patronus.output |
| 232 | + evaluate_config = getattr(patronus_config, "evaluate_config", {}) |
| 233 | + success_strategy: Literal["all_pass", "any_pass"] = getattr( |
| 234 | + evaluate_config, "success_strategy", "all_pass" |
| 235 | + ) |
| 236 | + api_params = getattr(evaluate_config, "params", {}) |
| 237 | + response = await patronus_evaluate_request( |
| 238 | + api_params=api_params, |
| 239 | + user_input=user_input, |
| 240 | + bot_response=bot_response, |
| 241 | + provided_context=provided_context, |
| 242 | + ) |
| 243 | + return { |
| 244 | + "pass": check_guardrail_pass( |
| 245 | + response=response, success_strategy=success_strategy |
| 246 | + ) |
| 247 | + } |
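Read back from the `getattr` calls above, the config section this action consumes has roughly the following shape, sketched here as the equivalent Python mapping (in a NeMo Guardrails project it would live under `rails.config.patronus.output`):

```python
# Field names mirror the getattr lookups in patronus_api_check_output.
# The evaluator entry is illustrative; extra keys placed in "params" are
# forwarded to the Evaluate API request body unchanged.
patronus_output_config = {
    "evaluate_config": {
        "success_strategy": "any_pass",  # falls back to "all_pass" when omitted
        "params": {
            "evaluators": [{"evaluator": "lynx"}],  # required; must be a list of dicts
        },
    }
}
```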