Skip to content

Commit

Permalink
feat(persistence): use sqlean v3.45.1 as sqlite engine (#2947)
Browse files Browse the repository at this point in the history
  • Loading branch information
RogerHYang authored Apr 23, 2024
1 parent 3fb88bf commit 3b202d7
Show file tree
Hide file tree
Showing 8 changed files with 139 additions and 49 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ dependencies = [
"sqlalchemy[asyncio]>=2.0.4, <3",
"alembic>=1.3.0, <2",
"aiosqlite",
"sqlean.py>=3.45.1",
]
dynamic = ["version"]

Expand Down Expand Up @@ -144,7 +145,6 @@ dependencies = [
"respx", # For OpenAI testing
"nest-asyncio", # for executor testing
"astunparse; python_version<'3.9'", # `ast.unparse(...)` is only available starting with Python 3.9
"sqlean.py", # for running GitHub CI on Windows, because its SQLite doesn't support JSON_EXTRACT(...).
]

[tool.hatch.envs.type]
Expand Down Expand Up @@ -324,6 +324,7 @@ module = [
"nest_asyncio",
"opentelemetry.*",
"pyarrow",
"sqlean",
]
ignore_missing_imports = true

Expand Down
49 changes: 32 additions & 17 deletions src/phoenix/db/engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,20 @@
import json
from datetime import datetime
from enum import Enum
from pathlib import Path
from sqlite3 import Connection
from typing import Any, Union
from typing import Any

import aiosqlite
import numpy as np
import sqlean
from sqlalchemy import URL, event, make_url
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine

from phoenix.db.migrate import migrate_in_thread
from phoenix.db.models import init_models

sqlean.extensions.enable("text")


def set_sqlite_pragma(connection: Connection, _: Any) -> None:
cursor = connection.cursor()
Expand All @@ -24,10 +27,6 @@ def set_sqlite_pragma(connection: Connection, _: Any) -> None:
cursor.close()


def get_db_url(driver: str = "sqlite+aiosqlite", database: Union[str, Path] = ":memory:") -> URL:
return URL.create(driver, database=str(database))


def get_printable_db_url(connection_str: str) -> str:
    """Render the connection string for display, masking the password."""
    url = make_url(connection_str)
    return url.render_as_string(hide_password=True)

Expand All @@ -39,9 +38,11 @@ def get_async_db_url(connection_str: str) -> URL:
url = make_url(connection_str)
if not url.database:
raise ValueError("Failed to parse database from connection string")
if "sqlite" in url.drivername:
return get_db_url(driver="sqlite+aiosqlite", database=url.database)
if "postgresql" in url.drivername:
if url.drivername.partition("+")[0] == "sqlite":
if url.database.startswith(":memory:"):
url = url.set(query={"cache": "shared"})
return url.set(drivername="sqlite+aiosqlite")
if url.drivername.partition("+")[0] == "postgresql":
url = url.set(drivername="postgresql+asyncpg")
# For some reason username and password cannot be parsed from the typical slot
# So we need to parse them out manually
Expand All @@ -61,22 +62,36 @@ def create_engine(connection_str: str, echo: bool = False) -> AsyncEngine:
url = make_url(connection_str)
if not url.database:
raise ValueError("Failed to parse database from connection string")
if "sqlite" in url.drivername:
# Split the URL to get the database name
return aio_sqlite_engine(database=url.database, echo=echo)
if "postgresql" in url.drivername:
if url.drivername.partition("+")[0] == "sqlite":
return aio_sqlite_engine(url=url, echo=echo)
if url.drivername.partition("+")[0] == "postgresql":
return aio_postgresql_engine(url=url, echo=echo)
raise ValueError(f"Unsupported driver: {url.drivername}")


def aio_sqlite_engine(
database: Union[str, Path] = ":memory:",
url: URL,
echo: bool = False,
) -> AsyncEngine:
url = get_db_url(driver="sqlite+aiosqlite", database=database)
engine = create_async_engine(url=url, echo=echo, json_serializer=_dumps)
async_url = get_async_db_url(url.render_as_string())
assert async_url.database

def async_creator() -> aiosqlite.Connection:
conn = aiosqlite.Connection(
lambda: sqlean.connect(async_url.database, uri=True),
iter_chunk_size=64,
)
conn.daemon = True
return conn

engine = create_async_engine(
url=async_url,
echo=echo,
json_serializer=_dumps,
async_creator=async_creator,
)
event.listen(engine.sync_engine, "connect", set_sqlite_pragma)
if str(database) == ":memory:":
if async_url.database.startswith(":memory:"):
try:
asyncio.get_running_loop()
except RuntimeError:
Expand Down
28 changes: 24 additions & 4 deletions src/phoenix/db/migrations/versions/cf03bd6bae1d_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,42 @@
"""

from typing import Sequence, Union
from typing import Any, Sequence, Union

import sqlalchemy as sa
from alembic import op
from sqlalchemy import JSON
from sqlalchemy.dialects import postgresql
from sqlalchemy.ext.compiler import compiles

# revision identifiers, used by Alembic.
revision: str = "cf03bd6bae1d"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None

JSON_ = JSON().with_variant(
postgresql.JSONB(), # type: ignore
"postgresql",

class JSONB(JSON):
    """JSON subtype whose compiler visit name is ``JSONB``.

    Overriding ``__visit_name__`` lets a dialect-specific ``@compiles``
    hook render the column type as the literal string "JSONB".
    """

    # See https://docs.sqlalchemy.org/en/20/core/custom_types.html
    __visit_name__ = "JSONB"


@compiles(JSONB, "sqlite")  # type: ignore
def _(*args: Any, **kwargs: Any) -> str:
    """Emit the literal type name "JSONB" when compiling DDL for SQLite."""
    # See https://docs.sqlalchemy.org/en/20/core/custom_types.html
    return "JSONB"


# Cross-dialect JSON column type: native JSONB on PostgreSQL, the custom
# JSONB shim on SQLite, and plain JSON everywhere else.
JSON_ = JSON().with_variant(
    postgresql.JSONB(),  # type: ignore
    "postgresql",
).with_variant(JSONB(), "sqlite")


Expand Down
62 changes: 59 additions & 3 deletions src/phoenix/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Float,
ForeignKey,
MetaData,
String,
TypeDecorator,
UniqueConstraint,
func,
Expand All @@ -28,9 +29,28 @@
)
from sqlalchemy.sql import expression

JSON_ = JSON().with_variant(
postgresql.JSONB(), # type: ignore
"postgresql",

class JSONB(JSON):
    """JSON subtype compiled under the visit name ``JSONB``.

    The changed ``__visit_name__`` allows the SQLite ``@compiles`` hook
    below to render the type as the literal string "JSONB".
    """

    # See https://docs.sqlalchemy.org/en/20/core/custom_types.html
    __visit_name__ = "JSONB"


@compiles(JSONB, "sqlite")  # type: ignore
def _(*args: Any, **kwargs: Any) -> str:
    """Render the JSONB type as the literal "JSONB" in SQLite DDL."""
    # See https://docs.sqlalchemy.org/en/20/core/custom_types.html
    return "JSONB"


# Dialect-aware JSON type used by the ORM models: JSONB on PostgreSQL and
# SQLite (via the shim above), generic JSON otherwise.
_pg_variant = JSON().with_variant(
    postgresql.JSONB(),  # type: ignore
    "postgresql",
)
JSON_ = _pg_variant.with_variant(JSONB(), "sqlite")


Expand All @@ -43,6 +63,7 @@ class UtcTimeStamp(TypeDecorator[datetime]):
programs are always timezone-aware.
"""

    # See https://docs.sqlalchemy.org/en/20/core/custom_types.html
cache_ok = True
impl = TIMESTAMP(timezone=True)
_LOCAL_TIMEZONE = datetime.now(timezone.utc).astimezone().tzinfo
Expand Down Expand Up @@ -128,11 +149,13 @@ class Trace(Base):

@hybrid_property
def latency_ms(self) -> float:
    """Python-side latency of the trace in milliseconds."""
    # See https://docs.sqlalchemy.org/en/20/orm/extensions/hybrid.html
    elapsed = self.end_time - self.start_time
    return elapsed.total_seconds() * 1000

@latency_ms.inplace.expression
@classmethod
def _latency_ms_expression(cls) -> ColumnElement[float]:
    # SQL-side counterpart of `latency_ms`; rendered per-dialect by the
    # @compiles handlers registered for LatencyMs in this module.
    # See https://docs.sqlalchemy.org/en/20/orm/extensions/hybrid.html
    return LatencyMs(cls.start_time, cls.end_time)

project: Mapped["Project"] = relationship(
Expand Down Expand Up @@ -181,11 +204,13 @@ class Span(Base):

@hybrid_property
def latency_ms(self) -> float:
    """Python-side latency of the span in milliseconds."""
    # See https://docs.sqlalchemy.org/en/20/orm/extensions/hybrid.html
    duration = self.end_time - self.start_time
    return 1000 * duration.total_seconds()

@latency_ms.inplace.expression
@classmethod
def _latency_ms_expression(cls) -> ColumnElement[float]:
    # SQL expression equivalent of the Python `latency_ms` property above;
    # LatencyMs compiles differently per dialect (see its @compiles hooks).
    # See https://docs.sqlalchemy.org/en/20/orm/extensions/hybrid.html
    return LatencyMs(cls.start_time, cls.end_time)

@hybrid_property
Expand All @@ -205,13 +230,15 @@ def cumulative_llm_token_count_total(self) -> int:


class LatencyMs(expression.FunctionElement[float]):
    """Custom SQL function element: latency between two timestamps, in ms.

    Takes (start_time, end_time) clauses; rendering is supplied by the
    @compiles handlers registered for this class.
    """

    # See https://docs.sqlalchemy.org/en/20/core/compiler.html
    inherit_cache = True  # compiled form is safe to cache
    type = Float()  # SQL result type of the expression
    name = "latency_ms"


@compiles(LatencyMs) # type: ignore
def _(element: Any, compiler: Any, **kw: Any) -> Any:
# See https://docs.sqlalchemy.org/en/20/core/compiler.html
start_time, end_time = list(element.clauses)
return compiler.process(
(func.extract("EPOCH", end_time) - func.extract("EPOCH", start_time)) * 1000, **kw
Expand All @@ -220,6 +247,7 @@ def _(element: Any, compiler: Any, **kw: Any) -> Any:

@compiles(LatencyMs, "sqlite") # type: ignore
def _(element: Any, compiler: Any, **kw: Any) -> Any:
# See https://docs.sqlalchemy.org/en/20/core/compiler.html
start_time, end_time = list(element.clauses)
return compiler.process(
# FIXME: We don't know why sqlite returns a slightly different value.
Expand All @@ -230,6 +258,34 @@ def _(element: Any, compiler: Any, **kw: Any) -> Any:
)


class TextContains(expression.FunctionElement[str]):
    """Custom SQL function element: substring containment test.

    Takes (string, substring) clauses; each dialect gets its own rendering
    via the @compiles handlers registered for this class.
    """

    # See https://docs.sqlalchemy.org/en/20/core/compiler.html
    inherit_cache = True  # compiled form is safe to cache
    type = String()  # declared SQL result type
    name = "text_contains"


@compiles(TextContains)  # type: ignore
def _(element: Any, compiler: Any, **kw: Any) -> Any:
    """Generic fallback: render as the dialect's ``contains`` operator."""
    # See https://docs.sqlalchemy.org/en/20/core/compiler.html
    haystack, needle = element.clauses
    return compiler.process(haystack.contains(needle), **kw)


@compiles(TextContains, "postgresql")  # type: ignore
def _(element: Any, compiler: Any, **kw: Any) -> Any:
    """PostgreSQL rendering: ``strpos(string, substring) > 0``."""
    # See https://docs.sqlalchemy.org/en/20/core/compiler.html
    # strpos returns the 1-based position of the substring, or 0 if absent.
    args = list(element.clauses)
    return compiler.process(func.strpos(args[0], args[1]) > 0, **kw)


@compiles(TextContains, "sqlite")  # type: ignore
def _(element: Any, compiler: Any, **kw: Any) -> Any:
    """SQLite rendering via the ``text_contains`` function.

    ``text_contains`` is presumably supplied by sqlean's ``text``
    extension (enabled in engines.py) — comparing its result with 0
    yields a boolean expression.
    """
    # See https://docs.sqlalchemy.org/en/20/core/compiler.html
    target, fragment = element.clauses
    return compiler.process(func.text_contains(target, fragment) > 0, **kw)


async def init_models(engine: AsyncEngine) -> None:
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
Expand Down
10 changes: 3 additions & 7 deletions src/phoenix/trace/dsl/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def __call__(self, select: Select[typing.Any]) -> Select[typing.Any]:
"cast": sqlalchemy.cast,
"Float": sqlalchemy.Float,
"String": sqlalchemy.String,
"TextContains": models.TextContains,
},
)
)
Expand Down Expand Up @@ -344,13 +345,8 @@ def visit_Compare(self, node: ast.Compare) -> typing.Any:
or (typing.cast(str, ast.get_source_segment(self._source, right))) in _NAMES
):
call = ast.Call(
# TODO(persistence): FIXME: This turns into `LIKE` which for sqlite is
# case-insensitive. We want case-sensitive matching for strings,
# so for sqlite we need to turn this into `GLOB` instead.
# TODO(persistence): FIXME: Special characters such as `%` for `LIKE`
# and `*` for `GLOB` need to be escaped.
func=ast.Attribute(value=right, attr="contains", ctx=ast.Load()),
args=[left],
func=ast.Name(id="TextContains", ctx=ast.Load()),
args=[right, left],
keywords=[],
)
if isinstance(op, ast.NotIn):
Expand Down
10 changes: 6 additions & 4 deletions tests/trace/dsl/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@

