From 4f9fa2ccbccd48c863a88d8370f608c278460165 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 10 Dec 2024 17:01:56 -0800 Subject: [PATCH 1/2] refactor(sagemaker/): separate chat + completion routes + make them both use base llm config Addresses https://github.com/andrewyng/aisuite/issues/113#issuecomment-2512369132 --- litellm/__init__.py | 3 +- .../get_supported_openai_params.py | 2 +- .../llms/OpenAI/chat/gpt_transformation.py | 6 +- litellm/llms/base_llm/transformation.py | 3 +- litellm/llms/sagemaker/chat/handler.py | 179 ++++++ litellm/llms/sagemaker/chat/transformation.py | 26 + litellm/llms/sagemaker/common_utils.py | 198 +++++++ .../{sagemaker.py => completion/handler.py} | 544 +++--------------- .../sagemaker/completion/transformation.py | 272 +++++++++ litellm/main.py | 35 +- litellm/utils.py | 40 +- tests/local_testing/test_config.py | 61 +- tests/local_testing/test_sagemaker.py | 2 +- tests/local_testing/test_streaming.py | 2 +- 14 files changed, 825 insertions(+), 548 deletions(-) create mode 100644 litellm/llms/sagemaker/chat/handler.py create mode 100644 litellm/llms/sagemaker/chat/transformation.py create mode 100644 litellm/llms/sagemaker/common_utils.py rename litellm/llms/sagemaker/{sagemaker.py => completion/handler.py} (55%) create mode 100644 litellm/llms/sagemaker/completion/transformation.py diff --git a/litellm/__init__.py b/litellm/__init__.py index 87be1d002f32..1b04c4d195d1 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -1099,7 +1099,8 @@ class LlmProviders(str, Enum): VertexAIAi21Config, ) -from .llms.sagemaker.sagemaker import SagemakerConfig +from .llms.sagemaker.completion.transformation import SagemakerConfig +from .llms.sagemaker.chat.transformation import SagemakerChatConfig from .llms.ollama import OllamaConfig from .llms.ollama_chat import OllamaChatConfig from .llms.maritalk import MaritTalkConfig diff --git a/litellm/litellm_core_utils/get_supported_openai_params.py b/litellm/litellm_core_utils/get_supported_openai_params.py index 383c2490c040..87a107e7b1ea 100644 --- a/litellm/litellm_core_utils/get_supported_openai_params.py +++ b/litellm/litellm_core_utils/get_supported_openai_params.py @@ -182,7 +182,7 @@ def get_supported_openai_params( # noqa: PLR0915 elif request_type == "embeddings": return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params() elif custom_llm_provider == "sagemaker": - return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"] + return litellm.SagemakerConfig().get_supported_openai_params(model=model) elif custom_llm_provider == "aleph_alpha": return [ "max_tokens", diff --git a/litellm/llms/OpenAI/chat/gpt_transformation.py b/litellm/llms/OpenAI/chat/gpt_transformation.py index 00df1c2b3ee8..d1496d8133cb 100644 --- a/litellm/llms/OpenAI/chat/gpt_transformation.py +++ b/litellm/llms/OpenAI/chat/gpt_transformation.py @@ -182,7 +182,11 @@ def transform_request( Returns: dict: The transformed request. Sent as the body of the API call. """ - raise NotImplementedError + return { + "model": model, + "messages": messages, + **optional_params, + } def transform_response( self, diff --git a/litellm/llms/base_llm/transformation.py b/litellm/llms/base_llm/transformation.py index 06d392e0b05d..c87c7a81d026 100644 --- a/litellm/llms/base_llm/transformation.py +++ b/litellm/llms/base_llm/transformation.py @@ -9,6 +9,7 @@ Any, AsyncIterator, Callable, + Dict, Iterator, List, Optional, @@ -33,7 +34,7 @@ def __init__( self, status_code: int, message: str, - headers: Optional[httpx.Headers] = None, + headers: Optional[Union[httpx.Headers, Dict]] = None, request: Optional[httpx.Request] = None, response: Optional[httpx.Response] = None, ): diff --git a/litellm/llms/sagemaker/chat/handler.py b/litellm/llms/sagemaker/chat/handler.py new file mode 100644 index 000000000000..f7024b94e324 --- /dev/null +++ b/litellm/llms/sagemaker/chat/handler.py @@ -0,0 +1,179 @@ +import json +from copy import deepcopy +from typing import Any, Callable, Dict, Optional, Union + +import httpx + +from litellm.utils import ModelResponse, get_secret + +from ...base_aws_llm import BaseAWSLLM +from ...prompt_templates.factory import custom_prompt, prompt_factory +from ..common_utils import AWSEventStreamDecoder +from .transformation import SagemakerChatConfig + + +class SagemakerChatHandler(BaseAWSLLM): + + def _load_credentials( + self, + optional_params: dict, + ): + try: + from botocore.credentials import Credentials + except ImportError: + raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") + ## CREDENTIALS ## + # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them + aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) + aws_access_key_id = optional_params.pop("aws_access_key_id", None) + aws_session_token = optional_params.pop("aws_session_token", None) + aws_region_name = optional_params.pop("aws_region_name", None) + aws_role_name = optional_params.pop("aws_role_name", None) + aws_session_name = optional_params.pop("aws_session_name", None) + aws_profile_name = optional_params.pop("aws_profile_name", None) + optional_params.pop( + "aws_bedrock_runtime_endpoint", None + ) # https://bedrock-runtime.{region_name}.amazonaws.com + aws_web_identity_token = optional_params.pop("aws_web_identity_token", None) + aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None) + + ### SET REGION NAME ### + if aws_region_name is None: + # check env # + litellm_aws_region_name = get_secret("AWS_REGION_NAME", None) + + if litellm_aws_region_name is not None and isinstance( + litellm_aws_region_name, str + ): + aws_region_name = litellm_aws_region_name + + standard_aws_region_name = get_secret("AWS_REGION", None) + if standard_aws_region_name is not None and isinstance( + standard_aws_region_name, str + ): + aws_region_name = standard_aws_region_name + + if aws_region_name is None: + aws_region_name = "us-west-2" + + credentials: Credentials = self.get_credentials( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + aws_region_name=aws_region_name, + aws_session_name=aws_session_name, + aws_profile_name=aws_profile_name, + aws_role_name=aws_role_name, + aws_web_identity_token=aws_web_identity_token, + aws_sts_endpoint=aws_sts_endpoint, + ) + return credentials, aws_region_name + + def _prepare_request( + self, + credentials, + model: str, + data: dict, + optional_params: dict, + aws_region_name: str, + extra_headers: Optional[dict] = None, + ): + try: + import boto3 + from botocore.auth import SigV4Auth + from botocore.awsrequest import AWSRequest + from botocore.credentials import Credentials + except ImportError: + raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") + + sigv4 = SigV4Auth(credentials, "sagemaker", aws_region_name) + if optional_params.get("stream") is True: + api_base = f"https://runtime.sagemaker.{aws_region_name}.amazonaws.com/endpoints/{model}/invocations-response-stream" + else: + api_base = f"https://runtime.sagemaker.{aws_region_name}.amazonaws.com/endpoints/{model}/invocations" + + sagemaker_base_url = optional_params.get("sagemaker_base_url", None) + if sagemaker_base_url is not None: + api_base = sagemaker_base_url + + encoded_data = json.dumps(data).encode("utf-8") + headers = {"Content-Type": "application/json"} + if extra_headers is not None: + headers = {"Content-Type": "application/json", **extra_headers} + request = AWSRequest( + method="POST", url=api_base, data=encoded_data, headers=headers + ) + sigv4.add_auth(request) + if ( + extra_headers is not None and "Authorization" in extra_headers + ): # prevent sigv4 from overwriting the auth header + request.headers["Authorization"] = extra_headers["Authorization"] + + prepped_request = request.prepare() + + return prepped_request + + def completion( + self, + model: str, + messages: list, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + logging_obj, + optional_params: dict, + litellm_params: dict, + timeout: Optional[Union[float, httpx.Timeout]] = None, + custom_prompt_dict={}, + logger_fn=None, + acompletion: bool = False, + headers: dict = {}, + ): + + # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker + credentials, aws_region_name = self._load_credentials(optional_params) + inference_params = deepcopy(optional_params) + stream = inference_params.pop("stream", None) + + from litellm.llms.openai_like.chat.handler import OpenAILikeChatHandler + + openai_like_chat_completions = OpenAILikeChatHandler() + inference_params["stream"] = True if stream is True else False + _data = SagemakerChatConfig().transform_request( + model=model, + messages=messages, + optional_params=inference_params, + litellm_params=litellm_params, + headers=headers, + ) + + prepared_request = self._prepare_request( + model=model, + data=_data, + optional_params=optional_params, + credentials=credentials, + aws_region_name=aws_region_name, + ) + + custom_stream_decoder = AWSEventStreamDecoder(model="", is_messages_api=True) + + return openai_like_chat_completions.completion( + model=model, + messages=messages, + api_base=prepared_request.url, + api_key=None, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + logging_obj=logging_obj, + optional_params=inference_params, + acompletion=acompletion, + litellm_params=litellm_params, + logger_fn=logger_fn, + timeout=timeout, + encoding=encoding, + headers=prepared_request.headers, # type: ignore + custom_endpoint=True, + custom_llm_provider="sagemaker_chat", + streaming_decoder=custom_stream_decoder, # type: ignore + ) diff --git a/litellm/llms/sagemaker/chat/transformation.py b/litellm/llms/sagemaker/chat/transformation.py new file mode 100644 index 000000000000..fa68f971af74 --- /dev/null +++ b/litellm/llms/sagemaker/chat/transformation.py @@ -0,0 +1,26 @@ +""" +Translate from OpenAI's `/v1/chat/completions` to Sagemaker's `/invocations` API + +Called if Sagemaker endpoint supports HF Messages API. + +LiteLLM Docs: https://docs.litellm.ai/docs/providers/aws_sagemaker#sagemaker-messages-api +Huggingface Docs: https://huggingface.co/docs/text-generation-inference/en/messages_api +""" + +from typing import Union + +from httpx._models import Headers + +from litellm.llms.base_llm.transformation import BaseLLMException + +from ...OpenAI.chat.gpt_transformation import OpenAIGPTConfig +from ..common_utils import SagemakerError + + +class SagemakerChatConfig(OpenAIGPTConfig): + def get_error_class( + self, error_message: str, status_code: int, headers: Union[dict, Headers] + ) -> BaseLLMException: + return SagemakerError( + status_code=status_code, message=error_message, headers=headers + ) diff --git a/litellm/llms/sagemaker/common_utils.py b/litellm/llms/sagemaker/common_utils.py new file mode 100644 index 000000000000..8fa450a8d57e --- /dev/null +++ b/litellm/llms/sagemaker/common_utils.py @@ -0,0 +1,198 @@ +import json +from typing import AsyncIterator, Iterator, List, Optional, Union + +import httpx + +from litellm import verbose_logger +from litellm.llms.base_llm.transformation import BaseLLMException +from litellm.types.utils import GenericStreamingChunk as GChunk +from litellm.types.utils import StreamingChatCompletionChunk + +_response_stream_shape_cache = None + + +class SagemakerError(BaseLLMException): + def __init__( + self, + status_code: int, + message: str, + headers: Optional[Union[dict, httpx.Headers]] = None, + ): + super().__init__(status_code=status_code, message=message, headers=headers) + + +class AWSEventStreamDecoder: + def __init__(self, model: str, is_messages_api: Optional[bool] = None) -> None: + from botocore.parsers import EventStreamJSONParser + + self.model = model + self.parser = EventStreamJSONParser() + self.content_blocks: List = [] + self.is_messages_api = is_messages_api + + def _chunk_parser_messages_api( + self, chunk_data: dict + ) -> StreamingChatCompletionChunk: + + openai_chunk = StreamingChatCompletionChunk(**chunk_data) + + return openai_chunk + + def _chunk_parser(self, chunk_data: dict) -> GChunk: + verbose_logger.debug("in sagemaker chunk parser, chunk_data %s", chunk_data) + _token = chunk_data.get("token", {}) or {} + _index = chunk_data.get("index", None) or 0 + is_finished = False + finish_reason = "" + + _text = _token.get("text", "") + if _text == "<|endoftext|>": + return GChunk( + text="", + index=_index, + is_finished=True, + finish_reason="stop", + usage=None, + ) + + return GChunk( + text=_text, + index=_index, + is_finished=is_finished, + finish_reason=finish_reason, + usage=None, + ) + + def iter_bytes( + self, iterator: Iterator[bytes] + ) -> Iterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]: + """Given an iterator that yields lines, iterate over it & yield every event encountered""" + from botocore.eventstream import EventStreamBuffer + + event_stream_buffer = EventStreamBuffer() + accumulated_json = "" + + for chunk in iterator: + event_stream_buffer.add_data(chunk) + for event in event_stream_buffer: + message = self._parse_message_from_event(event) + if message: + # remove data: prefix and "\n\n" at the end + message = message.replace("data:", "").replace("\n\n", "") + + # Accumulate JSON data + accumulated_json += message + + # Try to parse the accumulated JSON + try: + _data = json.loads(accumulated_json) + if self.is_messages_api: + yield self._chunk_parser_messages_api(chunk_data=_data) + else: + yield self._chunk_parser(chunk_data=_data) + # Reset accumulated_json after successful parsing + accumulated_json = "" + except json.JSONDecodeError: + # If it's not valid JSON yet, continue to the next event + continue + + # Handle any remaining data after the iterator is exhausted + if accumulated_json: + try: + _data = json.loads(accumulated_json) + if self.is_messages_api: + yield self._chunk_parser_messages_api(chunk_data=_data) + else: + yield self._chunk_parser(chunk_data=_data) + except json.JSONDecodeError: + # Handle or log any unparseable data at the end + verbose_logger.error( + f"Warning: Unparseable JSON data remained: {accumulated_json}" + ) + yield None + + async def aiter_bytes( + self, iterator: AsyncIterator[bytes] + ) -> AsyncIterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]: + """Given an async iterator that yields lines, iterate over it & yield every event encountered""" + from botocore.eventstream import EventStreamBuffer + + event_stream_buffer = EventStreamBuffer() + accumulated_json = "" + + async for chunk in iterator: + event_stream_buffer.add_data(chunk) + for event in event_stream_buffer: + message = self._parse_message_from_event(event) + if message: + verbose_logger.debug("sagemaker parsed chunk bytes %s", message) + # remove data: prefix and "\n\n" at the end + message = message.replace("data:", "").replace("\n\n", "") + + # Accumulate JSON data + accumulated_json += message + + # Try to parse the accumulated JSON + try: + _data = json.loads(accumulated_json) + if self.is_messages_api: + yield self._chunk_parser_messages_api(chunk_data=_data) + else: + yield self._chunk_parser(chunk_data=_data) + # Reset accumulated_json after successful parsing + accumulated_json = "" + except json.JSONDecodeError: + # If it's not valid JSON yet, continue to the next event + continue + + # Handle any remaining data after the iterator is exhausted + if accumulated_json: + try: + _data = json.loads(accumulated_json) + if self.is_messages_api: + yield self._chunk_parser_messages_api(chunk_data=_data) + else: + yield self._chunk_parser(chunk_data=_data) + except json.JSONDecodeError: + # Handle or log any unparseable data at the end + verbose_logger.error( + f"Warning: Unparseable JSON data remained: {accumulated_json}" + ) + yield None + + def _parse_message_from_event(self, event) -> Optional[str]: + response_dict = event.to_response_dict() + parsed_response = self.parser.parse(response_dict, get_response_stream_shape()) + + if response_dict["status_code"] != 200: + raise ValueError(f"Bad response code, expected 200: {response_dict}") + + if "chunk" in parsed_response: + chunk = parsed_response.get("chunk") + if not chunk: + return None + return chunk.get("bytes").decode() # type: ignore[no-any-return] + else: + chunk = response_dict.get("body") + if not chunk: + return None + + return chunk.decode() # type: ignore[no-any-return] + + +def get_response_stream_shape(): + global _response_stream_shape_cache + if _response_stream_shape_cache is None: + + from botocore.loaders import Loader + from botocore.model import ServiceModel + + loader = Loader() + sagemaker_service_dict = loader.load_service_model( + "sagemaker-runtime", "service-2" + ) + sagemaker_service_model = ServiceModel(sagemaker_service_dict) + _response_stream_shape_cache = sagemaker_service_model.shape_for( + "InvokeEndpointWithResponseStreamOutput" + ) + return _response_stream_shape_cache diff --git a/litellm/llms/sagemaker/sagemaker.py b/litellm/llms/sagemaker/completion/handler.py similarity index 55% rename from litellm/llms/sagemaker/sagemaker.py rename to litellm/llms/sagemaker/completion/handler.py index 2e6c72ac80c0..648f184e89c5 100644 --- a/litellm/llms/sagemaker/sagemaker.py +++ b/litellm/llms/sagemaker/completion/handler.py @@ -22,12 +22,7 @@ _get_httpx_client, get_async_httpx_client, ) -from litellm.types.llms.openai import ( - ChatCompletionToolCallChunk, - ChatCompletionUsageBlock, -) -from litellm.types.utils import GenericStreamingChunk as GChunk -from litellm.types.utils import StreamingChatCompletionChunk +from litellm.types.llms.openai import AllMessageValues from litellm.utils import ( CustomStreamWrapper, EmbeddingResponse, @@ -36,65 +31,12 @@ get_secret, ) -from ..base_aws_llm import BaseAWSLLM -from ..prompt_templates.factory import custom_prompt, prompt_factory - -_response_stream_shape_cache = None - - -class SagemakerError(Exception): - def __init__(self, status_code, message): - self.status_code = status_code - self.message = message - self.request = httpx.Request( - method="POST", url="https://us-west-2.console.aws.amazon.com/sagemaker" - ) - self.response = httpx.Response(status_code=status_code, request=self.request) - super().__init__( - self.message - ) # Call the base class constructor with the parameters it needs - - -class SagemakerConfig: - """ - Reference: https://d-uuwbxj1u4cnu.studio.us-west-2.sagemaker.aws/jupyter/default/lab/workspaces/auto-q/tree/DemoNotebooks/meta-textgeneration-llama-2-7b-SDK_1.ipynb - """ - - max_new_tokens: Optional[int] = None - top_p: Optional[float] = None - temperature: Optional[float] = None - return_full_text: Optional[bool] = None - - def __init__( - self, - max_new_tokens: Optional[int] = None, - top_p: Optional[float] = None, - temperature: Optional[float] = None, - return_full_text: Optional[bool] = None, - ) -> None: - locals_ = locals() - for key, value in locals_.items(): - if key != "self" and value is not None: - setattr(self.__class__, key, value) - - @classmethod - def get_config(cls): - return { - k: v - for k, v in cls.__dict__.items() - if not k.startswith("__") - and not isinstance( - v, - ( - types.FunctionType, - types.BuiltinFunctionType, - classmethod, - staticmethod, - ), - ) - and v is not None - } +from ...base_aws_llm import BaseAWSLLM +from ...prompt_templates.factory import custom_prompt, prompt_factory +from ..common_utils import AWSEventStreamDecoder, SagemakerError +from .transformation import SagemakerConfig +sagemaker_config = SagemakerConfig() """ SAGEMAKER AUTH Keys/Vars @@ -166,6 +108,7 @@ def _prepare_request( credentials, model: str, data: dict, + messages: List[AllMessageValues], optional_params: dict, aws_region_name: str, extra_headers: Optional[dict] = None, @@ -189,9 +132,12 @@ def _prepare_request( api_base = sagemaker_base_url encoded_data = json.dumps(data).encode("utf-8") - headers = {"Content-Type": "application/json"} - if extra_headers is not None: - headers = {"Content-Type": "application/json", **extra_headers} + headers = sagemaker_config.validate_environment( + headers=extra_headers, + model=model, + messages=messages, + optional_params=optional_params, + ) request = AWSRequest( method="POST", url=api_base, data=encoded_data, headers=headers ) @@ -205,49 +151,6 @@ def _prepare_request( return prepped_request - def _transform_prompt( - self, - model: str, - messages: List, - custom_prompt_dict: dict, - hf_model_name: Optional[str], - ) -> str: - if model in custom_prompt_dict: - # check if the model has a registered custom prompt - model_prompt_details = custom_prompt_dict[model] - prompt = custom_prompt( - role_dict=model_prompt_details.get("roles", None), - initial_prompt_value=model_prompt_details.get( - "initial_prompt_value", "" - ), - final_prompt_value=model_prompt_details.get("final_prompt_value", ""), - messages=messages, - ) - elif hf_model_name in custom_prompt_dict: - # check if the base huggingface model has a registered custom prompt - model_prompt_details = custom_prompt_dict[hf_model_name] - prompt = custom_prompt( - role_dict=model_prompt_details.get("roles", None), - initial_prompt_value=model_prompt_details.get( - "initial_prompt_value", "" - ), - final_prompt_value=model_prompt_details.get("final_prompt_value", ""), - messages=messages, - ) - else: - if hf_model_name is None: - if "llama-2" in model.lower(): # llama-2 model - if "chat" in model.lower(): # apply llama2 chat template - hf_model_name = "meta-llama/Llama-2-7b-chat-hf" - else: # apply regular llama2 template - hf_model_name = "meta-llama/Llama-2-7b" - hf_model_name = ( - hf_model_name or model - ) # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt) - prompt: str = prompt_factory(model=hf_model_name, messages=messages) # type: ignore - - return prompt - def completion( # noqa: PLR0915 self, model: str, @@ -257,13 +160,13 @@ def completion( # noqa: PLR0915 encoding, logging_obj, optional_params: dict, + litellm_params: dict, timeout: Optional[Union[float, httpx.Timeout]] = None, custom_prompt_dict={}, hf_model_name=None, - litellm_params=None, logger_fn=None, acompletion: bool = False, - use_messages_api: Optional[bool] = None, + headers: dict = {}, ): # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker @@ -272,50 +175,6 @@ def completion( # noqa: PLR0915 stream = inference_params.pop("stream", None) model_id = optional_params.get("model_id", None) - if use_messages_api is True: - from litellm.llms.openai_like.chat.handler import OpenAILikeChatHandler - - openai_like_chat_completions = OpenAILikeChatHandler() - inference_params["stream"] = True if stream is True else False - _data: Dict[str, Any] = { - "model": model, - "messages": messages, - **inference_params, - } - - prepared_request = self._prepare_request( - model=model, - data=_data, - optional_params=optional_params, - credentials=credentials, - aws_region_name=aws_region_name, - ) - - custom_stream_decoder = AWSEventStreamDecoder( - model="", is_messages_api=True - ) - - return openai_like_chat_completions.completion( - model=model, - messages=messages, - api_base=prepared_request.url, - api_key=None, - custom_prompt_dict=custom_prompt_dict, - model_response=model_response, - print_verbose=print_verbose, - logging_obj=logging_obj, - optional_params=inference_params, - acompletion=acompletion, - litellm_params=litellm_params, - logger_fn=logger_fn, - timeout=timeout, - encoding=encoding, - headers=prepared_request.headers, # type: ignore - custom_endpoint=True, - custom_llm_provider="sagemaker_chat", - streaming_decoder=custom_stream_decoder, # type: ignore - ) - ## Load Config config = litellm.SagemakerConfig.get_config() for k, v in config.items(): @@ -325,21 +184,6 @@ def completion( # noqa: PLR0915 inference_params[k] = v if stream is True: - data = {"parameters": inference_params, "stream": True} - prepared_request = self._prepare_request( - model=model, - data=data, - optional_params=optional_params, - credentials=credentials, - aws_region_name=aws_region_name, - ) - if model_id is not None: - # Add model_id as InferenceComponentName header - # boto3 doc: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_runtime_InvokeEndpoint.html - prepared_request.headers.update( - {"X-Amzn-SageMaker-Inference-Component": model_id} - ) - if acompletion is True: response = self.async_streaming( messages=messages, @@ -350,23 +194,25 @@ def completion( # noqa: PLR0915 encoding=encoding, model_response=model_response, logging_obj=logging_obj, - data=data, model_id=model_id, aws_region_name=aws_region_name, credentials=credentials, + headers=headers, + litellm_params=litellm_params, ) return response else: - prompt = self._transform_prompt( + data = sagemaker_config.transform_request( model=model, messages=messages, - custom_prompt_dict=custom_prompt_dict, - hf_model_name=hf_model_name, + optional_params=optional_params, + litellm_params=litellm_params, + headers=headers, ) - data["inputs"] = prompt prepared_request = self._prepare_request( model=model, data=data, + messages=messages, optional_params=optional_params, credentials=credentials, aws_region_name=aws_region_name, @@ -388,7 +234,7 @@ def completion( # noqa: PLR0915 if sync_response.status_code != 200: raise SagemakerError( status_code=sync_response.status_code, - message=sync_response.read(), + message=str(sync_response.read()), ) decoder = AWSEventStreamDecoder(model="") @@ -413,14 +259,6 @@ def completion( # noqa: PLR0915 return streaming_response # Non-Streaming Requests - _data = {"parameters": inference_params} - prepared_request_args = { - "model": model, - "data": _data, - "optional_params": optional_params, - "credentials": credentials, - "aws_region_name": aws_region_name, - } # Async completion if acompletion is True: @@ -432,21 +270,30 @@ def completion( # noqa: PLR0915 model_response=model_response, encoding=encoding, logging_obj=logging_obj, - data=_data, model_id=model_id, optional_params=optional_params, credentials=credentials, aws_region_name=aws_region_name, + headers=headers, + litellm_params=litellm_params, ) - prompt = self._transform_prompt( + ## Non-Streaming completion CALL + _data = sagemaker_config.transform_request( model=model, messages=messages, - custom_prompt_dict=custom_prompt_dict, - hf_model_name=hf_model_name, + optional_params=optional_params, + litellm_params=litellm_params, + headers=headers, ) - _data["inputs"] = prompt - ## Non-Streaming completion CALL + prepared_request_args = { + "model": model, + "data": _data, + "optional_params": optional_params, + "credentials": credentials, + "aws_region_name": aws_region_name, + "messages": messages, + } prepared_request = self._prepare_request(**prepared_request_args) try: if model_id is not None: @@ -507,53 +354,16 @@ def completion( # noqa: PLR0915 error_message += "\n pass in via `litellm.completion(..., model_id={InferenceComponentName})`" raise SagemakerError(status_code=status_code, message=error_message) - completion_response = sync_response.json() - ## LOGGING - logging_obj.post_call( - input=prompt, - api_key="", - original_response=completion_response, - additional_args={"complete_input_dict": _data}, - ) - print_verbose(f"raw model_response: {completion_response}") - ## RESPONSE OBJECT - try: - if isinstance(completion_response, list): - completion_response_choices = completion_response[0] - else: - completion_response_choices = completion_response - completion_output = "" - if "generation" in completion_response_choices: - completion_output += completion_response_choices["generation"] - elif "generated_text" in completion_response_choices: - completion_output += completion_response_choices["generated_text"] - - # check if the prompt template is part of output, if so - filter it out - if completion_output.startswith(prompt) and "" in prompt: - completion_output = completion_output.replace(prompt, "", 1) - - model_response.choices[0].message.content = completion_output # type: ignore - except Exception: - raise SagemakerError( - message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}", - status_code=500, - ) - - ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. - prompt_tokens = len(encoding.encode(prompt)) - completion_tokens = len( - encoding.encode(model_response["choices"][0]["message"].get("content", "")) - ) - - model_response.created = int(time.time()) - model_response.model = model - usage = Usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, + return sagemaker_config.transform_response( + model=model, + raw_response=sync_response, + model_response=model_response, + logging_obj=logging_obj, + request_data=_data, + messages=messages, + optional_params=optional_params, + encoding=encoding, ) - setattr(model_response, "usage", usage) - return model_response async def make_async_call( self, @@ -605,7 +415,7 @@ async def make_async_call( async def async_streaming( self, - messages: list, + messages: List[AllMessageValues], model: str, custom_prompt_dict: dict, hf_model_name: Optional[str], @@ -616,13 +426,15 @@ async def async_streaming( model_response: ModelResponse, model_id: Optional[str], logging_obj: Any, - data, + litellm_params: dict, + headers: dict, ): - data["inputs"] = self._transform_prompt( + data = await sagemaker_config.async_transform_request( model=model, messages=messages, - custom_prompt_dict=custom_prompt_dict, - hf_model_name=hf_model_name, + optional_params={**optional_params, "stream": True}, + litellm_params=litellm_params, + headers=headers, ) asyncified_prepare_request = asyncify(self._prepare_request) prepared_request_args = { @@ -631,6 +443,7 @@ async def async_streaming( "optional_params": optional_params, "credentials": credentials, "aws_region_name": aws_region_name, + "messages": messages, } prepared_request = await asyncified_prepare_request(**prepared_request_args) completion_stream = await self.make_async_call( @@ -658,7 +471,7 @@ async def async_streaming( async def async_completion( self, - messages: list, + messages: List[AllMessageValues], model: str, custom_prompt_dict: dict, hf_model_name: Optional[str], @@ -668,22 +481,23 @@ async def async_completion( model_response: ModelResponse, optional_params: dict, logging_obj: Any, - data: dict, model_id: Optional[str], + headers: dict, + litellm_params: dict, ): timeout = 300.0 async_handler = get_async_httpx_client( llm_provider=litellm.LlmProviders.SAGEMAKER ) - async_transform_prompt = asyncify(self._transform_prompt) - - data["inputs"] = await async_transform_prompt( + data = await sagemaker_config.async_transform_request( model=model, messages=messages, - custom_prompt_dict=custom_prompt_dict, - hf_model_name=hf_model_name, + optional_params=optional_params, + litellm_params=litellm_params, + headers=headers, ) + asyncified_prepare_request = asyncify(self._prepare_request) prepared_request_args = { "model": model, @@ -691,6 +505,7 @@ async def async_completion( "optional_params": optional_params, "credentials": credentials, "aws_region_name": aws_region_name, + "messages": messages, } prepared_request = await asyncified_prepare_request(**prepared_request_args) @@ -738,52 +553,16 @@ async def async_completion( if "Inference Component Name header is required" in error_message: error_message += "\n pass in via `litellm.completion(..., model_id={InferenceComponentName})`" raise SagemakerError(status_code=500, message=error_message) - completion_response = response.json() - ## LOGGING - logging_obj.post_call( - input=data["inputs"], - api_key="", - original_response=response, - additional_args={"complete_input_dict": data}, - ) - ## RESPONSE OBJECT - try: - if isinstance(completion_response, list): - completion_response_choices = completion_response[0] - else: - completion_response_choices = completion_response - completion_output = "" - if "generation" in completion_response_choices: - completion_output += completion_response_choices["generation"] - elif "generated_text" in completion_response_choices: - completion_output += completion_response_choices["generated_text"] - - # check if the prompt template is part of output, if so - filter it out - if completion_output.startswith(data["inputs"]) and "" in data["inputs"]: - completion_output = completion_output.replace(data["inputs"], "", 1) - - model_response.choices[0].message.content = completion_output # type: ignore - except Exception: - raise SagemakerError( - message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}", - status_code=500, - ) - - ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. - prompt_tokens = len(encoding.encode(data["inputs"])) - completion_tokens = len( - encoding.encode(model_response["choices"][0]["message"].get("content", "")) - ) - - model_response.created = int(time.time()) - model_response.model = model - usage = Usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, + return sagemaker_config.transform_response( + model=model, + raw_response=response, + model_response=model_response, + logging_obj=logging_obj, + request_data=data, + messages=messages, + optional_params=optional_params, + encoding=encoding, ) - setattr(model_response, "usage", usage) - return model_response def embedding( self, @@ -928,180 +707,3 @@ def embedding( ) return model_response - - -def get_response_stream_shape(): - global _response_stream_shape_cache - if _response_stream_shape_cache is None: - - from botocore.loaders import Loader - from botocore.model import ServiceModel - - loader = Loader() - sagemaker_service_dict = loader.load_service_model( - "sagemaker-runtime", "service-2" - ) - sagemaker_service_model = ServiceModel(sagemaker_service_dict) - _response_stream_shape_cache = sagemaker_service_model.shape_for( - "InvokeEndpointWithResponseStreamOutput" - ) - return _response_stream_shape_cache - - -class AWSEventStreamDecoder: - def __init__(self, model: str, is_messages_api: Optional[bool] = None) -> None: - from botocore.parsers import EventStreamJSONParser - - self.model = model - self.parser = EventStreamJSONParser() - self.content_blocks: List = [] - self.is_messages_api = is_messages_api - - def _chunk_parser_messages_api( - self, chunk_data: dict - ) -> StreamingChatCompletionChunk: - - openai_chunk = StreamingChatCompletionChunk(**chunk_data) - - return openai_chunk - - def _chunk_parser(self, chunk_data: dict) -> GChunk: - verbose_logger.debug("in sagemaker chunk parser, chunk_data %s", chunk_data) - _token = chunk_data.get("token", {}) or {} - _index = chunk_data.get("index", None) or 0 - is_finished = False - finish_reason = "" - - _text = _token.get("text", "") - if _text == "<|endoftext|>": - return GChunk( - text="", - index=_index, - is_finished=True, - finish_reason="stop", - usage=None, - ) - - return GChunk( - text=_text, - index=_index, - is_finished=is_finished, - finish_reason=finish_reason, - usage=None, - ) - - def iter_bytes( - self, iterator: Iterator[bytes] - ) -> Iterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]: - """Given an iterator that yields lines, iterate over it & yield every event encountered""" - from botocore.eventstream import EventStreamBuffer - - event_stream_buffer = EventStreamBuffer() - accumulated_json = "" - - for chunk in iterator: - event_stream_buffer.add_data(chunk) - for event in event_stream_buffer: - message = self._parse_message_from_event(event) - if message: - # remove data: prefix and "\n\n" at the end - message = message.replace("data:", "").replace("\n\n", "") - - # Accumulate JSON data - accumulated_json += message - - # Try to parse the accumulated JSON - try: - _data = json.loads(accumulated_json) - if self.is_messages_api: - yield self._chunk_parser_messages_api(chunk_data=_data) - else: - yield self._chunk_parser(chunk_data=_data) - # Reset accumulated_json after successful parsing - accumulated_json = "" - except json.JSONDecodeError: - # If it's not valid JSON yet, continue to the next event - continue - - # Handle any remaining data after the iterator is exhausted - if accumulated_json: - try: - _data = json.loads(accumulated_json) - if self.is_messages_api: - yield self._chunk_parser_messages_api(chunk_data=_data) - else: - yield self._chunk_parser(chunk_data=_data) - except json.JSONDecodeError: - # Handle or log any unparseable data at the end - verbose_logger.error( - f"Warning: Unparseable JSON data remained: {accumulated_json}" - ) - yield None - - async def aiter_bytes( - self, iterator: AsyncIterator[bytes] - ) -> AsyncIterator[Optional[Union[GChunk, StreamingChatCompletionChunk]]]: - """Given an async iterator that yields lines, iterate over it & yield every event encountered""" - from botocore.eventstream import EventStreamBuffer - - event_stream_buffer = EventStreamBuffer() - accumulated_json = "" - - async for chunk in iterator: - event_stream_buffer.add_data(chunk) - for event in event_stream_buffer: - message = self._parse_message_from_event(event) - if message: - verbose_logger.debug("sagemaker parsed chunk bytes %s", message) - # remove data: prefix and "\n\n" at the end - message = message.replace("data:", "").replace("\n\n", "") - - # Accumulate JSON data - accumulated_json += message - - # Try to parse the accumulated JSON - try: - _data = json.loads(accumulated_json) - if self.is_messages_api: - yield self._chunk_parser_messages_api(chunk_data=_data) - else: - yield self._chunk_parser(chunk_data=_data) - # Reset accumulated_json after successful parsing - accumulated_json = "" - except json.JSONDecodeError: - # If it's not valid JSON yet, continue to the next event - continue - - # Handle any remaining data after the iterator is exhausted - if accumulated_json: - try: - _data = json.loads(accumulated_json) - if self.is_messages_api: - yield self._chunk_parser_messages_api(chunk_data=_data) - else: - yield self._chunk_parser(chunk_data=_data) - except json.JSONDecodeError: - # Handle or log any unparseable data at the end - verbose_logger.error( - f"Warning: Unparseable JSON data remained: {accumulated_json}" - ) - yield None - - def _parse_message_from_event(self, event) -> Optional[str]: - response_dict = event.to_response_dict() - parsed_response = self.parser.parse(response_dict, get_response_stream_shape()) - - if response_dict["status_code"] != 200: - raise ValueError(f"Bad response code, expected 200: {response_dict}") - - if "chunk" in parsed_response: - chunk = parsed_response.get("chunk") - if not chunk: - return None - return chunk.get("bytes").decode() # type: ignore[no-any-return] - else: - chunk = response_dict.get("body") - if not chunk: - return None - - return chunk.decode() # type: ignore[no-any-return] diff --git a/litellm/llms/sagemaker/completion/transformation.py b/litellm/llms/sagemaker/completion/transformation.py new file mode 100644 index 000000000000..e6bfbb33f625 --- /dev/null +++ b/litellm/llms/sagemaker/completion/transformation.py @@ -0,0 +1,272 @@ +""" +Translate from OpenAI's `/v1/chat/completions` to Sagemaker's `/invoke` + +In the Huggingface TGI format. +""" + +import json +import time +import types +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from httpx._models import Headers, Response + +import litellm +from litellm.litellm_core_utils.asyncify import asyncify +from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException +from litellm.llms.prompt_templates.factory import custom_prompt, prompt_factory +from litellm.types.llms.openai import AllMessageValues +from litellm.types.utils import Usage + +from ..common_utils import SagemakerError + +if TYPE_CHECKING: + from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj + + LiteLLMLoggingObj = _LiteLLMLoggingObj +else: + LiteLLMLoggingObj = Any + + +class SagemakerConfig(BaseConfig): + """ + Reference: https://d-uuwbxj1u4cnu.studio.us-west-2.sagemaker.aws/jupyter/default/lab/workspaces/auto-q/tree/DemoNotebooks/meta-textgeneration-llama-2-7b-SDK_1.ipynb + """ + + max_new_tokens: Optional[int] = None + top_p: Optional[float] = None + temperature: Optional[float] = None + return_full_text: Optional[bool] = None + + def __init__( + self, + max_new_tokens: Optional[int] = None, + top_p: Optional[float] = None, + temperature: Optional[float] = None, + return_full_text: Optional[bool] = None, + ) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return super().get_config() + + def _transform_messages( + self, + messages: List[AllMessageValues], + ) -> List[AllMessageValues]: + return messages + + def get_error_class( + self, error_message: str, status_code: int, headers: Union[dict, Headers] + ) -> BaseLLMException: + return SagemakerError( + message=error_message, status_code=status_code, headers=headers + ) + + def get_supported_openai_params(self, model: str) -> List: + return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"] + + def map_openai_params( + self, + non_default_params: dict, + optional_params: dict, + model: str, + drop_params: bool, + ) -> dict: + for param, value in non_default_params.items(): + if param == "temperature": + if value == 0.0 or value == 0: + # hugging face exception raised when temp==0 + # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive + if not non_default_params.get( + "aws_sagemaker_allow_zero_temp", False + ): + value = 0.01 + + optional_params["temperature"] = value + if param == "top_p": + optional_params["top_p"] = value + if param == "n": + optional_params["best_of"] = value + optional_params["do_sample"] = ( + True # Need to sample if you want best of for hf inference endpoints + ) + if param == "stream": + optional_params["stream"] = value + if param == "stop": + optional_params["stop"] = value + if param == "max_tokens": + # HF TGI raises the following exception when max_new_tokens==0 + # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive + if value == 0: + value = 1 + optional_params["max_new_tokens"] = value + non_default_params.pop("aws_sagemaker_allow_zero_temp", None) + return optional_params + + def _transform_prompt( + self, + model: str, + messages: List, + custom_prompt_dict: dict, + hf_model_name: Optional[str], + ) -> str: + if model in custom_prompt_dict: + # check if the model has a registered custom prompt + model_prompt_details = custom_prompt_dict[model] + prompt = custom_prompt( + role_dict=model_prompt_details.get("roles", None), + initial_prompt_value=model_prompt_details.get( + "initial_prompt_value", "" + ), + final_prompt_value=model_prompt_details.get("final_prompt_value", ""), + messages=messages, + ) + elif hf_model_name in custom_prompt_dict: + # check if the base huggingface model has a registered custom prompt + model_prompt_details = custom_prompt_dict[hf_model_name] + prompt = custom_prompt( + role_dict=model_prompt_details.get("roles", None), + initial_prompt_value=model_prompt_details.get( + "initial_prompt_value", "" + ), + final_prompt_value=model_prompt_details.get("final_prompt_value", ""), + messages=messages, + ) + else: + if hf_model_name is None: + if "llama-2" in model.lower(): # llama-2 model + if "chat" in model.lower(): # apply llama2 chat template + hf_model_name = "meta-llama/Llama-2-7b-chat-hf" + else: # apply regular llama2 template + hf_model_name = "meta-llama/Llama-2-7b" + hf_model_name = ( + hf_model_name or model + ) # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt) + prompt: str = prompt_factory(model=hf_model_name, messages=messages) # type: ignore + + return prompt + + def transform_request( + self, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + headers: dict, + ) -> dict: + inference_params = optional_params.copy() + stream = inference_params.pop("stream", False) + data: Dict = {"parameters": inference_params} + if stream is True: + data["stream"] = True + + custom_prompt_dict = ( + litellm_params.get("custom_prompt_dict", None) or litellm.custom_prompt_dict + ) + + hf_model_name = litellm_params.get("hf_model_name", None) + + prompt = self._transform_prompt( + model=model, + messages=messages, + custom_prompt_dict=custom_prompt_dict, + hf_model_name=hf_model_name, + ) + data["inputs"] = prompt + + return data + + async def async_transform_request( + self, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + headers: dict, + ) -> dict: + return await asyncify(self.transform_request)( + model, messages, optional_params, litellm_params, headers + ) + + def transform_response( + self, + model: str, + raw_response: Response, + model_response: litellm.ModelResponse, + logging_obj: LiteLLMLoggingObj, + request_data: dict, + messages: List[AllMessageValues], + optional_params: dict, + encoding: str, + api_key: Optional[str] = None, + json_mode: Optional[bool] = None, + ) -> litellm.ModelResponse: + completion_response = raw_response.json() + ## LOGGING + logging_obj.post_call( + input=messages, + api_key="", + original_response=completion_response, + additional_args={"complete_input_dict": request_data}, + ) + + prompt = request_data["inputs"] + + ## RESPONSE OBJECT + try: + if isinstance(completion_response, list): + completion_response_choices = completion_response[0] + else: + completion_response_choices = completion_response + completion_output = "" + if "generation" in completion_response_choices: + completion_output += completion_response_choices["generation"] + elif "generated_text" in completion_response_choices: + completion_output += completion_response_choices["generated_text"] + + # check if the prompt template is part of output, if so - filter it out + if completion_output.startswith(prompt) and "" in prompt: + completion_output = completion_output.replace(prompt, "", 1) + + model_response.choices[0].message.content = completion_output # type: ignore + except Exception: + raise SagemakerError( + message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}", + status_code=500, + ) + + ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. + prompt_tokens = len(encoding.encode(prompt)) + completion_tokens = len( + encoding.encode(model_response["choices"][0]["message"].get("content", "")) + ) + + model_response.created = int(time.time()) + model_response.model = model + usage = Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + setattr(model_response, "usage", usage) + return model_response + + def validate_environment( + self, + headers: Optional[dict], + model: str, + messages: List[AllMessageValues], + optional_params: dict, + api_key: Optional[str] = None, + ) -> dict: + headers = {"Content-Type": "application/json"} + + if headers is not None: + headers = {"Content-Type": "application/json", **headers} + + return headers diff --git a/litellm/main.py b/litellm/main.py index 10142bdb25ea..84b73fe7f6e3 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -130,7 +130,8 @@ prompt_factory, stringify_json_tool_call_content, ) -from .llms.sagemaker.sagemaker import SagemakerLLM +from .llms.sagemaker.chat.handler import SagemakerChatHandler +from .llms.sagemaker.completion.handler import SagemakerLLM from .llms.text_completion_codestral import CodestralTextCompletion from .llms.together_ai.completion.handler import TogetherAITextCompletion from .llms.triton import TritonChatCompletion @@ -229,6 +230,7 @@ openai_like_embedding = OpenAILikeEmbeddingHandler() databricks_embedding = DatabricksEmbeddingHandler() base_llm_http_handler = BaseLLMHTTPHandler() +sagemaker_chat_completion = SagemakerChatHandler() ####### COMPLETION ENDPOINTS ################ @@ -2513,10 +2515,23 @@ def completion( # type: ignore # noqa: PLR0915 ## RESPONSE OBJECT response = model_response - elif ( - custom_llm_provider == "sagemaker" - or custom_llm_provider == "sagemaker_chat" - ): + elif custom_llm_provider == "sagemaker_chat": + # boto3 reads keys from .env + response = sagemaker_chat_completion.completion( + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + custom_prompt_dict=custom_prompt_dict, + logger_fn=logger_fn, + encoding=encoding, + logging_obj=logging, + acompletion=acompletion, + headers=headers or {}, + ) + elif custom_llm_provider == "sagemaker": # boto3 reads keys from .env model_response = sagemaker_llm.completion( model=model, @@ -2531,17 +2546,7 @@ def completion( # type: ignore # noqa: PLR0915 encoding=encoding, logging_obj=logging, acompletion=acompletion, - use_messages_api=( - True if custom_llm_provider == "sagemaker_chat" else False - ), ) - if optional_params.get("stream", False): - ## LOGGING - logging.post_call( - input=messages, - api_key=None, - original_response=model_response, - ) ## RESPONSE OBJECT response = model_response diff --git a/litellm/utils.py b/litellm/utils.py index 5321357a8700..31b96245f3a9 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -3145,31 +3145,16 @@ def _map_and_modify_arg(supported_params: dict, provider: str, model: str): ) _check_valid_arg(supported_params=supported_params) # temperature, top_p, n, stream, stop, max_tokens, n, presence_penalty default to None - if temperature is not None: - if temperature == 0.0 or temperature == 0: - # hugging face exception raised when temp==0 - # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive - if not passed_params.get("aws_sagemaker_allow_zero_temp", False): - temperature = 0.01 - optional_params["temperature"] = temperature - if top_p is not None: - optional_params["top_p"] = top_p - if n is not None: - optional_params["best_of"] = n - optional_params["do_sample"] = ( - True # Need to sample if you want best of for hf inference endpoints - ) - if stream is not None: - optional_params["stream"] = stream - if stop is not None: - optional_params["stop"] = stop - if max_tokens is not None: - # HF TGI raises the following exception when max_new_tokens==0 - # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive - if max_tokens == 0: - max_tokens = 1 - optional_params["max_new_tokens"] = max_tokens - passed_params.pop("aws_sagemaker_allow_zero_temp", None) + optional_params = litellm.SagemakerConfig().map_openai_params( + non_default_params=non_default_params, + optional_params=optional_params, + model=model, + drop_params=( + drop_params + if drop_params is not None and isinstance(drop_params, bool) + else False + ), + ) elif custom_llm_provider == "bedrock": supported_params = get_supported_openai_params( model=model, custom_llm_provider=custom_llm_provider @@ -6284,7 +6269,10 @@ def get_provider_chat_config( return litellm.VertexAIAnthropicConfig() elif litellm.LlmProviders.CLOUDFLARE == provider: return litellm.CloudflareChatConfig() - + elif litellm.LlmProviders.SAGEMAKER_CHAT == provider: + return litellm.SagemakerChatConfig() + elif litellm.LlmProviders.SAGEMAKER == provider: + return litellm.SagemakerConfig() return litellm.OpenAIGPTConfig() diff --git a/tests/local_testing/test_config.py b/tests/local_testing/test_config.py index c5896793a7f5..0c08aceec7d0 100644 --- a/tests/local_testing/test_config.py +++ b/tests/local_testing/test_config.py @@ -290,33 +290,34 @@ async def _monkey_patch_get_config(*args, **kwargs): assert len(llm_router.model_list) == len(model_list) + prev_llm_router_val -# def test_provider_config_manager(): -# from litellm import LITELLM_CHAT_PROVIDERS, LlmProviders -# from litellm.utils import ProviderConfigManager -# from litellm.llms.base_llm.transformation import BaseConfig -# from litellm.llms.OpenAI.chat.gpt_transformation import OpenAIGPTConfig - -# for provider in LITELLM_CHAT_PROVIDERS: -# assert isinstance( -# ProviderConfigManager.get_provider_chat_config( -# model="gpt-3.5-turbo", provider=LlmProviders(provider) -# ), -# BaseConfig, -# ), f"Provider {provider} is not a subclass of BaseConfig" - -# config = ProviderConfigManager.get_provider_chat_config( -# model="gpt-3.5-turbo", provider=LlmProviders(provider) -# ) - -# if ( -# provider != litellm.LlmProviders.OPENAI -# and provider != litellm.LlmProviders.OPENAI_LIKE -# and provider != litellm.LlmProviders.CUSTOM_OPENAI -# ): -# assert ( -# config.__class__.__name__ != "OpenAIGPTConfig" -# ), f"Provider {provider} is an instance of OpenAIGPTConfig" - -# assert ( -# "_abc_impl" not in config.get_config() -# ), f"Provider {provider} has _abc_impl" +def test_provider_config_manager(): + from litellm import LITELLM_CHAT_PROVIDERS, LlmProviders + from litellm.utils import ProviderConfigManager + from litellm.llms.base_llm.transformation import BaseConfig + from litellm.llms.OpenAI.chat.gpt_transformation import OpenAIGPTConfig + + LITELLM_CHAT_PROVIDERS = ["sagemaker_chat", "sagemaker"] + for provider in LITELLM_CHAT_PROVIDERS: + assert isinstance( + ProviderConfigManager.get_provider_chat_config( + model="gpt-3.5-turbo", provider=LlmProviders(provider) + ), + BaseConfig, + ), f"Provider {provider} is not a subclass of BaseConfig" + + config = ProviderConfigManager.get_provider_chat_config( + model="gpt-3.5-turbo", provider=LlmProviders(provider) + ) + + if ( + provider != litellm.LlmProviders.OPENAI + and provider != litellm.LlmProviders.OPENAI_LIKE + and provider != litellm.LlmProviders.CUSTOM_OPENAI + ): + assert ( + config.__class__.__name__ != "OpenAIGPTConfig" + ), f"Provider {provider} is an instance of OpenAIGPTConfig" + + assert ( + "_abc_impl" not in config.get_config() + ), f"Provider {provider} has _abc_impl" diff --git a/tests/local_testing/test_sagemaker.py b/tests/local_testing/test_sagemaker.py index 0185c71460c1..f602067338a3 100644 --- a/tests/local_testing/test_sagemaker.py +++ b/tests/local_testing/test_sagemaker.py @@ -129,7 +129,7 @@ async def test_completion_sagemaker_messages_api(sync_mode): "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614", ], ) -@pytest.mark.flaky(retries=3, delay=1) +# @pytest.mark.flaky(retries=3, delay=1) async def test_completion_sagemaker_stream(sync_mode, model): try: litellm.set_verbose = False diff --git a/tests/local_testing/test_streaming.py b/tests/local_testing/test_streaming.py index 7227099792a1..30d9d3e0f683 100644 --- a/tests/local_testing/test_streaming.py +++ b/tests/local_testing/test_streaming.py @@ -1750,7 +1750,7 @@ def test_sagemaker_weird_response(): try: import json - from litellm.llms.sagemaker.sagemaker import TokenIterator + from litellm.llms.sagemaker.completion.handler import TokenIterator chunk = """[INST] Hey, how's it going? [/INST], I'm doing well, thanks for asking! How about you? Is there anything you'd like to chat about or ask? I'm here to help with any questions you might have.""" From f9faae13a8315a30636a6d45087d5f9cb43f510f Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 10 Dec 2024 17:28:43 -0800 Subject: [PATCH 2/2] fix(main.py): pass hf model name + custom prompt dict to litellm params --- litellm/main.py | 2 ++ litellm/utils.py | 4 ++++ tests/local_testing/test_async_fn.py | 17 ----------------- 3 files changed, 6 insertions(+), 17 deletions(-) diff --git a/litellm/main.py b/litellm/main.py index 84b73fe7f6e3..e61fb2385e8b 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -1075,6 +1075,8 @@ def completion( # type: ignore # noqa: PLR0915 user_continue_message=kwargs.get("user_continue_message"), base_model=base_model, litellm_trace_id=kwargs.get("litellm_trace_id"), + hf_model_name=hf_model_name, + custom_prompt_dict=custom_prompt_dict, ) logging.update_environment_variables( model=model, diff --git a/litellm/utils.py b/litellm/utils.py index 31b96245f3a9..a129bf710203 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2076,6 +2076,8 @@ def get_litellm_params( user_continue_message=None, base_model=None, litellm_trace_id=None, + hf_model_name: Optional[str] = None, + custom_prompt_dict: Optional[dict] = None, ): litellm_params = { "acompletion": acompletion, @@ -2105,6 +2107,8 @@ def get_litellm_params( "base_model": base_model or _get_base_model_from_litellm_call_metadata(metadata=metadata), "litellm_trace_id": litellm_trace_id, + "hf_model_name": hf_model_name, + "custom_prompt_dict": custom_prompt_dict, } return litellm_params diff --git a/tests/local_testing/test_async_fn.py b/tests/local_testing/test_async_fn.py index ec322f6d5e06..1fc8d5320552 100644 --- a/tests/local_testing/test_async_fn.py +++ b/tests/local_testing/test_async_fn.py @@ -246,23 +246,6 @@ async def test_hf_completion_tgi(): # test_get_cloudflare_response_streaming() -@pytest.mark.skip(reason="AWS Suspended Account") -@pytest.mark.asyncio -async def test_completion_sagemaker(): - # litellm.set_verbose=True - try: - response = await acompletion( - model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4", - messages=[{"content": "Hello, how are you?", "role": "user"}], - ) - # Add any assertions here to check the response - print(response) - except litellm.Timeout as e: - pass - except Exception as e: - pytest.fail(f"Error occurred: {e}") - - def test_get_response_streaming(): import asyncio