Skip to content

Commit

Permalink
Update gemini.py for multiple tool calls + pre-commit formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
marklysze authored Jun 25, 2024
1 parent f0880f2 commit 3504ee7
Showing 1 changed file with 74 additions and 47 deletions.
121 changes: 74 additions & 47 deletions autogen/oai/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,25 @@

import google.generativeai as genai
import requests
from google.ai.generativelanguage import Content, FunctionCall, FunctionDeclaration, FunctionResponse, Part, Tool
import vertexai
from google.ai.generativelanguage import Content, FunctionCall, FunctionDeclaration, FunctionResponse, Part, Tool
from google.api_core.exceptions import InternalServerError
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.chat.chat_completion_message_tool_call import Function
from openai.types.completion_usage import CompletionUsage
from PIL import Image
from vertexai.generative_models import Content as VertexAIContent, FunctionDeclaration as VertexAIFunctionDeclaration, Tool as VertexAITool
from vertexai.generative_models import (
Content as VertexAIContent,
)
from vertexai.generative_models import (
FunctionDeclaration as VertexAIFunctionDeclaration,
)
from vertexai.generative_models import GenerativeModel
from vertexai.generative_models import Part as VertexAIPart
from vertexai.generative_models import (
Tool as VertexAITool,
)


class GeminiClient:
Expand Down Expand Up @@ -272,14 +280,14 @@ def create(self, params: Dict) -> ChatCompletion:
)

return response_oai

# If str is not a json string return str as is
def _to_json(self, str) -> dict:
try:
return json.loads(str)
except ValueError:
return str

def _oai_content_to_gemini_content(self, message: Dict[str, Any]) -> List:
"""Convert content from OAI format to Gemini format"""
rst = []
Expand All @@ -292,32 +300,43 @@ def _oai_content_to_gemini_content(self, message: Dict[str, Any]) -> List:

