Commit bb4f382

Merge pull request #136 from rgbkrk/streaming-input

Streaming input!

rgbkrk authored Feb 28, 2024
2 parents 8592749 + 66dcce0

Showing 13 changed files with 3,011 additions and 433 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -13,7 +13,7 @@ jobs:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.9", "3.10", "3.11", "3.12"]
        python-version: ["3.10", "3.11", "3.12"]

    steps:
      - uses: actions/checkout@v3
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Support parallel tool calling by default in `Chat`.
- Legacy support for function calling is available by passing `legacy_function_calling=True` to the `Chat` constructor.
- :new: `@incremental_display` decorator (see https://github.com/rgbkrk/chatlab/pull/136) that allows you to stream visualizations to the user as the model fills out function arguments.

![building a graph quickly](https://private-user-images.githubusercontent.com/836375/308375331-8953f679-5051-4416-b8a1-0994bde8b032.gif?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MDkwODIyMjEsIm5iZiI6MTcwOTA4MTkyMSwicGF0aCI6Ii84MzYzNzUvMzA4Mzc1MzMxLTg5NTNmNjc5LTUwNTEtNDQxNi1iOGExLTA5OTRiZGU4YjAzMi5naWY_WC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BS0lBVkNPRFlMU0E1M1BRSzRaQSUyRjIwMjQwMjI4JTJGdXMtZWFzdC0xJTJGczMlMkZhd3M0X3JlcXVlc3QmWC1BbXotRGF0ZT0yMDI0MDIyOFQwMDU4NDFaJlgtQW16LUV4cGlyZXM9MzAwJlgtQW16LVNpZ25hdHVyZT0yNmIxMWIzMTczM2VmYzA0M2VlMDY1MmE0YmE0ODUwZjJjNzMxYjIzYjc1MDI1MmJjYTBhNWZhMzg5MjExZThmJlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCZhY3Rvcl9pZD0wJmtleV9pZD0wJnJlcG9faWQ9MCJ9.hhQqvJtvm3cI8W-JTuQiu8rryl97O_xLKf7dpR2QFSQ)



## [1.3.0]

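To make the new entry concrete, here is a minimal usage sketch based on the docstring example in `chatlab/decorators.py` further down this diff. `KnowledgeGraph`, `Node`, and `Edge` are hypothetical pydantic models and graphviz is an assumed extra dependency; only `Chat`, `incremental_display`, and `register` come from chatlab itself.

```python
# Sketch only: KnowledgeGraph/Node/Edge and the graphviz dependency are
# assumptions for illustration; the chatlab imports are real.
from graphviz import Digraph
from pydantic import BaseModel

from chatlab import Chat
from chatlab.decorators import incremental_display


class Node(BaseModel):
    id: int
    label: str
    color: str = "black"


class Edge(BaseModel):
    source: int
    target: int
    label: str
    color: str = "black"


class KnowledgeGraph(BaseModel):
    nodes: list[Node] = []
    edges: list[Edge] = []


def visualize_knowledge_graph(kg: KnowledgeGraph, comment: str = "Knowledge Graph"):
    """Render the graph-so-far; re-rendered as arguments stream in."""
    dot = Digraph(comment=comment)
    for node in kg.nodes:
        dot.node(str(node.id), node.label, color=node.color)
    for edge in kg.edges:
        dot.edge(str(edge.source), str(edge.target), label=edge.label, color=edge.color)
    return dot


@incremental_display(visualize_knowledge_graph)
def store_knowledge_graph(kg: KnowledgeGraph, comment: str = "Knowledge Graph"):
    """Store a knowledge graph in a database."""
    ...


chat = Chat()
chat.register(store_knowledge_graph)
# await chat("Build a knowledge graph of the Python packaging ecosystem.")
```

The GIF above shows the effect: the visualization redraws as each node and edge streams in.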
2 changes: 2 additions & 0 deletions chatlab/__init__.py
@@ -32,6 +32,7 @@
)
from .registry import FunctionRegistry
from spork import Markdown
from instructor import Partial

__version__ = __version__

@@ -51,4 +52,5 @@
"FunctionRegistry",
"ChatlabMetadata",
"expose_exception_to_llm",
"Partial",
]
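`Partial` is re-exported from the instructor library here. Per instructor's documentation, it wraps a pydantic model so that required fields become optional, letting partially streamed arguments validate mid-generation. A sketch assuming that documented behavior; `Point` is a made-up model:

```python
from pydantic import BaseModel

from chatlab import Partial


class Point(BaseModel):
    x: int
    y: int


# A half-streamed payload would fail Point.model_validate, but the
# Partial-wrapped model accepts it while generation is still in flight.
half_done = Partial[Point].model_validate({"x": 1})
print(half_done)  # x=1 y=None
```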
7 changes: 6 additions & 1 deletion chatlab/chat.py
@@ -169,10 +169,15 @@ async def __process_stream(
    and tool_call.function.arguments is not None
    and tool_call.id is not None
):
    # Build up the tool call arguments as the deltas stream in
    tool_argument = ToolArguments(
        id=tool_call.id, name=tool_call.function.name, arguments=tool_call.function.arguments
    )

    # If the user provided a custom renderer, attach it to the tool argument object for display
    func = self.function_registry.get_chatlab_metadata(tool_call.function.name)
    if func is not None and func.render is not None:
        tool_argument.custom_render = func.render

    tool_argument.display()
    tool_calls.append(tool_argument)

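The metadata lookup above is the bridge between `@incremental_display` (added in `chatlab/decorators.py` below) and the stream: the registry hands back the function's metadata, and its `render` callable gets attached to the `ToolArguments` as `custom_render`. A small sketch of that lookup; `save_note` and `show` are invented names:

```python
from chatlab import Chat
from chatlab.decorators import incremental_display


def show(text: str):
    print(text)


@incremental_display(show)
def save_note(text: str):
    """Pretend to persist a note."""
    ...


chat = Chat()
chat.register(save_note)

# The same lookup __process_stream performs for each streamed tool call:
meta = chat.function_registry.get_chatlab_metadata("save_note")
assert meta is not None and meta.render is show
```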
52 changes: 45 additions & 7 deletions chatlab/decorators.py
@@ -30,15 +30,14 @@
"""


class ChatlabMetadata:
    """ChatLab metadata for a function."""

    expose_exception_to_llm: bool
from typing import Callable, Optional

    def __init__(self, expose_exception_to_llm=False):
        """Initialize ChatLab metadata for a function."""
        self.expose_exception_to_llm = expose_exception_to_llm
from pydantic import BaseModel

class ChatlabMetadata(BaseModel):
    """ChatLab metadata for a function."""
    expose_exception_to_llm: bool = False
    render: Optional[Callable] = None

def expose_exception_to_llm(func):
    """Expose exceptions from calling the function to the LLM.
@@ -70,3 +69,42 @@ def expose_exception_to_llm(func):

    func.chatlab_metadata.expose_exception_to_llm = True
    return func


'''
The `incremental_display` decorator lets you render a function's output while the model is still streaming in its arguments.

def visualize_knowledge_graph(kg: KnowledgeGraph, comment: str = "Knowledge Graph"):
    """Visualizes a knowledge graph using graphviz."""
    dot = Digraph(comment=comment)
    for node in kg.nodes:
        dot.node(str(node.id), node.label, color=node.color)
    for edge in kg.edges:
        dot.edge(str(edge.source), str(edge.target), label=edge.label, color=edge.color)
    return dot

@incremental_display(visualize_knowledge_graph)
def store_knowledge_graph(kg: KnowledgeGraph, comment: str = "Knowledge Graph"):
    """Store a knowledge graph in a database."""
    ...

chat.register(store_knowledge_graph)
'''

def incremental_display(render_func: Callable):
    def decorator(func):
        if not hasattr(func, "chatlab_metadata"):
            func.chatlab_metadata = ChatlabMetadata()

        # Make sure that chatlab_metadata is an instance of ChatlabMetadata
        if not isinstance(func.chatlab_metadata, ChatlabMetadata):
            raise Exception("func.chatlab_metadata must be an instance of ChatlabMetadata")

        func.chatlab_metadata.render = render_func
        return func

    return decorator
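The decorator's effect is easy to verify in isolation: it simply populates `chatlab_metadata.render` on the wrapped function. A quick sketch with invented names (`render_weather`, `fetch_weather`):

```python
from chatlab.decorators import ChatlabMetadata, incremental_display


def render_weather(city: str):
    print(f"Looking up weather for {city}...")


@incremental_display(render_weather)
def fetch_weather(city: str):
    """Fetch the weather for a city."""
    ...


assert isinstance(fetch_weather.chatlab_metadata, ChatlabMetadata)
assert fetch_weather.chatlab_metadata.render is render_weather
assert fetch_weather.chatlab_metadata.expose_exception_to_llm is False
```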

Empty file added chatlab/py.typed
133 changes: 73 additions & 60 deletions chatlab/registry.py
@@ -60,11 +60,12 @@ class WhatTime(BaseModel):

from openai.types import FunctionDefinition, FunctionParameters
from openai.types.chat.completion_create_params import Function, FunctionCall
from pydantic import BaseModel, create_model
from pydantic import BaseModel, Field, create_model

from openai.types.chat import ChatCompletionToolParam

from .decorators import ChatlabMetadata
from .errors import ChatLabError


class APIManifest(TypedDict, total=False):
@@ -80,13 +81,13 @@ class APIManifest(TypedDict, total=False):
"""


class FunctionArgumentError(Exception):
class FunctionArgumentError(ChatLabError):
"""Exception raised when a function is called with invalid arguments."""

pass


class UnknownFunctionError(Exception):
class UnknownFunctionError(ChatLabError):
"""Exception raised when a function is called that is not registered."""

pass
@@ -123,6 +124,49 @@ class FunctionSchemaConfig:
    arbitrary_types_allowed = True


def extract_model_from_function(func_name: str, function: Callable) -> Type[BaseModel]:
    """Create a pydantic model from a function's signature and type annotations."""
    # extract function parameters and their type annotations
    sig = inspect.signature(function)

    fields = {}
    required_fields = []
    for name, param in sig.parameters.items():
        # skip 'self' for class methods
        if name == "self":
            continue

        # determine type annotation
        if param.annotation == inspect.Parameter.empty:
            # no annotation, raise instead of falling back to Any
            raise Exception(f"`{name}` parameter of {func_name} must have a JSON-serializable type annotation")
        type_annotation = param.annotation

        default_value: Any = ...

        # determine if there is a default value
        if param.default != inspect.Parameter.empty:
            default_value = param.default
        else:
            required_fields.append(name)

        # Check if the annotation is a Union that includes None, indicating an optional parameter
        if get_origin(type_annotation) is Union:
            args = get_args(type_annotation)
            if len(args) == 2 and type(None) in args:
                type_annotation = next(arg for arg in args if arg is not type(None))
                default_value = None

        fields[name] = (type_annotation, Field(default=default_value) if default_value is not ... else ...)

    model = create_model(
        function.__name__,
        __config__=FunctionSchemaConfig,  # type: ignore
        __required__=required_fields,
        **fields,  # type: ignore
    )
    return model

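For a sense of what this helper produces, here is a sketch with a hypothetical `book_flight` function; the exact JSON schema layout depends on the pydantic version in use:

```python
from typing import Optional

from chatlab.registry import extract_model_from_function


def book_flight(destination: str, seats: int = 1, note: Optional[str] = None):
    """Book a flight."""
    ...


Model = extract_model_from_function("book_flight", book_flight)
schema = Model.model_json_schema()
# destination has no default, so it should be the only required property;
# the Optional[str] on note is unwrapped to str with a default of None.
print(schema["required"])  # expected: ["destination"]
```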

def generate_function_schema(
    function: Callable,
    parameter_schema: Optional[Union[Type["BaseModel"], dict]] = None,
@@ -149,44 +193,7 @@ def generate_function_schema(
    elif parameter_schema is not None:
        parameters = parameter_schema.model_json_schema()  # type: ignore
    else:
        # extract function parameters and their type annotations
        sig = inspect.signature(function)

        fields = {}
        for name, param in sig.parameters.items():
            # skip 'self' for class methods
            if name == "self":
                continue

            # determine type annotation
            if param.annotation == inspect.Parameter.empty:
                # no annotation, raise instead of falling back to Any
                raise Exception(f"`{name}` parameter of {func_name} must have a JSON-serializable type annotation")
            type_annotation = param.annotation

            # determine if there is a default value
            if param.default != inspect.Parameter.empty:
                default_value = param.default
            else:
                default_value = ...

            # Check if the annotation is a Union that includes None, indicating an optional parameter
            if get_origin(type_annotation) is Union:
                args = get_args(type_annotation)
                if len(args) == 2 and type(None) in args:
                    # It's an optional parameter
                    type_annotation = next(arg for arg in args if arg is not type(None))
                    default_value = None if default_value is ... else default_value

            fields[name] = (type_annotation, default_value)

        # create the pydantic model and return its JSON schema to pass into the 'parameters' part of the
        # function schema used by OpenAI
        model = create_model(
            function.__name__,
            __config__=FunctionSchemaConfig,  # type: ignore
            **fields,  # type: ignore
        )
        model = extract_model_from_function(func_name, function)
        parameters: dict = model.model_json_schema()  # type: ignore

    if "properties" not in parameters:
@@ -458,25 +465,8 @@ async def call(self, name: str, arguments: Optional[str] = None) -> Any:

        function = possible_function

        parameters: dict = {}

        if arguments is not None and arguments != "":
            try:
                parameters = json.loads(arguments)
            except json.JSONDecodeError:
                raise FunctionArgumentError(f"Invalid Function call on {name}. Arguments must be a valid JSON object")

        prepared_arguments = {}

        for param_name, param in inspect.signature(function).parameters.items():
            param_type = param.annotation
            arg_value = parameters.get(param_name)

            # Check if parameter type is a subclass of BaseModel and deserialize JSON into a pydantic model
            if inspect.isclass(param_type) and issubclass(param_type, BaseModel):
                prepared_arguments[param_name] = param_type.model_validate(arg_value)
            else:
                prepared_arguments[param_name] = cast(Any, arg_value)

        # TODO: Use the model extractor here
        prepared_arguments = extract_arguments(name, function, arguments)

        if asyncio.iscoroutinefunction(function):
            result = await function(**prepared_arguments)
@@ -494,3 +484,26 @@ def __contains__(self, name) -> bool:
    def function_definitions(self) -> list[FunctionDefinition]:
        """Get a list of function definitions."""
        return list(self.__schemas.values())


def extract_arguments(name: str, function: Callable, arguments: Optional[str]) -> dict:
    """Extract and validate arguments for a function call from a JSON string."""
    dict_arguments = {}
    if arguments is not None and arguments != "":
        try:
            dict_arguments = json.loads(arguments)
        except json.JSONDecodeError:
            raise FunctionArgumentError(f"Invalid Function call on {name}. Arguments must be a valid JSON object")

    prepared_arguments = {}

    for param_name, param in inspect.signature(function).parameters.items():
        param_type = param.annotation
        arg_value = dict_arguments.get(param_name)

        # Check if parameter type is a subclass of BaseModel and deserialize JSON into a pydantic model
        if inspect.isclass(param_type) and issubclass(param_type, BaseModel):
            prepared_arguments[param_name] = param_type.model_validate(arg_value)
        else:
            prepared_arguments[param_name] = cast(Any, arg_value)

    return prepared_arguments
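Finally, a sketch of the newly extracted `extract_arguments` helper in isolation, showing the pydantic deserialization path; `Point` and `plot` are invented for the example:

```python
from pydantic import BaseModel

from chatlab.registry import extract_arguments


class Point(BaseModel):
    x: int
    y: int


def plot(point: Point, label: str):
    """Plot a labeled point."""
    ...


args = extract_arguments("plot", plot, '{"point": {"x": 1, "y": 2}, "label": "origin"}')
# The Point annotation triggers model_validate; label passes through unchanged.
assert args == {"point": Point(x=1, y=2), "label": "origin"}
```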