Skip to content

Commit

Permalink
feat(sessions): alembic data migration queries for populating the pro…
Browse files Browse the repository at this point in the history
…ject sessions table (#5539)
  • Loading branch information
RogerHYang committed Dec 9, 2024
1 parent 16d9edd commit 50f5794
Show file tree
Hide file tree
Showing 5 changed files with 452 additions and 36 deletions.
86 changes: 85 additions & 1 deletion scripts/fixtures/multi-turn_chat_sessions.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"source": [
"from contextlib import ExitStack, contextmanager\n",
"from random import choice, choices, randint, random, shuffle\n",
"from uuid import uuid4\n",
"\n",
"import numpy as np\n",
"import openai\n",
Expand All @@ -33,9 +34,11 @@
"from opentelemetry.sdk.trace import SpanLimits, StatusCode, TracerProvider\n",
"from opentelemetry.sdk.trace.export import SimpleSpanProcessor\n",
"from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter\n",
"from opentelemetry.trace import format_span_id\n",
"from tiktoken import encoding_for_model\n",
"\n",
"import phoenix as px\n",
"from phoenix.trace import using_project\n",
"from phoenix.trace.span_evaluations import SpanEvaluations\n",
"\n",
"fake = Faker([\"ja_JP\", \"vi_VN\", \"ko_KR\", \"zh_CN\", \"th_TH\", \"bn_BD\"])"
Expand Down Expand Up @@ -168,12 +171,93 @@
" root.end(int(fake.future_datetime(\"+5s\").timestamp() * 10**9))"
]
},
{
"cell_type": "markdown",
"id": "b922a0c7",
"metadata": {},
"source": [
"# Generate Sessions (For Demos)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "98a91087",
"metadata": {},
"outputs": [],
"source": [
"session_count = randint(5, 10)\n",
"project_name = \"Sessions-Fixtures\"\n",
"\n",
"\n",
"def simulate_openai():\n",
" user_id = Faker().user_name()\n",
" session_id = str(uuid4())\n",
" client = openai.Client(api_key=\"sk-\")\n",
" model = \"gpt-4o-mini\"\n",
" encoding = encoding_for_model(model)\n",
" messages = np.concatenate(convo.sample(randint(1, 10)).values)\n",
" counts = [len(encoding.encode(m[\"content\"])) for m in messages]\n",
" openai_mock = OpenAIMock()\n",
" tracer = tracer_provider.get_tracer(__name__)\n",
" with openai_mock.router:\n",
" for i in range(1, len(messages), 2):\n",
" openai_mock.chat.completions.create.response = dict(\n",
" choices=[dict(index=0, finish_reason=\"stop\", message=messages[i])],\n",
" usage=dict(\n",
" prompt_tokens=sum(counts[:i]),\n",
" completion_tokens=counts[i],\n",
" total_tokens=sum(counts[: i + 1]),\n",
" ),\n",
" )\n",
" with ExitStack() as stack:\n",
" attributes = {\n",
" \"input.value\": messages[i - 1][\"content\"],\n",
" \"output.value\": messages[i][\"content\"],\n",
" \"session.id\": session_id,\n",
" \"user.id\": user_id,\n",
" }\n",
" root = stack.enter_context(\n",
" tracer.start_as_current_span(\n",
" \"root\",\n",
" attributes=attributes,\n",
" )\n",
" )\n",
" client.chat.completions.create(model=model, messages=messages[:i])\n",
" root.set_status(StatusCode.OK)\n",
"\n",
"\n",
"OpenAIInstrumentor().instrument(tracer_provider=tracer_provider)\n",
"try:\n",
" with using_project(project_name):\n",
" for _ in range(session_count):\n",
" simulate_openai()\n",
"finally:\n",
" OpenAIInstrumentor().uninstrument()\n",
"spans = export_spans(0)\n",
"\n",
"# Annotate root spans\n",
"root_span_ids = pd.Series(\n",
" [format_span_id(span.context.span_id) for span in spans if span.parent is None]\n",
")\n",
"for name in [\"Helpfulness\", \"Relevance\", \"Engagement\"]:\n",
" span_ids = root_span_ids.sample(frac=0.5)\n",
" df = pd.DataFrame(\n",
" {\n",
" \"context.span_id\": span_ids,\n",
" \"score\": np.random.rand(len(span_ids)),\n",
" \"label\": np.random.choice([\"👍\", \"👎\"], len(span_ids)),\n",
" }\n",
" ).set_index(\"context.span_id\")\n",
" px.Client().log_evaluations(SpanEvaluations(name, df))"
]
},
{
"cell_type": "markdown",
"id": "a2f2ac17",
"metadata": {},
"source": [
"# Genarate Sessions\n"
"# Genarate Sessions (For Development)"
]
},
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,75 @@
"""

from typing import Sequence, Union
from datetime import datetime
from typing import Any, Optional, Sequence, Union

import sqlalchemy as sa
from alembic import op
from openinference.semconv.trace import SpanAttributes
from sqlalchemy import (
JSON,
func,
insert,
select,
update,
)
from sqlalchemy.dialects import postgresql
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class JSONB(JSON):
    """JSON subtype whose compiler visit name is ``"JSONB"``.

    Used as the SQLite variant of the ``JSON_`` column type in this module.
    See https://docs.sqlalchemy.org/en/20/core/custom_types.html
    """

    __visit_name__ = "JSONB"


@compiles(JSONB, "sqlite")
def _(*args: Any, **kwargs: Any) -> str:
    """Render the custom ``JSONB`` type as the literal ``JSONB`` on SQLite.

    The compiler arguments are accepted but ignored; the type name is fixed.
    See https://docs.sqlalchemy.org/en/20/core/custom_types.html
    """
    return "JSONB"


# Column type for JSON documents: generic ``JSON`` by default, the PostgreSQL
# dialect's ``JSONB`` on "postgresql", and the local ``JSONB`` type on "sqlite".
JSON_ = (
    JSON()
    .with_variant(postgresql.JSONB(), "postgresql")  # type: ignore
    .with_variant(JSONB(), "sqlite")
)


class Base(DeclarativeBase):
    """Declarative base for the table models declared in this migration."""


class ProjectSession(Base):
    """Model of the ``project_sessions`` table as used by this migration.

    ``upgrade()`` inserts into ``session_id``, ``session_user``,
    ``project_id`` and ``start_time``, and joins on ``session_id`` to look up
    the generated ``id``.
    """

    __tablename__ = "project_sessions"

    id: Mapped[int] = mapped_column(primary_key=True)
    session_id: Mapped[str]
    session_user: Mapped[Optional[str]]
    project_id: Mapped[int]
    start_time: Mapped[datetime]


class Trace(Base):
    """Model of the ``traces`` table as used by this migration.

    ``upgrade()`` reads ``project_rowid`` and ``start_time`` when building
    session rows, and writes ``project_session_rowid`` once sessions exist.
    """

    __tablename__ = "traces"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Nullable; ``upgrade()`` sets it for traces whose root span has a session id.
    project_session_rowid: Mapped[Optional[int]]
    project_rowid: Mapped[int]
    start_time: Mapped[datetime]


class Span(Base):
    """Model of the ``spans`` table as used by this migration.

    ``upgrade()`` selects root spans (``parent_id IS NULL``), joins them to
    their trace through ``trace_rowid``, and reads the session and user ids
    out of the ``attributes`` JSON document.
    """

    __tablename__ = "spans"

    id: Mapped[int] = mapped_column(primary_key=True)
    trace_rowid: Mapped[int]
    parent_id: Mapped[Optional[str]]
    # JSON/JSONB depending on the dialect; see ``JSON_``.
    attributes: Mapped[dict[str, Any]] = mapped_column(JSON_, nullable=False)


# revision identifiers, used by Alembic.
revision: str = "4ded9e43755f"
Expand Down Expand Up @@ -47,10 +112,69 @@ def upgrade() -> None:
"traces",
["project_session_rowid"],
)
sessions_from_span = (
select(
Span.attributes[SESSION_ID].as_string().label("session_id"),
Span.attributes[USER_ID].as_string().label("session_user"),
Trace.project_rowid.label("project_id"),
Trace.start_time.label("start_time"),
func.row_number()
.over(
partition_by=Span.attributes[SESSION_ID],
order_by=[Trace.start_time, Trace.id, Span.id],
)
.label("rank"),
)
.join_from(Span, Trace, Span.trace_rowid == Trace.id)
.where(Span.parent_id.is_(None))
.where(Span.attributes[SESSION_ID].as_string() != "")
.subquery()
)
op.execute(
insert(ProjectSession).from_select(
[
"session_id",
"session_user",
"project_id",
"start_time",
],
select(
sessions_from_span.c.session_id,
sessions_from_span.c.session_user,
sessions_from_span.c.project_id,
sessions_from_span.c.start_time,
).where(sessions_from_span.c.rank == 1),
)
)
sessions_for_trace_id = (
select(
Span.trace_rowid,
ProjectSession.id.label("project_session_rowid"),
)
.join_from(
Span,
ProjectSession,
Span.attributes[SESSION_ID].as_string() == ProjectSession.session_id,
)
.where(Span.parent_id.is_(None))
.where(Span.attributes[SESSION_ID].as_string() != "")
.subquery()
)
op.execute(
(
update(Trace)
.values(project_session_rowid=sessions_for_trace_id.c.project_session_rowid)
.where(Trace.id == sessions_for_trace_id.c.trace_rowid)
)
)


def downgrade() -> None:
    """Revert this revision's schema changes.

    Drops the ``ix_traces_project_session_rowid`` index, removes the
    ``project_session_rowid`` column from ``traces`` (in batch mode), and
    finally drops the ``project_sessions`` table.
    """
    op.drop_index("ix_traces_project_session_rowid")
    with op.batch_alter_table("traces") as traces_batch:
        traces_batch.drop_column("project_session_rowid")
    op.drop_table("project_sessions")


# Span attribute names split on "." into lists of path segments. `upgrade()`
# uses these lists to index into the `Span.attributes` JSON column, e.g.
# `Span.attributes[SESSION_ID]`. Defining them at the bottom of the module
# works because the names are only looked up when `upgrade()` is called.
SESSION_ID = SpanAttributes.SESSION_ID.split(".")
USER_ID = SpanAttributes.USER_ID.split(".")
34 changes: 34 additions & 0 deletions tests/integration/db_migrations/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import os
from typing import Optional

from alembic import command
from alembic.config import Config
from phoenix.config import ENV_PHOENIX_SQL_DATABASE_SCHEMA
from sqlalchemy import Engine, Row, text


def _up(engine: Engine, alembic_config: Config, revision: str) -> None:
    """Upgrade the database to ``revision`` and verify the recorded version.

    A connection checked out from ``engine`` is handed to Alembic through
    ``alembic_config.attributes["connection"]`` before ``command.upgrade`` is
    invoked. Afterwards the engine is disposed and the version reported by
    ``_version_num`` must equal ``(revision,)``.
    """
    with engine.connect() as connection:
        alembic_config.attributes["connection"] = connection
        command.upgrade(alembic_config, revision)
    engine.dispose()
    expected = (revision,)
    assert _version_num(engine) == expected


def _down(engine: Engine, alembic_config: Config, revision: str) -> None:
    """Downgrade the database to ``revision`` and verify the recorded version.

    A connection checked out from ``engine`` is handed to Alembic through
    ``alembic_config.attributes["connection"]`` before ``command.downgrade``
    is invoked. Afterwards the engine is disposed and the version reported by
    ``_version_num`` is checked: ``None`` is expected when downgrading to
    ``"base"``, otherwise ``(revision,)``.
    """
    with engine.connect() as connection:
        alembic_config.attributes["connection"] = connection
        command.downgrade(alembic_config, revision)
    engine.dispose()
    if revision == "base":
        expected = None
    else:
        expected = (revision,)
    assert _version_num(engine) == expected


def _version_num(engine: Engine) -> Optional[Row[tuple[str]]]:
    """Return the first row of ``alembic_version.version_num``, if any.

    On PostgreSQL backends the table name is qualified with the schema read
    from the ``ENV_PHOENIX_SQL_DATABASE_SCHEMA`` environment variable.

    Args:
        engine: Engine connected to the database under test.

    Returns:
        The first result row, or ``None`` when the query yields no rows.

    Raises:
        KeyError: On PostgreSQL, if the schema environment variable is unset.
        AssertionError: On PostgreSQL, if the schema environment variable is
            set but empty.
    """
    schema_prefix = ""
    if engine.url.get_backend_name().startswith("postgresql"):
        # Bind the name outside the assert: under ``python -O`` asserts are
        # stripped, which would otherwise leave ``schema`` undefined below.
        schema = os.environ[ENV_PHOENIX_SQL_DATABASE_SCHEMA]
        assert schema, f"{ENV_PHOENIX_SQL_DATABASE_SCHEMA} must not be empty"
        schema_prefix = f"{schema}."
    table, column = "alembic_version", "version_num"
    # Identifiers cannot be bound as parameters, hence the formatted statement.
    stmt = text(f"SELECT {column} FROM {schema_prefix}{table}")
    with engine.connect() as conn:
        return conn.execute(stmt).first()
Loading

0 comments on commit 50f5794

Please sign in to comment.