Log applied guardrails on LLM API call #8452

Merged
merged 6 commits into from
Feb 11, 2025
7 changes: 7 additions & 0 deletions litellm/litellm_core_utils/litellm_logging.py
@@ -199,6 +199,7 @@ def __init__(
dynamic_async_failure_callbacks: Optional[
List[Union[str, Callable, CustomLogger]]
] = None,
applied_guardrails: Optional[List[str]] = None,
kwargs: Optional[Dict] = None,
):
_input: Optional[str] = messages # save original value of messages
@@ -271,6 +272,7 @@ def __init__(
"litellm_call_id": litellm_call_id,
"input": _input,
"litellm_params": litellm_params,
"applied_guardrails": applied_guardrails,
}

def process_dynamic_callbacks(self):
@@ -2852,6 +2854,7 @@ def get_standard_logging_metadata(
metadata: Optional[Dict[str, Any]],
litellm_params: Optional[dict] = None,
prompt_integration: Optional[str] = None,
applied_guardrails: Optional[List[str]] = None,
) -> StandardLoggingMetadata:
"""
Clean and filter the metadata dictionary to include only the specified keys in StandardLoggingMetadata.
@@ -2866,6 +2869,7 @@
- If the input metadata is None or not a dictionary, an empty StandardLoggingMetadata object is returned.
- If 'user_api_key' is present in metadata and is a valid SHA256 hash, it's stored as 'user_api_key_hash'.
"""

prompt_management_metadata: Optional[
StandardLoggingPromptManagementMetadata
] = None
@@ -2895,6 +2899,7 @@
requester_metadata=None,
user_api_key_end_user_id=None,
prompt_management_metadata=prompt_management_metadata,
applied_guardrails=applied_guardrails,
)
if isinstance(metadata, dict):
# Filter the metadata dictionary to include only the specified keys
@@ -3193,6 +3198,7 @@ def get_standard_logging_object_payload(
metadata=metadata,
litellm_params=litellm_params,
prompt_integration=kwargs.get("prompt_integration", None),
applied_guardrails=kwargs.get("applied_guardrails", None),
)

_request_body = proxy_server_request.get("body", {})
@@ -3328,6 +3334,7 @@ def get_standard_logging_metadata(
requester_metadata=None,
user_api_key_end_user_id=None,
prompt_management_metadata=None,
applied_guardrails=None,
)
if isinstance(metadata, dict):
# Filter the metadata dictionary to include only the specified keys
1 change: 1 addition & 0 deletions litellm/proxy/_types.py
@@ -1794,6 +1794,7 @@ class SpendLogsMetadata(TypedDict):
dict
] # special param to log k,v pairs to spendlogs for a call
requester_ip_address: Optional[str]
applied_guardrails: Optional[List[str]]


class SpendLogsPayload(TypedDict):
20 changes: 16 additions & 4 deletions litellm/proxy/spend_tracking/spend_tracking_utils.py
@@ -3,7 +3,7 @@
from datetime import datetime
from datetime import datetime as dt
from datetime import timezone
from typing import Optional, cast
from typing import List, Optional, cast

from pydantic import BaseModel

@@ -32,7 +32,9 @@ def _is_master_key(api_key: str, _master_key: Optional[str]) -> bool:
return False


def _get_spend_logs_metadata(metadata: Optional[dict]) -> SpendLogsMetadata:
def _get_spend_logs_metadata(
metadata: Optional[dict], applied_guardrails: Optional[List[str]] = None
) -> SpendLogsMetadata:
if metadata is None:
return SpendLogsMetadata(
user_api_key=None,
@@ -44,8 +46,9 @@ def _get_spend_logs_metadata(metadata: Optional[dict]) -> SpendLogsMetadata:
spend_logs_metadata=None,
requester_ip_address=None,
additional_usage_values=None,
applied_guardrails=None,
)
verbose_proxy_logger.debug(
verbose_proxy_logger.info(
"getting payload for SpendLogs, available keys in metadata: "
+ str(list(metadata.keys()))
)
@@ -58,6 +61,8 @@ def _get_spend_logs_metadata(metadata: Optional[dict]) -> SpendLogsMetadata:
if key in metadata
}
)
clean_metadata["applied_guardrails"] = applied_guardrails

return clean_metadata


@@ -130,7 +135,14 @@ def get_logging_payload(  # noqa: PLR0915
_model_group = metadata.get("model_group", "")

# clean up litellm metadata
clean_metadata = _get_spend_logs_metadata(metadata)
clean_metadata = _get_spend_logs_metadata(
metadata,
applied_guardrails=(
standard_logging_payload["metadata"].get("applied_guardrails", None)
if standard_logging_payload is not None
else None
),
)

special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"]
additional_usage_values = {}
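For reference, a minimal sketch of how the updated helper is expected to behave (it assumes an environment with this change installed; the metadata keys and guardrail name are illustrative): keys not declared on `SpendLogsMetadata` are filtered out, and `applied_guardrails` is attached to the cleaned dict.

```python
from litellm.proxy.spend_tracking.spend_tracking_utils import (
    _get_spend_logs_metadata,
)

# Only keys declared on SpendLogsMetadata survive the filtering;
# applied_guardrails is set explicitly from the new argument.
cleaned = _get_spend_logs_metadata(
    metadata={"user_api_key": "hashed-key", "unrelated_key": "dropped"},
    applied_guardrails=["pii-mask"],
)
assert cleaned["applied_guardrails"] == ["pii-mask"]
assert "unrelated_key" not in cleaned
```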
1 change: 1 addition & 0 deletions litellm/types/utils.py
@@ -1525,6 +1525,7 @@ class StandardLoggingMetadata(StandardLoggingUserAPIKeyMetadata):
requester_ip_address: Optional[str]
requester_metadata: Optional[dict]
prompt_management_metadata: Optional[StandardLoggingPromptManagementMetadata]
applied_guardrails: Optional[List[str]]


class StandardLoggingAdditionalHeaders(TypedDict, total=False):
34 changes: 34 additions & 0 deletions litellm/utils.py
@@ -60,6 +60,7 @@
from litellm.caching._internal_lru_cache import lru_cache_wrapper
from litellm.caching.caching import DualCache
from litellm.caching.caching_handler import CachingHandlerResponse, LLMCachingHandler
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import (
map_finish_reason,
@@ -418,6 +419,35 @@ def _custom_logger_class_exists_in_failure_callbacks(
)


def get_request_guardrails(kwargs: Dict[str, Any]) -> List[str]:
"""
Get the request guardrails from the kwargs
"""
metadata = kwargs.get("metadata") or {}
requester_metadata = metadata.get("requester_metadata") or {}
applied_guardrails = requester_metadata.get("guardrails") or []
return applied_guardrails


def get_applied_guardrails(kwargs: Dict[str, Any]) -> List[str]:
"""
- Add 'default_on' guardrails to the list
- Add request guardrails to the list
"""

request_guardrails = get_request_guardrails(kwargs)
applied_guardrails = []
for callback in litellm.callbacks:
if callback is not None and isinstance(callback, CustomGuardrail):
if callback.guardrail_name is not None:
if callback.default_on is True:
applied_guardrails.append(callback.guardrail_name)
elif callback.guardrail_name in request_guardrails:
applied_guardrails.append(callback.guardrail_name)

return applied_guardrails


def function_setup( # noqa: PLR0915
original_function: str, rules_obj, start_time, *args, **kwargs
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
@@ -436,6 +466,9 @@ def function_setup(  # noqa: PLR0915
## CUSTOM LLM SETUP ##
custom_llm_setup()

## GET APPLIED GUARDRAILS
applied_guardrails = get_applied_guardrails(kwargs)

## LOGGING SETUP
function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None

@@ -677,6 +710,7 @@ def function_setup(  # noqa: PLR0915
dynamic_async_success_callbacks=dynamic_async_success_callbacks,
dynamic_async_failure_callbacks=dynamic_async_failure_callbacks,
kwargs=kwargs,
applied_guardrails=applied_guardrails,
)

## check if metadata is passed in
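A short usage sketch of the new helper, mirroring the parametrized tests below (the guardrail names are illustrative; callbacks are registered on `litellm.callbacks` the same way the tests do it):

```python
import litellm
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.utils import get_applied_guardrails

# One guardrail that is always on, one that only runs when requested.
litellm.callbacks = [
    CustomGuardrail(guardrail_name="pii-mask", default_on=True),
    CustomGuardrail(guardrail_name="toxicity-filter", default_on=False),
]

# Request-level guardrails arrive via metadata.requester_metadata.guardrails.
kwargs = {
    "metadata": {"requester_metadata": {"guardrails": ["toxicity-filter"]}}
}

print(get_applied_guardrails(kwargs))  # -> ['pii-mask', 'toxicity-filter']
```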
96 changes: 91 additions & 5 deletions tests/litellm_utils_tests/test_utils.py
@@ -864,17 +864,24 @@ def test_convert_model_response_object():
== '{"type":"error","error":{"type":"invalid_request_error","message":"Output blocked by content filtering policy"}}'
)


@pytest.mark.parametrize(
"content, expected_reasoning, expected_content",
"content, expected_reasoning, expected_content",
[
(None, None, None),
("<think>I am thinking here</think>The sky is a canvas of blue", "I am thinking here", "The sky is a canvas of blue"),
(
"<think>I am thinking here</think>The sky is a canvas of blue",
"I am thinking here",
"The sky is a canvas of blue",
),
("I am a regular response", None, "I am a regular response"),

]
],
)
def test_parse_content_for_reasoning(content, expected_reasoning, expected_content):
assert(litellm.utils._parse_content_for_reasoning(content) == (expected_reasoning, expected_content))
assert litellm.utils._parse_content_for_reasoning(content) == (
expected_reasoning,
expected_content,
)


@pytest.mark.parametrize(
@@ -1874,3 +1881,82 @@ def test_validate_user_messages_invalid_content_type():

assert "Invalid message" in str(e)
print(e)


from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.utils import get_applied_guardrails
from unittest.mock import Mock


@pytest.mark.parametrize(
"test_case",
[
{
"name": "default_on_guardrail",
"callbacks": [
CustomGuardrail(guardrail_name="test_guardrail", default_on=True)
],
"kwargs": {"metadata": {"requester_metadata": {"guardrails": []}}},
"expected": ["test_guardrail"],
},
{
"name": "request_specific_guardrail",
"callbacks": [
CustomGuardrail(guardrail_name="test_guardrail", default_on=False)
],
"kwargs": {
"metadata": {"requester_metadata": {"guardrails": ["test_guardrail"]}}
},
"expected": ["test_guardrail"],
},
{
"name": "multiple_guardrails",
"callbacks": [
CustomGuardrail(guardrail_name="default_guardrail", default_on=True),
CustomGuardrail(guardrail_name="request_guardrail", default_on=False),
],
"kwargs": {
"metadata": {
"requester_metadata": {"guardrails": ["request_guardrail"]}
}
},
"expected": ["default_guardrail", "request_guardrail"],
},
{
"name": "empty_metadata",
"callbacks": [
CustomGuardrail(guardrail_name="test_guardrail", default_on=False)
],
"kwargs": {},
"expected": [],
},
{
"name": "none_callback",
"callbacks": [
None,
CustomGuardrail(guardrail_name="test_guardrail", default_on=True),
],
"kwargs": {},
"expected": ["test_guardrail"],
},
{
"name": "non_guardrail_callback",
"callbacks": [
Mock(),
CustomGuardrail(guardrail_name="test_guardrail", default_on=True),
],
"kwargs": {},
"expected": ["test_guardrail"],
},
],
)
def test_get_applied_guardrails(test_case):

# Setup
litellm.callbacks = test_case["callbacks"]

# Execute
result = get_applied_guardrails(test_case["kwargs"])

# Assert
assert sorted(result) == sorted(test_case["expected"])
@@ -9,7 +9,7 @@
"model": "gpt-4o",
"user": "",
"team_id": "",
"metadata": "{\"additional_usage_values\": {\"completion_tokens_details\": null, \"prompt_tokens_details\": null}}",
"metadata": "{\"applied_guardrails\": [], \"additional_usage_values\": {\"completion_tokens_details\": null, \"prompt_tokens_details\": null}}",
"cache_key": "Cache OFF",
"spend": 0.00022500000000000002,
"total_tokens": 30,
1 change: 1 addition & 0 deletions tests/logging_callback_tests/test_otel_logging.py
@@ -272,6 +272,7 @@ def validate_redacted_message_span_attributes(span):
"metadata.user_api_key_user_id",
"metadata.user_api_key_org_id",
"metadata.user_api_key_end_user_id",
"metadata.applied_guardrails",
]

_all_attributes = set(
Expand Down