diff --git a/outlines/generate/json.py b/outlines/generate/json.py index d098d920d..22f03fc26 100644 --- a/outlines/generate/json.py +++ b/outlines/generate/json.py @@ -65,6 +65,12 @@ def json( regex_str = build_regex_from_schema(schema, whitespace_pattern) generator = regex(model, regex_str, sampler) generator.format_sequence = lambda x: pyjson.loads(x) + elif isinstance(schema_object, dict) and schema_object.get("type") == "function": + # Handle OpenAI function call format + schema = pyjson.dumps(schema_object["function"]["parameters"]) + regex_str = build_regex_from_schema(schema, whitespace_pattern) + generator = regex(model, regex_str, sampler) + generator.format_sequence = lambda x: pyjson.loads(x) else: raise ValueError( f"Cannot parse schema {schema_object}. The schema must be either "