Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mistral optimised #396

Merged
merged 22 commits into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 105 additions & 15 deletions pydantic_ai_slim/pydantic_ai/models/mistral.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations as _annotations

import json
import os
from collections.abc import AsyncIterator, Iterable
from contextlib import asynccontextmanager
Expand Down Expand Up @@ -39,7 +40,6 @@
)

try:
from json_repair import repair_json
from mistralai import (
UNSET,
CompletionChunk as MistralCompletionChunk,
Expand Down Expand Up @@ -198,11 +198,10 @@ async def _stream_completions_create(
"""Create a streaming completion request to the Mistral model."""
response: MistralEventStreamAsync[MistralCompletionEvent] | None
mistral_messages = list(chain(*(self._map_message(m) for m in messages)))

model_settings = model_settings or {}

if self.result_tools and self.function_tools or self.function_tools:
# Function Calling Mode
# Function Calling
response = await self.client.chat.stream_async(
model=str(self.model_name),
messages=mistral_messages,
Expand All @@ -218,9 +217,9 @@ async def _stream_completions_create(
elif self.result_tools:
# Json Mode
parameters_json_schemas = [tool.parameters_json_schema for tool in self.result_tools]

user_output_format_message = self._generate_user_output_format(parameters_json_schemas)
mistral_messages.append(user_output_format_message)

response = await self.client.chat.stream_async(
model=str(self.model_name),
messages=mistral_messages,
Expand Down Expand Up @@ -270,12 +269,13 @@ def _map_function_and_result_tools_definition(self) -> list[MistralTool] | None:
@staticmethod
def _process_response(response: MistralChatCompletionResponse) -> ModelResponse:
"""Process a non-streamed response, and prepare a message to return."""
assert response.choices, 'Unexpected empty response choice.'

if response.created:
timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
else:
timestamp = _now_utc()

assert response.choices, 'Unexpected empty response choice.'
choice = response.choices[0]
content = choice.message.content
tool_calls = choice.message.tool_calls
Expand Down Expand Up @@ -330,6 +330,7 @@ async def _process_streamed_response(
content,
timestamp,
start_usage,
_JSONChunkParser(),
)

elif content:
Expand Down Expand Up @@ -509,6 +510,101 @@ def timestamp(self) -> datetime:
return self._timestamp


class _JSONChunkParser:
YanSte marked this conversation as resolved.
Show resolved Hide resolved
"""A class to repair JSON chunks that might be corrupted (e.g. missing closing quotes)."""

def __init__(self) -> None:
self.new_chars: list[str] = []
self.stack: list[str] = []
self.is_inside_string = False
self.escaped = False

def process_chunk(self, chunk: str) -> dict[str, Any] | None:
"""Process a JSON chunk, attempting to parse it into a valid JSON object by repairing issues."""
# Strip whitespace, newlines, backtick from the start and end
chunk = chunk.strip(' \n\r\t`')
try:
output_json: dict[str, Any] | None = json.loads(chunk)
return output_json
except json.JSONDecodeError:
pass # Continue to attempt repairing

return self._repair_json(chunk)

def _repair_json(self, chunk: str) -> dict[str, Any] | None:
"""Attempts to repair and parse the accumulated buffer as JSON, handling common issues."""
# Next string to continue processing from the previous iteration.
start_index = len(self.new_chars)
for char in chunk[start_index:]:
if self.is_inside_string:
# End of string detected
if char == '"' and not self.escaped:
self.is_inside_string = False

# Replace newline with escape sequence within a string
elif char == '\n' and not self.escaped:
char = '\\n'

# Toggle escaped status on encountering backslash
elif char == '\\':
self.escaped = not self.escaped

# Reset escaped status for other characters
else:
self.escaped = False
else:
# Start of string detected
if char == '"':
self.is_inside_string = True
self.escaped = False

# Track expected closing brace
elif char == '{':
self.stack.append('}')

# Track expected closing bracket
elif char == '[':
self.stack.append(']')

# Handle closing characters and check for mismatches
elif char == '}' or char == ']':
if self.stack and self.stack[-1] == char:
self.stack.pop()
else:
# Mismatched closing character means malformed input
return None

self.new_chars.append(char)

# Prepare closing characters to balance the structure (Copy)
closing_chars = self.stack[::]

# If inside a string, ensure it is closed
if self.is_inside_string:
closing_chars.append('"')
self.is_inside_string = True

# Reverse to maintain correct order of closing characters
closing_chars.reverse()

# (Copy)
repaired_chars = self.new_chars[::]

# Attempt to parse the repaired JSON string
while repaired_chars:
try:
value = ''.join(repaired_chars + closing_chars)
return json.loads(value)
except json.JSONDecodeError:
# Remove the last character and retry parsing
value = repaired_chars.pop()
# Check if the last character removed was an opening character
if closing_chars and closing_chars[0] == {'"': '"', '{': '}', '[': ']'}.get(value):
closing_chars.pop(0)

return None


@dataclass
class MistralStreamStructuredResponse(StreamStructuredResponse):
"""Implementation of `StreamStructuredResponse` for Mistral models."""
Expand All @@ -519,6 +615,7 @@ class MistralStreamStructuredResponse(StreamStructuredResponse):
_delta_content: str | None
_timestamp: datetime
_usage: Usage
_json: _JSONChunkParser

async def __anext__(self) -> None:
chunk = await self._response.__anext__()
Expand Down Expand Up @@ -546,20 +643,13 @@ def get(self, *, final: bool = False) -> ModelResponse:
calls.append(tool)

elif self._delta_content and self._result_tools:
# NOTE: Params set for the most efficient and fastest way.
output_json = repair_json(self._delta_content, return_objects=True, skip_json_loads=True)
assert isinstance(
output_json, dict
), f'Expected repair_json as type dict, invalid type: {type(output_json)}'
output_json: dict[str, Any] | None = self._json.process_chunk(self._delta_content)

if output_json:
for result_tool in self._result_tools.values():
# NOTE: Additional verification to prevent JSON validation to crash in `result.py`
# NOTE: Additional verification to prevent JSON validation from crashing in `_result.py`.
# Ensures required parameters in the JSON schema are respected, especially for stream-based return types.
# For example, `return_type=list[str]` expects a 'response' key with value type array of str.
# when `{"response":` then `repair_json` sets `{"response": ""}` (type not found default str)
# when `{"response": {` then `repair_json` sets `{"response": {}}` (type found)
# This ensures it's corrected to `{"response": {}}` and other required parameters and type.
# Example with BaseModel and required fields.
if not self._validate_required_json_shema(output_json, result_tool.parameters_json_schema):
continue

Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ openai = ["openai>=1.54.3"]
vertexai = ["google-auth>=2.36.0", "requests>=2.32.3"]
anthropic = ["anthropic>=0.40.0"]
groq = ["groq>=0.12.0"]
mistral = ["mistralai>=1.2.5", "json-repair>=0.30.3"]
mistral = ["mistralai>=1.2.5"]
logfire = ["logfire>=2.3"]

[dependency-groups]
Expand Down
Loading
Loading