if "tool_calls" in message:
if self.use_vertexai:
rst.append(VertexAIPart.from_dict({
"functionCall": {
"name": message["tool_calls"][0]["function"]["name"],
"args": json.loads(message["tool_calls"][0]["function"]["arguments"])
}
}))
for tool_call in message["tool_calls"]:
rst.append(
VertexAIPart.from_dict(
{
"functionCall": {
"name": tool_call["function"]["name"],
"args": json.loads(tool_call["function"]["arguments"]),
}
}
)
)
else:
rst.append(
Part(
function_call=FunctionCall(
name=message["tool_calls"][0]["function"]["name"],
args=json.loads(message["tool_calls"][0]["function"]["arguments"]),
for tool_call in message["tool_calls"]:
rst.append(
Part(
function_call=FunctionCall(
name=tool_call["function"]["name"],
args=json.loads(tool_call["function"]["arguments"]),
)
)
)
)
return rst

if message["role"] == "tool":
if self.use_vertexai:
rst.append(VertexAIPart.from_function_response(
name=message["name"],
response={"result": self._to_json(message["content"])}
))
rst.append(
VertexAIPart.from_function_response(
name=message["name"], response={"result": self._to_json(message["content"])}
)
)
else:
rst.append(
Part(function_response=FunctionResponse(name=message["name"], response={"result": self._to_json(message["content"])}))
Part(
function_response=FunctionResponse(
name=message["name"], response={"result": self._to_json(message["content"])}
)
)
)
return rst

Expand Down Expand Up @@ -376,50 +395,59 @@ def _calculate_gemini_cost(self, input_tokens: int, output_tokens: int, model_na
return 0.70 * input_tokens / 1e6 + 2.10 * output_tokens / 1e6

if "gemini-pro" not in model_name and "gemini-1.0-pro" not in model_name:
warnings.warn(f"Cost calculation is not implemented for model {model_name}. Using Gemini-1.0-Pro.", UserWarning)
warnings.warn(
f"Cost calculation is not implemented for model {model_name}. Using Gemini-1.0-Pro.", UserWarning
)

# Cost is $0.5 per million input tokens and $1.5 per million output tokens
return 0.5 * input_tokens / 1e6 + 1.5 * output_tokens / 1e6


def _oai_messages_to_gemini_messages(self, messages: list[Dict[str, Any]]) -> list[dict[str, Any]]:
"""Convert messages from OAI format to Gemini format.
Make sure the "user" role and "model" role are interleaved.
Also, make sure the last item is from the "user" role.
"""
prev_role = None
rst = []

def append_parts(parts, role):
if self.use_vertexai:
rst.append(VertexAIContent(parts=parts, role=role))
else:
rst.append(Content(parts=parts, role=role))

def append_text_to_last(text):
if self.use_vertexai:
rst[-1] = VertexAIContent(parts=[*rst[-1].parts, VertexAIPart.from_text(text)], role=rst[-1].role)
else:
rst[-1] = Content(parts=[*rst[-1].parts, Part(text=text)], role=rst[-1].role)

def is_function_call(parts):
return self.use_vertexai and parts[0].function_call or not self.use_vertexai and "function_call" in parts[0]

for i, message in enumerate(messages):

# Since the tool call message does not have the "name" field, we need to find the corresponding tool message.
if message["role"] == "tool":
message["name"] = [
m for m in messages if "tool_calls" in m and m["tool_calls"][0]["id"] == message["tool_call_id"]
][0]["tool_calls"][0]["function"]["name"]
m["tool_calls"][i]["function"]["name"]
for m in messages
if "tool_calls" in m
for i, tc in enumerate(m["tool_calls"])
if tc["id"] == message["tool_call_id"]
][0]

parts = self._oai_content_to_gemini_content(message)
role = "user" if message["role"] in ["user", "system"] else "function" if message["role"] == "tool" else "model"

role = (
"user"
if message["role"] in ["user", "system"]
else "function" if message["role"] == "tool" else "model"
)

# In Gemini if the current message is a function call then previous message should not be a model message.
if is_function_call(parts):
# If the previous message is a model message then add a dummy "continue" user message before the function call
if(prev_role == "model"):
if prev_role == "model":
append_parts(self._oai_content_to_gemini_content("continue"), "user")
append_parts(parts, role)
# In Gemini if the current message is a function response then next message should be a model message.
Expand Down Expand Up @@ -450,7 +478,7 @@ def is_function_call(parts):
else:
rst.append(Content(parts=self._oai_content_to_gemini_content("continue"), role="user"))
return rst

def _oai_tools_to_gemini_tools(self, tools: List[Dict[str, Any]]) -> List[Tool]:
"""Convert tools from OAI format to Gemini format."""
if len(tools) == 0:
Expand All @@ -461,7 +489,7 @@ def _oai_tools_to_gemini_tools(self, tools: List[Dict[str, Any]]) -> List[Tool]:
function_declaration = VertexAIFunctionDeclaration(
name=tool["function"]["name"],
description=tool["function"]["description"],
parameters=tool["function"]["parameters"]
parameters=tool["function"]["parameters"],
)
else:
function_declaration = FunctionDeclaration(
Expand All @@ -477,8 +505,9 @@ def _oai_tools_to_gemini_tools(self, tools: List[Dict[str, Any]]) -> List[Tool]:
else:
return [Tool(function_declarations=function_declarations)]


def _oai_function_parameters_to_gemini_function_parameters(self, function_definition: dict[str, any]) -> dict[str, any]:
def _oai_function_parameters_to_gemini_function_parameters(
self, function_definition: dict[str, any]
) -> dict[str, any]:
"""
Convert OpenAPI function definition parameters to Gemini function parameters definition.
The type key is renamed to type_ and the value is capitalized.
Expand All @@ -487,7 +516,7 @@ def _oai_function_parameters_to_gemini_function_parameters(self, function_defini
# Delete the default key as it is not supported in Gemini
if "default" in function_definition:
del function_definition["default"]

function_definition["type_"] = function_definition["type"].upper()
del function_definition["type"]
if "properties" in function_definition:
Expand All @@ -501,30 +530,28 @@ def _oai_function_parameters_to_gemini_function_parameters(self, function_defini
)
return function_definition


def _gemini_content_to_oai_choices(self, response: Content) -> List[Choice]:
"""Convert response from Gemini format to OAI format."""
text = None
tool_calls = None
tool_calls = []
for part in response.parts:
if part.function_call:
if self.use_vertexai:
arguments = VertexAIPart.to_dict(part)["function_call"]["args"]
else:
arguments = Part.to_dict(part)["function_call"]["args"]
tool_calls = [
tool_calls.append(
ChatCompletionMessageToolCall(
id=str(random.randint(0, 1000)),
type="function",
function=Function(
name=part.function_call.name,
arguments=json.dumps(arguments)
),
function=Function(name=part.function_call.name, arguments=json.dumps(arguments)),
)
]
)
elif part.text:
text = part.text
message = ChatCompletionMessage(role="assistant", content=text, function_call=None, tool_calls=tool_calls)
message = ChatCompletionMessage(
role="assistant", content=text, function_call=None, tool_calls=tool_calls if len(tool_calls) > 0 else None
)
return [Choice(finish_reason="tool_calls" if tool_calls else "stop", index=0, message=message)]


Expand Down

0 comments on commit 3504ee7

Please sign in to comment.