@pytest.fixture(scope="session")
def session_maker() -> sessionmaker:
# `sqlean` is added to help with running the test on GitHub CI for Windows,
# because its version of SQLite doesn't have `JSON_EXTRACT`.
sqlean.extensions.enable_all()
engine = create_engine("sqlite:///:memory:", module=sqlean, echo=True)
Base.metadata.create_all(engine)
session_maker = sessionmaker(engine)
Expand Down Expand Up @@ -124,7 +123,7 @@ def _insert_project_abc(session: Session) -> None:
start_time=datetime.fromisoformat("2021-01-01T00:00:00.000+00:00"),
end_time=datetime.fromisoformat("2021-01-01T00:00:30.000+00:00"),
attributes={
"input": {"value": "210"},
"input": {"value": "xy%z*"},
"output": {"value": "321"},
},
events=[],
Expand All @@ -147,6 +146,9 @@ def _insert_project_abc(session: Session) -> None:
start_time=datetime.fromisoformat("2021-01-01T00:00:00.000+00:00"),
end_time=datetime.fromisoformat("2021-01-01T00:00:05.000+00:00"),
attributes={
"input": {
"value": "XY%*Z",
},
"metadata": {
"a.b.c": 123,
"1.2.3": "abc",
Expand Down Expand Up @@ -182,7 +184,7 @@ def _insert_project_abc(session: Session) -> None:
attributes={
"attributes": "attributes",
"input": {
"value": "xyz",
"value": "xy%*z",
},
"retrieval": {
"documents": [
Expand Down
10 changes: 5 additions & 5 deletions tests/trace/dsl/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,15 @@ def test_get_attribute_keys_list(expression: str, expected: Optional[List[str]])
[
(
"parent_id is not None and 'abc' in name or span_kind == 'LLM' and span_id in ('123',)", # noqa E501
"or_(and_(parent_id != None, name.contains('abc')), and_(span_kind == 'LLM', span_id.in_(('123',))))" # noqa E501
"or_(and_(parent_id != None, TextContains(name, 'abc')), and_(span_kind == 'LLM', span_id.in_(('123',))))" # noqa E501
if sys.version_info >= (3, 9)
else "or_(and_((parent_id != None), name.contains('abc')), and_((span_kind == 'LLM'), span_id.in_(('123',))))", # noqa E501
else "or_(and_((parent_id != None), TextContains(name, 'abc')), and_((span_kind == 'LLM'), span_id.in_(('123',))))", # noqa E501
),
(
"(parent_id is None or 'abc' not in name) and not (span_kind != 'LLM' or span_id not in ('123',))", # noqa E501
"and_(or_(parent_id == None, not_(name.contains('abc'))), not_(or_(span_kind != 'LLM', span_id.not_in(('123',)))))" # noqa E501
"and_(or_(parent_id == None, not_(TextContains(name, 'abc'))), not_(or_(span_kind != 'LLM', span_id.not_in(('123',)))))" # noqa E501
if sys.version_info >= (3, 9)
else "and_(or_((parent_id == None), not_(name.contains('abc'))), not_(or_((span_kind != 'LLM'), span_id.not_in(('123',)))))", # noqa E501
else "and_(or_((parent_id == None), not_(TextContains(name, 'abc'))), not_(or_((span_kind != 'LLM'), span_id.not_in(('123',)))))", # noqa E501
),
(
"1000 < latency_ms < 2000 or status_code == 'ERROR' or 2000 <= cumulative_llm_token_count_total", # noqa E501
Expand All @@ -99,7 +99,7 @@ def test_get_attribute_keys_list(expression: str, expected: Optional[List[str]])
),
(
"first.value in (1,) and second.value in ('2',) and '3' in third.value",
"and_(attributes[['first', 'value']].as_float().in_((1,)), attributes[['second', 'value']].as_string().in_(('2',)), attributes[['third', 'value']].as_string().contains('3'))", # noqa E501
"and_(attributes[['first', 'value']].as_float().in_((1,)), attributes[['second', 'value']].as_string().in_(('2',)), TextContains(attributes[['third', 'value']].as_string(), '3'))", # noqa E501
),
(
"'1.0' < my.value < 2.0",
Expand Down
Loading

0 comments on commit 3b202d7

Please sign in to comment.