Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added new mixin to modify partial parsing behaviour #1152

Merged
merged 9 commits into from
Nov 14, 2024
17 changes: 17 additions & 0 deletions docs/concepts/partial.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,23 @@ description: Learn to utilize field-level streaming with Instructor and OpenAI f

# Streaming Partial Responses

!!! info "Literal"

If the data structure you're using has literal values, make sure your model inherits from the `PartialLiteralMixin` mixin.

```python
from instructor.dsl.partial import PartialLiteralMixin

class User(BaseModel, PartialLiteralMixin):
name: str
age: int
category: Literal["admin", "user", "guest"]

# The rest of your code below
```

This is necessary because `jiter` otherwise throws an error if it encounters an incomplete Literal value while it is being streamed in.

Field level streaming provides incremental snapshots of the current state of the response model that are immediately useable. This approach is particularly relevant in contexts like rendering UI components.

Instructor supports this pattern by making use of `create_partial`. This lets us dynamically create a new class that treats all of the original model's fields as `Optional`.
Expand Down
14 changes: 12 additions & 2 deletions instructor/dsl/partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ class MakeFieldsOptional:
pass


class PartialLiteralMixin:
pass


def _make_field_optional(
field: FieldInfo,
) -> tuple[Any, FieldInfo]:
Expand Down Expand Up @@ -127,10 +131,13 @@ def model_from_chunks(
) -> Generator[T_Model, None, None]:
potential_object = ""
partial_model = cls.get_partial_model()
partial_mode = (
"on" if issubclass(cls, PartialLiteralMixin) else "trailing-strings"
)
for chunk in json_chunks:
potential_object += chunk
obj = from_json(
(potential_object.strip() or "{}").encode(), partial_mode="on"
(potential_object.strip() or "{}").encode(), partial_mode=partial_mode
)
obj = partial_model.model_validate(obj, strict=None, **kwargs)
yield obj
Expand All @@ -141,10 +148,13 @@ async def model_from_chunks_async(
) -> AsyncGenerator[T_Model, None]:
potential_object = ""
partial_model = cls.get_partial_model()
partial_mode = (
"on" if issubclass(cls, PartialLiteralMixin) else "trailing-strings"
)
async for chunk in json_chunks:
potential_object += chunk
obj = from_json(
(potential_object.strip() or "{}").encode(), partial_mode="on"
(potential_object.strip() or "{}").encode(), partial_mode=partial_mode
)
obj = partial_model.model_validate(obj, strict=None, **kwargs)
yield obj
Expand Down
6 changes: 3 additions & 3 deletions tests/dsl/test_partial.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# type: ignore[all]
from pydantic import BaseModel, Field
from instructor.dsl.partial import Partial
from instructor.dsl.partial import Partial, PartialLiteralMixin
import pytest
import instructor
from openai import OpenAI, AsyncOpenAI
Expand Down Expand Up @@ -116,7 +116,7 @@ async def async_generator():


def test_summary_extraction():
class Summary(BaseModel):
class Summary(BaseModel, PartialLiteralMixin):
summary: str = Field(description="A detailed summary")

client = OpenAI()
Expand All @@ -143,7 +143,7 @@ class Summary(BaseModel):

@pytest.mark.asyncio
async def test_summary_extraction_async():
class Summary(BaseModel):
class Summary(BaseModel, PartialLiteralMixin):
summary: str = Field(description="A detailed summary")

client = AsyncOpenAI()
Expand Down
123 changes: 122 additions & 1 deletion tests/llm/test_openai/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pydantic import BaseModel
import pytest
import instructor
from instructor.dsl.partial import Partial
from instructor.dsl.partial import Partial, PartialLiteralMixin

from .util import models, modes

Expand Down Expand Up @@ -81,3 +81,124 @@ async def test_partial_model_async(model, mode, aclient):
)
async for m in model:
assert isinstance(m, UserExtract)


@pytest.mark.parametrize("model,mode", product(models, modes))
def test_literal_partial_mixin(model, mode, client):
class UserWithMixin(BaseModel, PartialLiteralMixin):
name: str
age: int

client = instructor.patch(client, mode=mode)
resp = client.chat.completions.create(
model=model,
response_model=Partial[UserWithMixin],
max_retries=2,
stream=True,
messages=[
{"role": "user", "content": "Jason Liu is 12 years old"},
],
)

changes = 0
last_name = None
last_age = None
for m in resp:
assert isinstance(m, UserWithMixin)
if m.name != last_name:
last_name = m.name
changes += 1
if m.age != last_age:
last_age = m.age
changes += 1

assert changes == 2 # Ensure we got at least one field update

class UserWithoutMixin(BaseModel):
name: str
age: int

resp = client.chat.completions.create(
model=model,
response_model=Partial[UserWithoutMixin],
max_retries=2,
stream=True,
messages=[
{"role": "user", "content": "Jason Liu is 12 years old"},
],
)

changes = 0
last_name = None
last_age = None
for m in resp:
assert isinstance(m, UserWithoutMixin)
if m.name != last_name:
last_name = m.name
changes += 1
if m.age != last_age:
last_age = m.age
changes += 1

assert changes > 3


@pytest.mark.asyncio
@pytest.mark.parametrize("model,mode", product(models, modes))
async def test_literal_partial_mixin_async(model, mode, aclient):
class UserWithMixin(BaseModel, PartialLiteralMixin):
name: str
age: int

client = instructor.patch(aclient, mode=mode)
resp = await client.chat.completions.create(
model=model,
response_model=Partial[UserWithMixin],
max_retries=2,
stream=True,
messages=[
{"role": "user", "content": "Jason Liu is 12 years old"},
],
)

changes = 0
last_name = None
last_age = None
async for m in resp:
assert isinstance(m, UserWithMixin)
if m.name != last_name:
last_name = m.name
changes += 1
if m.age != last_age:
last_age = m.age
changes += 1

assert changes == 2 # Ensure we got at least one field update

class UserWithoutMixin(BaseModel):
name: str
age: int

resp = await client.chat.completions.create(
model=model,
response_model=Partial[UserWithoutMixin],
max_retries=2,
stream=True,
messages=[
{"role": "user", "content": "Jason Liu is 12 years old"},
],
)

changes = 0
last_name = None
last_age = None
async for m in resp:
assert isinstance(m, UserWithoutMixin)
if m.name != last_name:
last_name = m.name
changes += 1
if m.age != last_age:
last_age = m.age
changes += 1

assert changes > 3
Loading