From 97b55d1037df98f935f0995efeff11669b6ecdff Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Wed, 13 Nov 2024 11:25:02 -0800 Subject: [PATCH] feat: only recognize session id on root span --- src/phoenix/db/insertion/span.py | 59 ++++++++++--------- .../project_sessions/test_project_sessions.py | 11 ++-- 2 files changed, 35 insertions(+), 35 deletions(-) diff --git a/src/phoenix/db/insertion/span.py b/src/phoenix/db/insertion/span.py index 7aa575d104..e9631c1e7e 100644 --- a/src/phoenix/db/insertion/span.py +++ b/src/phoenix/db/insertion/span.py @@ -37,36 +37,37 @@ async def insert_span( assert project_rowid is not None project_session: Optional[models.ProjectSession] = None - session_id = get_attribute_value(span.attributes, SpanAttributes.SESSION_ID) - session_user = get_attribute_value(span.attributes, SpanAttributes.USER_ID) - if session_id is not None and (not isinstance(session_id, str) or session_id.strip()): - session_id = str(session_id).strip() - assert isinstance(session_id, str) - if session_user is not None: - session_user = str(session_user).strip() - assert isinstance(session_user, str) - project_session = await session.scalar( - select(models.ProjectSession).filter_by(session_id=session_id) - ) - if project_session: - if project_session.end_time < span.end_time: - project_session.end_time = span.end_time - project_session.project_id = project_rowid - if span.start_time < project_session.start_time: - project_session.start_time = span.start_time - if session_user and project_session.session_user != session_user: - project_session.session_user = session_user - else: - project_session = models.ProjectSession( - project_id=project_rowid, - session_id=session_id, - session_user=session_user if session_user else None, - start_time=span.start_time, - end_time=span.end_time, + if span.parent_id is None: + session_id = get_attribute_value(span.attributes, SpanAttributes.SESSION_ID) + session_user = get_attribute_value(span.attributes, SpanAttributes.USER_ID) + if session_id is not None and (not isinstance(session_id, str) or session_id.strip()): + session_id = str(session_id).strip() + assert isinstance(session_id, str) + if session_user is not None: + session_user = str(session_user).strip() + assert isinstance(session_user, str) + project_session = await session.scalar( + select(models.ProjectSession).filter_by(session_id=session_id) ) - session.add(project_session) - if project_session in session.dirty: - await session.flush() + if project_session: + if project_session.end_time < span.end_time: + project_session.end_time = span.end_time + project_session.project_id = project_rowid + if span.start_time < project_session.start_time: + project_session.start_time = span.start_time + if session_user and project_session.session_user != session_user: + project_session.session_user = session_user + else: + project_session = models.ProjectSession( + project_id=project_rowid, + session_id=session_id, + session_user=session_user if session_user else None, + start_time=span.start_time, + end_time=span.end_time, + ) + session.add(project_session) + if project_session in session.dirty: + await session.flush() trace_id = span.context.trace_id trace = await session.scalar(select(models.Trace).filter_by(trace_id=trace_id)) diff --git a/tests/integration/project_sessions/test_project_sessions.py b/tests/integration/project_sessions/test_project_sessions.py index 185063e4d8..24516654c4 100644 --- a/tests/integration/project_sessions/test_project_sessions.py +++ b/tests/integration/project_sessions/test_project_sessions.py @@ -41,18 +41,16 @@ def test_span_ingestion_with_session_id( assert num_traces > 1 and num_spans_per_trace > 2 project_names = [token_hex(8)] spans: List[Span] = [] + wrong_session_id = token_hex(32) for _ in range(num_traces): project_names.append(token_hex(8)) with ExitStack() as stack: for i in range(num_spans_per_trace): if i == 0: - # Not all spans are required to have `session_id`. - attributes = None - elif i == 1: - # In case of conflict, the `Span` with later `end_time` wins. attributes = {SpanAttributes.SESSION_ID: session_id} - elif str_session_id: - attributes = {SpanAttributes.SESSION_ID: token_hex(8)} + else: + # Session ID on non-root spans will be ignored. + attributes = {SpanAttributes.SESSION_ID: wrong_session_id} span = _start_span( project_name=project_names[-1], exporter=_grpc_span_exporter(), @@ -84,6 +82,7 @@ def test_span_ingestion_with_session_id( assert not sessions_by_id return assert sessions_by_id + assert wrong_session_id not in sessions_by_id assert (session := sessions_by_id.get(str_session_id)) assert (traces := [edge["node"] for edge in session["node"]["traces"]["edges"]]) assert len(traces) == num_traces