Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added new mixin to modify partial parsing behaviour #1152

Merged
merged 9 commits into from
Nov 14, 2024
20 changes: 17 additions & 3 deletions instructor/dsl/partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ class MakeFieldsOptional:
pass


class LiteralPartialMixin:
    """Marker mixin selecting ``partial_mode="on"`` for streamed JSON parsing.

    The class defines no attributes or methods.  The streaming helpers
    ``model_from_chunks`` and ``model_from_chunks_async`` test
    ``issubclass(cls, LiteralPartialMixin)``: models that inherit from this
    mixin are parsed with ``partial_mode="on"``, every other model with
    ``partial_mode="trailing-strings"``.
    """


def _make_field_optional(
field: FieldInfo,
) -> tuple[Any, FieldInfo]:
Expand Down Expand Up @@ -127,9 +131,14 @@ def model_from_chunks(
) -> Generator[T_Model, None, None]:
potential_object = ""
partial_model = cls.get_partial_model()
partial_mode = (
"on" if issubclass(cls, LiteralPartialMixin) else "trailing-strings"
)
for chunk in json_chunks:
potential_object += chunk
obj = from_json((potential_object.strip() or "{}").encode(), partial_mode="on")
obj = from_json(
(potential_object.strip() or "{}").encode(), partial_mode=partial_mode
)
obj = partial_model.model_validate(obj, strict=None, **kwargs)
yield obj

Expand All @@ -139,9 +148,14 @@ async def model_from_chunks_async(
) -> AsyncGenerator[T_Model, None]:
potential_object = ""
partial_model = cls.get_partial_model()
partial_mode = (
"on" if issubclass(cls, LiteralPartialMixin) else "trailing-strings"
)
async for chunk in json_chunks:
potential_object += chunk
obj = from_json((potential_object.strip() or "{}").encode(), partial_mode="on")
obj = from_json(
(potential_object.strip() or "{}").encode(), partial_mode=partial_mode
)
obj = partial_model.model_validate(obj, strict=None, **kwargs)
yield obj

Expand All @@ -163,7 +177,7 @@ def extract_json(
import json

resp = chunk.candidates[0].content.parts[0].function_call
resp_dict = type(resp).to_dict(resp) # type:ignore
resp_dict = type(resp).to_dict(resp) # type:ignore
if "args" in resp_dict:
yield json.dumps(resp_dict["args"])
elif chunk.choices:
Expand Down
124 changes: 123 additions & 1 deletion tests/llm/test_openai/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pydantic import BaseModel
import pytest
import instructor
from instructor.dsl.partial import Partial
from instructor.dsl.partial import Partial, LiteralPartialMixin

from .util import models, modes

Expand Down Expand Up @@ -81,3 +81,125 @@ async def test_partial_model_async(model, mode, aclient):
)
async for m in model:
assert isinstance(m, UserExtract)


@pytest.mark.parametrize("model,mode", product(models, modes))
def test_literal_partial_mixin(model, mode, client):
    """Compare streamed field updates with and without ``LiteralPartialMixin``.

    With the mixin, ``name`` and ``age`` must each change exactly once
    (two changes in total).  Without it, more than three changes are
    expected over the same prompt.
    """

    class UserWithMixin(BaseModel, LiteralPartialMixin):
        name: str
        age: int

    class UserWithoutMixin(BaseModel):
        name: str
        age: int

    patched = instructor.patch(client, mode=mode)

    def count_field_changes(response_model):
        # Stream partial objects and count how often either field's value
        # differs from the value seen on the previous partial object.
        stream = patched.chat.completions.create(
            model=model,
            response_model=Partial[response_model],
            max_retries=2,
            stream=True,
            messages=[
                {"role": "user", "content": "Jason Liu is 12 years old"},
            ],
        )
        previous = (None, None)
        updates = 0
        for partial in stream:
            assert isinstance(partial, response_model)
            current = (partial.name, partial.age)
            updates += sum(
                1 for new, old in zip(current, previous) if new != old
            )
            previous = current
        return updates

    # One transition per field: None -> final name, None -> final age.
    assert count_field_changes(UserWithMixin) == 2

    assert count_field_changes(UserWithoutMixin) > 3

@pytest.mark.asyncio
@pytest.mark.parametrize("model,mode", product(models, modes))
async def test_literal_partial_mixin_async(model, mode, aclient):
    """Async variant: compare streamed updates with and without the mixin.

    The test requests the ``aclient`` fixture (as the other async test in
    this module does) because the ``create`` call is awaited and the
    response is consumed with ``async for``.

    With ``LiteralPartialMixin``, ``name`` and ``age`` must each change
    exactly once (two changes in total).  Without it, more than three
    changes are expected over the same prompt.
    """

    class UserWithMixin(BaseModel, LiteralPartialMixin):
        name: str
        age: int

    class UserWithoutMixin(BaseModel):
        name: str
        age: int

    patched = instructor.patch(aclient, mode=mode)

    async def count_field_changes(response_model):
        # Stream partial objects and count how often either field's value
        # differs from the value seen on the previous partial object.
        stream = await patched.chat.completions.create(
            model=model,
            response_model=Partial[response_model],
            max_retries=2,
            stream=True,
            messages=[
                {"role": "user", "content": "Jason Liu is 12 years old"},
            ],
        )
        previous = (None, None)
        updates = 0
        async for partial in stream:
            assert isinstance(partial, response_model)
            current = (partial.name, partial.age)
            updates += sum(
                1 for new, old in zip(current, previous) if new != old
            )
            previous = current
        return updates

    # One transition per field: None -> final name, None -> final age.
    assert await count_field_changes(UserWithMixin) == 2

    assert await count_field_changes(UserWithoutMixin) > 3
Loading