feat(custom_logger.py): expose new async_dataset_hook for modifying… (#6331)

* feat(custom_logger.py): expose new `async_dataset_hook` for modifying/rejecting argilla items before logging

Allows users more control over what gets logged to Argilla for annotations
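
A minimal consumer-side sketch of the new hook (the class name and the filtering rule are illustrative; the signature mirrors the `async_dataset_hook` stub added to `custom_logger.py` in this commit):

```python
from typing import Optional

from litellm.integrations.custom_logger import CustomLogger
from litellm.types.integrations.argilla import ArgillaItem
from litellm.types.utils import StandardLoggingPayload


class FilteringArgillaHook(CustomLogger):
    async def async_dataset_hook(
        self,
        logged_item: ArgillaItem,
        standard_logging_payload: Optional[StandardLoggingPayload],
    ) -> Optional[ArgillaItem]:
        # Returning None tells the Argilla logger to skip this item entirely.
        if not logged_item["fields"].get("response"):
            return None
        # The item can also be modified in place before it is queued.
        logged_item["fields"]["response"] = str(logged_item["fields"]["response"]).strip()
        return logged_item
```

The Argilla logger iterates `litellm.callbacks` and invokes this hook for every registered `CustomLogger`, so registering an instance there (e.g. `litellm.callbacks = [FilteringArgillaHook()]`) should be all that is needed.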

* feat(google_ai_studio_endpoints.py): add new `/azure/*` pass through route

Enables pass-through requests for the Azure provider
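
A hedged sketch of calling the new route from a client. It assumes a proxy running on `localhost:4000` and a virtual key `sk-1234`; the sub-path, deployment name, and `api-version` below are illustrative and simply follow what the Azure OpenAI API itself expects, since everything after `/azure/` is forwarded:

```python
import requests

PROXY_BASE = "http://localhost:4000"  # hypothetical local proxy

resp = requests.post(
    f"{PROXY_BASE}/azure/openai/deployments/my-gpt-4o/chat/completions",
    params={"api-version": "2024-02-15-preview"},       # illustrative api-version
    headers={"Authorization": "Bearer sk-1234"},         # proxy virtual key
    json={"messages": [{"role": "user", "content": "Hello via the pass-through route"}]},
)
print(resp.status_code, resp.json())
```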

* feat(utils.py): support checking ollama `/api/show` endpoint for retrieving ollama model info

Fixes #6322
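
The lookup boils down to a POST against Ollama's `/api/show` endpoint; a rough stand-alone equivalent of what the new `OllamaConfig.get_model_info` (shown in the diff below) does, with the base URL and model name as illustrative defaults:

```python
import requests

# Same call the new OllamaConfig.get_model_info issues via litellm's HTTP client.
api_base = "http://localhost:11434"  # overridable via OLLAMA_API_BASE
info = requests.post(f"{api_base}/api/show", json={"name": "mistral"}, timeout=10).json()

# Max tokens: first "*context_length*" entry under "model_info".
max_tokens = next(
    (v for k, v in info.get("model_info", {}).items() if "context_length" in k),
    None,
)
# Function calling: the chat template mentions "tools".
supports_function_calling = "tools" in str(info.get("template", "") or "").lower()
print(max_tokens, supports_function_calling)
```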

* fix(user_api_key_auth.py): add `/key/delete` to the allowed `ui_routes`

Fixes #6236

* fix(user_api_key_auth.py): remove type ignore

* fix(user_api_key_auth.py): route UI vs. API token checks differently

Fixes #6238

* feat(internal_user_endpoints.py): support setting models as a default internal user param

Closes #6239

* fix(user_api_key_auth.py): fix exception string

* fix(user_api_key_auth.py): fix error string

* fix: fix test
krrishdholakia authored Oct 20, 2024
1 parent 7cc12bd commit 905ebeb
Showing 16 changed files with 422 additions and 153 deletions.
1 change: 1 addition & 0 deletions docs/my-website/docs/proxy/self_serve.md
@@ -207,6 +207,7 @@ litellm_settings:
user_role: "internal_user" # one of "internal_user", "internal_user_viewer", "proxy_admin", "proxy_admin_viewer". New SSO users not in litellm will be created as this user
max_budget: 100 # Optional[float], optional): $100 budget for a new SSO sign in user
budget_duration: 30d # Optional[str], optional): 30 days budget_duration for a new SSO sign in user
models: ["gpt-3.5-turbo"] # Optional[List[str]], optional): models to be used by a new SSO sign in user
upperbound_key_generate_params: # Upperbound for /key/generate requests when self-serve flow is on
69 changes: 30 additions & 39 deletions litellm/integrations/argilla.py
@@ -21,53 +21,22 @@
import litellm
from litellm._logging import verbose_logger
from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.llms.prompt_templates.common_utils import get_content_from_model_response
from litellm.types.integrations.argilla import (
SUPPORTED_PAYLOAD_FIELDS,
ArgillaCredentialsObject,
ArgillaItem,
ArgillaPayload,
)
from litellm.types.utils import StandardLoggingPayload


class LangsmithInputs(BaseModel):
model: Optional[str] = None
messages: Optional[List[Any]] = None
stream: Optional[bool] = None
call_type: Optional[str] = None
litellm_call_id: Optional[str] = None
completion_start_time: Optional[datetime] = None
temperature: Optional[float] = None
max_tokens: Optional[int] = None
custom_llm_provider: Optional[str] = None
input: Optional[List[Any]] = None
log_event_type: Optional[str] = None
original_response: Optional[Any] = None
response_cost: Optional[float] = None

# LiteLLM Virtual Key specific fields
user_api_key: Optional[str] = None
user_api_key_user_id: Optional[str] = None
user_api_key_team_alias: Optional[str] = None


class ArgillaItem(TypedDict):
fields: Dict[str, Any]


class ArgillaPayload(TypedDict):
items: List[ArgillaItem]


class ArgillaCredentialsObject(TypedDict):
ARGILLA_API_KEY: str
ARGILLA_DATASET_NAME: str
ARGILLA_BASE_URL: str


SUPPORTED_PAYLOAD_FIELDS = ["messages", "response"]


def is_serializable(value):
non_serializable_types = (
types.CoroutineType,
@@ -215,7 +184,7 @@ def get_str_response(self, payload: StandardLoggingPayload) -> str:

def _prepare_log_data(
self, kwargs, response_obj, start_time, end_time
) -> ArgillaItem:
) -> Optional[ArgillaItem]:
try:
# Ensure everything in the payload is converted to str
payload: Optional[StandardLoggingPayload] = kwargs.get(
@@ -235,6 +204,7 @@ def _prepare_log_data(
argilla_item["fields"][k] = argilla_response
else:
argilla_item["fields"][k] = payload.get(v, None)

return argilla_item
except Exception:
raise
@@ -294,6 +264,9 @@ def log_success_event(self, kwargs, response_obj, start_time, end_time):
response_obj,
)
data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
if data is None:
return

self.log_queue.append(data)
verbose_logger.debug(
f"Langsmith, event added to queue. Will flush in {self.flush_interval} seconds..."
@@ -321,7 +294,25 @@ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
kwargs,
response_obj,
)
payload: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object", None
)

data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)

## ALLOW CUSTOM LOGGERS TO MODIFY / FILTER DATA BEFORE LOGGING
for callback in litellm.callbacks:
if isinstance(callback, CustomLogger):
try:
if data is None:
break
data = await callback.async_dataset_hook(data, payload)
except NotImplementedError:
pass

if data is None:
return

self.log_queue.append(data)
verbose_logger.debug(
"Langsmith logging: queue length %s, batch size %s",
16 changes: 16 additions & 0 deletions litellm/integrations/custom_logger.py
@@ -10,13 +10,15 @@

from litellm.caching.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.integrations.argilla import ArgillaItem
from litellm.types.llms.openai import ChatCompletionRequest
from litellm.types.services import ServiceLoggerPayload
from litellm.types.utils import (
AdapterCompletionStreamWrapper,
EmbeddingResponse,
ImageResponse,
ModelResponse,
StandardLoggingPayload,
)


@@ -108,6 +110,20 @@ def translate_completion_output_params_streaming(
"""
pass

### DATASET HOOKS #### - currently only used for Argilla

async def async_dataset_hook(
self,
logged_item: ArgillaItem,
standard_logging_payload: Optional[StandardLoggingPayload],
) -> Optional[ArgillaItem]:
"""
- Decide if the result should be logged to Argilla.
- Modify the result before logging to Argilla.
- Return None if the result should not be logged to Argilla.
"""
raise NotImplementedError("async_dataset_hook not implemented")

#### CALL HOOKS - proxy only ####
"""
Control the modify incoming / outgoing data before calling the model
53 changes: 52 additions & 1 deletion litellm/llms/ollama.py
@@ -14,7 +14,8 @@

import litellm
from litellm import verbose_logger
from litellm.types.utils import ProviderField, StreamingChoices
from litellm.secret_managers.main import get_secret_str
from litellm.types.utils import ModelInfo, ProviderField, StreamingChoices

from .prompt_templates.factory import custom_prompt, prompt_factory

@@ -163,6 +164,56 @@ def get_supported_openai_params(
"response_format",
]

def _supports_function_calling(self, ollama_model_info: dict) -> bool:
"""
Check if the 'template' field in the ollama_model_info contains a 'tools' or 'function' key.
"""
_template: str = str(ollama_model_info.get("template", "") or "")
return "tools" in _template.lower()

def _get_max_tokens(self, ollama_model_info: dict) -> Optional[int]:
_model_info: dict = ollama_model_info.get("model_info", {})

for k, v in _model_info.items():
if "context_length" in k:
return v
return None

def get_model_info(self, model: str) -> ModelInfo:
"""
curl http://localhost:11434/api/show -d '{
"name": "mistral"
}'
"""
api_base = get_secret_str("OLLAMA_API_BASE") or "http://localhost:11434"

try:
response = litellm.module_level_client.post(
url=f"{api_base}/api/show",
json={"name": model},
)
except Exception as e:
raise Exception(
f"OllamaError: Error getting model info for {model}. Set Ollama API Base via `OLLAMA_API_BASE` environment variable. Error: {e}"
)

model_info = response.json()

_max_tokens: Optional[int] = self._get_max_tokens(model_info)

return ModelInfo(
key=model,
litellm_provider="ollama",
mode="chat",
supported_openai_params=self.get_supported_openai_params(),
supports_function_calling=self._supports_function_calling(model_info),
input_cost_per_token=0.0,
output_cost_per_token=0.0,
max_tokens=_max_tokens,
max_input_tokens=_max_tokens,
max_output_tokens=_max_tokens,
)


# ollama wants plain base64 jpeg/png files as images. strip any leading dataURI
# and convert to jpeg if necessary.
2 changes: 1 addition & 1 deletion litellm/proxy/_new_secret_config.yaml
@@ -2,4 +2,4 @@ model_list:
- model_name: "gpt-4o-audio-preview"
litellm_params:
model: gpt-4o-audio-preview
api_key: os.environ/OPENAI_API_KEY
api_key: os.environ/OPENAI_API_KEY
25 changes: 25 additions & 0 deletions litellm/proxy/_types.py
@@ -335,6 +335,31 @@ class LiteLLMRoutes(enum.Enum):
"/metrics",
]

ui_routes = [
"/sso",
"/sso/get/ui_settings",
"/login",
"/key/generate",
"/key/update",
"/key/info",
"/key/delete",
"/config",
"/spend",
"/user",
"/model/info",
"/v2/model/info",
"/v2/key/info",
"/models",
"/v1/models",
"/global/spend",
"/global/spend/logs",
"/global/spend/keys",
"/global/spend/models",
"/global/predict/spend/logs",
"/global/activity",
"/health/services",
] + info_routes

internal_user_routes = (
[
"/key/generate",