diff --git a/chatlab/__init__.py b/chatlab/__init__.py
index 52f4853..9f724ed 100644
--- a/chatlab/__init__.py
+++ b/chatlab/__init__.py
@@ -32,6 +32,7 @@
 )
 from .registry import FunctionRegistry
 from spork import Markdown
+from instructor import Partial
 
 __version__ = __version__
 
@@ -51,4 +52,5 @@
     "FunctionRegistry",
     "ChatlabMetadata",
     "expose_exception_to_llm",
+    "Partial"
 ]
diff --git a/chatlab/chat.py b/chatlab/chat.py
index 18249a0..25a6e5f 100644
--- a/chatlab/chat.py
+++ b/chatlab/chat.py
@@ -172,9 +172,9 @@ async def __process_stream(
                 tool_argument = ToolArguments(
                     id=tool_call.id, name=tool_call.function.name, arguments=tool_call.function.arguments
                 )
 
-                # Now we get the function and see if it has the ChatLabMetadata for a render func
-                func = self.function_registry.get_chatlab_metadata(tool_call.function.name)
+                # If the user provided a custom renderer, set it on the tool argument object for displaying
+                func = self.function_registry.get_chatlab_metadata(tool_call.function.name)
 
                 if func is not None and func.render is not None:
                     tool_argument.custom_render = func.render
diff --git a/chatlab/registry.py b/chatlab/registry.py
index f26457c..1fd2c7d 100644
--- a/chatlab/registry.py
+++ b/chatlab/registry.py
@@ -65,7 +65,7 @@ class WhatTime(BaseModel):
 from openai.types.chat import ChatCompletionToolParam
 
 from .decorators import ChatlabMetadata
-
+from .errors import ChatLabError
 
 class APIManifest(TypedDict, total=False):
     """The schema for the API."""
@@ -80,13 +80,13 @@ class APIManifest(TypedDict, total=False):
     """
 
 
-class FunctionArgumentError(Exception):
+class FunctionArgumentError(ChatLabError):
     """Exception raised when a function is called with invalid arguments."""
 
     pass
 
 
-class UnknownFunctionError(Exception):
+class UnknownFunctionError(ChatLabError):
     """Exception raised when a function is called that is not registered."""
 
     pass
diff --git a/chatlab/views/tools.py b/chatlab/views/tools.py
index f58dd38..c1b0a07 100644
--- a/chatlab/views/tools.py
+++ b/chatlab/views/tools.py
@@ -2,6 +2,8 @@
 from pydantic import ValidationError
 from spork import AutoUpdate
 
+import warnings
+
 from ..components.function_details import ChatFunctionComponent
 from ..registry import FunctionRegistry, FunctionArgumentError, UnknownFunctionError, extract_arguments, extract_model_from_function
 
@@ -101,23 +103,23 @@ def render(self):
                 possible_args = parser.parse(self.arguments)
 
                 Model = extract_model_from_function(self.name, self.custom_render)
-                model = Partial[Model](**possible_args)
-
-                kwargs = {}
+                model = Model.model_validate(possible_args)
 
                 # Pluck the kwargs out from the crafted model, as we can't pass the pydantic model as the arguments
                 # However any "inner" models should retain their pydantic Model nature
-                for k in model.__dict__.keys():
-                    kwargs[k] = getattr(model, k)
+                kwargs = {k: getattr(model, k) for k in model.__dict__.keys()}
 
-                return self.custom_render(**kwargs)
             except FunctionArgumentError:
                 return None
             except ValidationError:
                 return None
+
+            try:
+                return self.custom_render(**kwargs)
             except Exception as e:
-                #print(f"Exception in custom render for {self.name}.", e)
-                #print(self.arguments)
+                # Exception in userland code
+                # Would be preferable to bubble up, however
+                # it might be due to us passing a not-quite model
                 return None
 
         return ChatFunctionComponent(name=self.name, verbage=self.verbage, input=self.arguments)
diff --git a/notebooks/knowledge-graph.ipynb b/notebooks/knowledge-graph.ipynb
index 61b48ce..9265b53 100644
--- a/notebooks/knowledge-graph.ipynb
+++ b/notebooks/knowledge-graph.ipynb
@@ -19,20 +19,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# ChatGPT sometimes makes some wrong colors which produce lots of Graphviz Warnings when IPython `display`s\n",
-    "import warnings\n",
-    "\n",
-    "warnings.simplefilter(\"ignore\")\n",
-    "warnings.filterwarnings(\"ignore\")"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [
     {
@@ -44,145 +31,86 @@
      [Diff of the inline Graphviz SVG output elided; the SVG markup did not survive extraction. The old rendering was a ten-node graph from "Project Planning" through "Certificate of Final Completion" (with an optional "Discretionary Review" branch); the new rendering is a six-node chain: Start, Submit Application, Review Application, Inspection, Approval, End. The accompanying "text/plain" Digraph repr line changes accordingly.]
@@ -192,10 +120,6 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "Warning: light is not a known color.\n",
-      "Warning: light is not a known color.\n",
-      "Warning: light is not a known color.\n",
-      "Warning: light is not a known color.\n",
       "Warning: light is not a known color.\n",
       "Warning: light is not a known color.\n",
       "Warning: light is not a known color.\n",
@@ -207,10 +131,18 @@
     {
      "data": {
       "text/markdown": [
-       "The network graph representing the permit process in San Francisco has been created and stored. It outlines the sequence of steps starting from project planning to obtaining the certificate of final completion. The graph also indicates that the discretionary review is an optional step in the process."
+       "I have created a network graph representing the permit process in San Francisco. The process includes the following steps:\n",
+       "1. Start\n",
+       "2. Submit Application\n",
+       "3. Review Application\n",
+       "4. Inspection\n",
+       "5. Approval\n",
+       "6. End\n",
+       "\n",
+       "Each step is connected in sequence to show the flow of the permit process."
      ],
      "text/plain": [
-       "AssistantMessageView(display_id='3bb096a3-14d9-4674-91b9-f83732e7ef5e', content='The network graph representing the permit process in San Francisco has been created and stored. It outlines the sequence of steps starting from project planning to obtaining the certificate of final completion. The graph also indicates that the discretionary review is an optional step in the process.', finished=False, has_displayed=True)"
+       "AssistantMessageView(display_id='8f698249-c893-435f-be86-2c7029159908', content='I have created a network graph representing the permit process in San Francisco. The process includes the following steps:\\n1. Start\\n2. Submit Application\\n3. Review Application\\n4. Inspection\\n5. Approval\\n6. End\\n\\nEach step is connected in sequence to show the flow of the permit process.', finished=False, has_displayed=True)"
      ]
     },
     "metadata": {},
@@ -218,11 +150,15 @@
     }
    ],
    "source": [
+    "import warnings\n",
     "from graphviz import Digraph\n",
     "from pydantic import BaseModel, Field\n",
     "from typing import List\n",
     "from chatlab import Chat, system\n",
-    "from chatlab.decorators import render_stream\n",
+    "from chatlab.decorators import incremental_display\n",
+    "\n",
+    "#warnings.simplefilter(\"ignore\")\n",
+    "#warnings.filterwarnings(\"ignore\")\n",
     "\n",
     "\n",
     "class Node(BaseModel):\n",
@@ -256,7 +192,7 @@
     "    return dot\n",
     "\n",
     "\n",
-    "@render_stream(visualize_knowledge_graph)\n",
+    "@incremental_display(visualize_knowledge_graph)\n",
     "def store_knowledge_graph(kg: KnowledgeGraph, comment: str):\n",
     "    \"\"\"Creates a graphviz diagram for the user and stores it in their database.\"\"\"\n",
     "    # dot = visualize_knowledge_graph(kg)\n",
@@ -268,8 +204,8 @@
     "    system(\n",
     "        \"You are running inside a jupyter notebook. Your responses appear as markdown in the notebook. Functions you run can produce side effects.\"\n",
     "    ),\n",
-    "    #model=\"gpt-3.5-turbo\",\n",
-    "    model=\"gpt-4-turbo-preview\",\n",
+    "    model=\"gpt-3.5-turbo\",\n",
+    "    #model=\"gpt-4-turbo-preview\",\n",
    "    chat_functions=[store_knowledge_graph],\n",
     ")\n",
     "\n",