Skip to content

Commit

Permalink
fix logfire instrumentation of sql queries
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin committed Dec 10, 2024
1 parent cc44897 commit 3aa23e1
Showing 1 changed file with 32 additions and 14 deletions.
46 changes: 32 additions & 14 deletions pydantic_ai_examples/chat_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@
from concurrent.futures.thread import ThreadPoolExecutor
from contextlib import asynccontextmanager
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Annotated, Callable, TypeVar
from typing import Annotated, Any, Callable, TypeVar

import fastapi
import logfire
from fastapi import Depends, Request
from fastapi.responses import HTMLResponse, Response, StreamingResponse
from pydantic import Field, TypeAdapter
from typing_extensions import ParamSpec
from typing_extensions import LiteralString, ParamSpec

from pydantic_ai import Agent
from pydantic_ai.messages import (
Expand All @@ -35,6 +36,7 @@
logfire.configure(send_to_logfire='if-token-present')

agent = Agent('openai:gpt-4o')
THIS_DIR = Path(__file__).parent


@asynccontextmanager
Expand Down Expand Up @@ -80,7 +82,6 @@ async def stream_messages():
# stream the user prompt so that can be displayed straight away
yield MessageTypeAdapter.dump_json(UserPrompt(content=prompt)) + b'\n'
# get the chat history so far to pass as context to the agent
assert database is not None, 'Database not initialised'
messages = await database.get_messages()
# run the agent with the user prompt and the chat history
async with agent.run_stream(prompt, message_history=messages) as result:
Expand All @@ -96,7 +97,6 @@ async def stream_messages():
return StreamingResponse(stream_messages(), media_type='text/plain')


THIS_DIR = Path(__file__).parent
MessageTypeAdapter: TypeAdapter[Message] = TypeAdapter(
Annotated[Message, Field(discriminator='role')]
)
Expand All @@ -106,7 +106,11 @@ async def stream_messages():

@dataclass
class Database:
"""Rudimentary database to store chat messages in SQLite."""
"""Rudimentary database to store chat messages in SQLite.
The SQLite standard library package is synchronous, so we
use a thread pool executor to run queries asynchronously.
"""

con: sqlite3.Connection
_loop: asyncio.AbstractEventLoop
Expand All @@ -117,10 +121,11 @@ class Database:
async def connect(
cls, file: Path = THIS_DIR / '.chat_app_messages.sqlite'
) -> AsyncIterator[Database]:
loop = asyncio.get_event_loop()
executor = ThreadPoolExecutor(max_workers=1)
con = await loop.run_in_executor(executor, cls._connect, file)
slf = cls(con, loop, executor)
with logfire.span('connect to DB'):
loop = asyncio.get_event_loop()
executor = ThreadPoolExecutor(max_workers=1)
con = await loop.run_in_executor(executor, cls._connect, file)
slf = cls(con, loop, executor)
try:
yield slf
finally:
Expand All @@ -139,27 +144,40 @@ def _connect(file: Path) -> sqlite3.Connection:

async def add_messages(self, messages: bytes):
await self._asyncify(
self.con.execute,
self._execute,
'INSERT INTO messages (message_list) VALUES (?);',
(messages,),
messages,
commit=True,
)
await self._asyncify(self.con.commit)

async def get_messages(self) -> list[Message]:
c = await self._asyncify(
self.con.execute, 'SELECT message_list FROM messages order by id desc'
self._execute, 'SELECT message_list FROM messages order by id desc'
)
rows = await self._asyncify(c.fetchall)
messages: list[Message] = []
for row in rows:
messages.extend(MessagesTypeAdapter.validate_json(row[0]))
return messages

def _execute(
self, sql: LiteralString, *args: Any, commit: bool = False
) -> sqlite3.Cursor:
cur = self.con.cursor()
cur.execute(sql, args)
if commit:
self.con.commit()
return cur

async def _asyncify(
self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
) -> R:
assert kwargs == {}, 'kwargs not supported'
return await self._loop.run_in_executor(self._executor, func, *args) # type: ignore
return await self._loop.run_in_executor( # type: ignore
self._executor,
partial(func, **kwargs),
*args, # type: ignore
)


if __name__ == '__main__':
Expand Down

0 comments on commit 3aa23e1

Please sign in to comment.