Skip to content

Commit

Permalink
refactor(openai): use function_token_reserve
Browse files Browse the repository at this point in the history
  • Loading branch information
zhudotexe committed Aug 17, 2023
1 parent 319a0cc commit 5d01270
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 18 deletions.
30 changes: 18 additions & 12 deletions kani/engines/openai/engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import functools

from kani.ai_function import AIFunction
from kani.exceptions import MissingModelDependencies
from kani.models import ChatMessage
from . import function_calling
from .client import OpenAIClient
from .models import ChatCompletion, FunctionSpec
from ..base import BaseEngine
Expand All @@ -25,14 +28,7 @@


class OpenAIEngine(BaseEngine):
"""Engine for using the OpenAI API.
.. caution::
Due to having to track "hidden" tokens for the function spec, it is not recommended to reuse an OpenAIEngine
instance in multiple kani. To take advantage of reuse, construct a shared :class:`.OpenAIClient` and
initialize OpenAIEngine with ``client=the_client_instance`` rather than ``api_key="..."``.
"""
"""Engine for using the OpenAI API."""

def __init__(
self,
Expand Down Expand Up @@ -73,7 +69,6 @@ def __init__(
self.max_context_size = max_context_size
self.hyperparams = hyperparams
self.tokenizer = None # tiktoken caches a tokenizer globally in module, so we can unconditionally load it
self.token_reserve = 0
self._load_tokenizer()

def _load_tokenizer(self):
Expand Down Expand Up @@ -103,10 +98,21 @@ async def predict(
completion = await self.client.create_chat_completion(
model=self.model, messages=messages, functions=function_spec, **self.hyperparams, **hyperparams
)
# calculate function calling reserve tokens on first run
if functions and self.token_reserve == 0:
self.token_reserve = max(completion.prompt_tokens - sum(self.message_len(m) for m in messages), 0)
return completion

def function_token_reserve(self, functions: list[AIFunction]) -> int:
    """Return the number of context tokens to reserve for the given function spec.

    An empty function list reserves nothing; otherwise the (hashable) frozenset of
    functions is handed to the memoized implementation so repeated calls with the
    same spec are cheap.
    """
    return self._function_token_reserve_impl(frozenset(functions)) if functions else 0

def _function_token_reserve_impl(self, functions):
    """Estimate the "hidden" prompt tokens OpenAI consumes for a function spec.

    :param functions: a frozenset of AIFunctions (frozen upstream so it is hashable
        and usable as a cache key).
    :returns: the estimated token count, memoized per engine instance and spec.
    """
    # NOTE: this was previously memoized with @functools.lru_cache(maxsize=256), but an
    # lru_cache on an instance method keys entries on ``self`` and keeps every engine
    # instance alive for the lifetime of the cache (ruff B019). Cache per-instance
    # instead; __dict__.setdefault lazily creates the dict without touching __init__.
    cache = self.__dict__.setdefault("_function_token_reserve_cache", {})
    if functions in cache:
        return cache[functions]
    # openai doesn't tell us exactly how their function prompt works, so
    # we rely on community reverse-engineering to build the right prompt
    # hopefully OpenAI releases a utility to calculate this in the future, this seems kind of fragile
    prompt = function_calling.prompt(functions)
    # +16 accounts for the internal MD headers and namespace {} delimiters
    reserve = len(self.tokenizer.encode(prompt)) + 16
    cache[functions] = reserve
    return reserve

async def close(self):
    """Release resources held by the underlying :class:`.OpenAIClient`."""
    await self.client.close()
81 changes: 81 additions & 0 deletions kani/engines/openai/function_calling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""
As OpenAI doesn't tell us exactly how functions are exposed to GPT, we have to rely on some community reverse
engineering to build a reliable method to reserve tokens for AI Functions.
See https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573 for more details.
"""
import json
import warnings

from kani.ai_function import AIFunction


def prompt(functions: list[AIFunction]) -> str:
    """Render the full (reverse-engineered) function prompt for the given functions."""
    rendered = [format_function(f) for f in functions]
    return "".join(rendered)


def format_function(function: AIFunction) -> str:
    """Render one AIFunction as the TypeScript-style declaration OpenAI is believed to
    use internally (``// desc\\ntype name = (_: {...}) => any;``).

    The output is used only for token estimation, never sent to the API, so its exact
    byte content matters: changing whitespace or punctuation changes the estimate.
    """
    # Thanks @CGamesPlay for https://gist.github.com/CGamesPlay/dd4f108f27e2eec145eedf5c717318f5, which this
    # implementation is based on.
    def resolve_ref(schema):
        # follow a local "#/$defs/..." reference to the schema it points at
        # (relies on REF_TEMPLATE in kani.json_schema producing this ref shape)
        if schema.get("$ref") is not None:
            *_, ref = schema["$ref"].rsplit("/", 1)
            schema = json_schema["$defs"][ref]
        return schema

    def format_schema(schema, indent):
        # dispatch on the JSON schema node type; returns the rendered TS type string
        schema = resolve_ref(schema)
        if "enum" in schema:
            return format_enum(schema)
        elif schema["type"] == "object":
            return format_object(schema, indent)
        elif schema["type"] == "array":
            return format_schema(schema["items"], indent) + "[]"
        elif schema["type"] in ("string", "number", "integer", "boolean", "null"):  # these are all 1 token!
            return schema["type"]
        # unknown type: render it verbatim but warn that the estimate may be off
        warnings.warn(
            f"Unknown JSON schema type estimating tokens for OpenAI: {schema['type']!r}\n"
            "The returned estimate may be off by a significant amount."
        )
        return schema["type"]

    def format_enum(schema):
        # enums render as a union of JSON literals, e.g. "a" | "b" | 3
        return " | ".join(json.dumps(o) for o in schema["enum"])

    def format_object(schema, indent):
        # renders an object schema as a TS object literal type; returns None when there
        # are no renderable properties (the caller then omits the parameter entirely)
        result = "{\n"
        if "properties" not in schema or len(schema["properties"]) == 0:
            if schema.get("additionalProperties", False):
                return "object"
            return None
        for key, value in schema["properties"].items():
            value = resolve_ref(value)
            value_rendered = format_schema(value, indent + 1)
            if value_rendered is None:
                # nested object with no properties: skip the key entirely
                continue
            # property descriptions are only emitted at the top level of the args object
            if "description" in value and indent == 0:
                for line in value["description"].strip().split("\n"):
                    result += f"{' ' * indent}// {line}\n"
            optional = "" if key in schema.get("required", {}) else "?"
            comment = "" if value.get("default") is None else f" // default: {format_default(value)}"
            result += f"{' ' * indent}{key}{optional}: {value_rendered},{comment}\n"
        result += (" " * (indent - 1)) + "}"
        return result

    def format_default(schema):
        # integer-valued floats render with a trailing .0 (e.g. 1 -> "1.0"); presumably
        # this matches how the upstream reverse-engineering observed defaults — TODO confirm
        v = schema["default"]
        if schema["type"] == "number":
            return f"{v:.1f}" if float(v).is_integer() else str(v)
        else:
            return str(v)

    json_schema = function.json_schema
    if function.desc:
        out = f"// {function.desc}\ntype {function.name} = ("
    else:
        out = f"type {function.name} = ("
    formatted = format_object(json_schema, 0)
    if formatted is not None:
        out += "_: " + formatted
    out += ") => any;\n\n"
    return out
8 changes: 6 additions & 2 deletions kani/json_schema.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import inspect
import typing
from typing import TYPE_CHECKING, Optional
from typing import Optional, TYPE_CHECKING

import pydantic

if TYPE_CHECKING:
from .ai_function import AIParam

# this is the same as Pydantic's default as of v2.1, but we specify it here because some downstream things rely on it
# (e.g. engines.openai.function_calling)
REF_TEMPLATE = "#/$defs/{model}"


class AIParamSchema:
"""Used to annotate parameters of AIFunctions in order to make generating their schema nicer.
Expand Down Expand Up @@ -152,4 +156,4 @@ def create_json_schema(params: list[AIParamSchema]) -> dict:
fields[param.name] = (param.type, pydantic.Field(**field_kwargs))
# create a temp model for generating json schemas
pydantic_model = pydantic.create_model("_FunctionSpec", **fields)
return pydantic_model.model_json_schema(schema_generator=JSONSchemaBuilder)
return pydantic_model.model_json_schema(schema_generator=JSONSchemaBuilder, ref_template=REF_TEMPLATE)
14 changes: 10 additions & 4 deletions tests/test_type_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,11 @@ class BadEnum(enum.Enum):
async def example_primitives(
a: str,
b: float,
c: Annotated[str, AIParam(desc="I am C")],
d: Annotated[int, "I am not an AIParam"] = 2,
c: bool,
d: int,
e: None,
aa: Annotated[str, AIParam(desc="I am AA")],
dd: Annotated[int, "I am not an AIParam"] = 2,
):
"""description!"""
pass
Expand Down Expand Up @@ -80,10 +83,13 @@ def test_schema_primitives():
"properties": {
"a": {"type": "string"},
"b": {"type": "number"},
"c": {"description": "I am C", "type": "string"},
"c": {"type": "boolean"},
"d": {"type": "integer"},
"e": {"type": "null"},
"aa": {"description": "I am AA", "type": "string"},
"dd": {"default": 2, "type": "integer"},
},
"required": ["a", "b", "c"],
"required": ["a", "b", "c", "d", "e", "aa"],
"type": "object",
},
)
Expand Down

0 comments on commit 5d01270

Please sign in to comment.