From d84047be415c5defbbaf9649c3b2de598516c966 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Mon, 10 Feb 2025 18:45:02 -0800
Subject: [PATCH 1/5] fix(litellm_logging.py): support saving applied
 guardrails in logging object

allows list of applied guardrails to be logged for proxy admin's knowledge
---
 litellm/litellm_core_utils/litellm_logging.py |  3 +
 litellm/proxy/_new_secret_config.yaml         |  9 +-
 litellm/utils.py                              | 34 +++++++
 tests/litellm_utils_tests/test_utils.py       | 96 ++++++++++++++++++-
 4 files changed, 136 insertions(+), 6 deletions(-)

diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index 28182b75acd1..e8be1b901f47 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -199,6 +199,7 @@ def __init__(
         dynamic_async_failure_callbacks: Optional[
             List[Union[str, Callable, CustomLogger]]
         ] = None,
+        applied_guardrails: Optional[List[str]] = None,
         kwargs: Optional[Dict] = None,
     ):
         _input: Optional[str] = messages  # save original value of messages
@@ -271,6 +272,7 @@ def __init__(
             "litellm_call_id": litellm_call_id,
             "input": _input,
             "litellm_params": litellm_params,
+            "applied_guardrails": applied_guardrails,
         }

     def process_dynamic_callbacks(self):
@@ -2866,6 +2868,7 @@ def get_standard_logging_metadata(
     - If the input metadata is None or not a dictionary, an empty StandardLoggingMetadata object is returned.
     - If 'user_api_key' is present in metadata and is a valid SHA256 hash, it's stored as 'user_api_key_hash'.
     """
+
     prompt_management_metadata: Optional[
         StandardLoggingPromptManagementMetadata
     ] = None
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 5347d4a791ba..aa6808adc0df 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -49,4 +49,11 @@ general_settings:
 router_settings:
   redis_host: os.environ/REDIS_HOST
   redis_password: os.environ/REDIS_PASSWORD
-  redis_port: os.environ/REDIS_PORT
\ No newline at end of file
+  redis_port: os.environ/REDIS_PORT
+
+guardrails:
+  - guardrail_name: "custom-pre-guard"
+    litellm_params:
+      guardrail: custom_guardrail.myCustomGuardrail # 👈 Key change
+      mode: "pre_call" # runs async_pre_call_hook
+      default_on: true
\ No newline at end of file
diff --git a/litellm/utils.py b/litellm/utils.py
index 7cdfc2ebbe46..dbefc90dc61d 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -60,6 +60,7 @@
 from litellm.caching._internal_lru_cache import lru_cache_wrapper
 from litellm.caching.caching import DualCache
 from litellm.caching.caching_handler import CachingHandlerResponse, LLMCachingHandler
+from litellm.integrations.custom_guardrail import CustomGuardrail
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.litellm_core_utils.core_helpers import (
     map_finish_reason,
@@ -418,6 +419,35 @@ def _custom_logger_class_exists_in_failure_callbacks(
     )


+def get_request_guardrails(kwargs: Dict[str, Any]) -> List[str]:
+    """
+    Get the request guardrails from the kwargs
+    """
+    metadata = kwargs.get("metadata") or {}
+    requester_metadata = metadata.get("requester_metadata") or {}
+    applied_guardrails = requester_metadata.get("guardrails") or []
+    return applied_guardrails
+
+
+def get_applied_guardrails(kwargs: Dict[str, Any]) -> List[str]:
+    """
+    - Add 'default_on' guardrails to the list
+    - Add request guardrails to the list
+    """
+
+    request_guardrails = get_request_guardrails(kwargs)
+    applied_guardrails = []
+    for callback in litellm.callbacks:
+        if callback is not None and isinstance(callback, CustomGuardrail):
+            if callback.guardrail_name is not None:
+                if callback.default_on is True:
+                    applied_guardrails.append(callback.guardrail_name)
+                elif callback.guardrail_name in request_guardrails:
+                    applied_guardrails.append(callback.guardrail_name)
+
+    return applied_guardrails
+
+
 def function_setup(  # noqa: PLR0915
     original_function: str, rules_obj, start_time, *args, **kwargs
 ):  # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
@@ -436,6 +466,9 @@ def function_setup(  # noqa: PLR0915
     ## CUSTOM LLM SETUP ##
     custom_llm_setup()

+    ## GET APPLIED GUARDRAILS
+    applied_guardrails = get_applied_guardrails(kwargs)
+
     ## LOGGING SETUP
     function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None

@@ -677,6 +710,7 @@ def function_setup(  # noqa: PLR0915
             dynamic_async_success_callbacks=dynamic_async_success_callbacks,
             dynamic_async_failure_callbacks=dynamic_async_failure_callbacks,
             kwargs=kwargs,
+            applied_guardrails=applied_guardrails,
         )

         ## check if metadata is passed in
diff --git a/tests/litellm_utils_tests/test_utils.py b/tests/litellm_utils_tests/test_utils.py
index 75630c81d89a..b19282563c4a 100644
--- a/tests/litellm_utils_tests/test_utils.py
+++ b/tests/litellm_utils_tests/test_utils.py
@@ -864,17 +864,24 @@ def test_convert_model_response_object():
         == '{"type":"error","error":{"type":"invalid_request_error","message":"Output blocked by content filtering policy"}}'
     )

+
 @pytest.mark.parametrize(
-    "content, expected_reasoning, expected_content", 
+    "content, expected_reasoning, expected_content",
     [
         (None, None, None),
-        ("<think>I am thinking here</think>The sky is a canvas of blue", "I am thinking here", "The sky is a canvas of blue"),
+        (
+            "<think>I am thinking here</think>The sky is a canvas of blue",
+            "I am thinking here",
+            "The sky is a canvas of blue",
+        ),
         ("I am a regular response", None, "I am a regular response"),
-
-    ]
+    ],
 )
 def test_parse_content_for_reasoning(content, expected_reasoning, expected_content):
-    assert(litellm.utils._parse_content_for_reasoning(content) == (expected_reasoning, expected_content))
+    assert litellm.utils._parse_content_for_reasoning(content) == (
+        expected_reasoning,
+        expected_content,
+    )


 @pytest.mark.parametrize(
@@ -1874,3 +1881,82 @@ def test_validate_user_messages_invalid_content_type():
         assert "Invalid message" in str(e)

     print(e)
+
+
+from litellm.integrations.custom_guardrail import CustomGuardrail
+from litellm.utils import get_applied_guardrails
+from unittest.mock import Mock
+
+
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        {
+            "name": "default_on_guardrail",
+            "callbacks": [
+                CustomGuardrail(guardrail_name="test_guardrail", default_on=True)
+            ],
+            "kwargs": {"metadata": {"requester_metadata": {"guardrails": []}}},
+            "expected": ["test_guardrail"],
+        },
+        {
+            "name": "request_specific_guardrail",
+            "callbacks": [
+                CustomGuardrail(guardrail_name="test_guardrail", default_on=False)
+            ],
+            "kwargs": {
+                "metadata": {"requester_metadata": {"guardrails": ["test_guardrail"]}}
+            },
+            "expected": ["test_guardrail"],
+        },
+        {
+            "name": "multiple_guardrails",
+            "callbacks": [
+                CustomGuardrail(guardrail_name="default_guardrail", default_on=True),
+                CustomGuardrail(guardrail_name="request_guardrail", default_on=False),
+            ],
+            "kwargs": {
+                "metadata": {
+                    "requester_metadata": {"guardrails": ["request_guardrail"]}
+                }
+            },
+            "expected": ["default_guardrail", "request_guardrail"],
+        },
+        {
+            "name": "empty_metadata",
+            "callbacks": [
+                CustomGuardrail(guardrail_name="test_guardrail", default_on=False)
+            ],
+            "kwargs": {},
+            "expected": [],
+        },
+        {
+            "name": "none_callback",
+            "callbacks": [
+                None,
+                CustomGuardrail(guardrail_name="test_guardrail", default_on=True),
+            ],
+            "kwargs": {},
+            "expected": ["test_guardrail"],
+        },
+        {
+            "name": "non_guardrail_callback",
+            "callbacks": [
+                Mock(),
+                CustomGuardrail(guardrail_name="test_guardrail", default_on=True),
+            ],
+            "kwargs": {},
+            "expected": ["test_guardrail"],
+        },
+    ],
+)
+def test_get_applied_guardrails(test_case):
+
+    # Setup
+    litellm.callbacks = test_case["callbacks"]
+
+    # Execute
+    result = get_applied_guardrails(test_case["kwargs"])
+
+    # Assert
+    assert sorted(result) == sorted(test_case["expected"])
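A minimal usage sketch of the helper added above (guardrail names here are hypothetical; it assumes guardrails are registered on litellm.callbacks as CustomGuardrail instances, exactly as the new tests do):

    import litellm
    from litellm.integrations.custom_guardrail import CustomGuardrail
    from litellm.utils import get_applied_guardrails

    # "pii-mask" runs on every request; "prompt-injection" only when the
    # request names it in metadata.requester_metadata.guardrails
    litellm.callbacks = [
        CustomGuardrail(guardrail_name="pii-mask", default_on=True),
        CustomGuardrail(guardrail_name="prompt-injection", default_on=False),
    ]

    kwargs = {"metadata": {"requester_metadata": {"guardrails": ["prompt-injection"]}}}
    print(get_applied_guardrails(kwargs))  # ['pii-mask', 'prompt-injection']
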
From 3ef79c50c69c36badfe4438c509f4cc20f9901b6 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Mon, 10 Feb 2025 18:57:14 -0800
Subject: [PATCH 2/5] feat(spend_tracking_utils.py): log applied guardrails to
 spend logs

makes it easy for admin to know what guardrails were applied on a request
---
 litellm/litellm_core_utils/litellm_logging.py |  4 ++++
 litellm/proxy/_types.py                       |  1 +
 .../spend_tracking/spend_tracking_utils.py    | 20 +++++++++++++++----
 litellm/types/utils.py                        |  1 +
 4 files changed, 22 insertions(+), 4 deletions(-)

diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index e8be1b901f47..516d9dba349e 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -2854,6 +2854,7 @@ def get_standard_logging_metadata(
     metadata: Optional[Dict[str, Any]],
     litellm_params: Optional[dict] = None,
     prompt_integration: Optional[str] = None,
+    applied_guardrails: Optional[List[str]] = None,
 ) -> StandardLoggingMetadata:
     """
     Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.
@@ -2898,6 +2899,7 @@ def get_standard_logging_metadata(
         requester_metadata=None,
         user_api_key_end_user_id=None,
         prompt_management_metadata=prompt_management_metadata,
+        applied_guardrails=applied_guardrails,
     )
     if isinstance(metadata, dict):
         # Filter the metadata dictionary to include only the specified keys
@@ -3196,6 +3198,7 @@ def get_standard_logging_object_payload(
         metadata=metadata,
         litellm_params=litellm_params,
         prompt_integration=kwargs.get("prompt_integration", None),
+        applied_guardrails=kwargs.get("applied_guardrails", None),
     )

     _request_body = proxy_server_request.get("body", {})
@@ -3331,6 +3334,7 @@ def get_standard_logging_metadata(
         requester_metadata=None,
         user_api_key_end_user_id=None,
         prompt_management_metadata=None,
+        applied_guardrails=None,
     )
     if isinstance(metadata, dict):
         # Filter the metadata dictionary to include only the specified keys
diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index 7b2435e67c49..0a0be7e1fb47 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -1790,6 +1790,7 @@ class SpendLogsMetadata(TypedDict):
         dict
     ]  # special param to log k,v pairs to spendlogs for a call
     requester_ip_address: Optional[str]
+    applied_guardrails: Optional[List[str]]


 class SpendLogsPayload(TypedDict):
diff --git a/litellm/proxy/spend_tracking/spend_tracking_utils.py b/litellm/proxy/spend_tracking/spend_tracking_utils.py
index ccf0836e05bb..f12220766baa 100644
--- a/litellm/proxy/spend_tracking/spend_tracking_utils.py
+++ b/litellm/proxy/spend_tracking/spend_tracking_utils.py
@@ -3,7 +3,7 @@
 from datetime import datetime
 from datetime import datetime as dt
 from datetime import timezone
-from typing import Optional, cast
+from typing import List, Optional, cast

 from pydantic import BaseModel

@@ -32,7 +32,9 @@ def _is_master_key(api_key: str, _master_key: Optional[str]) -> bool:
     return False


-def _get_spend_logs_metadata(metadata: Optional[dict]) -> SpendLogsMetadata:
+def _get_spend_logs_metadata(
+    metadata: Optional[dict], applied_guardrails: Optional[List[str]] = None
+) -> SpendLogsMetadata:
     if metadata is None:
         return SpendLogsMetadata(
             user_api_key=None,
@@ -44,8 +46,9 @@ def _get_spend_logs_metadata(metadata: Optional[dict]) -> SpendLogsMetadata:
             spend_logs_metadata=None,
             requester_ip_address=None,
             additional_usage_values=None,
+            applied_guardrails=None,
         )
-    verbose_proxy_logger.debug(
+    verbose_proxy_logger.info(
         "getting payload for SpendLogs, available keys in metadata: "
         + str(list(metadata.keys()))
     )
@@ -58,6 +61,8 @@ def _get_spend_logs_metadata(metadata: Optional[dict]) -> SpendLogsMetadata:
             if key in metadata
         }
     )
+    clean_metadata["applied_guardrails"] = applied_guardrails
+
     return clean_metadata


@@ -130,7 +135,14 @@ def get_logging_payload(  # noqa: PLR0915
     _model_group = metadata.get("model_group", "")

     # clean up litellm metadata
-    clean_metadata = _get_spend_logs_metadata(metadata)
+    clean_metadata = _get_spend_logs_metadata(
+        metadata,
+        applied_guardrails=(
+            standard_logging_payload["metadata"].get("applied_guardrails", None)
+            if standard_logging_payload is not None
+            else None
+        ),
+    )

     special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"]
     additional_usage_values = {}
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index 556bae94e759..9139f6be4c62 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -1525,6 +1525,7 @@ class StandardLoggingMetadata(StandardLoggingUserAPIKeyMetadata):
     requester_ip_address: Optional[str]
     requester_metadata: Optional[dict]
     prompt_management_metadata: Optional[StandardLoggingPromptManagementMetadata]
+    applied_guardrails: Optional[List[str]]


 class StandardLoggingAdditionalHeaders(TypedDict, total=False):
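With this change, a request that ran the custom-pre-guard guardrail from the earlier config would surface it in the spend-log metadata column; roughly (illustrative values only, not the full SpendLogsMetadata shape):

    # illustrative slice of the serialized metadata stored in spend logs
    {
        "spend_logs_metadata": None,
        "requester_ip_address": None,
        "applied_guardrails": ["custom-pre-guard"],
        "additional_usage_values": {
            "completion_tokens_details": None,
            "prompt_tokens_details": None,
        },
    }
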
From ef3c9408d2b1f935458f81103f3d34d7146c2cb9 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Mon, 10 Feb 2025 20:36:14 -0800
Subject: [PATCH 3/5] ci(config.yml): uninstall posthog from ci/cd
---
 .circleci/config.yml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.circleci/config.yml b/.circleci/config.yml
index 1af15b03a5bc..e5e015f5199b 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -72,6 +72,7 @@ jobs:
             pip install "jsonschema==4.22.0"
             pip install "pytest-xdist==3.6.1"
             pip install "websockets==10.4"
+            pip uninstall posthog -y
       - save_cache:
          paths:
            - ./venv
From 47d28c8afad20bc622e583c0ded1290dae7cca01 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Mon, 10 Feb 2025 21:56:50 -0800
Subject: [PATCH 4/5] test: fix tests
---
 tests/local_testing/test_traceloop.py        | 1 +
 .../gcs_pub_sub_body/spend_logs_payload.json | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/local_testing/test_traceloop.py b/tests/local_testing/test_traceloop.py
index 5cab8dd59cca..0deb319e1e80 100644
--- a/tests/local_testing/test_traceloop.py
+++ b/tests/local_testing/test_traceloop.py
@@ -26,6 +26,7 @@ def exporter():
     return exporter


+@pytest.mark.skip(reason="moved to using 'otel' for logging")
 @pytest.mark.parametrize("model", ["claude-3-5-haiku-20241022", "gpt-3.5-turbo"])
 def test_traceloop_logging(exporter, model):
     litellm.completion(
diff --git a/tests/logging_callback_tests/gcs_pub_sub_body/spend_logs_payload.json b/tests/logging_callback_tests/gcs_pub_sub_body/spend_logs_payload.json
index 0e78e60b7693..08c6b4518328 100644
--- a/tests/logging_callback_tests/gcs_pub_sub_body/spend_logs_payload.json
+++ b/tests/logging_callback_tests/gcs_pub_sub_body/spend_logs_payload.json
@@ -9,7 +9,7 @@
     "model": "gpt-4o",
     "user": "",
     "team_id": "",
-    "metadata": "{\"additional_usage_values\": {\"completion_tokens_details\": null, \"prompt_tokens_details\": null}}",
+    "metadata": "{\"applied_guardrails\": [], \"additional_usage_values\": {\"completion_tokens_details\": null, \"prompt_tokens_details\": null}}",
     "cache_key": "Cache OFF",
     "spend": 0.00022500000000000002,
     "total_tokens": 30,
From 9d0ff7ec3167faa574642dcd0a1cc04e34270677 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Mon, 10 Feb 2025 22:12:52 -0800
Subject: [PATCH 5/5] test: update test
---
 tests/logging_callback_tests/test_otel_logging.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/logging_callback_tests/test_otel_logging.py b/tests/logging_callback_tests/test_otel_logging.py
index 9c19c9d261ab..d37e46bf19f3 100644
--- a/tests/logging_callback_tests/test_otel_logging.py
+++ b/tests/logging_callback_tests/test_otel_logging.py
@@ -272,6 +272,7 @@ def validate_redacted_message_span_attributes(span):
         "metadata.user_api_key_user_id",
         "metadata.user_api_key_org_id",
         "metadata.user_api_key_end_user_id",
+        "metadata.applied_guardrails",
     ]

     _all_attributes = set(