-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
show an example of using parallel function calling
Some `SQLModel` stuff while we're at it.
- Loading branch information
Showing
1 changed file
with
367 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,367 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"\n", | ||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.3.2\u001b[0m\n", | ||
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", | ||
"Note: you may need to restart the kernel to use updated packages.\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"%pip install sqlmodel -q" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from typing import Optional\n", | ||
"from sqlmodel import SQLModel, Field, create_engine, Session\n", | ||
"\n", | ||
"# SQL Model uses Pydantic Models under the hood\n", | ||
"\n", | ||
"class Character(SQLModel, table=True):\n", | ||
" id: Optional[int] = Field(default=None, primary_key=True)\n", | ||
" name: str\n", | ||
" race: str\n", | ||
" character_class: str\n", | ||
" level: int\n", | ||
" background: str\n", | ||
" player_name: Optional[str] = None\n", | ||
" experience_points: int = 0\n", | ||
" strength: int\n", | ||
" dexterity: int\n", | ||
" constitution: int\n", | ||
" intelligence: int\n", | ||
" wisdom: int\n", | ||
" charisma: int\n", | ||
" hit_points: int\n", | ||
" armor_class: int\n", | ||
" alignment: str\n", | ||
" skills: str # Storing as comma-separated string\n", | ||
" languages: str # Storing as comma-separated string\n", | ||
" equipment: str # Storing as comma-separated string\n", | ||
" spells: Optional[str] = None # Storing as comma-separated string\n", | ||
"\n", | ||
" def _repr_llm_(self):\n", | ||
" return f\"<Character {self.id} {self.name}>\"\n", | ||
" \n", | ||
" def __repr__(self):\n", | ||
" return f\"<Character {self.id} {self.name}>\"\n", | ||
"\n", | ||
"# SQLite Database URL\n", | ||
"DATABASE_URL = \"sqlite:///:memory:\"\n", | ||
"engine = create_engine(DATABASE_URL)\n", | ||
"\n", | ||
"# Create the database tables\n", | ||
"SQLModel.metadata.create_all(engine)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import random\n", | ||
"\n", | ||
"def roll_die(sides: int = 6):\n", | ||
" \"\"\"Roll a die with the given number of sides.\"\"\"\n", | ||
" return random.randint(1, sides)\n", | ||
"\n", | ||
"# Function to add a new character\n", | ||
"def add_character(character: Character):\n", | ||
" \"\"\"Adds a character to our characters database\"\"\"\n", | ||
" with Session(engine) as session:\n", | ||
" session.add(character)\n", | ||
" session.commit\n", | ||
"\n", | ||
" return character\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from chatlab import FunctionRegistry\n", | ||
"\n", | ||
"fr = FunctionRegistry()\n", | ||
"fr.register_functions([roll_die, add_character])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from openai import OpenAI\n", | ||
"client = OpenAI()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from chatlab import tool_result\n", | ||
"\n", | ||
"async def chatloop(initial_messages):\n", | ||
" \"\"\"Emit messages encountered as well as tool results, making sure to autorun tools and respond to the model.\"\"\"\n", | ||
" buffer = initial_messages.copy()\n", | ||
"\n", | ||
" resp = client.chat.completions.create(\n", | ||
" model=\"gpt-3.5-turbo-1106\",\n", | ||
" messages=initial_messages,\n", | ||
"\n", | ||
" # Pass in the tools from the function registry. The model will choose\n", | ||
" # whether it uses 0, 1, 2, or N many tools.\n", | ||
" tools=fr.tools,\n", | ||
" tool_choice=\"auto\"\n", | ||
" )\n", | ||
"\n", | ||
" message = resp.choices[0].message\n", | ||
" buffer.append(message)\n", | ||
"\n", | ||
" yield message\n", | ||
"\n", | ||
" # call each of the tools\n", | ||
" if message.tool_calls is not None:\n", | ||
" for tool in message.tool_calls:\n", | ||
" result = await fr.call(tool.function.name, tool.function.arguments)\n", | ||
"\n", | ||
" # An assistant message with 'tool_calls' must be followed by tool messages responding to each 'tool_call_id'.\n", | ||
" tool_call_response = tool_result(tool.id, name=tool.function.name, content=str(result))\n", | ||
" yield tool_call_response\n", | ||
" buffer.append(tool_call_response)\n", | ||
" \n", | ||
" # Once all tools have been called, call the model again\n", | ||
" async for m in chatloop(buffer):\n", | ||
" yield m\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 10, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/markdown": [ | ||
"`roll_die()` → `17`" | ||
], | ||
"text/plain": [ | ||
"`roll_die()` → `17`" | ||
] | ||
}, | ||
"metadata": { | ||
"text/markdown": { | ||
"chatlab": { | ||
"default": true | ||
} | ||
} | ||
}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"text/markdown": [ | ||
"`roll_die()` → `13`" | ||
], | ||
"text/plain": [ | ||
"`roll_die()` → `13`" | ||
] | ||
}, | ||
"metadata": { | ||
"text/markdown": { | ||
"chatlab": { | ||
"default": true | ||
} | ||
} | ||
}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"text/markdown": [ | ||
"`roll_die()` → `5`" | ||
], | ||
"text/plain": [ | ||
"`roll_die()` → `5`" | ||
] | ||
}, | ||
"metadata": { | ||
"text/markdown": { | ||
"chatlab": { | ||
"default": true | ||
} | ||
} | ||
}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"text/markdown": [ | ||
"`roll_die()` → `13`" | ||
], | ||
"text/plain": [ | ||
"`roll_die()` → `13`" | ||
] | ||
}, | ||
"metadata": { | ||
"text/markdown": { | ||
"chatlab": { | ||
"default": true | ||
} | ||
} | ||
}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"text/markdown": [ | ||
"`roll_die()` → `10`" | ||
], | ||
"text/plain": [ | ||
"`roll_die()` → `10`" | ||
] | ||
}, | ||
"metadata": { | ||
"text/markdown": { | ||
"chatlab": { | ||
"default": true | ||
} | ||
} | ||
}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"text/markdown": [ | ||
"`roll_die()` → `14`" | ||
], | ||
"text/plain": [ | ||
"`roll_die()` → `14`" | ||
] | ||
}, | ||
"metadata": { | ||
"text/markdown": { | ||
"chatlab": { | ||
"default": true | ||
} | ||
} | ||
}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"text/markdown": [ | ||
"`roll_die()` → `11`" | ||
], | ||
"text/plain": [ | ||
"`roll_die()` → `11`" | ||
] | ||
}, | ||
"metadata": { | ||
"text/markdown": { | ||
"chatlab": { | ||
"default": true | ||
} | ||
} | ||
}, | ||
"output_type": "display_data" | ||
}, | ||
{ | ||
"data": { | ||
"text/markdown": [ | ||
"> Here are the results for your character's stats:\n", | ||
"> \n", | ||
"> - Strength: 17\n", | ||
"> - Perception: 13\n", | ||
"> - Endurance: 5\n", | ||
"> - Charisma: 13\n", | ||
"> - Intelligence: 10\n", | ||
"> - Agility: 14\n", | ||
"> - Luck: 11" | ||
], | ||
"text/plain": [ | ||
"> Here are the results for your character's stats:\n", | ||
"> \n", | ||
"> - Strength: 17\n", | ||
"> - Perception: 13\n", | ||
"> - Endurance: 5\n", | ||
"> - Charisma: 13\n", | ||
"> - Intelligence: 10\n", | ||
"> - Agility: 14\n", | ||
"> - Luck: 11" | ||
] | ||
}, | ||
"metadata": { | ||
"text/markdown": { | ||
"chatlab": { | ||
"default": true | ||
} | ||
} | ||
}, | ||
"output_type": "display_data" | ||
} | ||
], | ||
"source": [ | ||
"from pydantic import BaseModel\n", | ||
"from chatlab import system, user, Markdown\n", | ||
"\n", | ||
"async for message in chatloop([\n", | ||
" system(\"Create your character for the Fallout RPG. The user is the DM.\"),\n", | ||
" user(\"Roll for the following stats: Strength, Perception, Endurance, Charisma, Intelligence, Agility, and Luck.\")\n", | ||
" ]):\n", | ||
" # When message is a pydantic model, convert to a dict\n", | ||
"\n", | ||
" if isinstance(message, BaseModel):\n", | ||
" message = message.model_dump()\n", | ||
"\n", | ||
" role = message['role']\n", | ||
" content = message.get('content')\n", | ||
"\n", | ||
" if(role == \"assistant\" and content is not None):\n", | ||
" display(Markdown(\"> \" + content.replace(\"\\n\", \"\\n> \")))\n", | ||
" if(role == \"tool\"):\n", | ||
" display(Markdown(f\"`{message['name']}()` → `{content}`\"))" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "chatlab-3PJ-KiVK-py3.12", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.12.1" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |