Skip to content

Commit

Permalink
fix: renamed the mixin
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanleomk committed Nov 13, 2024
1 parent 87684eb commit 53dfdb0
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 70 deletions.
6 changes: 3 additions & 3 deletions docs/concepts/partial.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ description: Learn to utilize field-level streaming with Instructor and OpenAI f

!!! info "Literal"

If the data structure you're using has literal values, you need to make sure to import the `PartialStringHandlingMixin` mixin.
If the data structure you're using has literal values, you need to make sure to import the `PartialLiteralMixin` mixin.

```python
from instructor.dsl.partial import PartialStringHandlingMixin
from instructor.dsl.partial import PartialLiteralMixin

class User(BaseModel, PartialStringHandlingMixin):
class User(BaseModel, PartialLiteralMixin):
name: str
age: int
category: Literal["admin", "user", "guest"]
Expand Down
6 changes: 3 additions & 3 deletions instructor/dsl/partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class MakeFieldsOptional:
pass


class PartialStringHandlingMixin:
class PartialLiteralMixin:
pass


Expand Down Expand Up @@ -132,7 +132,7 @@ def model_from_chunks(
potential_object = ""
partial_model = cls.get_partial_model()
partial_mode = (
"on" if issubclass(cls, PartialStringHandlingMixin) else "trailing-strings"
"on" if issubclass(cls, PartialLiteralMixin) else "trailing-strings"
)
for chunk in json_chunks:
potential_object += chunk
Expand All @@ -149,7 +149,7 @@ async def model_from_chunks_async(
potential_object = ""
partial_model = cls.get_partial_model()
partial_mode = (
"on" if issubclass(cls, PartialStringHandlingMixin) else "trailing-strings"
"on" if issubclass(cls, PartialLiteralMixin) else "trailing-strings"
)
async for chunk in json_chunks:
potential_object += chunk
Expand Down
6 changes: 3 additions & 3 deletions tests/dsl/test_partial.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# type: ignore[all]
from pydantic import BaseModel, Field
from instructor.dsl.partial import Partial, PartialStringHandlingMixin
from instructor.dsl.partial import Partial, PartialLiteralMixin
import pytest
import instructor
from openai import OpenAI, AsyncOpenAI
Expand Down Expand Up @@ -116,7 +116,7 @@ async def async_generator():


def test_summary_extraction():
class Summary(BaseModel, PartialStringHandlingMixin):
class Summary(BaseModel, PartialLiteralMixin):
summary: str = Field(description="A detailed summary")

client = OpenAI()
Expand All @@ -143,7 +143,7 @@ class Summary(BaseModel, PartialStringHandlingMixin):

@pytest.mark.asyncio
async def test_summary_extraction_async():
class Summary(BaseModel, PartialStringHandlingMixin):
class Summary(BaseModel, PartialLiteralMixin):
summary: str = Field(description="A detailed summary")

client = AsyncOpenAI()
Expand Down
123 changes: 62 additions & 61 deletions tests/llm/test_openai/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pydantic import BaseModel
import pytest
import instructor
from instructor.dsl.partial import Partial, PartialStringHandlingMixin
from instructor.dsl.partial import Partial, PartialLiteralMixin

from .util import models, modes

Expand Down Expand Up @@ -85,7 +85,7 @@ async def test_partial_model_async(model, mode, aclient):

@pytest.mark.parametrize("model,mode", product(models, modes))
def test_literal_partial_mixin(model, mode, client):
class UserWithMixin(BaseModel, PartialStringHandlingMixin):
class UserWithMixin(BaseModel, PartialLiteralMixin):
name: str
age: int

Expand Down Expand Up @@ -142,62 +142,63 @@ class UserWithoutMixin(BaseModel):

assert changes > 3

@pytest.mark.asyncio
@pytest.mark.parametrize("model,mode", product(models, modes))
async def test_literal_partial_mixin_async(model, mode, client):
class UserWithMixin(BaseModel, PartialStringHandlingMixin):
name: str
age: int

client = instructor.patch(client, mode=mode)
resp = await client.chat.completions.create(
model=model,
response_model=Partial[UserWithMixin],
max_retries=2,
stream=True,
messages=[
{"role": "user", "content": "Jason Liu is 12 years old"},
],
)

changes = 0
last_name = None
last_age = None
async for m in resp:
assert isinstance(m, UserWithMixin)
if m.name != last_name:
last_name = m.name
changes += 1
if m.age != last_age:
last_age = m.age
changes += 1

assert changes == 2 # Ensure we got at least one field update

class UserWithoutMixin(BaseModel):
name: str
age: int

resp = await client.chat.completions.create(
model=model,
response_model=Partial[UserWithoutMixin],
max_retries=2,
stream=True,
messages=[
{"role": "user", "content": "Jason Liu is 12 years old"},
],
)

changes = 0
last_name = None
last_age = None
async for m in resp:
assert isinstance(m, UserWithoutMixin)
if m.name != last_name:
last_name = m.name
changes += 1
if m.age != last_age:
last_age = m.age
changes += 1

assert changes > 3

@pytest.mark.asyncio
@pytest.mark.parametrize("model,mode", product(models, modes))
async def test_literal_partial_mixin_async(model, mode, client):
class UserWithMixin(BaseModel, PartialLiteralMixin):
name: str
age: int

client = instructor.patch(client, mode=mode)
resp = await client.chat.completions.create(
model=model,
response_model=Partial[UserWithMixin],
max_retries=2,
stream=True,
messages=[
{"role": "user", "content": "Jason Liu is 12 years old"},
],
)

changes = 0
last_name = None
last_age = None
async for m in resp:
assert isinstance(m, UserWithMixin)
if m.name != last_name:
last_name = m.name
changes += 1
if m.age != last_age:
last_age = m.age
changes += 1

assert changes == 2 # Ensure we got at least one field update

class UserWithoutMixin(BaseModel):
name: str
age: int

resp = await client.chat.completions.create(
model=model,
response_model=Partial[UserWithoutMixin],
max_retries=2,
stream=True,
messages=[
{"role": "user", "content": "Jason Liu is 12 years old"},
],
)

changes = 0
last_name = None
last_age = None
async for m in resp:
assert isinstance(m, UserWithoutMixin)
if m.name != last_name:
last_name = m.name
changes += 1
if m.age != last_age:
last_age = m.age
changes += 1

assert changes > 3

0 comments on commit 53dfdb0

Please sign in to comment.