diff --git a/docs/my-website/docs/proxy/self_serve.md b/docs/my-website/docs/proxy/self_serve.md index 0e4198f0cac7..e04aa4b44398 100644 --- a/docs/my-website/docs/proxy/self_serve.md +++ b/docs/my-website/docs/proxy/self_serve.md @@ -207,6 +207,7 @@ litellm_settings: user_role: "internal_user" # one of "internal_user", "internal_user_viewer", "proxy_admin", "proxy_admin_viewer". New SSO users not in litellm will be created as this user max_budget: 100 # Optional[float], optional): $100 budget for a new SSO sign in user budget_duration: 30d # Optional[str], optional): 30 days budget_duration for a new SSO sign in user + models: ["gpt-3.5-turbo"] # Optional[List[str]], optional): models to be used by a new SSO sign in user upperbound_key_generate_params: # Upperbound for /key/generate requests when self-serve flow is on diff --git a/litellm/integrations/argilla.py b/litellm/integrations/argilla.py index 5c0bd4b1ea83..352543d82cf7 100644 --- a/litellm/integrations/argilla.py +++ b/litellm/integrations/argilla.py @@ -21,53 +21,22 @@ import litellm from litellm._logging import verbose_logger from litellm.integrations.custom_batch_logger import CustomBatchLogger +from litellm.integrations.custom_logger import CustomLogger from litellm.llms.custom_httpx.http_handler import ( AsyncHTTPHandler, get_async_httpx_client, httpxSpecialProvider, ) from litellm.llms.prompt_templates.common_utils import get_content_from_model_response +from litellm.types.integrations.argilla import ( + SUPPORTED_PAYLOAD_FIELDS, + ArgillaCredentialsObject, + ArgillaItem, + ArgillaPayload, +) from litellm.types.utils import StandardLoggingPayload -class LangsmithInputs(BaseModel): - model: Optional[str] = None - messages: Optional[List[Any]] = None - stream: Optional[bool] = None - call_type: Optional[str] = None - litellm_call_id: Optional[str] = None - completion_start_time: Optional[datetime] = None - temperature: Optional[float] = None - max_tokens: Optional[int] = None - custom_llm_provider: Optional[str] = None - input: Optional[List[Any]] = None - log_event_type: Optional[str] = None - original_response: Optional[Any] = None - response_cost: Optional[float] = None - - # LiteLLM Virtual Key specific fields - user_api_key: Optional[str] = None - user_api_key_user_id: Optional[str] = None - user_api_key_team_alias: Optional[str] = None - - -class ArgillaItem(TypedDict): - fields: Dict[str, Any] - - -class ArgillaPayload(TypedDict): - items: List[ArgillaItem] - - -class ArgillaCredentialsObject(TypedDict): - ARGILLA_API_KEY: str - ARGILLA_DATASET_NAME: str - ARGILLA_BASE_URL: str - - -SUPPORTED_PAYLOAD_FIELDS = ["messages", "response"] - - def is_serializable(value): non_serializable_types = ( types.CoroutineType, @@ -215,7 +184,7 @@ def get_str_response(self, payload: StandardLoggingPayload) -> str: def _prepare_log_data( self, kwargs, response_obj, start_time, end_time - ) -> ArgillaItem: + ) -> Optional[ArgillaItem]: try: # Ensure everything in the payload is converted to str payload: Optional[StandardLoggingPayload] = kwargs.get( @@ -235,6 +204,7 @@ def _prepare_log_data( argilla_item["fields"][k] = argilla_response else: argilla_item["fields"][k] = payload.get(v, None) + return argilla_item except Exception: raise @@ -294,6 +264,9 @@ def log_success_event(self, kwargs, response_obj, start_time, end_time): response_obj, ) data = self._prepare_log_data(kwargs, response_obj, start_time, end_time) + if data is None: + return + self.log_queue.append(data) verbose_logger.debug( f"Langsmith, event added to queue. 
Will flush in {self.flush_interval} seconds..." @@ -321,7 +294,25 @@ async def async_log_success_event(self, kwargs, response_obj, start_time, end_ti kwargs, response_obj, ) + payload: Optional[StandardLoggingPayload] = kwargs.get( + "standard_logging_object", None + ) + data = self._prepare_log_data(kwargs, response_obj, start_time, end_time) + + ## ALLOW CUSTOM LOGGERS TO MODIFY / FILTER DATA BEFORE LOGGING + for callback in litellm.callbacks: + if isinstance(callback, CustomLogger): + try: + if data is None: + break + data = await callback.async_dataset_hook(data, payload) + except NotImplementedError: + pass + + if data is None: + return + self.log_queue.append(data) verbose_logger.debug( "Langsmith logging: queue length %s, batch size %s", diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py index 1d23d29047d3..bdcb6a52f7a1 100644 --- a/litellm/integrations/custom_logger.py +++ b/litellm/integrations/custom_logger.py @@ -10,6 +10,7 @@ from litellm.caching.caching import DualCache from litellm.proxy._types import UserAPIKeyAuth +from litellm.types.integrations.argilla import ArgillaItem from litellm.types.llms.openai import ChatCompletionRequest from litellm.types.services import ServiceLoggerPayload from litellm.types.utils import ( @@ -17,6 +18,7 @@ EmbeddingResponse, ImageResponse, ModelResponse, + StandardLoggingPayload, ) @@ -108,6 +110,20 @@ def translate_completion_output_params_streaming( """ pass + ### DATASET HOOKS #### - currently only used for Argilla + + async def async_dataset_hook( + self, + logged_item: ArgillaItem, + standard_logging_payload: Optional[StandardLoggingPayload], + ) -> Optional[ArgillaItem]: + """ + - Decide if the result should be logged to Argilla. + - Modify the result before logging to Argilla. + - Return None if the result should not be logged to Argilla. + """ + raise NotImplementedError("async_dataset_hook not implemented") + #### CALL HOOKS - proxy only #### """ Control the modify incoming / outgoung data before calling the model diff --git a/litellm/llms/ollama.py b/litellm/llms/ollama.py index 3fc482b4af21..e08e6f6938c1 100644 --- a/litellm/llms/ollama.py +++ b/litellm/llms/ollama.py @@ -14,7 +14,8 @@ import litellm from litellm import verbose_logger -from litellm.types.utils import ProviderField, StreamingChoices +from litellm.secret_managers.main import get_secret_str +from litellm.types.utils import ModelInfo, ProviderField, StreamingChoices from .prompt_templates.factory import custom_prompt, prompt_factory @@ -163,6 +164,56 @@ def get_supported_openai_params( "response_format", ] + def _supports_function_calling(self, ollama_model_info: dict) -> bool: + """ + Check if the 'template' field in the ollama_model_info contains a 'tools' or 'function' key. + """ + _template: str = str(ollama_model_info.get("template", "") or "") + return "tools" in _template.lower() + + def _get_max_tokens(self, ollama_model_info: dict) -> Optional[int]: + _model_info: dict = ollama_model_info.get("model_info", {}) + + for k, v in _model_info.items(): + if "context_length" in k: + return v + return None + + def get_model_info(self, model: str) -> ModelInfo: + """ + curl http://localhost:11434/api/show -d '{ + "name": "mistral" + }' + """ + api_base = get_secret_str("OLLAMA_API_BASE") or "http://localhost:11434" + + try: + response = litellm.module_level_client.post( + url=f"{api_base}/api/show", + json={"name": model}, + ) + except Exception as e: + raise Exception( + f"OllamaError: Error getting model info for {model}. 
Set Ollama API Base via `OLLAMA_API_BASE` environment variable. Error: {e}" + ) + + model_info = response.json() + + _max_tokens: Optional[int] = self._get_max_tokens(model_info) + + return ModelInfo( + key=model, + litellm_provider="ollama", + mode="chat", + supported_openai_params=self.get_supported_openai_params(), + supports_function_calling=self._supports_function_calling(model_info), + input_cost_per_token=0.0, + output_cost_per_token=0.0, + max_tokens=_max_tokens, + max_input_tokens=_max_tokens, + max_output_tokens=_max_tokens, + ) + # ollama wants plain base64 jpeg/png files as images. strip any leading dataURI # and convert to jpeg if necessary. diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index c38c90c945a7..506a56cc6c86 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -2,4 +2,4 @@ model_list: - model_name: "gpt-4o-audio-preview" litellm_params: model: gpt-4o-audio-preview - api_key: os.environ/OPENAI_API_KEY + api_key: os.environ/OPENAI_API_KEY \ No newline at end of file diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index e104b46ae1e6..629e002b56f5 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -335,6 +335,31 @@ class LiteLLMRoutes(enum.Enum): "/metrics", ] + ui_routes = [ + "/sso", + "/sso/get/ui_settings", + "/login", + "/key/generate", + "/key/update", + "/key/info", + "/key/delete", + "/config", + "/spend", + "/user", + "/model/info", + "/v2/model/info", + "/v2/key/info", + "/models", + "/v1/models", + "/global/spend", + "/global/spend/logs", + "/global/spend/keys", + "/global/spend/models", + "/global/predict/spend/logs", + "/global/activity", + "/health/services", + ] + info_routes + internal_user_routes = ( [ "/key/generate", diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 8322a78dfa13..20b262d3176f 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -105,6 +105,88 @@ def _get_bearer_token( return api_key +def _is_ui_route_allowed( + route: str, + user_obj: Optional[LiteLLM_UserTable] = None, +) -> bool: + """ + - Route b/w ui token check and normal token check + """ + # this token is only used for managing the ui + allowed_routes = LiteLLMRoutes.ui_routes.value + # check if the current route startswith any of the allowed routes + if ( + route is not None + and isinstance(route, str) + and any(route.startswith(allowed_route) for allowed_route in allowed_routes) + ): + # Do something if the current route starts with any of the allowed routes + return True + else: + if user_obj is not None and _is_user_proxy_admin(user_obj=user_obj): + return True + elif _has_user_setup_sso() and route in LiteLLMRoutes.sso_only_routes.value: + return True + else: + raise Exception( + f"This key is made for LiteLLM UI, Tried to access route: {route}. 
Not allowed" + ) + + +def _is_api_route_allowed( + route: str, + request: Request, + request_data: dict, + api_key: str, + valid_token: Optional[UserAPIKeyAuth], + user_obj: Optional[LiteLLM_UserTable] = None, +) -> bool: + """ + - Route b/w api token check and normal token check + """ + _user_role = _get_user_role(user_obj=user_obj) + + if valid_token is None: + raise Exception("Invalid proxy server token passed") + + if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin + non_proxy_admin_allowed_routes_check( + user_obj=user_obj, + _user_role=_user_role, + route=route, + request=request, + request_data=request_data, + api_key=api_key, + valid_token=valid_token, + ) + return True + + +def _is_allowed_route( + route: str, + token_type: Literal["ui", "api"], + request: Request, + request_data: dict, + api_key: str, + valid_token: Optional[UserAPIKeyAuth], + user_obj: Optional[LiteLLM_UserTable] = None, +) -> bool: + """ + - Route b/w ui token check and normal token check + """ + if token_type == "ui": + return _is_ui_route_allowed(route=route, user_obj=user_obj) + else: + return _is_api_route_allowed( + route=route, + request=request, + request_data=request_data, + api_key=api_key, + valid_token=valid_token, + user_obj=user_obj, + ) + + async def user_api_key_auth( # noqa: PLR0915 request: Request, api_key: str = fastapi.Security(api_key_header), @@ -1041,81 +1123,27 @@ async def user_api_key_auth( # noqa: PLR0915 if _end_user_object is not None: valid_token_dict.update(end_user_params) - _user_role = _get_user_role(user_obj=user_obj) - - if not _is_user_proxy_admin(user_obj=user_obj): # if non-admin - non_proxy_admin_allowed_routes_check( - user_obj=user_obj, - _user_role=_user_role, - route=route, - request=request, - request_data=request_data, - api_key=api_key, - valid_token=valid_token, - ) - # check if token is from litellm-ui, litellm ui makes keys to allow users to login with sso. 
These keys can only be used for LiteLLM UI functions # sso/login, ui/login, /key functions and /user functions # this will never be allowed to call /chat/completions token_team = getattr(valid_token, "team_id", None) + token_type: Literal["ui", "api"] = ( + "ui" + if token_team is not None and token_team == "litellm-dashboard" + else "api" + ) + _is_route_allowed = _is_allowed_route( + route=route, + token_type=token_type, + user_obj=user_obj, + request=request, + request_data=request_data, + api_key=api_key, + valid_token=valid_token, + ) + if not _is_route_allowed: + raise HTTPException(401, detail="Invalid route for UI token") - if token_team is not None and token_team == "litellm-dashboard": - # this token is only used for managing the ui - allowed_routes = [ - "/sso", - "/sso/get/ui_settings", - "/login", - "/key/generate", - "/key/update", - "/key/info", - "/config", - "/spend", - "/user", - "/model/info", - "/v2/model/info", - "/v2/key/info", - "/models", - "/v1/models", - "/global/spend", - "/global/spend/logs", - "/global/spend/keys", - "/global/spend/models", - "/global/predict/spend/logs", - "/global/activity", - "/health/services", - ] + LiteLLMRoutes.info_routes.value # type: ignore - # check if the current route startswith any of the allowed routes - if ( - route is not None - and isinstance(route, str) - and any( - route.startswith(allowed_route) for allowed_route in allowed_routes - ) - ): - # Do something if the current route starts with any of the allowed routes - pass - else: - if user_obj is not None and _is_user_proxy_admin(user_obj=user_obj): - return UserAPIKeyAuth( - api_key=api_key, - user_role=LitellmUserRoles.PROXY_ADMIN, - parent_otel_span=parent_otel_span, - **valid_token_dict, - ) - elif ( - _has_user_setup_sso() - and route in LiteLLMRoutes.sso_only_routes.value - ): - return UserAPIKeyAuth( - api_key=api_key, - user_role=_user_role, # type: ignore - parent_otel_span=parent_otel_span, - **valid_token_dict, - ) - else: - raise Exception( - f"This key is made for LiteLLM UI, Tried to access route: {route}. 
Not allowed" - ) if valid_token is None: # No token was found when looking up in the DB raise Exception("Invalid proxy server token passed") diff --git a/litellm/proxy/management_endpoints/internal_user_endpoints.py b/litellm/proxy/management_endpoints/internal_user_endpoints.py index fa96fe08caf6..743cc8d67950 100644 --- a/litellm/proxy/management_endpoints/internal_user_endpoints.py +++ b/litellm/proxy/management_endpoints/internal_user_endpoints.py @@ -41,6 +41,40 @@ router = APIRouter() +def _update_internal_user_params(data_json: dict, data: NewUserRequest) -> dict: + if "user_id" in data_json and data_json["user_id"] is None: + data_json["user_id"] = str(uuid.uuid4()) + auto_create_key = data_json.pop("auto_create_key", True) + if auto_create_key is False: + data_json["table_name"] = ( + "user" # only create a user, don't create key if 'auto_create_key' set to False + ) + + is_internal_user = False + if data.user_role == LitellmUserRoles.INTERNAL_USER: + is_internal_user = True + if litellm.default_internal_user_params: + for key, value in litellm.default_internal_user_params.items(): + if key not in data_json or data_json[key] is None: + data_json[key] = value + elif ( + key == "models" + and isinstance(data_json[key], list) + and len(data_json[key]) == 0 + ): + data_json[key] = value + + if "max_budget" in data_json and data_json["max_budget"] is None: + if is_internal_user and litellm.max_internal_user_budget is not None: + data_json["max_budget"] = litellm.max_internal_user_budget + + if "budget_duration" in data_json and data_json["budget_duration"] is None: + if is_internal_user and litellm.internal_user_budget_duration is not None: + data_json["budget_duration"] = litellm.internal_user_budget_duration + + return data_json + + @router.post( "/user/new", tags=["Internal User management"], @@ -94,26 +128,7 @@ async def new_user( from litellm.proxy.proxy_server import general_settings, proxy_logging_obj data_json = data.json() # type: ignore - if "user_id" in data_json and data_json["user_id"] is None: - data_json["user_id"] = str(uuid.uuid4()) - auto_create_key = data_json.pop("auto_create_key", True) - if auto_create_key is False: - data_json["table_name"] = ( - "user" # only create a user, don't create key if 'auto_create_key' set to False - ) - - is_internal_user = False - if data.user_role == LitellmUserRoles.INTERNAL_USER: - is_internal_user = True - - if "max_budget" in data_json and data_json["max_budget"] is None: - if is_internal_user and litellm.max_internal_user_budget is not None: - data_json["max_budget"] = litellm.max_internal_user_budget - - if "budget_duration" in data_json and data_json["budget_duration"] is None: - if is_internal_user and litellm.internal_user_budget_duration is not None: - data_json["budget_duration"] = litellm.internal_user_budget_duration - + data_json = _update_internal_user_params(data_json, data) response = await generate_key_helper_fn(request_type="user", **data_json) # Admin UI Logic diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 1a1516e923d2..6f5ecbc6545b 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1585,10 +1585,6 @@ async def get_config(self, config_file_path: Optional[str] = None) -> dict: printed_yaml = copy.deepcopy(config) printed_yaml.pop("environment_variables", None) - verbose_proxy_logger.debug( - f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}" - ) - config = 
self._check_for_os_environ_vars(config=config) return config diff --git a/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py b/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py index 6e232d3a6725..667a21a3ca50 100644 --- a/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py +++ b/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py @@ -40,6 +40,7 @@ from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( create_pass_through_route, ) +from litellm.secret_managers.main import get_secret_str router = APIRouter() default_vertex_config = None @@ -226,3 +227,53 @@ async def bedrock_proxy_route( ) return received_value + + +@router.api_route("/azure/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"]) +async def azure_proxy_route( + endpoint: str, + request: Request, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + base_target_url = get_secret_str(secret_name="AZURE_API_BASE") + if base_target_url is None: + raise Exception( + "Required 'AZURE_API_BASE' in environment to make pass-through calls to Azure." + ) + encoded_endpoint = httpx.URL(endpoint).path + + # Ensure endpoint starts with '/' for proper URL construction + if not encoded_endpoint.startswith("/"): + encoded_endpoint = "/" + encoded_endpoint + + # Construct the full target URL using httpx + base_url = httpx.URL(base_target_url) + updated_url = base_url.copy_with(path=encoded_endpoint) + + # Add or update query parameters + azure_api_key = get_secret_str(secret_name="AZURE_API_KEY") + + ## check for streaming + is_streaming_request = False + if "stream" in str(updated_url): + is_streaming_request = True + + ## CREATE PASS-THROUGH + endpoint_func = create_pass_through_route( + endpoint=endpoint, + target=str(updated_url), + custom_headers={ + "authorization": "Bearer {}".format(azure_api_key), + "api-key": "{}".format(azure_api_key), + }, + ) # dynamically construct pass-through endpoint based on incoming path + received_value = await endpoint_func( + request, + fastapi_response, + user_api_key_dict, + stream=is_streaming_request, # type: ignore + query_params=dict(request.query_params), # type: ignore + ) + + return received_value diff --git a/litellm/types/integrations/argilla.py b/litellm/types/integrations/argilla.py new file mode 100644 index 000000000000..6c0de762a700 --- /dev/null +++ b/litellm/types/integrations/argilla.py @@ -0,0 +1,21 @@ +import os +from datetime import datetime as dt +from enum import Enum +from typing import Any, Dict, List, Literal, Optional, Set, TypedDict + + +class ArgillaItem(TypedDict): + fields: Dict[str, Any] + + +class ArgillaPayload(TypedDict): + items: List[ArgillaItem] + + +class ArgillaCredentialsObject(TypedDict): + ARGILLA_API_KEY: str + ARGILLA_DATASET_NAME: str + ARGILLA_BASE_URL: str + + +SUPPORTED_PAYLOAD_FIELDS = ["messages", "response"] diff --git a/litellm/utils.py b/litellm/utils.py index 51aea33a4756..45e24847cde3 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -1821,6 +1821,7 @@ def supports_function_calling( model=model, custom_llm_provider=custom_llm_provider ) + ## CHECK IF MODEL SUPPORTS FUNCTION CALLING ## model_info = litellm.get_model_info( model=model, custom_llm_provider=custom_llm_provider ) @@ -4768,6 +4769,8 @@ def _get_max_position_embeddings(model_name): supports_assistant_prefill=None, supports_prompt_caching=None, ) + elif custom_llm_provider == "ollama" or custom_llm_provider == "ollama_chat": + return 
litellm.OllamaConfig().get_model_info(model) else: """ Check if: (in order of specificity) @@ -4964,7 +4967,9 @@ def _get_max_position_embeddings(model_name): supports_audio_input=_model_info.get("supports_audio_input", False), supports_audio_output=_model_info.get("supports_audio_output", False), ) - except Exception: + except Exception as e: + if "OllamaError" in str(e): + raise e raise Exception( "This model isn't mapped yet. model={}, custom_llm_provider={}. Add it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json.".format( model, custom_llm_provider diff --git a/tests/local_testing/test_get_model_info.py b/tests/local_testing/test_get_model_info.py index 20f9aa16e0cd..82ce9c4651f9 100644 --- a/tests/local_testing/test_get_model_info.py +++ b/tests/local_testing/test_get_model_info.py @@ -11,6 +11,7 @@ import litellm from litellm import get_model_info +from unittest.mock import AsyncMock, MagicMock, patch def test_get_model_info_simple_model_name(): @@ -74,3 +75,25 @@ def test_get_model_info_gemini_pro(): info = litellm.get_model_info("gemini-1.5-pro-002") print("info", info) assert info["key"] == "gemini-1.5-pro-002" + + +def test_get_model_info_ollama_chat(): + from litellm.llms.ollama import OllamaConfig + + with patch.object( + litellm.module_level_client, + "post", + return_value=MagicMock( + json=lambda: { + "model_info": {"llama.context_length": 32768}, + "template": "tools", + } + ), + ): + info = OllamaConfig().get_model_info("mistral") + print("info", info) + assert info["supports_function_calling"] is True + + info = get_model_info("ollama/mistral") + print("info", info) + assert info["supports_function_calling"] is True diff --git a/tests/local_testing/test_proxy_utils.py b/tests/local_testing/test_proxy_utils.py index 5bb9bdc165d8..a74e9e78b33b 100644 --- a/tests/local_testing/test_proxy_utils.py +++ b/tests/local_testing/test_proxy_utils.py @@ -406,3 +406,29 @@ def test_add_litellm_data_for_backend_llm_call(headers, expected_data): data = add_litellm_data_for_backend_llm_call(headers) assert json.dumps(data, sort_keys=True) == json.dumps(expected_data, sort_keys=True) + + +def test_update_internal_user_params(): + from litellm.proxy.management_endpoints.internal_user_endpoints import ( + _update_internal_user_params, + ) + from litellm.proxy._types import NewUserRequest + + litellm.default_internal_user_params = { + "max_budget": 100, + "budget_duration": "30d", + "models": ["gpt-3.5-turbo"], + } + + data = NewUserRequest(user_role="internal_user", user_email="krrish3@berri.ai") + data_json = data.model_dump() + updated_data_json = _update_internal_user_params(data_json, data) + assert updated_data_json["models"] == litellm.default_internal_user_params["models"] + assert ( + updated_data_json["max_budget"] + == litellm.default_internal_user_params["max_budget"] + ) + assert ( + updated_data_json["budget_duration"] + == litellm.default_internal_user_params["budget_duration"] + ) diff --git a/tests/local_testing/test_user_api_key_auth.py b/tests/local_testing/test_user_api_key_auth.py index 6f0132312eea..47f96ccf22f0 100644 --- a/tests/local_testing/test_user_api_key_auth.py +++ b/tests/local_testing/test_user_api_key_auth.py @@ -291,3 +291,28 @@ async def test_auth_with_allowed_routes(route, should_raise_error): await user_api_key_auth(request=request, api_key="Bearer " + user_key) setattr(proxy_server, "general_settings", initial_general_settings) + + +@pytest.mark.parametrize("route", ["/global/spend/logs", "/key/delete"]) +def 
test_is_ui_route_allowed(route): + from litellm.proxy.auth.user_api_key_auth import _is_ui_route_allowed + from litellm.proxy._types import LiteLLM_UserTable + + received_args: dict = { + "route": route, + "user_obj": LiteLLM_UserTable( + user_id="3b803c0e-666e-4e99-bd5c-6e534c07e297", + max_budget=None, + spend=0.0, + model_max_budget={}, + model_spend={}, + user_email="my-test-email@1234.com", + models=[], + tpm_limit=None, + rpm_limit=None, + user_role="internal_user", + organization_memberships=[], + ), + } + + assert _is_ui_route_allowed(**received_args) diff --git a/tests/local_testing/test_utils.py b/tests/local_testing/test_utils.py index 61b424129e25..684e41da8c0e 100644 --- a/tests/local_testing/test_utils.py +++ b/tests/local_testing/test_utils.py @@ -448,24 +448,19 @@ def test_token_counter(): # test_token_counter() -def test_supports_function_calling(): +@pytest.mark.parametrize( + "model, expected_bool", + [ + ("gpt-3.5-turbo", True), + ("azure/gpt-4-1106-preview", True), + ("groq/gemma-7b-it", True), + ("anthropic.claude-instant-v1", False), + ("palm/chat-bison", False), + ], +) +def test_supports_function_calling(model, expected_bool): try: - assert litellm.supports_function_calling(model="gpt-3.5-turbo") == True - assert ( - litellm.supports_function_calling(model="azure/gpt-4-1106-preview") == True - ) - assert litellm.supports_function_calling(model="groq/gemma-7b-it") == True - assert ( - litellm.supports_function_calling(model="anthropic.claude-instant-v1") - == False - ) - assert litellm.supports_function_calling(model="palm/chat-bison") == False - assert litellm.supports_function_calling(model="ollama/llama2") == False - assert ( - litellm.supports_function_calling(model="anthropic.claude-instant-v1") - == False - ) - assert litellm.supports_function_calling(model="claude-2") == False + assert litellm.supports_function_calling(model=model) == expected_bool except Exception as e: pytest.fail(f"Error occurred: {e}")
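
The new async_dataset_hook on CustomLogger gives callbacks a chance to modify or drop an ArgillaItem before the Argilla batch logger queues it; returning None skips the item. Below is a minimal sketch of a filtering hook; the class name and the "response" field key are illustrative assumptions, since real field keys come from the configured argilla_transformation_object.

# A minimal sketch of the new dataset hook (names here are illustrative, not
# taken from the diff). The updated Argilla integration iterates
# litellm.callbacks in async_log_success_event and calls async_dataset_hook on
# every CustomLogger it finds; returning None drops the item, returning a
# (possibly modified) ArgillaItem keeps it.
from typing import Optional

import litellm
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.integrations.argilla import ArgillaItem
from litellm.types.utils import StandardLoggingPayload


class DropEmptyResponsesHook(CustomLogger):
    async def async_dataset_hook(
        self,
        logged_item: ArgillaItem,
        standard_logging_payload: Optional[StandardLoggingPayload],
    ) -> Optional[ArgillaItem]:
        # "response" assumes a field of that name was configured in the
        # argilla_transformation_object; adjust to your own field keys.
        if not logged_item["fields"].get("response"):
            return None  # skip logging items with no response text
        return logged_item  # items can also be mutated here before queueing


# The hook only fires if it is present in litellm.callbacks; the Argilla
# logger itself is enabled separately.
litellm.callbacks.append(DropEmptyResponsesHook())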
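
The _update_internal_user_params helper now also fills in a default models list for new internal users, matching the models entry added to default_internal_user_params in the docs. A small sketch mirroring the added test, with an arbitrary example email:

# Sketch mirroring test_update_internal_user_params: defaults are applied when
# the corresponding request fields are None/missing, and for "models" also when
# the request carries an empty list. The email below is an arbitrary example.
import litellm
from litellm.proxy._types import NewUserRequest
from litellm.proxy.management_endpoints.internal_user_endpoints import (
    _update_internal_user_params,
)

litellm.default_internal_user_params = {
    "max_budget": 100,
    "budget_duration": "30d",
    "models": ["gpt-3.5-turbo"],
}

data = NewUserRequest(user_role="internal_user", user_email="user@example.com")
data_json = _update_internal_user_params(data.model_dump(), data)
assert data_json["models"] == ["gpt-3.5-turbo"]
assert data_json["max_budget"] == 100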
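
Finally, litellm.get_model_info now resolves ollama/ and ollama_chat/ models by calling OllamaConfig().get_model_info, which queries the Ollama /api/show endpoint. A sketch, assuming an Ollama server is reachable at OLLAMA_API_BASE (defaulting to http://localhost:11434) and has the model pulled:

# Sketch: querying model metadata for a locally served Ollama model. This hits
# the live /api/show endpoint, so it only works against a running Ollama
# server that has the model (here "mistral", as in the docstring example).
import litellm

info = litellm.get_model_info("ollama/mistral")
print(info["supports_function_calling"])  # derived from the Ollama model template
print(info["max_input_tokens"])           # from the context_length entry in Ollama's model_info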