Merge pull request #108 from Chainlit/willy/eng-1754-fix-mistralai-instrumentation-for-100

Willy/eng 1754 fix mistralai instrumentation for 100
clementsirieix authored Aug 8, 2024
2 parents 5011c6d + 4a0b063 commit 98b3571
Showing 4 changed files with 66 additions and 62 deletions.
4 changes: 2 additions & 2 deletions literalai/helper.py
@@ -18,9 +18,9 @@ def ensure_values_serializable(data):
pass

try:
from mistralai.models.chat_completion import ChatMessage
from mistralai import UserMessage

if isinstance(data, ChatMessage):
if isinstance(data, UserMessage):
return filter_none_values(data.model_dump())
except ImportError:
pass
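
For context on the new import path, here is a minimal sketch (not part of the diff, assuming mistralai >= 1.0.0) of how the updated helper serializes a message; the filter_none_values stand-in below only illustrates the existing helper in literalai/helper.py that drops None entries:

from mistralai import UserMessage  # top-level export in mistralai >= 1.0.0


def filter_none_values(data: dict) -> dict:
    # Illustrative stand-in for the helper already defined in literalai/helper.py.
    return {k: v for k, v in data.items() if v is not None}


# UserMessage is a pydantic model, so model_dump() returns a plain dict,
# e.g. {"content": "Hello", "role": "user"}.
message = UserMessage(content="Hello")
serializable = filter_none_values(message.model_dump())
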
101 changes: 53 additions & 48 deletions literalai/instrumentation/mistralai.py
@@ -1,89 +1,92 @@
import time
from typing import TYPE_CHECKING, AsyncGenerator, Dict, Union
from typing import TYPE_CHECKING, AsyncGenerator, Dict, Generator, Union

from literalai.instrumentation import MISTRALAI_PROVIDER
from literalai.requirements import check_all_requirements

if TYPE_CHECKING:
from literalai.client import LiteralClient

from types import GeneratorType

from literalai.context import active_steps_var, active_thread_var
from literalai.helper import ensure_values_serializable
from literalai.observability.generation import GenerationMessage, CompletionGeneration, ChatGeneration, GenerationType
from literalai.observability.generation import (
ChatGeneration,
CompletionGeneration,
GenerationMessage,
GenerationType,
)
from literalai.wrappers import AfterContext, BeforeContext, wrap_all

REQUIREMENTS = ["mistralai>=0.2.0"]
REQUIREMENTS = ["mistralai>=1.0.0"]

APIS_TO_WRAP = [
{
"module": "mistralai.client",
"object": "MistralClient",
"method": "chat",
"module": "mistralai.chat",
"object": "Chat",
"method": "complete",
"metadata": {
"type": GenerationType.CHAT,
},
"async": False,
},
{
"module": "mistralai.client",
"object": "MistralClient",
"method": "chat_stream",
"module": "mistralai.chat",
"object": "Chat",
"method": "stream",
"metadata": {
"type": GenerationType.CHAT,
},
"async": False,
},
{
"module": "mistralai.async_client",
"object": "MistralAsyncClient",
"method": "chat",
"module": "mistralai.chat",
"object": "Chat",
"method": "complete_async",
"metadata": {
"type": GenerationType.CHAT,
},
"async": True,
},
{
"module": "mistralai.async_client",
"object": "MistralAsyncClient",
"method": "chat_stream",
"module": "mistralai.chat",
"object": "Chat",
"method": "stream_async",
"metadata": {
"type": GenerationType.CHAT,
},
"async": True,
},
{
"module": "mistralai.client",
"object": "MistralClient",
"method": "completion",
"module": "mistralai.fim",
"object": "Fim",
"method": "complete",
"metadata": {
"type": GenerationType.COMPLETION,
},
"async": False,
},
{
"module": "mistralai.client",
"object": "MistralClient",
"method": "completion_stream",
"module": "mistralai.fim",
"object": "Fim",
"method": "stream",
"metadata": {
"type": GenerationType.COMPLETION,
},
"async": False,
},
{
"module": "mistralai.async_client",
"object": "MistralAsyncClient",
"method": "completion",
"module": "mistralai.fim",
"object": "Fim",
"method": "complete_async",
"metadata": {
"type": GenerationType.COMPLETION,
},
"async": True,
},
{
"module": "mistralai.async_client",
"object": "MistralAsyncClient",
"method": "completion_stream",
"module": "mistralai.fim",
"object": "Fim",
"method": "stream_async",
"metadata": {
"type": GenerationType.COMPLETION,
},
@@ -239,13 +242,13 @@ async def before(context: BeforeContext, *args, **kwargs):

return before

from mistralai.models.chat_completion import DeltaMessage
from mistralai import DeltaMessage

def process_delta(new_delta: DeltaMessage, message_completion: GenerationMessage):
if new_delta.tool_calls:
if "tool_calls" not in message_completion:
message_completion["tool_calls"] = []
delta_tool_call = new_delta.tool_calls[0]
delta_tool_call = new_delta.tool_calls[0] # type: ignore
delta_function = delta_tool_call.function
if not delta_function:
return False
@@ -273,9 +276,11 @@ def process_delta(new_delta: DeltaMessage, message_completion: GenerationMessage
else:
return False

from mistralai import models

def streaming_response(
generation: Union[ChatGeneration, CompletionGeneration],
result: GeneratorType,
result: Generator[models.CompletionEvent, None, None],
context: AfterContext,
):
completion = ""
@@ -286,8 +291,8 @@ def streaming_response(
token_count = 0
for chunk in result:
if generation and isinstance(generation, ChatGeneration):
if len(chunk.choices) > 0:
ok = process_delta(chunk.choices[0].delta, message_completion)
if len(chunk.data.choices) > 0:
ok = process_delta(chunk.data.choices[0].delta, message_completion)
if not ok:
yield chunk
continue
@@ -298,22 +303,22 @@
token_count += 1
elif generation and isinstance(generation, CompletionGeneration):
if (
len(chunk.choices) > 0
and chunk.choices[0].message.content is not None
len(chunk.data.choices) > 0
and chunk.data.choices[0].delta.content is not None
):
if generation.tt_first_token is None:
generation.tt_first_token = (
time.time() - context["start"]
) * 1000
token_count += 1
completion += chunk.choices[0].message.content
completion += chunk.data.choices[0].delta.content

if (
generation
and getattr(chunk, "model", None)
and generation.model != chunk.model
and generation.model != chunk.data.model
):
generation.model = chunk.model
generation.model = chunk.data.model

yield chunk

@@ -358,7 +363,7 @@ def after(result, context: AfterContext, *args, **kwargs):
generation.model = model
if generation.settings:
generation.settings["model"] = model
if isinstance(result, GeneratorType):
if isinstance(result, Generator):
return streaming_response(generation, result, context)
else:
generation.duration = time.time() - context["start"]
@@ -387,7 +392,7 @@ def after(result, context: AfterContext, *args, **kwargs):

async def async_streaming_response(
generation: Union[ChatGeneration, CompletionGeneration],
result: AsyncGenerator,
result: AsyncGenerator[models.CompletionEvent, None],
context: AfterContext,
):
completion = ""
@@ -398,8 +403,8 @@ async def async_streaming_response(
token_count = 0
async for chunk in result:
if generation and isinstance(generation, ChatGeneration):
if len(chunk.choices) > 0:
ok = process_delta(chunk.choices[0].delta, message_completion)
if len(chunk.data.choices) > 0:
ok = process_delta(chunk.data.choices[0].delta, message_completion)
if not ok:
yield chunk
continue
@@ -410,22 +415,22 @@
token_count += 1
elif generation and isinstance(generation, CompletionGeneration):
if (
len(chunk.choices) > 0
and chunk.choices[0].message.content is not None
len(chunk.data.choices) > 0
and chunk.data.choices[0].delta is not None
):
if generation.tt_first_token is None:
generation.tt_first_token = (
time.time() - context["start"]
) * 1000
token_count += 1
completion += chunk.choices[0].message.content
completion += chunk.data.choices[0].delta.content or ""

if (
generation
and getattr(chunk, "model", None)
and generation.model != chunk.model
and generation.model != chunk.data.model
):
generation.model = chunk.model
generation.model = chunk.data.model

yield chunk

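
As a reference for the wrapped entry points in APIS_TO_WRAP, a minimal sketch (assuming mistralai >= 1.0.0; the API key and model names are placeholders) of the client calls the instrumentation now targets, including the CompletionEvent payload that sits under chunk.data in streaming responses:

from mistralai import Mistral

client = Mistral(api_key="YOUR_API_KEY")  # placeholder key

# mistralai.chat.Chat: complete / stream (and complete_async / stream_async).
response = client.chat.complete(
    model="open-mistral-7b",
    messages=[{"role": "user", "content": "Hello"}],
)
print(response.choices[0].message.content)

# Streaming yields CompletionEvent objects whose chunk sits under .data,
# which is why the wrapper reads chunk.data.choices[0].delta above.
for chunk in client.chat.stream(
    model="open-mistral-7b",
    messages=[{"role": "user", "content": "Hello"}],
):
    if chunk.data.choices and chunk.data.choices[0].delta.content:
        print(chunk.data.choices[0].delta.content, end="")

# mistralai.fim.Fim: complete / stream for fill-in-the-middle completions.
fim = client.fim.complete(model="codestral-2405", prompt="1+1=")
print(fim.choices[0].message.content)
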
2 changes: 1 addition & 1 deletion requirements-dev.txt
@@ -8,4 +8,4 @@ mypy
langchain
llama-index
pytest_httpx
mistralai < 1.0.0
mistralai
21 changes: 10 additions & 11 deletions tests/e2e/test_mistralai.py
@@ -3,12 +3,11 @@
from asyncio import sleep

import pytest
from mistralai.async_client import MistralAsyncClient
from mistralai.client import MistralClient
from mistralai import Mistral
from pytest_httpx import HTTPXMock

from literalai.client import LiteralClient
from literalai.observability.generation import CompletionGeneration, ChatGeneration
from literalai.observability.generation import ChatGeneration, CompletionGeneration


@pytest.fixture
@@ -63,13 +62,13 @@ async def test_chat(self, client: "LiteralClient", httpx_mock: "HTTPXMock"):
},
}
)
mai_client = MistralClient(api_key="j3s4V1z4")
mai_client = Mistral(api_key="j3s4V1z4")
thread_id = None

@client.thread
def main():
# https://docs.mistral.ai/api/#operation/createChatCompletion
mai_client.chat(
mai_client.chat.complete(
model="open-mistral-7b",
messages=[
{
@@ -124,13 +123,13 @@ async def test_completion(self, client: "LiteralClient", httpx_mock: "HTTPXMock"
},
)

mai_client = MistralClient(api_key="j3s4V1z4")
mai_client = Mistral(api_key="j3s4V1z4")
thread_id = None

@client.thread
def main():
# https://docs.mistral.ai/api/#operation/createFIMCompletion
mai_client.completion(
mai_client.fim.complete(
model="codestral-2405",
prompt="1+1=",
temperature=0,
@@ -183,13 +182,13 @@ async def test_async_chat(self, client: "LiteralClient", httpx_mock: "HTTPXMock"
},
)

mai_client = MistralAsyncClient(api_key="j3s4V1z4")
mai_client = Mistral(api_key="j3s4V1z4")
thread_id = None

@client.thread
async def main():
# https://docs.mistral.ai/api/#operation/createChatCompletion
await mai_client.chat(
await mai_client.chat.complete_async(
model="open-mistral-7b",
messages=[
{
@@ -246,13 +245,13 @@ async def test_async_completion(
},
)

mai_client = MistralAsyncClient(api_key="j3s4V1z4")
mai_client = Mistral(api_key="j3s4V1z4")
thread_id = None

@client.thread
async def main():
# https://docs.mistral.ai/api/#operation/createFIMCompletion
await mai_client.completion(
await mai_client.fim.complete_async(
model="codestral-2405",
prompt="1+1=",
temperature=0,
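
The updated tests exercise the non-streaming paths; for the streaming methods the instrumentation also wraps, a minimal sketch (assuming mistralai >= 1.0.0 and a placeholder API key) of the async call that chat.stream_async exposes on the unified Mistral client:

import asyncio

from mistralai import Mistral


async def stream_chat() -> None:
    client = Mistral(api_key="YOUR_API_KEY")  # placeholder key
    events = await client.chat.stream_async(
        model="open-mistral-7b",
        messages=[{"role": "user", "content": "Hello"}],
    )
    if events is None:
        return
    # Each event wraps the chunk under .data, matching async_streaming_response above.
    async for event in events:
        delta = event.data.choices[0].delta.content
        if delta:
            print(delta, end="")


asyncio.run(stream_chat())
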
