Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add parallel function call examples #115

Merged
merged 3 commits into from
May 7, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
258 changes: 202 additions & 56 deletions quickstarts/Function_calling.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
Expand Down Expand Up @@ -70,18 +70,9 @@
"metadata": {
"id": "9OEoeosRTv-5"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m137.4/137.4 kB\u001b[0m \u001b[31m933.2 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h"
]
}
],
"outputs": [],
"source": [
"!pip install -U -q google-generativeai # Install the Python SDK"
"!pip install -U -q google-generativeai # Install the Python SDK"
]
},
{
Expand Down Expand Up @@ -115,7 +106,8 @@
"outputs": [],
"source": [
"from google.colab import userdata\n",
"GOOGLE_API_KEY=userdata.get('GOOGLE_API_KEY')\n",
"\n",
"GOOGLE_API_KEY = userdata.get(\"GOOGLE_API_KEY\")\n",
"genai.configure(api_key=GOOGLE_API_KEY)"
]
},
Expand Down Expand Up @@ -163,24 +155,29 @@
}
],
"source": [
"def add(a:float, b:float):\n",
"def add(a: float, b: float):\n",
" \"\"\"returns a + b.\"\"\"\n",
" return a+b\n",
" return a + b\n",
"\n",
"\n",
"def subtract(a:float, b:float):\n",
"def subtract(a: float, b: float):\n",
" \"\"\"returns a - b.\"\"\"\n",
" return a-b\n",
" return a - b\n",
"\n",
"def multiply(a:float, b:float):\n",
"\n",
"def multiply(a: float, b: float):\n",
" \"\"\"returns a * b.\"\"\"\n",
" return a*b\n",
" return a * b\n",
"\n",
"\n",
"def divide(a:float, b:float):\n",
"def divide(a: float, b: float):\n",
" \"\"\"returns a / b.\"\"\"\n",
" return a*b\n",
" return a * b\n",
"\n",
"\n",
"model = genai.GenerativeModel(model_name='gemini-1.0-pro',\n",
" tools=[add, subtract, multiply, divide])\n",
"model = genai.GenerativeModel(\n",
" model_name=\"gemini-1.0-pro\", tools=[add, subtract, multiply, divide]\n",
")\n",
"\n",
"model"
]
Expand Down Expand Up @@ -244,7 +241,9 @@
}
],
"source": [
"response = chat.send_message('I have 57 cats, each owns 44 mittens, how many mittens is that in total?')\n",
"response = chat.send_message(\n",
" \"I have 57 cats, each owns 44 mittens, how many mittens is that in total?\"\n",
")\n",
"response.text"
]
},
Expand All @@ -267,7 +266,7 @@
}
],
"source": [
"57*44"
"57 * 44"
]
},
{
Expand Down Expand Up @@ -319,7 +318,7 @@
"source": [
"for content in chat.history:\n",
" print(content.role, \"->\", [type(part).to_dict(part) for part in content.parts])\n",
" print('-'*80)"
" print(\"-\" * 80)"
]
},
{
Expand Down Expand Up @@ -373,25 +372,27 @@
},
"outputs": [],
"source": [
"def find_movies(description: str, location: str = ''):\n",
" \"\"\"find movie titles currently playing in theaters based on any description, genre, title words, etc.\n",
"def find_movies(description: str, location: str = \"\"):\n",
" \"\"\"find movie titles currently playing in theaters based on any description, genre, title words, etc.\n",
"\n",
" Args:\n",
" description: Any kind of description including category or genre, title words, attributes, etc.\n",
" location: The city and state, e.g. San Francisco, CA or a zip code e.g. 95616\n",
" \"\"\"\n",
" return ['Barbie', 'Oppenheimer']\n",
" Args:\n",
" description: Any kind of description including category or genre, title words, attributes, etc.\n",
" location: The city and state, e.g. San Francisco, CA or a zip code e.g. 95616\n",
" \"\"\"\n",
" return [\"Barbie\", \"Oppenheimer\"]\n",
"\n",
"def find_theaters(location: str, movie: str = ''):\n",
"\n",
"def find_theaters(location: str, movie: str = \"\"):\n",
" \"\"\"Find theaters based on location and optionally movie title which are is currently playing in theaters.\n",
"\n",
" Args:\n",
" location: The city and state, e.g. San Francisco, CA or a zip code e.g. 95616\n",
" movie: Any movie title\n",
" \"\"\"\n",
" return ['Googleplex 16', 'Android Theatre']\n",
" return [\"Googleplex 16\", \"Android Theatre\"]\n",
"\n",
"\n",
"def get_showtimes(location:str, movie:str, theater:str, date:str):\n",
"def get_showtimes(location: str, movie: str, theater: str, date: str):\n",
" \"\"\"\n",
" Find the start times for movies playing in a specific theater.\n",
"\n",
Expand All @@ -401,7 +402,7 @@
" thearer: Name of the theater\n",
" date: Date for requested showtime\n",
" \"\"\"\n",
" return ['10:00', '11:00']"
" return [\"10:00\", \"11:00\"]"
]
},
{
Expand All @@ -422,13 +423,12 @@
"outputs": [],
"source": [
"functions = {\n",
" 'find_movies': find_movies,\n",
" 'find_theaters': find_theaters,\n",
" 'get_showtimes': get_showtimes,\n",
" \"find_movies\": find_movies,\n",
" \"find_theaters\": find_theaters,\n",
" \"get_showtimes\": get_showtimes,\n",
"}\n",
"\n",
"model = genai.GenerativeModel(model_name='gemini-1.0-pro',\n",
" tools=functions.values())"
"model = genai.GenerativeModel(model_name=\"gemini-1.0-pro\", tools=functions.values())"
]
},
{
Expand Down Expand Up @@ -476,7 +476,9 @@
}
],
"source": [
"response = model.generate_content('Which theaters in Mountain View show the Barbie movie?')\n",
"response = model.generate_content(\n",
" \"Which theaters in Mountain View show the Barbie movie?\"\n",
")\n",
"response.candidates[0].content.parts"
]
},
Expand All @@ -496,7 +498,7 @@
"elif ...\n",
"```\n",
"\n",
"However, since we made the `functions` dictionary earlier, we can simplify this to:"
"However, since you already made the `functions` dictionary, this can be simplified to:"
]
},
{
Expand All @@ -516,16 +518,17 @@
],
"source": [
"def call_function(function_call, functions):\n",
" function_name = function_call.name\n",
" function_args = function_call.args\n",
" return functions[function_name](**function_args)\n",
" function_name = function_call.name\n",
" function_args = function_call.args\n",
" return functions[function_name](**function_args)\n",
"\n",
"\n",
"part = response.candidates[0].content.parts[0]\n",
"\n",
"# Check if it's a function call; in real use you'd need to also handle text\n",
"# responses as you won't know what the model will respond with.\n",
"if part.function_call:\n",
" result = call_function(part.function_call, functions)\n",
" result = call_function(part.function_call, functions)\n",
"\n",
"print(result)"
]
Expand Down Expand Up @@ -560,27 +563,170 @@
"\n",
"# Put the result in a protobuf Struct\n",
"s = Struct()\n",
"s.update({'result': result})\n",
"s.update({\"result\": result})\n",
"\n",
"# Update this after https://github.com/google/generative-ai-python/issues/243\n",
"function_response = glm.Part(\n",
" function_response=glm.FunctionResponse(name='find_theaters', response=s))\n",
" function_response=glm.FunctionResponse(name=\"find_theaters\", response=s)\n",
")\n",
"\n",
"# Build the message history\n",
"messages = [\n",
" {'role':'user',\n",
" 'parts': ['Which theaters in Mountain View show the Barbie movie?.']},\n",
" {'role':'model',\n",
" 'parts': response.candidates[0].content.parts},\n",
" {'role':'user',\n",
" 'parts': [function_response]}\n",
" # fmt: off\n",
" {\"role\": \"user\",\n",
" \"parts\": [\"Which theaters in Mountain View show the Barbie movie?.\"]},\n",
" {\"role\": \"model\",\n",
" \"parts\": response.candidates[0].content.parts},\n",
" {\"role\": \"user\",\n",
" \"parts\": [function_response]},\n",
" # fmt: on\n",
"]\n",
"\n",
"# Generate the next response\n",
"response = model.generate_content(messages)\n",
"print(response.text)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EuwKoNIhGBJN"
},
"source": [
"## Parallel function calls\n",
"\n",
"The Gemini API can call multiple functions in a single turn. This caters for scenarios where there are multiple function calls that can take place independently to complete a task.\n",
"\n",
"First set the tools up. Unlike the movie example above, these functions do not require input from each other to be called so they should be good candidates for parallel calling."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cJ-mSixWGqLv"
},
"outputs": [],
"source": [
"def power_disco_ball(power: bool) -> bool:\n",
" \"\"\"Powers the spinning disco ball.\"\"\"\n",
" print(f\"Disco ball is {'spinning!' if power else 'stopped.'}\")\n",
" return True\n",
"\n",
"\n",
"def start_music(energetic: bool, loud: bool, bpm: int) -> str:\n",
" \"\"\"Play some music matching the specified parameters.\n",
"\n",
" Args:\n",
" energetic: Whether the music is energetic or not.\n",
" loud: Whether the music is loud or not.\n",
" bpm: The beats per minute of the music.\n",
"\n",
" Returns: The name of the song being played.\n",
" \"\"\"\n",
" print(f\"Starting music! {energetic=} {loud=}, {bpm=}\")\n",
" return \"Never gonna give you up.\"\n",
"\n",
"\n",
"def dim_lights(brightness: float) -> bool:\n",
" \"\"\"Dim the lights.\n",
"\n",
" Args:\n",
" brightness: The brightness of the lights, 0.0 is off, 1.0 is full.\n",
" \"\"\"\n",
" print(f\"Lights are now set to {brightness:.0%}\")\n",
" return True"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zlrmXN7fxQi0"
},
"source": [
"Now call the model with an instruction that could use all of the specified tools."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "21ecYHLgIsCl"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"power_disco_ball(power=True)\n",
"start_music(energetic=True, loud=True, bpm=120.0)\n",
"dim_lights(brightness=0.3)\n"
]
}
],
"source": [
"# Set the model up with tools.\n",
"house_fns = [power_disco_ball, start_music, dim_lights]\n",
"# Try this out with Pro and Flash...\n",
"model = genai.GenerativeModel(model_name=\"gemini-1.5-pro-latest\", tools=house_fns)\n",
"\n",
"# Call the API.\n",
"chat = model.start_chat()\n",
"response = chat.send_message(\"Turn this place into a party!\")\n",
"\n",
"# Print out each of the function calls requested from this single call.\n",
"for part in response.parts:\n",
" if fn := part.function_call:\n",
" args = \", \".join(f\"{key}={val}\" for key, val in fn.args.items())\n",
" print(f\"{fn.name}({args})\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "t6iYpty7yZct"
},
"source": [
"Each of the printed results reflects a single function call that the model has requested. To send the results back, include the responses in the same order as they were requested."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "L7RxoiR3foBR"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Let's get this party started! I've turned on the disco ball, started playing some upbeat music, and dimmed the lights. 🎶✨ Get ready to dance! 🕺💃 \n",
"\n",
"\n"
]
}
],
"source": [
"import google.ai.generativelanguage as glm\n",
"\n",
"# Simulate the responses from the specified tools.\n",
"responses = {\n",
" \"power_disco_ball\": True,\n",
" \"start_music\": \"Never gonna give you up.\",\n",
" \"dim_lights\": True,\n",
"}\n",
"\n",
"# Build the response parts.\n",
"response_parts = [\n",
" glm.Part(function_response=glm.FunctionResponse(name=fn, response={\"result\": val}))\n",
" for fn, val in responses.items()\n",
"]\n",
"\n",
"response = chat.send_message(response_parts)\n",
"print(response.text)"
]
},
{
"cell_type": "markdown",
"metadata": {
Expand Down
Loading