
Commit

feat: added another test
cpacker committed Nov 27, 2024
1 parent 42624b8 commit 49003e7
Showing 4 changed files with 174 additions and 15 deletions.
80 changes: 65 additions & 15 deletions letta/functions/schema_generator.py
@@ -223,26 +223,74 @@ def create_task_plan(steps: list[Step]):

     def clean_property(prop: dict) -> dict:
         """Clean up a property schema to match desired format"""

         if "description" not in prop:
             raise ValueError(f"Property {prop} lacks a 'description' key")

         return {
             "type": "str" if prop["type"] == "string" else prop["type"],
             "description": prop["description"],
         }

-    def clean_schema(schema_part: dict) -> dict:
-        properties = {}
-        for name, prop in schema_part["properties"].items():
-            if "items" in prop:  # Handle arrays
-                properties[name] = {"type": "array", "items": clean_schema(prop["items"])}
-            else:
-                properties[name] = clean_property(prop)
+    def resolve_ref(ref: str, schema: dict) -> dict:
+        """Resolve a $ref reference in the schema"""
+        if not ref.startswith("#/$defs/"):
+            raise ValueError(f"Unexpected reference format: {ref}")
+
+        model_name = ref.split("/")[-1]
+        if model_name not in schema.get("$defs", {}):
+            raise ValueError(f"Reference {model_name} not found in schema definitions")
+
+        return schema["$defs"][model_name]
+
+    def clean_schema(schema_part: dict, full_schema: dict) -> dict:
+        """Clean up a schema part, handling references and nested structures"""
+        # Handle $ref
+        if "$ref" in schema_part:
+            schema_part = resolve_ref(schema_part["$ref"], full_schema)
+
+        if "type" not in schema_part:
+            raise ValueError(f"Schema part lacks a 'type' key: {schema_part}")
+
+        # Handle array type
+        if schema_part["type"] == "array":
+            items_schema = schema_part["items"]
+            if "$ref" in items_schema:
+                items_schema = resolve_ref(items_schema["$ref"], full_schema)
+            return {"type": "array", "items": clean_schema(items_schema, full_schema), "description": schema_part.get("description", "")}
+
+        # Handle object type
+        if schema_part["type"] == "object":
+            if "properties" not in schema_part:
+                raise ValueError(f"Object schema lacks 'properties' key: {schema_part}")
+
+            properties = {}
+            for name, prop in schema_part["properties"].items():
+                if "items" in prop:  # Handle arrays
+                    if "description" not in prop:
+                        raise ValueError(f"Property {prop} lacks a 'description' key")
+                    properties[name] = {
+                        "type": "array",
+                        "items": clean_schema(prop["items"], full_schema),
+                        "description": prop["description"],
+                    }
+                else:
+                    properties[name] = clean_property(prop)
+
-        return {
-            "type": "object",
-            "properties": properties,
-            "required": schema_part.get("required", []),
-        }
+            pydantic_model_schema_dict = {
+                "type": "object",
+                "properties": properties,
+                "required": schema_part.get("required", []),
+            }
+            if "description" in schema_part:
+                pydantic_model_schema_dict["description"] = schema_part["description"]
+
+            return pydantic_model_schema_dict
+
+        # Handle primitive types
+        return clean_property(schema_part)

-    return clean_schema(schema)
+    return clean_schema(schema_part=schema, full_schema=schema)
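
For context on the new resolve_ref helper: pydantic v2's model_json_schema() does not inline nested models. Instead it emits a "$ref" pointing into a top-level "$defs" table, which the old clean_schema never followed. A minimal sketch of the shape being handled (model names here are illustrative, not taken from the commit):

    from pydantic import BaseModel, Field

    class Step(BaseModel):
        name: str = Field(..., description="Name of the step.")

    class Steps(BaseModel):
        steps: list[Step] = Field(..., description="A list of steps.")

    schema = Steps.model_json_schema()

    # The element type of the list is emitted as a reference, not inlined:
    assert schema["properties"]["steps"]["items"] == {"$ref": "#/$defs/Step"}

    # What resolve_ref does: follow "#/$defs/Step" back to the model definition.
    ref = schema["properties"]["steps"]["items"]["$ref"]
    target = schema["$defs"][ref.split("/")[-1]]
    assert target["properties"]["name"]["type"] == "string"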


@@ -289,14 +337,16 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[str] = None) -> dict:
             and not get_origin(param.annotation)
             and issubclass(param.annotation, BaseModel)
         ):
+            print("Generating schema for pydantic model:", param.annotation)
             # Extract the properties from the pydantic model
-            # schema["parameters"]["properties"][param.name] = pydantic_model_to_open_ai(param.annotation)
-            schema["parameters"]["properties"] = pydantic_model_to_json_schema(param.annotation)
+            schema["parameters"]["properties"][param.name] = pydantic_model_to_json_schema(param.annotation)
+            schema["parameters"]["properties"][param.name]["description"] = param_doc.description

         # Otherwise, we convert the Python typing to JSON schema types
         # NOTE: important - if a dict or list, the internal type can be a Pydantic model itself
         # however in that
         else:
+            print("Generating schema for non-pydantic model:", param.annotation)
             # Grab the description for the parameter from the extended docstring
             # If it doesn't exist, we should raise an error
             param_doc = next((d for d in docstring.params if d.arg_name == param.name), None)
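The second hunk also changes how the cleaned model schema is attached to the function schema: assigning to the whole "properties" dict would clobber any sibling parameters, while keying by param.name nests the model under its own argument. A small self-contained sketch of that difference, using placeholder dicts rather than letta's real output:

    # Placeholder standing in for the result of pydantic_model_to_json_schema(...).
    cleaned = {"type": "object", "properties": {"steps": {"type": "array"}}}

    # Old behavior: assigning to "properties" itself drops sibling parameters.
    params = {"type": "object", "properties": {"other_arg": {"type": "str"}}}
    params["properties"] = cleaned
    assert "other_arg" not in params["properties"]

    # New behavior: the model schema becomes the entry for this one parameter.
    params = {"type": "object", "properties": {"other_arg": {"type": "str"}}}
    params["properties"]["steps"] = cleaned
    assert "other_arg" in params["properties"] and "steps" in params["properties"]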
23 changes: 23 additions & 0 deletions tests/test_tool_schema_parsing.py
@@ -7,6 +7,21 @@
 # from .test_tool_schema_parsing_files.list_of_pydantic_example import create_task_plan


+def _clean_diff(d1, d2):
+    """Utility function to clean up the diff between two dictionaries."""
+
+    # Keys in d1 but not in d2
+    removed = {k: d1[k] for k in d1.keys() - d2.keys()}
+
+    # Keys in d2 but not in d1
+    added = {k: d2[k] for k in d2.keys() - d1.keys()}
+
+    # Keys in both but values changed
+    changed = {k: (d1[k], d2[k]) for k in d1.keys() & d2.keys() if d1[k] != d2[k]}
+
+    # Only include non-empty differences
+    return {k: v for k, v in {"removed": removed, "added": added, "changed": changed}.items() if v}


 def _compare_schemas(generated_schema: dict, expected_schema: dict, strip_heartbeat: bool = True):
     """Compare an autogenerated schema to an expected schema."""

@@ -19,8 +34,12 @@ def _compare_schemas(generated_schema: dict, expected_schema: dict, strip_heartbeat: bool = True):
     # Check that the two schemas are equal
     # If not, pretty print the difference by dumping with indent=4
     if generated_schema != expected_schema:
+        print("==== GENERATED SCHEMA ====")
         print(json.dumps(generated_schema, indent=4))
+        print("==== EXPECTED SCHEMA ====")
         print(json.dumps(expected_schema, indent=4))
+        print("==== DIFF ====")
+        print(json.dumps(_clean_diff(generated_schema, expected_schema), indent=4))
         raise AssertionError("Schemas are not equal")
     else:
         print("Schemas are equal")
@@ -47,4 +66,8 @@ def _run_schema_test(schema_name: str, desired_function_name: str):
 def test_derive_openai_json_schema():
     """Test that the schema generator works across a variety of example source code inputs."""

+    print("==== TESTING basic example where the arg is a pydantic model ====")
     _run_schema_test("list_of_pydantic_example", "create_task_plan")
+
+    print("==== TESTING more complex example where the arg is a nested pydantic model ====")
+    _run_schema_test("nested_pydantic_as_arg_example", "create_task_plan")
39 changes: 39 additions & 0 deletions tests/test_tool_schema_parsing_files/nested_pydantic_as_arg_example.json
@@ -0,0 +1,39 @@
+{
+    "name": "create_task_plan",
+    "description": "Creates a task plan for the current task.",
+    "parameters": {
+        "type": "object",
+        "properties": {
+            "steps": {
+                "type": "object",
+                "description": "List of steps to add to the task plan.",
+                "properties": {
+                    "steps": {
+                        "type": "array",
+                        "description": "A list of steps to add to the task plan.",
+                        "items": {
+                            "type": "object",
+                            "properties": {
+                                "name": {
+                                    "type": "str",
+                                    "description": "Name of the step."
+                                },
+                                "key": {
+                                    "type": "str",
+                                    "description": "Unique identifier for the step."
+                                },
+                                "description": {
+                                    "type": "str",
+                                    "description": "An exhaustive description of what this step is trying to achieve and accomplish."
+                                }
+                            },
+                            "required": ["name", "key", "description"]
+                        }
+                    }
+                },
+                "required": ["steps"]
+            }
+        },
+        "required": ["steps"]
+    }
+}
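
The "str" type tags above (where standard JSON Schema would say "string") come from clean_property in schema_generator.py, which rewrites the type before schemas are compared. A tiny sketch mirroring that mapping:

    # Mirrors the clean_property logic from the diff above.
    prop = {"type": "string", "description": "Name of the step."}
    cleaned = {
        "type": "str" if prop["type"] == "string" else prop["type"],
        "description": prop["description"],
    }
    assert cleaned == {"type": "str", "description": "Name of the step."}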
47 changes: 47 additions & 0 deletions tests/test_tool_schema_parsing_files/nested_pydantic_as_arg_example.py
@@ -0,0 +1,47 @@
+from pydantic import BaseModel, Field
+
+
+class Step(BaseModel):
+    name: str = Field(
+        ...,
+        description="Name of the step.",
+    )
+    key: str = Field(
+        ...,
+        description="Unique identifier for the step.",
+    )
+    description: str = Field(
+        ...,
+        description="An exhaustive description of what this step is trying to achieve and accomplish.",
+    )
+
+
+# NOTE: this example is pretty contrived - you probably don't want to have a nested pydantic model with
+# a single field that's the same as the variable name (in this case, `steps`)
+class Steps(BaseModel):
+    steps: list[Step] = Field(
+        ...,
+        description="A list of steps to add to the task plan.",
+    )
+
+
+def create_task_plan(steps: Steps) -> str:
+    """
+    Creates a task plan for the current task.
+    It takes in a list of steps, and updates the task with the new steps provided.
+    If there are any current steps, they will be overwritten.
+    Each step in the list should have the following format:
+    {
+        "name": <string> -- Name of the step.
+        "key": <string> -- Unique identifier for the step.
+        "description": <string> -- An exhaustive description of what this step is trying to achieve and accomplish.
+    }
+    Args:
+        steps: List of steps to add to the task plan.
+    Returns:
+        str: A summary of the updated task plan after the update.
+    """
+    DUMMY_MESSAGE = "Task plan created successfully."
+    return DUMMY_MESSAGE
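
For reference, a hypothetical call to the new example function (the step values are made up; the return value is the module's DUMMY_MESSAGE):

    plan = Steps(
        steps=[
            Step(
                name="Draft outline",
                key="draft_outline",
                description="Sketch the major sections of the document.",
            )
        ]
    )
    print(create_task_plan(plan))  # Task plan created successfully.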
