show an example of using parallel function calling
Some `SQLModel` stuff while we're at it.
rgbkrk committed Jan 16, 2024
1 parent 71f64cc commit f99a4c9
Showing 1 changed file with 367 additions and 0 deletions.
367 changes: 367 additions & 0 deletions notebooks/parallel-function-calling.ipynb
@@ -0,0 +1,367 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.3.2\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%pip install sqlmodel -q"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from typing import Optional\n",
"from sqlmodel import SQLModel, Field, create_engine, Session\n",
"\n",
"# SQL Model uses Pydantic Models under the hood\n",
"\n",
"class Character(SQLModel, table=True):\n",
" id: Optional[int] = Field(default=None, primary_key=True)\n",
" name: str\n",
" race: str\n",
" character_class: str\n",
" level: int\n",
" background: str\n",
" player_name: Optional[str] = None\n",
" experience_points: int = 0\n",
" strength: int\n",
" dexterity: int\n",
" constitution: int\n",
" intelligence: int\n",
" wisdom: int\n",
" charisma: int\n",
" hit_points: int\n",
" armor_class: int\n",
" alignment: str\n",
" skills: str # Storing as comma-separated string\n",
" languages: str # Storing as comma-separated string\n",
" equipment: str # Storing as comma-separated string\n",
" spells: Optional[str] = None # Storing as comma-separated string\n",
"\n",
" def _repr_llm_(self):\n",
" return f\"<Character {self.id} {self.name}>\"\n",
" \n",
" def __repr__(self):\n",
" return f\"<Character {self.id} {self.name}>\"\n",
"\n",
"# SQLite Database URL\n",
"DATABASE_URL = \"sqlite:///:memory:\"\n",
"engine = create_engine(DATABASE_URL)\n",
"\n",
"# Create the database tables\n",
"SQLModel.metadata.create_all(engine)"
]
},
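{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch (plain SQLModel API, nothing specific to this notebook) for reading rows back with `select`. The table is empty at this point; running the same query again after the chat further down lists any characters the model saved."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sqlmodel import select\n",
"\n",
"# List every Character currently stored (empty until characters are added)\n",
"with Session(engine) as session:\n",
"    characters = session.exec(select(Character)).all()\n",
"\n",
"characters"
]
},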
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"\n",
"def roll_die(sides: int = 6):\n",
" \"\"\"Roll a die with the given number of sides.\"\"\"\n",
" return random.randint(1, sides)\n",
"\n",
"# Function to add a new character\n",
"def add_character(character: Character):\n",
" \"\"\"Adds a character to our characters database\"\"\"\n",
" with Session(engine) as session:\n",
" session.add(character)\n",
" session.commit\n",
"\n",
" return character\n"
]
},
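{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick local smoke test of both tools before the model gets them: roll ability scores with `roll_die` and persist a row with `add_character`. The character and all of its values below are made-up examples."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical example character; every value here is an arbitrary placeholder\n",
"grog = Character(\n",
"    name=\"Grog\",\n",
"    race=\"Half-Orc\",\n",
"    character_class=\"Barbarian\",\n",
"    level=1,\n",
"    background=\"Outlander\",\n",
"    alignment=\"Chaotic Neutral\",\n",
"    strength=roll_die(20),\n",
"    dexterity=roll_die(20),\n",
"    constitution=roll_die(20),\n",
"    intelligence=roll_die(20),\n",
"    wisdom=roll_die(20),\n",
"    charisma=roll_die(20),\n",
"    hit_points=12,\n",
"    armor_class=13,\n",
"    skills=\"Athletics,Survival\",\n",
"    languages=\"Common,Orc\",\n",
"    equipment=\"Greataxe,Explorer's Pack\",\n",
")\n",
"\n",
"add_character(grog)"
]
},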
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from chatlab import FunctionRegistry\n",
"\n",
"fr = FunctionRegistry()\n",
"fr.register_functions([roll_die, add_character])"
]
},
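{
"cell_type": "markdown",
"metadata": {},
"source": [
"`fr.tools` is what gets passed to the chat API below: the tool definitions ChatLab generates from each registered function's signature and docstring. Inspecting it shows what the model will see when deciding which tools to call."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The generated tool definitions that will be sent with each chat request\n",
"fr.tools"
]
},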
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from openai import OpenAI\n",
"client = OpenAI()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from chatlab import tool_result\n",
"\n",
"async def chatloop(initial_messages):\n",
" \"\"\"Emit messages encountered as well as tool results, making sure to autorun tools and respond to the model.\"\"\"\n",
" buffer = initial_messages.copy()\n",
"\n",
" resp = client.chat.completions.create(\n",
" model=\"gpt-3.5-turbo-1106\",\n",
" messages=initial_messages,\n",
"\n",
" # Pass in the tools from the function registry. The model will choose\n",
" # whether it uses 0, 1, 2, or N many tools.\n",
" tools=fr.tools,\n",
" tool_choice=\"auto\"\n",
" )\n",
"\n",
" message = resp.choices[0].message\n",
" buffer.append(message)\n",
"\n",
" yield message\n",
"\n",
" # call each of the tools\n",
" if message.tool_calls is not None:\n",
" for tool in message.tool_calls:\n",
" result = await fr.call(tool.function.name, tool.function.arguments)\n",
"\n",
" # An assistant message with 'tool_calls' must be followed by tool messages responding to each 'tool_call_id'.\n",
" tool_call_response = tool_result(tool.id, name=tool.function.name, content=str(result))\n",
" yield tool_call_response\n",
" buffer.append(tool_call_response)\n",
" \n",
" # Once all tools have been called, call the model again\n",
" async for m in chatloop(buffer):\n",
" yield m\n"
]
},
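{
"cell_type": "markdown",
"metadata": {},
"source": [
"When the model answers with several `tool_calls` in a single assistant message (parallel function calling), the next request has to carry one `tool` message per `tool_call_id` before the model is called again. Roughly, `buffer` grows like the sketch below; the ids and values are placeholders in the OpenAI chat message format, and the `name` field simply mirrors what `tool_result` is given above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustration only: one round of parallel tool calls (placeholder ids and values)\n",
"example_round = [\n",
"    {\n",
"        \"role\": \"assistant\",\n",
"        \"content\": None,\n",
"        \"tool_calls\": [\n",
"            {\"id\": \"call_abc\", \"type\": \"function\",\n",
"             \"function\": {\"name\": \"roll_die\", \"arguments\": \"{}\"}},\n",
"            {\"id\": \"call_def\", \"type\": \"function\",\n",
"             \"function\": {\"name\": \"roll_die\", \"arguments\": \"{}\"}},\n",
"        ],\n",
"    },\n",
"    # One tool message answering each tool_call_id\n",
"    {\"role\": \"tool\", \"tool_call_id\": \"call_abc\", \"name\": \"roll_die\", \"content\": \"17\"},\n",
"    {\"role\": \"tool\", \"tool_call_id\": \"call_def\", \"name\": \"roll_die\", \"content\": \"13\"},\n",
"]"
]
},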
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"`roll_die()` → `17`"
],
"text/plain": [
"`roll_die()` → `17`"
]
},
"metadata": {
"text/markdown": {
"chatlab": {
"default": true
}
}
},
"output_type": "display_data"
},
{
"data": {
"text/markdown": [
"`roll_die()` → `13`"
],
"text/plain": [
"`roll_die()` → `13`"
]
},
"metadata": {
"text/markdown": {
"chatlab": {
"default": true
}
}
},
"output_type": "display_data"
},
{
"data": {
"text/markdown": [
"`roll_die()` → `5`"
],
"text/plain": [
"`roll_die()` → `5`"
]
},
"metadata": {
"text/markdown": {
"chatlab": {
"default": true
}
}
},
"output_type": "display_data"
},
{
"data": {
"text/markdown": [
"`roll_die()` → `13`"
],
"text/plain": [
"`roll_die()` → `13`"
]
},
"metadata": {
"text/markdown": {
"chatlab": {
"default": true
}
}
},
"output_type": "display_data"
},
{
"data": {
"text/markdown": [
"`roll_die()` → `10`"
],
"text/plain": [
"`roll_die()` → `10`"
]
},
"metadata": {
"text/markdown": {
"chatlab": {
"default": true
}
}
},
"output_type": "display_data"
},
{
"data": {
"text/markdown": [
"`roll_die()` → `14`"
],
"text/plain": [
"`roll_die()` → `14`"
]
},
"metadata": {
"text/markdown": {
"chatlab": {
"default": true
}
}
},
"output_type": "display_data"
},
{
"data": {
"text/markdown": [
"`roll_die()` → `11`"
],
"text/plain": [
"`roll_die()` → `11`"
]
},
"metadata": {
"text/markdown": {
"chatlab": {
"default": true
}
}
},
"output_type": "display_data"
},
{
"data": {
"text/markdown": [
"> Here are the results for your character's stats:\n",
"> \n",
"> - Strength: 17\n",
"> - Perception: 13\n",
"> - Endurance: 5\n",
"> - Charisma: 13\n",
"> - Intelligence: 10\n",
"> - Agility: 14\n",
"> - Luck: 11"
],
"text/plain": [
"> Here are the results for your character's stats:\n",
"> \n",
"> - Strength: 17\n",
"> - Perception: 13\n",
"> - Endurance: 5\n",
"> - Charisma: 13\n",
"> - Intelligence: 10\n",
"> - Agility: 14\n",
"> - Luck: 11"
]
},
"metadata": {
"text/markdown": {
"chatlab": {
"default": true
}
}
},
"output_type": "display_data"
}
],
"source": [
"from pydantic import BaseModel\n",
"from chatlab import system, user, Markdown\n",
"\n",
"async for message in chatloop([\n",
" system(\"Create your character for the Fallout RPG. The user is the DM.\"),\n",
" user(\"Roll for the following stats: Strength, Perception, Endurance, Charisma, Intelligence, Agility, and Luck.\")\n",
" ]):\n",
" # When message is a pydantic model, convert to a dict\n",
"\n",
" if isinstance(message, BaseModel):\n",
" message = message.model_dump()\n",
"\n",
" role = message['role']\n",
" content = message.get('content')\n",
"\n",
" if(role == \"assistant\" and content is not None):\n",
" display(Markdown(\"> \" + content.replace(\"\\n\", \"\\n> \")))\n",
" if(role == \"tool\"):\n",
" display(Markdown(f\"`{message['name']}()` → `{content}`\"))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "chatlab-3PJ-KiVK-py3.12",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
