diff --git a/litellm/llms/gemini.py b/litellm/llms/gemini.py
index 126569ecc0c9..21d85a2fcfd2 100644
--- a/litellm/llms/gemini.py
+++ b/litellm/llms/gemini.py
@@ -6,7 +6,8 @@
 from litellm.utils import ModelResponse, get_secret, Choices, Message, Usage
 import litellm
 import sys, httpx
-from .prompt_templates.factory import prompt_factory, custom_prompt
+from .prompt_templates.factory import prompt_factory, custom_prompt, get_system_prompt
+from packaging.version import Version
 
 
 class GeminiError(Exception):
@@ -103,6 +104,13 @@
                 break
 
 
+def supports_system_instruction():
+    import google.generativeai as genai
+
+    gemini_pkg_version = Version(genai.__version__)
+    return gemini_pkg_version >= Version("0.5.0")
+
+
 def completion(
     model: str,
     messages: list,
@@ -124,7 +132,7 @@
             "Importing google.generativeai failed, please run 'pip install -q google-generativeai"
         )
     genai.configure(api_key=api_key)
-
+    system_prompt = ""
     if model in custom_prompt_dict:
         # check if the model has a registered custom prompt
         model_prompt_details = custom_prompt_dict[model]
@@ -135,6 +143,7 @@
             messages=messages,
         )
     else:
+        system_prompt, messages = get_system_prompt(messages=messages)
         prompt = prompt_factory(
             model=model, messages=messages, custom_llm_provider="gemini"
         )
@@ -166,7 +175,11 @@
     )
     ## COMPLETION CALL
     try:
-        _model = genai.GenerativeModel(f"models/{model}")
+        _params = {"model_name": "models/{}".format(model)}
+        _system_instruction = supports_system_instruction()
+        if _system_instruction and len(system_prompt) > 0:
+            _params["system_instruction"] = system_prompt
+        _model = genai.GenerativeModel(**_params)
         if stream == True:
             if acompletion == True:
 
@@ -213,11 +226,12 @@ async def async_streaming():
                     encoding=encoding,
                 )
         else:
-            response = _model.generate_content(
-                contents=prompt,
-                generation_config=genai.types.GenerationConfig(**inference_params),
-                safety_settings=safety_settings,
-            )
+            params = {
+                "contents": prompt,
+                "generation_config": genai.types.GenerationConfig(**inference_params),
+                "safety_settings": safety_settings,
+            }
+            response = _model.generate_content(**params)
     except Exception as e:
         raise GeminiError(
             message=str(e),
diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
index 5d3d70a5972e..2abb544095d7 100644
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -959,7 +959,20 @@ def parse_xml_params(xml_content, json_schema: Optional[dict] = None):
     return params
 
 
-###
+### GEMINI HELPER FUNCTIONS ###
+
+
+def get_system_prompt(messages):
+    system_prompt_indices = []
+    system_prompt = ""
+    for idx, message in enumerate(messages):
+        if message["role"] == "system":
+            system_prompt += message["content"]
+            system_prompt_indices.append(idx)
+    if len(system_prompt_indices) > 0:
+        for idx in reversed(system_prompt_indices):
+            messages.pop(idx)
+    return system_prompt, messages
 
 
 def convert_openai_message_to_cohere_tool_result(message):
diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py
index a38530f151db..ddfc8e9fb7d4 100644
--- a/litellm/llms/vertex_ai.py
+++ b/litellm/llms/vertex_ai.py
@@ -417,7 +417,7 @@ def completion(
         return async_completion(**data)
 
     if mode == "vision":
-        print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
+        print_verbose("\nMaking VertexAI Gemini Pro / Pro Vision Call")
         print_verbose(f"\nProcessing input messages = {messages}")
         tools = optional_params.pop("tools", None)
         prompt, images = _gemini_vision_convert_messages(messages=messages)