diff --git a/app/schema.graphql b/app/schema.graphql index 9a66772638..d23a7312f1 100644 --- a/app/schema.graphql +++ b/app/schema.graphql @@ -1325,6 +1325,16 @@ type ProjectSession implements Node { """The Globally Unique ID of this object""" id: GlobalID! sessionId: String! + sessionUser: String + startTime: DateTime! + endTime: DateTime! + + """Duration of the session in seconds""" + durationS: Float! + numTraces: Int! + firstInput: SpanIOValue + lastOutput: SpanIOValue + tokenUsage: TokenUsage! traces(first: Int = 50, last: Int, after: String, before: String): TraceConnection! } @@ -1706,6 +1716,12 @@ type TimeSeriesDataPoint { value: Float } +type TokenUsage { + prompt: Int! + completion: Int! + total: Int! +} + type ToolCallChunk implements ChatCompletionSubscriptionPayload { datasetExampleId: GlobalID id: String! @@ -1716,7 +1732,12 @@ type Trace implements Node { """The Globally Unique ID of this object""" id: GlobalID! traceId: String! + startTime: DateTime! + endTime: DateTime! projectId: GlobalID! + projectSessionId: GlobalID + session: ProjectSession + rootSpan: Span spans(first: Int = 50, last: Int, after: String, before: String): SpanConnection! 
"""Annotations associated with the trace.""" diff --git a/app/src/Routes.tsx b/app/src/Routes.tsx index b56d696c0b..7e77af3035 100644 --- a/app/src/Routes.tsx +++ b/app/src/Routes.tsx @@ -8,6 +8,7 @@ import { Layout } from "./pages/Layout"; import { spanPlaygroundPageLoaderQuery$data } from "./pages/playground/__generated__/spanPlaygroundPageLoaderQuery.graphql"; import { PlaygroundExamplePage } from "./pages/playground/PlaygroundExamplePage"; import { projectLoaderQuery$data } from "./pages/project/__generated__/projectLoaderQuery.graphql"; +import { SessionPage } from "./pages/trace/SessionPage"; import { APIsPage, AuthenticatedRoot, @@ -118,6 +119,7 @@ const router = createBrowserRouter( } /> }> } /> + } /> diff --git a/app/src/pages/project/ProjectPage.tsx b/app/src/pages/project/ProjectPage.tsx index 57847f41b2..5017fcf47b 100644 --- a/app/src/pages/project/ProjectPage.tsx +++ b/app/src/pages/project/ProjectPage.tsx @@ -19,8 +19,10 @@ import { } from "@phoenix/components/datetime"; import { ProjectPageQuery } from "./__generated__/ProjectPageQuery.graphql"; +import { ProjectPageSessionsQuery as ProjectPageSessionsQueryType } from "./__generated__/ProjectPageSessionsQuery.graphql"; import { ProjectPageSpansQuery as ProjectPageSpansQueryType } from "./__generated__/ProjectPageSpansQuery.graphql"; import { ProjectPageHeader } from "./ProjectPageHeader"; +import { SessionsTable } from "./SessionsTable"; import { SpanFilterConditionProvider } from "./SpanFilterConditionContext"; import { SpansTable } from "./SpansTable"; import { StreamToggle } from "./StreamToggle"; @@ -75,6 +77,14 @@ const ProjectPageSpansQuery = graphql` } `; +const ProjectPageSessionsQuery = graphql` + query ProjectPageSessionsQuery($id: GlobalID!, $timeRange: TimeRange!) 
{ + project: node(id: $id) { + ...SessionsTable_sessions + } + } +`; + export function ProjectPageContent({ projectId, timeRange, @@ -109,18 +119,35 @@ export function ProjectPageContent({ ); const [spansQueryReference, loadSpansQuery, disposeSpansQuery] = useQueryLoader(ProjectPageSpansQuery); + const [sessionsQueryReference, loadSessionsQuery, disposeSessionsQuery] = + useQueryLoader(ProjectPageSessionsQuery); const onTabChange = useCallback( (index: number) => { if (index === 1) { + disposeSessionsQuery(); loadSpansQuery({ id: projectId as string, timeRange: timeRangeVariable, }); + } else if (index === 2) { + disposeSpansQuery(); + loadSessionsQuery({ + id: projectId as string, + timeRange: timeRangeVariable, + }); } else { disposeSpansQuery(); + disposeSessionsQuery(); } }, - [disposeSpansQuery, loadSpansQuery, projectId, timeRangeVariable] + [ + disposeSpansQuery, + loadSpansQuery, + disposeSessionsQuery, + loadSessionsQuery, + projectId, + timeRangeVariable, + ] ); return (
@@ -161,6 +188,22 @@ export function ProjectPageContent({ ); }} + + {({ isSelected }) => { + return ( + isSelected && + sessionsQueryReference && ( + + + + + + ) + ); + }} + @@ -177,3 +220,12 @@ function SpansTabContent({ const data = usePreloadedQuery(ProjectPageSpansQuery, queryReference); return ; } + +function SessionsTabContent({ + queryReference, +}: { + queryReference: PreloadedQuery; +}) { + const data = usePreloadedQuery(ProjectPageSessionsQuery, queryReference); + return ; +} diff --git a/app/src/pages/project/SessionsTable.tsx b/app/src/pages/project/SessionsTable.tsx new file mode 100644 index 0000000000..7d1daf51a0 --- /dev/null +++ b/app/src/pages/project/SessionsTable.tsx @@ -0,0 +1,294 @@ +/* eslint-disable react/prop-types */ +import React, { + startTransition, + useEffect, + useMemo, + useRef, + useState, +} from "react"; +import { graphql, usePaginationFragment } from "react-relay"; +import { useNavigate } from "react-router"; +import { + ColumnDef, + ExpandedState, + flexRender, + getCoreRowModel, + getExpandedRowModel, + getSortedRowModel, + SortingState, + useReactTable, +} from "@tanstack/react-table"; +import { css } from "@emotion/react"; + +import { Icon, Icons, View } from "@arizeai/components"; + +import { selectableTableCSS } from "@phoenix/components/table/styles"; +import { TimestampCell } from "@phoenix/components/table/TimestampCell"; +import { useStreamState } from "@phoenix/contexts/StreamStateContext"; +import { useTracingContext } from "@phoenix/contexts/TracingContext"; + +import { IntCell, TextCell } from "../../components/table"; +import { TokenCount } from "../../components/trace/TokenCount"; + +import { SessionsTable_sessions$key } from "./__generated__/SessionsTable_sessions.graphql"; +import { SessionsTableQuery } from "./__generated__/SessionsTableQuery.graphql"; +import { SessionsTableEmpty } from "./SessionsTableEmpty"; +import { SpanFilterConditionField } from "./SpanFilterConditionField"; +import { spansTableCSS } 
from "./styles"; + +type SessionsTableProps = { + project: SessionsTable_sessions$key; +}; + +const PAGE_SIZE = 50; + +export function SessionsTable(props: SessionsTableProps) { + // we need a reference to the scrolling element for pagination logic down below + const tableContainerRef = useRef(null); + const [sorting, setSorting] = useState([]); + const [filterCondition, setFilterCondition] = useState(""); + const navigate = useNavigate(); + const { fetchKey } = useStreamState(); + const { data, loadNext, hasNext, isLoadingNext, refetch } = + usePaginationFragment( + graphql` + fragment SessionsTable_sessions on Project + @refetchable(queryName: "SessionsTableQuery") + @argumentDefinitions( + after: { type: "String", defaultValue: null } + first: { type: "Int", defaultValue: 50 } + ) { + name + sessions(first: $first, after: $after, timeRange: $timeRange) + @connection(key: "SessionsTable_sessions") { + edges { + session: node { + id + sessionId + numTraces + startTime + endTime + firstInput { + value + } + lastOutput { + value + } + tokenUsage { + prompt + completion + total + } + } + } + } + } + `, + props.project + ); + const tableData = useMemo(() => { + return data.sessions.edges.map(({ session }) => session); + }, [data]); + type TableRow = (typeof tableData)[number]; + const columns: ColumnDef[] = [ + { + header: "session id", + accessorKey: "sessionId", + enableSorting: false, + cell: TextCell, + }, + { + header: "first input", + accessorKey: "firstInput.value", + enableSorting: false, + cell: TextCell, + }, + { + header: "last output", + accessorKey: "lastOutput.value", + enableSorting: false, + cell: TextCell, + }, + { + header: "start time", + accessorKey: "startTime", + enableSorting: false, + cell: TimestampCell, + }, + { + header: "end time", + accessorKey: "endTime", + enableSorting: false, + cell: TimestampCell, + }, + { + header: "total tokens", + accessorKey: "tokenUsage.total", + enableSorting: false, + minSize: 80, + cell: ({ row, getValue }) => 
{ + const value = getValue(); + if (value == null || typeof value !== "number") { + return "--"; + } + const { prompt, completion } = row.original.tokenUsage; + return ( + + ); + }, + }, + { + header: "total traces", + accessorKey: "numTraces", + enableSorting: false, + cell: IntCell, + }, + ]; + useEffect(() => { + startTransition(() => { + refetch( + { + after: null, + first: PAGE_SIZE, + }, + { fetchPolicy: "store-and-network" } + ); + }); + }, [sorting, refetch, filterCondition, fetchKey]); + const fetchMoreOnBottomReached = React.useCallback( + (containerRefElement?: HTMLDivElement | null) => { + if (containerRefElement) { + const { scrollHeight, scrollTop, clientHeight } = containerRefElement; + // once the user has scrolled within 300px of the bottom of the table, fetch more data if there is any + if ( + scrollHeight - scrollTop - clientHeight < 300 && + !isLoadingNext && + hasNext + ) { + loadNext(PAGE_SIZE); + } + } + }, + [hasNext, isLoadingNext, loadNext] + ); + const [expanded, setExpanded] = useState({}); + const columnVisibility = useTracingContext((state) => state.columnVisibility); + const table = useReactTable({ + columns, + data: tableData, + onExpandedChange: setExpanded, + manualSorting: true, + state: { + sorting, + expanded, + columnVisibility, + }, + enableSubRowSelection: false, + onSortingChange: setSorting, + getCoreRowModel: getCoreRowModel(), + getSortedRowModel: getSortedRowModel(), + getExpandedRowModel: getExpandedRowModel(), + getRowId: (row) => row.id, + }); + const rows = table.getRowModel().rows; + const isEmpty = rows.length === 0; + return ( +
+ + + +
fetchMoreOnBottomReached(e.target as HTMLDivElement)} + ref={tableContainerRef} + > + + + {table.getHeaderGroups().map((headerGroup) => ( + + {headerGroup.headers.map((header) => ( + + ))} + + ))} + + {isEmpty ? ( + + ) : ( + + {rows.map((row) => { + return ( + + navigate(`sessions/${encodeURIComponent(row.id)}`) + } + > + {row.getVisibleCells().map((cell) => { + return ( + + ); + })} + + ); + })} + + )} +
+ {header.isPlaceholder ? null : ( +
+ {flexRender( + header.column.columnDef.header, + header.getContext() + )} + {header.column.getIsSorted() ? ( + + ) : ( + + ) + } + /> + ) : null} +
+ )} +
+ {flexRender( + cell.column.columnDef.cell, + cell.getContext() + )} +
+
+
+ ); +} diff --git a/app/src/pages/project/SessionsTableEmpty.tsx b/app/src/pages/project/SessionsTableEmpty.tsx new file mode 100644 index 0000000000..a35e5cee5c --- /dev/null +++ b/app/src/pages/project/SessionsTableEmpty.tsx @@ -0,0 +1,24 @@ +import React from "react"; +import { css } from "@emotion/react"; + +import { Flex } from "@arizeai/components"; + +export function SessionsTableEmpty() { + return ( + + + css` + text-align: center; + padding: ${theme.spacing.margin24}px ${theme.spacing.margin24}px !important; + `} + > + + No sessions found for this project + + + + + ); +} diff --git a/app/src/pages/project/__generated__/ProjectPageSessionsQuery.graphql.ts b/app/src/pages/project/__generated__/ProjectPageSessionsQuery.graphql.ts new file mode 100644 index 0000000000..37dab6d571 --- /dev/null +++ b/app/src/pages/project/__generated__/ProjectPageSessionsQuery.graphql.ts @@ -0,0 +1,334 @@ +/** + * @generated SignedSource<<2ec99b88de5f2bfca5562d9f2bfa9bc2>> + * @lightSyntaxTransform + * @nogrep + */ + +/* tslint:disable */ +/* eslint-disable */ +// @ts-nocheck + +import { ConcreteRequest, Query } from 'relay-runtime'; +import { FragmentRefs } from "relay-runtime"; +export type TimeRange = { + end: string; + start: string; +}; +export type ProjectPageSessionsQuery$variables = { + id: string; + timeRange: TimeRange; +}; +export type ProjectPageSessionsQuery$data = { + readonly project: { + readonly " $fragmentSpreads": FragmentRefs<"SessionsTable_sessions">; + }; +}; +export type ProjectPageSessionsQuery = { + response: ProjectPageSessionsQuery$data; + variables: ProjectPageSessionsQuery$variables; +}; + +const node: ConcreteRequest = (function(){ +var v0 = [ + { + "defaultValue": null, + "kind": "LocalArgument", + "name": "id" + }, + { + "defaultValue": null, + "kind": "LocalArgument", + "name": "timeRange" + } +], +v1 = [ + { + "kind": "Variable", + "name": "id", + "variableName": "id" + } +], +v2 = { + "alias": null, + "args": null, + "kind": "ScalarField", + 
"name": "__typename", + "storageKey": null +}, +v3 = { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "id", + "storageKey": null +}, +v4 = [ + { + "kind": "Literal", + "name": "first", + "value": 50 + }, + { + "kind": "Variable", + "name": "timeRange", + "variableName": "timeRange" + } +], +v5 = [ + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "value", + "storageKey": null + } +]; +return { + "fragment": { + "argumentDefinitions": (v0/*: any*/), + "kind": "Fragment", + "metadata": null, + "name": "ProjectPageSessionsQuery", + "selections": [ + { + "alias": "project", + "args": (v1/*: any*/), + "concreteType": null, + "kind": "LinkedField", + "name": "node", + "plural": false, + "selections": [ + { + "args": null, + "kind": "FragmentSpread", + "name": "SessionsTable_sessions" + } + ], + "storageKey": null + } + ], + "type": "Query", + "abstractKey": null + }, + "kind": "Request", + "operation": { + "argumentDefinitions": (v0/*: any*/), + "kind": "Operation", + "name": "ProjectPageSessionsQuery", + "selections": [ + { + "alias": "project", + "args": (v1/*: any*/), + "concreteType": null, + "kind": "LinkedField", + "name": "node", + "plural": false, + "selections": [ + (v2/*: any*/), + { + "kind": "TypeDiscriminator", + "abstractKey": "__isNode" + }, + (v3/*: any*/), + { + "kind": "InlineFragment", + "selections": [ + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "name", + "storageKey": null + }, + { + "alias": null, + "args": (v4/*: any*/), + "concreteType": "ProjectSessionConnection", + "kind": "LinkedField", + "name": "sessions", + "plural": false, + "selections": [ + { + "alias": null, + "args": null, + "concreteType": "ProjectSessionEdge", + "kind": "LinkedField", + "name": "edges", + "plural": true, + "selections": [ + { + "alias": "session", + "args": null, + "concreteType": "ProjectSession", + "kind": "LinkedField", + "name": "node", + "plural": false, + "selections": [ + (v3/*: any*/), + { 
+ "alias": null, + "args": null, + "kind": "ScalarField", + "name": "sessionId", + "storageKey": null + }, + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "numTraces", + "storageKey": null + }, + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "startTime", + "storageKey": null + }, + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "endTime", + "storageKey": null + }, + { + "alias": null, + "args": null, + "concreteType": "SpanIOValue", + "kind": "LinkedField", + "name": "firstInput", + "plural": false, + "selections": (v5/*: any*/), + "storageKey": null + }, + { + "alias": null, + "args": null, + "concreteType": "SpanIOValue", + "kind": "LinkedField", + "name": "lastOutput", + "plural": false, + "selections": (v5/*: any*/), + "storageKey": null + }, + { + "alias": null, + "args": null, + "concreteType": "TokenUsage", + "kind": "LinkedField", + "name": "tokenUsage", + "plural": false, + "selections": [ + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "prompt", + "storageKey": null + }, + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "completion", + "storageKey": null + }, + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "total", + "storageKey": null + } + ], + "storageKey": null + } + ], + "storageKey": null + }, + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "cursor", + "storageKey": null + }, + { + "alias": null, + "args": null, + "concreteType": "ProjectSession", + "kind": "LinkedField", + "name": "node", + "plural": false, + "selections": [ + (v2/*: any*/) + ], + "storageKey": null + } + ], + "storageKey": null + }, + { + "alias": null, + "args": null, + "concreteType": "PageInfo", + "kind": "LinkedField", + "name": "pageInfo", + "plural": false, + "selections": [ + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "endCursor", + "storageKey": null + }, + { + "alias": null, + 
"args": null, + "kind": "ScalarField", + "name": "hasNextPage", + "storageKey": null + } + ], + "storageKey": null + } + ], + "storageKey": null + }, + { + "alias": null, + "args": (v4/*: any*/), + "filters": [ + "timeRange" + ], + "handle": "connection", + "key": "SessionsTable_sessions", + "kind": "LinkedHandle", + "name": "sessions" + } + ], + "type": "Project", + "abstractKey": null + } + ], + "storageKey": null + } + ] + }, + "params": { + "cacheID": "84ec25aac21d8b45bb3c4481ed30c6e7", + "id": null, + "metadata": {}, + "name": "ProjectPageSessionsQuery", + "operationKind": "query", + "text": "query ProjectPageSessionsQuery(\n $id: GlobalID!\n $timeRange: TimeRange!\n) {\n project: node(id: $id) {\n __typename\n ...SessionsTable_sessions\n __isNode: __typename\n id\n }\n}\n\nfragment SessionsTable_sessions on Project {\n name\n sessions(first: 50, timeRange: $timeRange) {\n edges {\n session: node {\n id\n sessionId\n numTraces\n startTime\n endTime\n firstInput {\n value\n }\n lastOutput {\n value\n }\n tokenUsage {\n prompt\n completion\n total\n }\n }\n cursor\n node {\n __typename\n }\n }\n pageInfo {\n endCursor\n hasNextPage\n }\n }\n id\n}\n" + } +}; +})(); + +(node as any).hash = "58a3bc9becb7a255676ca0efafb8d0b7"; + +export default node; diff --git a/app/src/pages/project/__generated__/SessionsTableQuery.graphql.ts b/app/src/pages/project/__generated__/SessionsTableQuery.graphql.ts new file mode 100644 index 0000000000..f652065888 --- /dev/null +++ b/app/src/pages/project/__generated__/SessionsTableQuery.graphql.ts @@ -0,0 +1,364 @@ +/** + * @generated SignedSource<> + * @lightSyntaxTransform + * @nogrep + */ + +/* tslint:disable */ +/* eslint-disable */ +// @ts-nocheck + +import { ConcreteRequest, Query } from 'relay-runtime'; +import { FragmentRefs } from "relay-runtime"; +export type TimeRange = { + end: string; + start: string; +}; +export type SessionsTableQuery$variables = { + after?: string | null; + first?: number | null; + id: string; + 
timeRange?: TimeRange | null; +}; +export type SessionsTableQuery$data = { + readonly node: { + readonly " $fragmentSpreads": FragmentRefs<"SessionsTable_sessions">; + }; +}; +export type SessionsTableQuery = { + response: SessionsTableQuery$data; + variables: SessionsTableQuery$variables; +}; + +const node: ConcreteRequest = (function(){ +var v0 = { + "defaultValue": null, + "kind": "LocalArgument", + "name": "after" +}, +v1 = { + "defaultValue": 50, + "kind": "LocalArgument", + "name": "first" +}, +v2 = { + "defaultValue": null, + "kind": "LocalArgument", + "name": "id" +}, +v3 = { + "defaultValue": null, + "kind": "LocalArgument", + "name": "timeRange" +}, +v4 = [ + { + "kind": "Variable", + "name": "id", + "variableName": "id" + } +], +v5 = { + "kind": "Variable", + "name": "after", + "variableName": "after" +}, +v6 = { + "kind": "Variable", + "name": "first", + "variableName": "first" +}, +v7 = { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "__typename", + "storageKey": null +}, +v8 = { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "id", + "storageKey": null +}, +v9 = [ + (v5/*: any*/), + (v6/*: any*/), + { + "kind": "Variable", + "name": "timeRange", + "variableName": "timeRange" + } +], +v10 = [ + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "value", + "storageKey": null + } +]; +return { + "fragment": { + "argumentDefinitions": [ + (v0/*: any*/), + (v1/*: any*/), + (v2/*: any*/), + (v3/*: any*/) + ], + "kind": "Fragment", + "metadata": null, + "name": "SessionsTableQuery", + "selections": [ + { + "alias": null, + "args": (v4/*: any*/), + "concreteType": null, + "kind": "LinkedField", + "name": "node", + "plural": false, + "selections": [ + { + "args": [ + (v5/*: any*/), + (v6/*: any*/) + ], + "kind": "FragmentSpread", + "name": "SessionsTable_sessions" + } + ], + "storageKey": null + } + ], + "type": "Query", + "abstractKey": null + }, + "kind": "Request", + "operation": { + 
"argumentDefinitions": [ + (v0/*: any*/), + (v1/*: any*/), + (v3/*: any*/), + (v2/*: any*/) + ], + "kind": "Operation", + "name": "SessionsTableQuery", + "selections": [ + { + "alias": null, + "args": (v4/*: any*/), + "concreteType": null, + "kind": "LinkedField", + "name": "node", + "plural": false, + "selections": [ + (v7/*: any*/), + { + "kind": "TypeDiscriminator", + "abstractKey": "__isNode" + }, + (v8/*: any*/), + { + "kind": "InlineFragment", + "selections": [ + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "name", + "storageKey": null + }, + { + "alias": null, + "args": (v9/*: any*/), + "concreteType": "ProjectSessionConnection", + "kind": "LinkedField", + "name": "sessions", + "plural": false, + "selections": [ + { + "alias": null, + "args": null, + "concreteType": "ProjectSessionEdge", + "kind": "LinkedField", + "name": "edges", + "plural": true, + "selections": [ + { + "alias": "session", + "args": null, + "concreteType": "ProjectSession", + "kind": "LinkedField", + "name": "node", + "plural": false, + "selections": [ + (v8/*: any*/), + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "sessionId", + "storageKey": null + }, + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "numTraces", + "storageKey": null + }, + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "startTime", + "storageKey": null + }, + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "endTime", + "storageKey": null + }, + { + "alias": null, + "args": null, + "concreteType": "SpanIOValue", + "kind": "LinkedField", + "name": "firstInput", + "plural": false, + "selections": (v10/*: any*/), + "storageKey": null + }, + { + "alias": null, + "args": null, + "concreteType": "SpanIOValue", + "kind": "LinkedField", + "name": "lastOutput", + "plural": false, + "selections": (v10/*: any*/), + "storageKey": null + }, + { + "alias": null, + "args": null, + "concreteType": "TokenUsage", + "kind": 
"LinkedField", + "name": "tokenUsage", + "plural": false, + "selections": [ + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "prompt", + "storageKey": null + }, + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "completion", + "storageKey": null + }, + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "total", + "storageKey": null + } + ], + "storageKey": null + } + ], + "storageKey": null + }, + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "cursor", + "storageKey": null + }, + { + "alias": null, + "args": null, + "concreteType": "ProjectSession", + "kind": "LinkedField", + "name": "node", + "plural": false, + "selections": [ + (v7/*: any*/) + ], + "storageKey": null + } + ], + "storageKey": null + }, + { + "alias": null, + "args": null, + "concreteType": "PageInfo", + "kind": "LinkedField", + "name": "pageInfo", + "plural": false, + "selections": [ + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "endCursor", + "storageKey": null + }, + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "hasNextPage", + "storageKey": null + } + ], + "storageKey": null + } + ], + "storageKey": null + }, + { + "alias": null, + "args": (v9/*: any*/), + "filters": [ + "timeRange" + ], + "handle": "connection", + "key": "SessionsTable_sessions", + "kind": "LinkedHandle", + "name": "sessions" + } + ], + "type": "Project", + "abstractKey": null + } + ], + "storageKey": null + } + ] + }, + "params": { + "cacheID": "eae7b869a03863c68f10ff3d78c3f46b", + "id": null, + "metadata": {}, + "name": "SessionsTableQuery", + "operationKind": "query", + "text": "query SessionsTableQuery(\n $after: String = null\n $first: Int = 50\n $timeRange: TimeRange\n $id: GlobalID!\n) {\n node(id: $id) {\n __typename\n ...SessionsTable_sessions_2HEEH6\n __isNode: __typename\n id\n }\n}\n\nfragment SessionsTable_sessions_2HEEH6 on Project {\n name\n sessions(first: $first, after: 
$after, timeRange: $timeRange) {\n edges {\n session: node {\n id\n sessionId\n numTraces\n startTime\n endTime\n firstInput {\n value\n }\n lastOutput {\n value\n }\n tokenUsage {\n prompt\n completion\n total\n }\n }\n cursor\n node {\n __typename\n }\n }\n pageInfo {\n endCursor\n hasNextPage\n }\n }\n id\n}\n" + } +}; +})(); + +(node as any).hash = "ffd50d06a86cb2efbd63be2f7e658dbf"; + +export default node; diff --git a/app/src/pages/project/__generated__/SessionsTable_sessions.graphql.ts b/app/src/pages/project/__generated__/SessionsTable_sessions.graphql.ts new file mode 100644 index 0000000000..06fe1abde9 --- /dev/null +++ b/app/src/pages/project/__generated__/SessionsTable_sessions.graphql.ts @@ -0,0 +1,301 @@ +/** + * @generated SignedSource<> + * @lightSyntaxTransform + * @nogrep + */ + +/* tslint:disable */ +/* eslint-disable */ +// @ts-nocheck + +import { ReaderFragment, RefetchableFragment } from 'relay-runtime'; +import { FragmentRefs } from "relay-runtime"; +export type SessionsTable_sessions$data = { + readonly id: string; + readonly name: string; + readonly sessions: { + readonly edges: ReadonlyArray<{ + readonly session: { + readonly endTime: string; + readonly firstInput: { + readonly value: string; + } | null; + readonly id: string; + readonly lastOutput: { + readonly value: string; + } | null; + readonly numTraces: number; + readonly sessionId: string; + readonly startTime: string; + readonly tokenUsage: { + readonly completion: number; + readonly prompt: number; + readonly total: number; + }; + }; + }>; + }; + readonly " $fragmentType": "SessionsTable_sessions"; +}; +export type SessionsTable_sessions$key = { + readonly " $data"?: SessionsTable_sessions$data; + readonly " $fragmentSpreads": FragmentRefs<"SessionsTable_sessions">; +}; + +import SessionsTableQuery_graphql from './SessionsTableQuery.graphql'; + +const node: ReaderFragment = (function(){ +var v0 = [ + "sessions" +], +v1 = { + "alias": null, + "args": null, + "kind": "ScalarField", 
+ "name": "id", + "storageKey": null +}, +v2 = [ + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "value", + "storageKey": null + } +]; +return { + "argumentDefinitions": [ + { + "defaultValue": null, + "kind": "LocalArgument", + "name": "after" + }, + { + "defaultValue": 50, + "kind": "LocalArgument", + "name": "first" + }, + { + "kind": "RootArgument", + "name": "timeRange" + } + ], + "kind": "Fragment", + "metadata": { + "connection": [ + { + "count": "first", + "cursor": "after", + "direction": "forward", + "path": (v0/*: any*/) + } + ], + "refetch": { + "connection": { + "forward": { + "count": "first", + "cursor": "after" + }, + "backward": null, + "path": (v0/*: any*/) + }, + "fragmentPathInResult": [ + "node" + ], + "operation": SessionsTableQuery_graphql, + "identifierInfo": { + "identifierField": "id", + "identifierQueryVariableName": "id" + } + } + }, + "name": "SessionsTable_sessions", + "selections": [ + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "name", + "storageKey": null + }, + { + "alias": "sessions", + "args": [ + { + "kind": "Variable", + "name": "timeRange", + "variableName": "timeRange" + } + ], + "concreteType": "ProjectSessionConnection", + "kind": "LinkedField", + "name": "__SessionsTable_sessions_connection", + "plural": false, + "selections": [ + { + "alias": null, + "args": null, + "concreteType": "ProjectSessionEdge", + "kind": "LinkedField", + "name": "edges", + "plural": true, + "selections": [ + { + "alias": "session", + "args": null, + "concreteType": "ProjectSession", + "kind": "LinkedField", + "name": "node", + "plural": false, + "selections": [ + (v1/*: any*/), + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "sessionId", + "storageKey": null + }, + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "numTraces", + "storageKey": null + }, + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "startTime", + "storageKey": 
null + }, + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "endTime", + "storageKey": null + }, + { + "alias": null, + "args": null, + "concreteType": "SpanIOValue", + "kind": "LinkedField", + "name": "firstInput", + "plural": false, + "selections": (v2/*: any*/), + "storageKey": null + }, + { + "alias": null, + "args": null, + "concreteType": "SpanIOValue", + "kind": "LinkedField", + "name": "lastOutput", + "plural": false, + "selections": (v2/*: any*/), + "storageKey": null + }, + { + "alias": null, + "args": null, + "concreteType": "TokenUsage", + "kind": "LinkedField", + "name": "tokenUsage", + "plural": false, + "selections": [ + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "prompt", + "storageKey": null + }, + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "completion", + "storageKey": null + }, + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "total", + "storageKey": null + } + ], + "storageKey": null + } + ], + "storageKey": null + }, + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "cursor", + "storageKey": null + }, + { + "alias": null, + "args": null, + "concreteType": "ProjectSession", + "kind": "LinkedField", + "name": "node", + "plural": false, + "selections": [ + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "__typename", + "storageKey": null + } + ], + "storageKey": null + } + ], + "storageKey": null + }, + { + "alias": null, + "args": null, + "concreteType": "PageInfo", + "kind": "LinkedField", + "name": "pageInfo", + "plural": false, + "selections": [ + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "endCursor", + "storageKey": null + }, + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "hasNextPage", + "storageKey": null + } + ], + "storageKey": null + } + ], + "storageKey": null + }, + (v1/*: any*/) + ], + "type": "Project", + "abstractKey": null +}; +})(); + 
+(node as any).hash = "ffd50d06a86cb2efbd63be2f7e658dbf"; + +export default node; diff --git a/app/src/pages/trace/SessionDetails.tsx b/app/src/pages/trace/SessionDetails.tsx new file mode 100644 index 0000000000..4cf89b1f81 --- /dev/null +++ b/app/src/pages/trace/SessionDetails.tsx @@ -0,0 +1,79 @@ +import React, { useMemo } from "react"; +import { graphql, useLazyLoadQuery } from "react-relay"; +import { css } from "@emotion/react"; + +import { SessionDetailsQuery } from "./__generated__/SessionDetailsQuery.graphql"; + +export type SessionDetailsProps = { + sessionId: string; +}; + +/** + * A component that shows the details of a session + */ +export function SessionDetails(props: SessionDetailsProps) { + const { sessionId } = props; + const data = useLazyLoadQuery( + graphql` + query SessionDetailsQuery($id: GlobalID!) { + session: node(id: $id) { + ... on ProjectSession { + traces { + edges { + trace: node { + rootSpan { + input { + value + } + output { + value + } + } + } + } + } + } + } + } + `, + { + id: sessionId, + }, + { + fetchPolicy: "store-and-network", + } + ); + const spansList = useMemo(() => { + const gqlSpans = data.session?.traces?.edges || []; + return gqlSpans.map(({ trace }) => trace); + }, [data]); + return ( +
+ + + + + + + + + + {spansList.map((trace, index) => ( + + + + + + ))} + +
#UserAssistant
{index + 1}{trace.rootSpan?.input?.value}{trace.rootSpan?.output?.value}
+
+ ); +} diff --git a/app/src/pages/trace/SessionPage.tsx b/app/src/pages/trace/SessionPage.tsx new file mode 100644 index 0000000000..82a0d560b9 --- /dev/null +++ b/app/src/pages/trace/SessionPage.tsx @@ -0,0 +1,25 @@ +import React from "react"; +import { useNavigate, useParams } from "react-router"; + +import { Dialog, DialogContainer } from "@arizeai/components"; + +import { SessionDetails } from "./SessionDetails"; + +/** + * A component that shows the details of a session + */ +export function SessionPage() { + const { sessionId, projectId } = useParams(); + const navigate = useNavigate(); + return ( + navigate(`/projects/${projectId}`)} + > + + + + + ); +} diff --git a/app/src/pages/trace/__generated__/SessionDetailsQuery.graphql.ts b/app/src/pages/trace/__generated__/SessionDetailsQuery.graphql.ts new file mode 100644 index 0000000000..8babc62608 --- /dev/null +++ b/app/src/pages/trace/__generated__/SessionDetailsQuery.graphql.ts @@ -0,0 +1,207 @@ +/** + * @generated SignedSource<<09739022ca9b52d61d2eb5646d7e9906>> + * @lightSyntaxTransform + * @nogrep + */ + +/* tslint:disable */ +/* eslint-disable */ +// @ts-nocheck + +import { ConcreteRequest, Query } from 'relay-runtime'; +export type SessionDetailsQuery$variables = { + id: string; +}; +export type SessionDetailsQuery$data = { + readonly session: { + readonly traces?: { + readonly edges: ReadonlyArray<{ + readonly trace: { + readonly rootSpan: { + readonly input: { + readonly value: string; + } | null; + readonly output: { + readonly value: string; + } | null; + } | null; + }; + }>; + }; + }; +}; +export type SessionDetailsQuery = { + response: SessionDetailsQuery$data; + variables: SessionDetailsQuery$variables; +}; + +const node: ConcreteRequest = (function(){ +var v0 = [ + { + "defaultValue": null, + "kind": "LocalArgument", + "name": "id" + } +], +v1 = [ + { + "kind": "Variable", + "name": "id", + "variableName": "id" + } +], +v2 = [ + { + "alias": null, + "args": null, + "kind": "ScalarField", + 
"name": "value", + "storageKey": null + } +], +v3 = { + "kind": "InlineFragment", + "selections": [ + { + "alias": null, + "args": null, + "concreteType": "TraceConnection", + "kind": "LinkedField", + "name": "traces", + "plural": false, + "selections": [ + { + "alias": null, + "args": null, + "concreteType": "TraceEdge", + "kind": "LinkedField", + "name": "edges", + "plural": true, + "selections": [ + { + "alias": "trace", + "args": null, + "concreteType": "Trace", + "kind": "LinkedField", + "name": "node", + "plural": false, + "selections": [ + { + "alias": null, + "args": null, + "concreteType": "Span", + "kind": "LinkedField", + "name": "rootSpan", + "plural": false, + "selections": [ + { + "alias": null, + "args": null, + "concreteType": "SpanIOValue", + "kind": "LinkedField", + "name": "input", + "plural": false, + "selections": (v2/*: any*/), + "storageKey": null + }, + { + "alias": null, + "args": null, + "concreteType": "SpanIOValue", + "kind": "LinkedField", + "name": "output", + "plural": false, + "selections": (v2/*: any*/), + "storageKey": null + } + ], + "storageKey": null + } + ], + "storageKey": null + } + ], + "storageKey": null + } + ], + "storageKey": null + } + ], + "type": "ProjectSession", + "abstractKey": null +}; +return { + "fragment": { + "argumentDefinitions": (v0/*: any*/), + "kind": "Fragment", + "metadata": null, + "name": "SessionDetailsQuery", + "selections": [ + { + "alias": "session", + "args": (v1/*: any*/), + "concreteType": null, + "kind": "LinkedField", + "name": "node", + "plural": false, + "selections": [ + (v3/*: any*/) + ], + "storageKey": null + } + ], + "type": "Query", + "abstractKey": null + }, + "kind": "Request", + "operation": { + "argumentDefinitions": (v0/*: any*/), + "kind": "Operation", + "name": "SessionDetailsQuery", + "selections": [ + { + "alias": "session", + "args": (v1/*: any*/), + "concreteType": null, + "kind": "LinkedField", + "name": "node", + "plural": false, + "selections": [ + { + "alias": null, + 
"args": null, + "kind": "ScalarField", + "name": "__typename", + "storageKey": null + }, + (v3/*: any*/), + { + "kind": "TypeDiscriminator", + "abstractKey": "__isNode" + }, + { + "alias": null, + "args": null, + "kind": "ScalarField", + "name": "id", + "storageKey": null + } + ], + "storageKey": null + } + ] + }, + "params": { + "cacheID": "8b29972b49439ce8942839cc38a0c5be", + "id": null, + "metadata": {}, + "name": "SessionDetailsQuery", + "operationKind": "query", + "text": "query SessionDetailsQuery(\n $id: GlobalID!\n) {\n session: node(id: $id) {\n __typename\n ... on ProjectSession {\n traces {\n edges {\n trace: node {\n rootSpan {\n input {\n value\n }\n output {\n value\n }\n }\n }\n }\n }\n }\n __isNode: __typename\n id\n }\n}\n" + } +}; +})(); + +(node as any).hash = "e2fd415f62c6b3af00e449a4dc383959"; + +export default node; diff --git a/pyproject.toml b/pyproject.toml index c686126918..25029bea2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ dependencies = [ "starlette", "uvicorn", "psutil", - "strawberry-graphql==0.243.1", # need to pin version because we're monkey-patching + "strawberry-graphql==0.247.0", # need to pin version because we're monkey-patching "pyarrow", "typing-extensions>=4.6", "scipy", @@ -87,7 +87,7 @@ dev = [ "pytest-postgresql", "asyncpg", "psycopg[binary,pool]", - "strawberry-graphql[debug-server,opentelemetry]==0.243.1", # need to pin version because we're monkey-patching + "strawberry-graphql[debug-server,opentelemetry]==0.247.0", # need to pin version because we're monkey-patching "pre-commit", "arize[AutoEmbeddings, LLM_Evaluation]", "llama-index>=0.10.3", @@ -133,7 +133,7 @@ container = [ "opentelemetry-instrumentation-sqlalchemy", "opentelemetry-instrumentation-grpc", "py-grpc-prometheus", - "strawberry-graphql[opentelemetry]==0.243.1", # need to pin version because we're monkey-patching + "strawberry-graphql[opentelemetry]==0.247.0", # need to pin version because we're monkey-patching "uvloop; 
platform_system != 'Windows'", "fast-hdbscan>=0.2.0", "numba>=0.60.0", # https://github.com/astral-sh/uv/issues/6281 @@ -229,7 +229,7 @@ dependencies = [ "py-grpc-prometheus", "pypistats", # this is needed to type-check third-party packages "requests", # this is needed to type-check third-party packages - "strawberry-graphql[opentelemetry]==0.243.1", # need to pin version because we're monkey-patching + "strawberry-graphql[opentelemetry]==0.247.0", # need to pin version because we're monkey-patching "tenacity", "types-cachetools", "types-protobuf", diff --git a/requirements/build-graphql-schema.txt b/requirements/build-graphql-schema.txt index f91b682420..6d7989141f 100644 --- a/requirements/build-graphql-schema.txt +++ b/requirements/build-graphql-schema.txt @@ -1 +1 @@ -strawberry-graphql[cli]==0.243.1 +strawberry-graphql[cli]==0.247.0 diff --git a/requirements/type-check.txt b/requirements/type-check.txt index c2c5367119..feb846c375 100644 --- a/requirements/type-check.txt +++ b/requirements/type-check.txt @@ -18,7 +18,7 @@ psycopg[binary,pool] py-grpc-prometheus pypistats # this is needed to type-check third-party packages requests # this is needed to type-check third-party packages -strawberry-graphql[opentelemetry]==0.243.1 # need to pin version because we're monkey-patching +strawberry-graphql[opentelemetry]==0.247.0 # need to pin version because we're monkey-patching tenacity types-cachetools types-protobuf diff --git a/scripts/fixtures/multi-turn_chat_sessions.ipynb b/scripts/fixtures/multi-turn_chat_sessions.ipynb index 1b80fcdf43..a1f82ceb04 100644 --- a/scripts/fixtures/multi-turn_chat_sessions.ipynb +++ b/scripts/fixtures/multi-turn_chat_sessions.ipynb @@ -17,34 +17,29 @@ "metadata": {}, "outputs": [], "source": [ - "from base64 import b64encode\n", "from contextlib import ExitStack, contextmanager\n", - "from io import BytesIO\n", - "from random import choice, randint, random, shuffle\n", - "from secrets import token_hex\n", - "from time import 
sleep\n", + "from random import choice, choices, randint, random, shuffle\n", "\n", + "import numpy as np\n", "import openai\n", + "import pandas as pd\n", "from datasets import load_dataset\n", "from faker import Faker\n", - "from langchain_core.messages import AIMessage, HumanMessage\n", - "from langchain_openai import ChatOpenAI\n", - "from llama_index.core.llms import ChatMessage\n", - "from llama_index.llms.openai import OpenAI\n", "from mdgen import MarkdownPostProvider\n", "from openai_responses import OpenAIMock\n", - "from openinference.instrumentation.langchain import LangChainInstrumentor\n", - "from openinference.instrumentation.llama_index import LlamaIndexInstrumentor\n", + "from openinference.instrumentation import using_session, using_user\n", "from openinference.instrumentation.openai import OpenAIInstrumentor\n", "from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes\n", "from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter\n", "from opentelemetry.sdk.trace import SpanLimits, StatusCode, TracerProvider\n", "from opentelemetry.sdk.trace.export import SimpleSpanProcessor\n", "from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter\n", - "from PIL import Image\n", "from tiktoken import encoding_for_model\n", "\n", - "fake = Faker()\n", + "import phoenix as px\n", + "from phoenix.trace.span_evaluations import SpanEvaluations\n", + "\n", + "fake = Faker(\"ja_JP\")\n", "fake.add_provider(MarkdownPostProvider)" ] }, @@ -105,11 +100,21 @@ "outputs": [], "source": [ "def gen_session_id():\n", - " return token_hex(32) if random() < 0.5 else int(abs(random()) * 1_000_000_000)\n", + " p = random()\n", + " if p < 0.1:\n", + " return \":\" * randint(1, 5)\n", + " if p < 0.9:\n", + " return Faker([\"ja_JP\", \"vi_VN\", \"ko_KR\", \"zh_CN\"]).address()\n", + " return int(abs(random()) * 1_000_000_000)\n", "\n", "\n", "def gen_user_id():\n", - " return fake.user_name() if 
random() < 0.5 else int(abs(random()) * 1_000_000_000)\n", + " p = random()\n", + " if p < 0.1:\n", + " return \":\" * randint(1, 5)\n", + " if p < 0.9:\n", + " return Faker([\"ja_JP\", \"vi_VN\", \"ko_KR\", \"zh_CN\"]).name()\n", + " return int(abs(random()) * 1_000_000_000)\n", "\n", "\n", "def export_spans():\n", @@ -118,66 +123,49 @@ " shuffle(spans)\n", " for span in spans:\n", " otlp_span_exporter.export([span])\n", - " sleep(0.01)\n", " in_memory_span_exporter.clear()\n", + " session_count = len({id_ for span in spans if (id_ := span.attributes.get(\"session.id\"))})\n", + " trace_count = len({span.context.trace_id for span in spans})\n", + " print(f\"Exported {session_count} sessions, {trace_count} traces, {len(spans)} spans\")\n", + " return spans\n", "\n", "\n", "def rand_span_kind():\n", " yield SpanAttributes.OPENINFERENCE_SPAN_KIND, choice(list(OpenInferenceSpanKindValues)).value\n", "\n", "\n", - "def set_session_id(span, has_session_id, session_id):\n", - " if not has_session_id and random() < 0.1:\n", - " span.set_attribute(SpanAttributes.SESSION_ID, session_id)\n", - " return True\n", - " return has_session_id\n", - "\n", - "\n", - "def set_user_id(span, has_user_id, user_id):\n", - " if not has_user_id and random() < 0.1:\n", - " span.set_attribute(SpanAttributes.USER_ID, user_id)\n", - " return True\n", - " return has_user_id\n", + "def rand_status_code():\n", + " return choices(\n", + " [StatusCode.OK, StatusCode.ERROR, StatusCode.UNSET], k=1, weights=[0.98, 0.01, 0.01]\n", + " )[0]\n", "\n", "\n", "@contextmanager\n", - "def trace_tree(session_id, user_id):\n", - " has_session_id = has_user_id = False\n", - " tracer = tracer_provider.get_tracer(__name__)\n", - " with ExitStack() as trace:\n", - " root = trace.enter_context(\n", - " tracer.start_as_current_span(\n", - " \"root\",\n", - " attributes=dict(rand_span_kind()),\n", - " end_on_exit=False,\n", - " )\n", - " )\n", - " root.set_status(choice([StatusCode.OK] * 100 + list(StatusCode)))\n", 
- " for _ in range(randint(0, 10)):\n", - " span = trace.enter_context(\n", - " tracer.start_as_current_span(\"parent\", attributes=dict(rand_span_kind()))\n", - " )\n", - " has_session_id = set_session_id(span, has_session_id, session_id)\n", - " has_user_id = set_user_id(span, has_user_id, user_id)\n", - " span.set_status(choice([StatusCode.OK] * 100 + list(StatusCode)))\n", - " for _ in range(randint(0, 10)):\n", - " span = tracer.start_span(\"sibling\", attributes=dict(rand_span_kind()))\n", - " has_session_id = set_session_id(span, has_session_id, session_id)\n", - " has_user_id = set_user_id(span, has_user_id, user_id)\n", - " span.set_status(choice([StatusCode.OK] * 100 + list(StatusCode)))\n", - " span.end()\n", + "def trace_tree(tracer, n=5):\n", + " if n <= 0:\n", " yield\n", - " for _ in range(randint(0, 10)):\n", - " span = tracer.start_span(\"sibling\", attributes=dict(rand_span_kind()))\n", - " has_session_id = set_session_id(span, has_session_id, session_id)\n", - " has_user_id = set_user_id(span, has_user_id, user_id)\n", - " span.set_status(choice([StatusCode.OK] * 100 + list(StatusCode)))\n", - " span.end()\n", - " if not has_session_id:\n", - " root.set_attribute(SpanAttributes.SESSION_ID, session_id)\n", - " if not has_user_id:\n", - " root.set_attribute(SpanAttributes.USER_ID, user_id)\n", - " root.end()" + " return\n", + " has_yielded = False\n", + " with tracer.start_as_current_span(\n", + " Faker(\"ja_JP\").kana_name(),\n", + " attributes=dict(rand_span_kind()),\n", + " end_on_exit=False,\n", + " ) as root:\n", + " for _ in range(randint(0, n)):\n", + " with trace_tree(tracer, randint(0, n - 1)):\n", + " if not has_yielded and random() < 0.5:\n", + " yield\n", + " has_yielded = True\n", + " else:\n", + " pass\n", + " if not has_yielded:\n", + " yield\n", + " has_yielded = True\n", + " for _ in range(randint(0, n)):\n", + " with trace_tree(tracer, randint(0, n - 1)):\n", + " pass\n", + " root.set_status(rand_status_code())\n", + " 
root.end(int(fake.future_datetime(\"+5s\").timestamp() * 10**9))" ] }, { @@ -185,15 +173,7 @@ "id": "a2f2ac17", "metadata": {}, "source": [ - "# Text Only" - ] - }, - { - "cell_type": "markdown", - "id": "2abc6b1f", - "metadata": {}, - "source": [ - "## OpenAI" + "# Genarate Sessions" ] }, { @@ -203,66 +183,21 @@ "metadata": {}, "outputs": [], "source": [ - "session_count = 5\n", - "user_id = gen_user_id()\n", + "session_count = randint(5, 10)\n", + "tree_complexity = 4 # set to 0 for single span under root\n", "\n", "\n", - "def simulate_openai(messages):\n", + "def simulate_openai():\n", + " user_id = gen_user_id() if random() < 0.9 else \" \"\n", " session_id = gen_session_id()\n", " client = openai.Client(api_key=\"sk-\")\n", " model = \"gpt-4o-mini\"\n", " encoding = encoding_for_model(model)\n", + " messages = np.concatenate(convo.sample(randint(1, 10)).values)\n", " counts = [len(encoding.encode(m[\"content\"])) for m in messages]\n", " openai_mock = OpenAIMock()\n", - " with ExitStack() as stack:\n", - " stack.enter_context(openai_mock.router)\n", - " for i in range(1, len(messages), 2):\n", - " openai_mock.chat.completions.create.response = dict(\n", - " choices=[dict(index=0, finish_reason=\"stop\", message=messages[i])],\n", - " usage=dict(\n", - " prompt_tokens=sum(counts[:i]),\n", - " completion_tokens=counts[i],\n", - " total_tokens=sum(counts[: i + 1]),\n", - " ),\n", - " )\n", - " with trace_tree(session_id, user_id):\n", - " client.chat.completions.create(model=model, messages=messages[:i])\n", - "\n", - "\n", - "OpenAIInstrumentor().instrument(tracer_provider=tracer_provider)\n", - "convo.sample(session_count).apply(simulate_openai)\n", - "OpenAIInstrumentor().uninstrument()\n", - "export_spans()" - ] - }, - { - "cell_type": "markdown", - "id": "b3367066", - "metadata": {}, - "source": [ - "## LangChain" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0f42e1f8", - "metadata": {}, - "outputs": [], - "source": [ - 
"session_count = 5\n", - "user_id = gen_user_id()\n", - "\n", - "\n", - "def simulate_langchain(messages):\n", - " session_id = gen_session_id()\n", - " model = \"gpt-4o-mini\"\n", - " encoding = encoding_for_model(model)\n", - " counts = [len(encoding.encode(m[\"content\"])) for m in messages]\n", - " llm = ChatOpenAI(model_name=model, openai_api_key=\"sk-\")\n", - " openai_mock = OpenAIMock()\n", - " with ExitStack() as stack:\n", - " stack.enter_context(openai_mock.router)\n", + " tracer = tracer_provider.get_tracer(__name__)\n", + " with openai_mock.router:\n", " for i in range(1, len(messages), 2):\n", " openai_mock.chat.completions.create.response = dict(\n", " choices=[dict(index=0, finish_reason=\"stop\", message=messages[i])],\n", @@ -272,140 +207,53 @@ " total_tokens=sum(counts[: i + 1]),\n", " ),\n", " )\n", - " with trace_tree(session_id, user_id):\n", - " llm.invoke(\n", - " [\n", - " HumanMessage(m[\"content\"])\n", - " if m[\"role\"] == \"user\"\n", - " else AIMessage(m[\"content\"])\n", - " for m in messages[:i]\n", - " ]\n", + " with ExitStack() as stack:\n", + " attributes = {\n", + " \"input.value\": messages[i - 1][\"content\"],\n", + " \"output.value\": messages[i][\"content\"],\n", + " }\n", + " if random() < 0.5:\n", + " attributes[\"session.id\"] = session_id\n", + " attributes[\"user.id\"] = user_id\n", + " else:\n", + " stack.enter_context(using_session(session_id))\n", + " stack.enter_context(using_user(user_id))\n", + " root = stack.enter_context(\n", + " tracer.start_as_current_span(\n", + " \"root\",\n", + " attributes=attributes,\n", + " end_on_exit=False,\n", + " )\n", " )\n", - "\n", - "\n", - "LangChainInstrumentor().instrument(tracer_provider=tracer_provider)\n", - "convo.sample(session_count).apply(simulate_langchain)\n", - "LangChainInstrumentor().uninstrument()\n", - "export_spans()" - ] - }, - { - "cell_type": "markdown", - "id": "cc33b8eb", - "metadata": {}, - "source": [ - "## Llama-Index" - ] - }, - { - "cell_type": "code", 
- "execution_count": null, - "id": "03417027", - "metadata": {}, - "outputs": [], - "source": [ - "session_count = 5\n", - "user_id = gen_user_id()\n", - "\n", - "\n", - "def simulate_llama_index(messages):\n", - " session_id = gen_session_id()\n", - " model = \"gpt-4o-mini\"\n", - " encoding = encoding_for_model(model)\n", - " counts = [len(encoding.encode(m[\"content\"])) for m in messages]\n", - " llm = OpenAI(api_key=\"sk-\")\n", - " openai_mock = OpenAIMock()\n", - " with ExitStack() as stack:\n", - " stack.enter_context(openai_mock.router)\n", - " for i in range(1, len(messages), 2):\n", - " openai_mock.chat.completions.create.response = dict(\n", - " choices=[dict(index=0, finish_reason=\"stop\", message=messages[i])],\n", - " usage=dict(\n", - " prompt_tokens=sum(counts[:i]),\n", - " completion_tokens=counts[i],\n", - " total_tokens=sum(counts[: i + 1]),\n", - " ),\n", - " )\n", - " with trace_tree(session_id, user_id):\n", - " llm.complete([ChatMessage(**m) for m in messages[:i]])\n", - "\n", - "\n", - "LlamaIndexInstrumentor().instrument(tracer_provider=tracer_provider)\n", - "convo.sample(session_count).apply(simulate_llama_index)\n", - "LlamaIndexInstrumentor().uninstrument()\n", - "export_spans()" - ] - }, - { - "cell_type": "markdown", - "id": "ce8fe3ab", - "metadata": {}, - "source": [ - "# Vision" - ] - }, - { - "cell_type": "markdown", - "id": "74d13b81", - "metadata": {}, - "source": [ - "## OpenAI" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "96ce318b", - "metadata": {}, - "outputs": [], - "source": [ - "session_count = 5\n", - "user_id = gen_user_id()\n", - "\n", - "\n", - "def simulate_openai_vision():\n", - " session_id = gen_session_id()\n", - " client = openai.Client(api_key=\"sk-\")\n", - " model = \"gpt-4o-mini\"\n", - " encoding = encoding_for_model(model)\n", - " openai_mock = OpenAIMock()\n", - " messages = []\n", - " usage = dict(prompt_tokens=0, completion_tokens=0, total_tokens=0)\n", - " with ExitStack() 
as stack:\n", - " stack.enter_context(openai_mock.router)\n", - " for _ in range(randint(5, 20)):\n", - " text = fake.post(size=\"small\")\n", - " if random() < 0.5:\n", - " images = []\n", - " for _ in range(randint(3, 10)):\n", - " img = Image.new(\"RGB\", (5, 5), fake.color_rgb())\n", - " buffered = BytesIO()\n", - " img.save(buffered, format=\"PNG\")\n", - " url = f\"data:image/png;base64,{b64encode(buffered.getvalue()).decode()}\"\n", - " images.append(dict(type=\"image_url\", image_url=dict(url=url)))\n", - " content = [dict(type=\"text\", text=text)] + images\n", - " shuffle(content)\n", - " else:\n", - " content = text\n", - " request = dict(role=\"user\", content=content)\n", - " response = dict(role=\"assistant\", content=fake.post(size=\"medium\"))\n", - " usage[\"prompt_tokens\"] += len(encoding.encode(text))\n", - " usage[\"completion_tokens\"] += len(encoding.encode(response[\"content\"]))\n", - " usage[\"total_tokens\"] = usage[\"prompt_tokens\"] + usage[\"completion_tokens\"]\n", - " messages.extend([request, response])\n", - " openai_mock.chat.completions.create.response = dict(\n", - " choices=[dict(index=0, finish_reason=\"stop\", message=messages[-1])],\n", - " usage=usage,\n", - " )\n", - " with trace_tree(session_id, user_id):\n", - " client.chat.completions.create(model=model, messages=messages[:-1])\n", + " with trace_tree(tracer, tree_complexity):\n", + " client.chat.completions.create(model=model, messages=messages[:i])\n", + " root.set_status(rand_status_code())\n", + " root.end(int(fake.future_datetime(\"+5s\").timestamp() * 10**9))\n", "\n", "\n", "OpenAIInstrumentor().instrument(tracer_provider=tracer_provider)\n", - "for _ in range(session_count):\n", - " simulate_openai_vision()\n", - "OpenAIInstrumentor().uninstrument()\n", - "export_spans()" + "try:\n", + " for _ in range(session_count):\n", + " simulate_openai()\n", + "finally:\n", + " OpenAIInstrumentor().uninstrument()\n", + "spans = export_spans()\n", + "\n", + "# Annotate root 
spans\n", + "root_span_ids = pd.Series(\n", + " [span.context.span_id.to_bytes(8, \"big\").hex() for span in spans if span.parent is None]\n", + ")\n", + "for name in \"ABC\":\n", + " span_ids = root_span_ids.sample(frac=0.5)\n", + " df = pd.DataFrame(\n", + " {\n", + " \"context.span_id\": span_ids,\n", + " \"score\": np.random.rand(len(span_ids)),\n", + " \"label\": np.random.choice([\"👍\", \"👎\"], len(span_ids)),\n", + " \"explanation\": [fake.paragraph(10) for _ in range(len(span_ids))],\n", + " }\n", + " ).set_index(\"context.span_id\")\n", + " px.Client().log_evaluations(SpanEvaluations(name, df))" ] } ], diff --git a/src/phoenix/db/insertion/span.py b/src/phoenix/db/insertion/span.py index 1ca80b2b08..a5fcba714f 100644 --- a/src/phoenix/db/insertion/span.py +++ b/src/phoenix/db/insertion/span.py @@ -38,89 +38,63 @@ async def insert_span( project_session: Optional[models.ProjectSession] = None session_id = get_attribute_value(span.attributes, SpanAttributes.SESSION_ID) + session_user = get_attribute_value(span.attributes, SpanAttributes.USER_ID) if session_id is not None and (not isinstance(session_id, str) or session_id.strip()): session_id = str(session_id).strip() assert isinstance(session_id, str) + if session_user is not None: + session_user = str(session_user).strip() + assert isinstance(session_user, str) project_session = await session.scalar( select(models.ProjectSession).filter_by(session_id=session_id) ) if project_session: - project_session_needs_update = False - project_session_end_time = None - project_session_project_id = None if project_session.end_time < span.end_time: - project_session_needs_update = True - project_session_end_time = span.end_time - project_session_project_id = project_rowid - project_session_start_time = None + project_session.end_time = span.end_time + project_session.project_id = project_rowid if span.start_time < project_session.start_time: - project_session_needs_update = True - project_session_start_time = 
span.start_time - if project_session_needs_update: - project_session = await session.scalar( - update(models.ProjectSession) - .filter_by(id=project_session.id) - .values( - start_time=project_session_start_time or project_session.start_time, - end_time=project_session_end_time or project_session.end_time, - project_id=project_session_project_id or project_session.project_id, - ) - .returning(models.ProjectSession) - ) + project_session.start_time = span.start_time + if session_user and project_session.session_user != session_user: + project_session.session_user = session_user else: - project_session = await session.scalar( - insert(models.ProjectSession) - .values( - project_id=project_rowid, - session_id=session_id, - start_time=span.start_time, - end_time=span.end_time, - ) - .returning(models.ProjectSession) + project_session = models.ProjectSession( + project_id=project_rowid, + session_id=session_id, + session_user=session_user if session_user else None, + start_time=span.start_time, + end_time=span.end_time, ) + session.add(project_session) + if project_session in session.dirty: + await session.flush() trace_id = span.context.trace_id - trace: Optional[models.Trace] = await session.scalar( - select(models.Trace).filter_by(trace_id=trace_id) - ) + trace = await session.scalar(select(models.Trace).filter_by(trace_id=trace_id)) if trace: - trace_needs_update = False - trace_end_time = None - trace_project_rowid = None - trace_project_session_id = None + if project_session and ( + trace.project_session_rowid is None + or ( + trace.end_time < span.end_time and trace.project_session_rowid != project_session.id + ) + ): + trace.project_session_rowid = project_session.id if trace.end_time < span.end_time: - trace_needs_update = True - trace_end_time = span.end_time - trace_project_rowid = project_rowid - trace_project_session_id = project_session.id if project_session else None - trace_start_time = None + trace.end_time = span.end_time + trace.project_rowid = 
project_rowid if span.start_time < trace.start_time: - trace_needs_update = True - trace_start_time = span.start_time - if trace_needs_update: - await session.execute( - update(models.Trace) - .filter_by(id=trace.id) - .values( - start_time=trace_start_time or trace.start_time, - end_time=trace_end_time or trace.end_time, - project_rowid=trace_project_rowid or trace.project_rowid, - project_session_id=trace_project_session_id or trace.project_session_id, - ) - ) + trace.start_time = span.start_time else: - trace = await session.scalar( - insert(models.Trace) - .values( - project_rowid=project_rowid, - trace_id=span.context.trace_id, - start_time=span.start_time, - end_time=span.end_time, - project_session_id=project_session.id if project_session else None, - ) - .returning(models.Trace) + trace = models.Trace( + project_rowid=project_rowid, + trace_id=span.context.trace_id, + start_time=span.start_time, + end_time=span.end_time, + project_session_rowid=project_session.id if project_session else None, ) - assert trace is not None + session.add(trace) + if trace in session.dirty: + await session.flush() + cumulative_error_count = int(span.status_code is SpanStatusCode.ERROR) cumulative_llm_token_count_prompt = cast( int, get_attribute_value(span.attributes, SpanAttributes.LLM_TOKEN_COUNT_PROMPT) or 0 diff --git a/src/phoenix/db/migrations/versions/4ded9e43755f_create_project_session_table.py b/src/phoenix/db/migrations/versions/4ded9e43755f_create_project_sessions_table.py similarity index 79% rename from src/phoenix/db/migrations/versions/4ded9e43755f_create_project_session_table.py rename to src/phoenix/db/migrations/versions/4ded9e43755f_create_project_sessions_table.py index 5939626a7f..61dc0e7b7d 100644 --- a/src/phoenix/db/migrations/versions/4ded9e43755f_create_project_session_table.py +++ b/src/phoenix/db/migrations/versions/4ded9e43755f_create_project_sessions_table.py @@ -23,6 +23,7 @@ def upgrade() -> None: "project_sessions", sa.Column("id", sa.Integer, 
primary_key=True), sa.Column("session_id", sa.String, unique=True, nullable=False), + sa.Column("session_user", sa.String, index=True), sa.Column( "project_id", sa.Integer, @@ -31,26 +32,26 @@ def upgrade() -> None: index=True, ), sa.Column("start_time", sa.TIMESTAMP(timezone=True), index=True, nullable=False), - sa.Column("end_time", sa.TIMESTAMP(timezone=True), index=True, nullable=False), + sa.Column("end_time", sa.TIMESTAMP(timezone=True), nullable=False), ) with op.batch_alter_table("traces") as batch_op: batch_op.add_column( sa.Column( - "project_session_id", + "project_session_rowid", sa.Integer, sa.ForeignKey("project_sessions.id", ondelete="CASCADE"), nullable=True, ), ) op.create_index( - "ix_traces_project_session_id", + "ix_traces_project_session_rowid", "traces", - ["project_session_id"], + ["project_session_rowid"], ) def downgrade() -> None: - op.drop_index("ix_traces_project_session_id") + op.drop_index("ix_traces_project_session_rowid") with op.batch_alter_table("traces") as batch_op: - batch_op.drop_column("project_session_id") + batch_op.drop_column("project_session_rowid") op.drop_table("project_sessions") diff --git a/src/phoenix/db/models.py b/src/phoenix/db/models.py index 5f8c207b0d..5eefd6aa0d 100644 --- a/src/phoenix/db/models.py +++ b/src/phoenix/db/models.py @@ -160,6 +160,7 @@ class ProjectSession(Base): __tablename__ = "project_sessions" id: Mapped[int] = mapped_column(primary_key=True) session_id: Mapped[str] = mapped_column(String, nullable=False, unique=True) + session_user: Mapped[Optional[str]] = mapped_column(index=True) project_id: Mapped[int] = mapped_column( ForeignKey("projects.id", ondelete="CASCADE"), nullable=False, @@ -179,12 +180,12 @@ class Trace(Base): id: Mapped[int] = mapped_column(primary_key=True) project_rowid: Mapped[int] = mapped_column( ForeignKey("projects.id", ondelete="CASCADE"), + nullable=False, index=True, ) trace_id: Mapped[str] - project_session_id: Mapped[int] = mapped_column( + project_session_rowid: 
Mapped[Optional[int]] = mapped_column( ForeignKey("project_sessions.id", ondelete="CASCADE"), - nullable=True, index=True, ) start_time: Mapped[datetime] = mapped_column(UtcTimeStamp, index=True) diff --git a/src/phoenix/server/api/context.py b/src/phoenix/server/api/context.py index ed39375614..e777a9ef28 100644 --- a/src/phoenix/server/api/context.py +++ b/src/phoenix/server/api/context.py @@ -36,7 +36,7 @@ SpanDescendantsDataLoader, SpanProjectsDataLoader, TokenCountDataLoader, - TraceRowIdsDataLoader, + TraceByTraceIdsDataLoader, UserRolesDataLoader, UsersDataLoader, ) @@ -73,7 +73,7 @@ class DataLoaders: span_descendants: SpanDescendantsDataLoader span_projects: SpanProjectsDataLoader token_counts: TokenCountDataLoader - trace_row_ids: TraceRowIdsDataLoader + trace_by_trace_ids: TraceByTraceIdsDataLoader project_by_name: ProjectByNameDataLoader users: UsersDataLoader user_roles: UserRolesDataLoader diff --git a/src/phoenix/server/api/dataloaders/__init__.py b/src/phoenix/server/api/dataloaders/__init__.py index 3024283c05..8e33ee97b9 100644 --- a/src/phoenix/server/api/dataloaders/__init__.py +++ b/src/phoenix/server/api/dataloaders/__init__.py @@ -24,7 +24,7 @@ from .span_descendants import SpanDescendantsDataLoader from .span_projects import SpanProjectsDataLoader from .token_counts import TokenCountCache, TokenCountDataLoader -from .trace_row_ids import TraceRowIdsDataLoader +from .trace_by_trace_ids import TraceByTraceIdsDataLoader from .user_roles import UserRolesDataLoader from .users import UsersDataLoader @@ -49,7 +49,7 @@ "SpanDescendantsDataLoader", "SpanProjectsDataLoader", "TokenCountDataLoader", - "TraceRowIdsDataLoader", + "TraceByTraceIdsDataLoader", "ProjectByNameDataLoader", "SpanAnnotationsDataLoader", "UsersDataLoader", diff --git a/src/phoenix/server/api/dataloaders/trace_by_trace_ids.py b/src/phoenix/server/api/dataloaders/trace_by_trace_ids.py new file mode 100644 index 0000000000..e8d2fe6326 --- /dev/null +++ 
b/src/phoenix/server/api/dataloaders/trace_by_trace_ids.py @@ -0,0 +1,26 @@ +from typing import List, Optional + +from sqlalchemy import select +from strawberry.dataloader import DataLoader +from typing_extensions import TypeAlias + +from phoenix.db import models +from phoenix.server.types import DbSessionFactory + +TraceId: TypeAlias = str +Key: TypeAlias = TraceId +TraceRowId: TypeAlias = int +ProjectRowId: TypeAlias = int +Result: TypeAlias = Optional[models.Trace] + + +class TraceByTraceIdsDataLoader(DataLoader[Key, Result]): + def __init__(self, db: DbSessionFactory) -> None: + super().__init__(load_fn=self._load_fn) + self._db = db + + async def _load_fn(self, keys: List[Key]) -> List[Result]: + stmt = select(models.Trace).where(models.Trace.trace_id.in_(keys)) + async with self._db() as session: + result = {trace.trace_id: trace for trace in await session.scalars(stmt)} + return [result.get(trace_id) for trace_id in keys] diff --git a/src/phoenix/server/api/dataloaders/trace_row_ids.py b/src/phoenix/server/api/dataloaders/trace_row_ids.py deleted file mode 100644 index 101501d5c9..0000000000 --- a/src/phoenix/server/api/dataloaders/trace_row_ids.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import Optional - -from sqlalchemy import select -from strawberry.dataloader import DataLoader -from typing_extensions import TypeAlias - -from phoenix.db import models -from phoenix.server.types import DbSessionFactory - -TraceId: TypeAlias = str -Key: TypeAlias = TraceId -TraceRowId: TypeAlias = int -ProjectRowId: TypeAlias = int -Result: TypeAlias = Optional[tuple[TraceRowId, ProjectRowId]] - - -class TraceRowIdsDataLoader(DataLoader[Key, Result]): - def __init__(self, db: DbSessionFactory) -> None: - super().__init__(load_fn=self._load_fn) - self._db = db - - async def _load_fn(self, keys: list[Key]) -> list[Result]: - stmt = select( - models.Trace.trace_id, - models.Trace.id, - models.Trace.project_rowid, - ).where(models.Trace.trace_id.in_(keys)) - async with 
self._db() as session: - result = { - trace_id: (id_, project_rowid) - async for trace_id, id_, project_rowid in await session.stream(stmt) - } - return list(map(result.get, keys)) diff --git a/src/phoenix/server/api/mutations/project_mutations.py b/src/phoenix/server/api/mutations/project_mutations.py index 30d38620f1..1eb046d23a 100644 --- a/src/phoenix/server/api/mutations/project_mutations.py +++ b/src/phoenix/server/api/mutations/project_mutations.py @@ -38,10 +38,20 @@ async def clear_project(self, info: Info[Context, None], input: ClearProjectInpu project_id = from_global_id_with_expected_type( global_id=input.id, expected_type_name="Project" ) - delete_statement = delete(models.Trace).where(models.Trace.project_rowid == project_id) + delete_statement = ( + delete(models.Trace) + .where(models.Trace.project_rowid == project_id) + .returning(models.Trace.project_session_rowid) + ) if input.end_time: delete_statement = delete_statement.where(models.Trace.start_time < input.end_time) async with info.context.db() as session: - await session.execute(delete_statement) + deleted_trace_project_session_ids = await session.scalars(delete_statement) + if deleted_trace_project_session_ids: + await session.execute( + delete(models.ProjectSession).where( + models.ProjectSession.id.in_(set(deleted_trace_project_session_ids)) + ) + ) info.context.event_queue.put(SpanDeleteEvent((project_id,))) return Query() diff --git a/src/phoenix/server/api/queries.py b/src/phoenix/server/api/queries.py index 724b2ecf76..7762683e5b 100644 --- a/src/phoenix/server/api/queries.py +++ b/src/phoenix/server/api/queries.py @@ -82,7 +82,7 @@ from phoenix.server.api.types.SortDir import SortDir from phoenix.server.api.types.Span import Span, to_gql_span from phoenix.server.api.types.SystemApiKey import SystemApiKey -from phoenix.server.api.types.Trace import Trace +from phoenix.server.api.types.Trace import to_gql_trace from phoenix.server.api.types.User import User, to_gql_user from 
phoenix.server.api.types.UserApiKey import UserApiKey, to_gql_api_key from phoenix.server.api.types.UserRole import UserRole @@ -446,17 +446,12 @@ async def node(self, id: GlobalID, info: Info[Context, None]) -> Node: gradient_end_color=project.gradient_end_color, ) elif type_name == "Trace": - trace_stmt = select( - models.Trace.id, - models.Trace.project_rowid, - ).where(models.Trace.id == node_id) + trace_stmt = select(models.Trace).filter_by(id=node_id) async with info.context.db() as session: - trace = (await session.execute(trace_stmt)).first() + trace = await session.scalar(trace_stmt) if trace is None: raise NotFound(f"Unknown trace: {id}") - return Trace( - id_attr=trace.id, trace_id=trace.trace_id, project_rowid=trace.project_rowid - ) + return to_gql_trace(trace) elif type_name == Span.__name__: span_stmt = ( select(models.Span) @@ -470,14 +465,6 @@ async def node(self, id: GlobalID, info: Info[Context, None]) -> Node: if span is None: raise NotFound(f"Unknown span: {id}") return to_gql_span(span) - elif type_name == ProjectSession.__name__: - async with info.context.db() as session: - project_session = await session.scalar( - select(models.ProjectSession).filter_by(id=node_id) - ) - if project_session is None: - raise NotFound(f"Unknown project_session: {id}") - return to_gql_project_session(project_session) elif type_name == Dataset.__name__: dataset_stmt = select(models.Dataset).where(models.Dataset.id == node_id) async with info.context.db() as session: @@ -553,6 +540,15 @@ async def node(self, id: GlobalID, info: Info[Context, None]) -> Node: ): raise NotFound(f"Unknown user: {id}") return to_gql_user(user) + elif type_name == ProjectSession.__name__: + async with info.context.db() as session: + if not ( + project_session := await session.scalar( + select(models.ProjectSession).filter_by(id=node_id) + ) + ): + raise NotFound(f"Unknown project session: {id}") + return to_gql_project_session(project_session) raise NotFound(f"Unknown node type: {type_name}") 
@strawberry.field diff --git a/src/phoenix/server/api/types/ExperimentRun.py b/src/phoenix/server/api/types/ExperimentRun.py index 6150323a4b..c01f913f18 100644 --- a/src/phoenix/server/api/types/ExperimentRun.py +++ b/src/phoenix/server/api/types/ExperimentRun.py @@ -20,7 +20,7 @@ CursorString, connection_from_list, ) -from phoenix.server.api.types.Trace import Trace +from phoenix.server.api.types.Trace import Trace, to_gql_trace if TYPE_CHECKING: from phoenix.server.api.types.DatasetExample import DatasetExample @@ -61,11 +61,10 @@ async def annotations( async def trace(self, info: Info) -> Optional[Trace]: if not self.trace_id: return None - dataloader = info.context.data_loaders.trace_row_ids + dataloader = info.context.data_loaders.trace_by_trace_ids if (trace := await dataloader.load(self.trace_id)) is None: return None - trace_rowid, project_rowid = trace - return Trace(id_attr=trace_rowid, trace_id=self.trace_id, project_rowid=project_rowid) + return to_gql_trace(trace) @strawberry.field async def example( diff --git a/src/phoenix/server/api/types/ExperimentRunAnnotation.py b/src/phoenix/server/api/types/ExperimentRunAnnotation.py index f144b715f2..a1e8d539ec 100644 --- a/src/phoenix/server/api/types/ExperimentRunAnnotation.py +++ b/src/phoenix/server/api/types/ExperimentRunAnnotation.py @@ -8,7 +8,7 @@ from phoenix.db import models from phoenix.server.api.types.AnnotatorKind import ExperimentRunAnnotatorKind -from phoenix.server.api.types.Trace import Trace +from phoenix.server.api.types.Trace import Trace, to_gql_trace @strawberry.type @@ -29,11 +29,10 @@ class ExperimentRunAnnotation(Node): async def trace(self, info: Info) -> Optional[Trace]: if not self.trace_id: return None - dataloader = info.context.data_loaders.trace_row_ids + dataloader = info.context.data_loaders.trace_by_trace_ids if (trace := await dataloader.load(self.trace_id)) is None: return None - trace_row_id, project_row_id = trace - return Trace(id_attr=trace_row_id, 
trace_id=self.trace_id, project_rowid=project_row_id) + return to_gql_trace(trace) def to_gql_experiment_run_annotation( diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index 3dfa4100b2..c8129abb31 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -31,7 +31,7 @@ from phoenix.server.api.types.ProjectSession import ProjectSession, to_gql_project_session from phoenix.server.api.types.SortDir import SortDir from phoenix.server.api.types.Span import Span, to_gql_span -from phoenix.server.api.types.Trace import Trace +from phoenix.server.api.types.Trace import Trace, to_gql_trace from phoenix.server.api.types.ValidationResult import ValidationResult from phoenix.trace.dsl import SpanFilter @@ -146,14 +146,14 @@ async def span_latency_ms_quantile( @strawberry.field async def trace(self, trace_id: ID, info: Info[Context, None]) -> Optional[Trace]: stmt = ( - select(models.Trace.id) + select(models.Trace) .where(models.Trace.trace_id == str(trace_id)) .where(models.Trace.project_rowid == self.id_attr) ) async with info.context.db() as session: - if (id_attr := await session.scalar(stmt)) is None: + if (trace := await session.scalar(stmt)) is None: return None - return Trace(id_attr=id_attr, trace_id=trace_id, project_rowid=self.id_attr) + return to_gql_trace(trace) @strawberry.field async def spans( @@ -259,12 +259,12 @@ async def sessions( stmt = stmt.where(table.start_time < time_range.end) if after: cursor = Cursor.from_string(after) - stmt = stmt.where(table.id > cursor.rowid) + stmt = stmt.where(table.id < cursor.rowid) if first: stmt = stmt.limit( first + 1 # over-fetch by one to determine whether there's a next page ) - stmt = stmt.order_by(table.id) + stmt = stmt.order_by(table.id.desc()) cursors_and_nodes = [] async with info.context.db() as session: records = await session.scalars(stmt) diff --git a/src/phoenix/server/api/types/ProjectSession.py 
b/src/phoenix/server/api/types/ProjectSession.py index 436b0f494a..666937fb66 100644 --- a/src/phoenix/server/api/types/ProjectSession.py +++ b/src/phoenix/server/api/types/ProjectSession.py @@ -1,20 +1,125 @@ -from typing import Optional +from datetime import datetime +from typing import TYPE_CHECKING, Annotated, ClassVar, Optional, Type import strawberry -from sqlalchemy import desc, select -from strawberry import UNSET, Info +from openinference.semconv.trace import SpanAttributes +from sqlalchemy import func, select +from sqlalchemy.sql.functions import coalesce +from strawberry import UNSET, Info, lazy from strawberry.relay import Connection, Node, NodeID from phoenix.db import models from phoenix.server.api.context import Context +from phoenix.server.api.types.MimeType import MimeType from phoenix.server.api.types.pagination import ConnectionArgs, CursorString, connection_from_list -from phoenix.server.api.types.Trace import Trace, to_gql_trace +from phoenix.server.api.types.SpanIOValue import SpanIOValue +from phoenix.server.api.types.TokenUsage import TokenUsage + +if TYPE_CHECKING: + from phoenix.server.api.types.Trace import Trace @strawberry.type class ProjectSession(Node): + _table: ClassVar[Type[models.ProjectSession]] = models.ProjectSession id_attr: NodeID[int] session_id: str + session_user: Optional[str] + start_time: datetime + end_time: datetime + + @strawberry.field(description="Duration of the session in seconds") # type: ignore + async def duration_s(self) -> float: + return (self.end_time - self.start_time).total_seconds() + + @strawberry.field + async def num_traces( + self, + info: Info[Context, None], + ) -> int: + stmt = select(func.count(models.Trace.id)).filter_by(project_session_rowid=self.id_attr) + async with info.context.db() as session: + return await session.scalar(stmt) or 0 + + @strawberry.field + async def first_input( + self, + info: Info[Context, None], + ) -> Optional[SpanIOValue]: + stmt = ( + select( + 
models.Span.attributes[INPUT_VALUE].label("value"), + models.Span.attributes[INPUT_MIME_TYPE].label("mime_type"), + ) + .join(models.Trace) + .filter_by(project_session_rowid=self.id_attr) + .where(models.Span.parent_id.is_(None)) + .order_by(models.Trace.start_time.asc()) + .limit(1) + ) + async with info.context.db() as session: + record = (await session.execute(stmt)).first() + if record is None or record.value is None: + return None + return SpanIOValue( + mime_type=MimeType(record.mime_type), + value=record.value, + ) + + @strawberry.field + async def last_output( + self, + info: Info[Context, None], + ) -> Optional[SpanIOValue]: + stmt = ( + select( + models.Span.attributes[OUTPUT_VALUE].label("value"), + models.Span.attributes[OUTPUT_MIME_TYPE].label("mime_type"), + ) + .join(models.Trace) + .filter_by(project_session_rowid=self.id_attr) + .where(models.Span.parent_id.is_(None)) + .order_by(models.Trace.start_time.desc()) + .limit(1) + ) + async with info.context.db() as session: + record = (await session.execute(stmt)).first() + if record is None or record.value is None: + return None + return SpanIOValue( + mime_type=MimeType(record.mime_type), + value=record.value, + ) + + @strawberry.field + async def token_usage( + self, + info: Info[Context, None], + ) -> TokenUsage: + stmt = ( + select( + func.sum(coalesce(models.Span.cumulative_llm_token_count_prompt, 0)).label( + "prompt" + ), + func.sum(coalesce(models.Span.cumulative_llm_token_count_completion, 0)).label( + "completion" + ), + ) + .join(models.Trace) + .filter_by(project_session_rowid=self.id_attr) + .where(models.Span.parent_id.is_(None)) + .limit(1) + ) + async with info.context.db() as session: + usage = (await session.execute(stmt)).first() + return ( + TokenUsage( + prompt=usage.prompt or 0, + completion=usage.completion or 0, + ) + if usage + else TokenUsage() + ) @strawberry.field async def traces( @@ -24,7 +129,9 @@ async def traces( last: Optional[int] = UNSET, after: 
Optional[CursorString] = UNSET, before: Optional[CursorString] = UNSET, - ) -> Connection[Trace]: + ) -> Connection[Annotated["Trace", lazy(".Trace")]]: + from phoenix.server.api.types.Trace import to_gql_trace + args = ConnectionArgs( first=first, after=after if isinstance(after, CursorString) else None, @@ -33,8 +140,8 @@ async def traces( ) stmt = ( select(models.Trace) - .filter_by(project_session_id=self.id_attr) - .order_by(desc(models.Trace.id)) + .filter_by(project_session_rowid=self.id_attr) + .order_by(models.Trace.start_time) .limit(first) ) async with info.context.db() as session: @@ -43,10 +150,17 @@ async def traces( return connection_from_list(data=data, args=args) -def to_gql_project_session( - project_session: models.ProjectSession, -) -> ProjectSession: +def to_gql_project_session(project_session: models.ProjectSession) -> ProjectSession: return ProjectSession( id_attr=project_session.id, session_id=project_session.session_id, + session_user=project_session.session_user, + start_time=project_session.start_time, + end_time=project_session.end_time, ) + + +INPUT_VALUE = SpanAttributes.INPUT_VALUE.split(".") +INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE.split(".") +OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE.split(".") +OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE.split(".") diff --git a/src/phoenix/server/api/types/Span.py b/src/phoenix/server/api/types/Span.py index 19a9966518..eef31a2809 100644 --- a/src/phoenix/server/api/types/Span.py +++ b/src/phoenix/server/api/types/Span.py @@ -24,17 +24,16 @@ SpanAnnotationColumn, SpanAnnotationSort, ) +from phoenix.server.api.types.DocumentRetrievalMetrics import DocumentRetrievalMetrics +from phoenix.server.api.types.Evaluation import DocumentEvaluation +from phoenix.server.api.types.ExampleRevisionInterface import ExampleRevision from phoenix.server.api.types.GenerativeProvider import GenerativeProvider +from phoenix.server.api.types.MimeType import MimeType from phoenix.server.api.types.SortDir 
import SortDir -from phoenix.server.api.types.SpanAnnotation import to_gql_span_annotation +from phoenix.server.api.types.SpanAnnotation import SpanAnnotation, to_gql_span_annotation +from phoenix.server.api.types.SpanIOValue import SpanIOValue from phoenix.trace.attributes import get_attribute_value -from .DocumentRetrievalMetrics import DocumentRetrievalMetrics -from .Evaluation import DocumentEvaluation -from .ExampleRevisionInterface import ExampleRevision -from .MimeType import MimeType -from .SpanAnnotation import SpanAnnotation - if TYPE_CHECKING: from phoenix.server.api.types.Project import Project @@ -71,18 +70,6 @@ class SpanContext: span_id: ID -@strawberry.type -class SpanIOValue: - mime_type: MimeType - value: str - - @strawberry.field( - description="Truncate value up to `chars` characters, appending '...' if truncated.", - ) # type: ignore - def truncated_value(self, chars: int = 100) -> str: - return f"{self.value[: max(0, chars - 3)]}..." if len(self.value) > chars else self.value - - @strawberry.enum class SpanStatusCode(Enum): OK = "OK" diff --git a/src/phoenix/server/api/types/SpanIOValue.py b/src/phoenix/server/api/types/SpanIOValue.py new file mode 100644 index 0000000000..d6395af686 --- /dev/null +++ b/src/phoenix/server/api/types/SpanIOValue.py @@ -0,0 +1,15 @@ +import strawberry + +from phoenix.server.api.types.MimeType import MimeType + + +@strawberry.type +class SpanIOValue: + mime_type: MimeType + value: str + + @strawberry.field( + description="Truncate value up to `chars` characters, appending '...' if truncated.", + ) # type: ignore + def truncated_value(self, chars: int = 100) -> str: + return f"{self.value[: max(0, chars - 3)]}..." 
if len(self.value) > chars else self.value diff --git a/src/phoenix/server/api/types/TokenUsage.py b/src/phoenix/server/api/types/TokenUsage.py new file mode 100644 index 0000000000..5e9c9896ba --- /dev/null +++ b/src/phoenix/server/api/types/TokenUsage.py @@ -0,0 +1,11 @@ +import strawberry + + +@strawberry.type +class TokenUsage: + prompt: int = 0 + completion: int = 0 + + @strawberry.field + async def total(self) -> int: + return self.prompt + self.completion diff --git a/src/phoenix/server/api/types/Trace.py b/src/phoenix/server/api/types/Trace.py index 5ba5348b12..22bc9df476 100644 --- a/src/phoenix/server/api/types/Trace.py +++ b/src/phoenix/server/api/types/Trace.py @@ -1,11 +1,13 @@ from __future__ import annotations -from typing import Optional +from datetime import datetime +from typing import TYPE_CHECKING, Annotated, Optional, Union import strawberry +from openinference.semconv.trace import SpanAttributes from sqlalchemy import desc, select from sqlalchemy.orm import contains_eager -from strawberry import UNSET, Private +from strawberry import UNSET, Private, lazy from strawberry.relay import Connection, GlobalID, Node, NodeID from strawberry.types import Info @@ -21,12 +23,18 @@ from phoenix.server.api.types.Span import Span, to_gql_span from phoenix.server.api.types.TraceAnnotation import TraceAnnotation, to_gql_trace_annotation +if TYPE_CHECKING: + from phoenix.server.api.types.ProjectSession import ProjectSession + @strawberry.type class Trace(Node): id_attr: NodeID[int] project_rowid: Private[int] + project_session_rowid: Private[Optional[int]] trace_id: str + start_time: datetime + end_time: datetime @strawberry.field async def project_id(self) -> GlobalID: @@ -34,6 +42,49 @@ async def project_id(self) -> GlobalID: return GlobalID(type_name=Project.__name__, node_id=str(self.project_rowid)) + @strawberry.field + async def project_session_id(self) -> Optional[GlobalID]: + if self.project_session_rowid is None: + return None + from 
phoenix.server.api.types.ProjectSession import ProjectSession + + return GlobalID(type_name=ProjectSession.__name__, node_id=str(self.project_session_rowid)) + + @strawberry.field + async def session( + self, + info: Info[Context, None], + ) -> Union[Annotated["ProjectSession", lazy(".ProjectSession")], None]: + if self.project_session_rowid is None: + return None + from phoenix.server.api.types.ProjectSession import to_gql_project_session + + stmt = select(models.ProjectSession).filter_by(id=self.project_session_rowid) + async with info.context.db() as session: + project_session = await session.scalar(stmt) + if project_session is None: + return None + return to_gql_project_session(project_session) + + @strawberry.field + async def root_span( + self, + info: Info[Context, None], + ) -> Optional[Span]: + stmt = ( + select(models.Span) + .join(models.Trace) + .where(models.Trace.id == self.id_attr) + .options(contains_eager(models.Span.trace).load_only(models.Trace.trace_id)) + .where(models.Span.parent_id.is_(None)) + .limit(1) + ) + async with info.context.db() as session: + span = await session.scalar(stmt) + if span is None: + return None + return to_gql_span(span) + @strawberry.field async def spans( self, @@ -88,5 +139,12 @@ def to_gql_trace(trace: models.Trace) -> Trace: return Trace( id_attr=trace.id, project_rowid=trace.project_rowid, + project_session_rowid=trace.project_session_rowid, trace_id=trace.trace_id, + start_time=trace.start_time, + end_time=trace.end_time, ) + + +INPUT_VALUE = SpanAttributes.INPUT_VALUE.split(".") +OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE.split(".") diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py index c580d2daf0..b7fb5c2372 100644 --- a/src/phoenix/server/app.py +++ b/src/phoenix/server/app.py @@ -92,7 +92,7 @@ SpanDescendantsDataLoader, SpanProjectsDataLoader, TokenCountDataLoader, - TraceRowIdsDataLoader, + TraceByTraceIdsDataLoader, UserRolesDataLoader, UsersDataLoader, ) @@ -617,7 +617,7 @@ def 
get_context() -> Context: db, cache_map=cache_for_dataloaders.token_count if cache_for_dataloaders else None, ), - trace_row_ids=TraceRowIdsDataLoader(db), + trace_by_trace_ids=TraceByTraceIdsDataLoader(db), project_by_name=ProjectByNameDataLoader(db), users=UsersDataLoader(db), user_roles=UserRolesDataLoader(db), diff --git a/tests/integration/db_migrations/test_up_and_down_migrations.py b/tests/integration/db_migrations/test_up_and_down_migrations.py index 0b81f1724d..08c6b7bb5f 100644 --- a/tests/integration/db_migrations/test_up_and_down_migrations.py +++ b/tests/integration/db_migrations/test_up_and_down_migrations.py @@ -68,6 +68,12 @@ def test_up_and_down_migrations( assert isinstance(column.type, VARCHAR) del column + column = columns.pop("session_user", None) + assert column is not None + assert column.nullable + assert isinstance(column.type, VARCHAR) + del column + column = columns.pop("project_id", None) assert column is not None assert not column.nullable @@ -96,7 +102,7 @@ def test_up_and_down_migrations( assert not index.unique del index - index = indexes.pop("ix_project_sessions_end_time", None) + index = indexes.pop("ix_project_sessions_session_user", None) assert index is not None assert not index.unique del index @@ -164,7 +170,7 @@ def test_up_and_down_migrations( assert isinstance(column.type, TIMESTAMP) del column - column = columns.pop("project_session_id", None) + column = columns.pop("project_session_rowid", None) assert column is not None assert column.nullable assert isinstance(column.type, INTEGER) @@ -185,7 +191,7 @@ def test_up_and_down_migrations( assert not index.unique del index - index = indexes.pop("ix_traces_project_session_id", None) + index = indexes.pop("ix_traces_project_session_rowid", None) assert index is not None assert not index.unique del index @@ -208,7 +214,7 @@ def test_up_and_down_migrations( assert constraint.ondelete == "CASCADE" del constraint - constraint = 
constraints.pop("fk_traces_project_session_id_project_sessions", None) + constraint = constraints.pop("fk_traces_project_session_rowid_project_sessions", None) assert isinstance(constraint, ForeignKeyConstraint) assert constraint.ondelete == "CASCADE" del constraint diff --git a/tests/unit/_helpers.py b/tests/unit/_helpers.py new file mode 100644 index 0000000000..85f8107622 --- /dev/null +++ b/tests/unit/_helpers.py @@ -0,0 +1,133 @@ +from datetime import datetime, timedelta, timezone +from secrets import token_hex +from typing import Any, Dict, Optional, Type, TypeVar, cast + +import httpx +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from strawberry.relay import GlobalID + +from phoenix.db import models + + +async def _node( + field: str, + type_name: str, + id_: int, + httpx_client: httpx.AsyncClient, +) -> dict[str, Any]: + query = "query($id:GlobalID!){node(id:$id){... on " + type_name + "{" + field + "}}}" + gid = str(GlobalID(type_name, str(id_))) + response = await httpx_client.post( + "/graphql", + json={"query": query, "variables": {"id": gid}}, + ) + assert response.status_code == 200 + response_json = response.json() + assert response_json.get("errors") is None + key = field.split("{")[0].split("(")[0] + return cast(dict[str, Any], response_json["data"]["node"][key]) + + +_RecordT = TypeVar("_RecordT", bound=models.Base) + + +async def _get_record_by_id( + session: AsyncSession, + table: Type[_RecordT], + id_: int, +) -> Optional[_RecordT]: + return cast(Optional[_RecordT], await session.scalar(select(table).filter_by(id=id_))) + + +async def _add_project( + session: AsyncSession, + name: Optional[str] = None, +) -> models.Project: + project = models.Project(name=name or token_hex(4)) + session.add(project) + await session.flush() + assert isinstance(await _get_record_by_id(session, models.Project, project.id), models.Project) + return project + + +async def _add_trace( + session: AsyncSession, + project: 
models.Project, + project_session: Optional[models.ProjectSession] = None, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, +) -> models.Trace: + start_time = start_time or datetime.now(timezone.utc) + end_time = end_time or (start_time + timedelta(seconds=10)) + trace = models.Trace( + trace_id=token_hex(16), + start_time=start_time, + end_time=end_time, + project_rowid=project.id, + project_session_rowid=None if project_session is None else project_session.id, + ) + session.add(trace) + await session.flush() + assert isinstance(await _get_record_by_id(session, models.Trace, trace.id), models.Trace) + return trace + + +async def _add_span( + session: AsyncSession, + trace: models.Trace, + attributes: Optional[Dict[str, Any]] = None, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + parent_span: Optional[models.Span] = None, + span_kind: str = "LLM", + cumulative_llm_token_count_prompt: int = 0, + cumulative_llm_token_count_completion: int = 0, +) -> models.Span: + start_time = start_time or datetime.now(timezone.utc) + end_time = end_time or (start_time + timedelta(seconds=10)) + span = models.Span( + name=token_hex(4), + span_id=token_hex(8), + parent_id=None if parent_span is None else parent_span.span_id, + span_kind=span_kind, + start_time=start_time, + end_time=end_time, + status_code="OK", + status_message="test_status_message", + cumulative_error_count=0, + cumulative_llm_token_count_prompt=cumulative_llm_token_count_prompt, + cumulative_llm_token_count_completion=cumulative_llm_token_count_completion, + attributes=attributes or {}, + trace_rowid=trace.id, + ) + session.add(span) + await session.flush() + assert isinstance(await _get_record_by_id(session, models.Span, span.id), models.Span) + return span + + +async def _add_project_session( + session: AsyncSession, + project: models.Project, + session_id: Optional[str] = None, + session_user: Optional[str] = None, + start_time: 
Optional[datetime] = None, + end_time: Optional[datetime] = None, +) -> models.ProjectSession: + start_time = start_time or datetime.now(timezone.utc) + end_time = end_time or (start_time + timedelta(seconds=10)) + project_session = models.ProjectSession( + session_id=session_id or token_hex(4), + session_user=session_user, + project_id=project.id, + start_time=start_time, + end_time=end_time, + ) + session.add(project_session) + await session.flush() + assert isinstance( + await _get_record_by_id(session, models.ProjectSession, project_session.id), + models.ProjectSession, + ) + return project_session diff --git a/tests/unit/server/api/types/test_ProjectSession.py b/tests/unit/server/api/types/test_ProjectSession.py new file mode 100644 index 0000000000..13a0a7939d --- /dev/null +++ b/tests/unit/server/api/types/test_ProjectSession.py @@ -0,0 +1,146 @@ +from datetime import datetime, timedelta, timezone +from typing import Any + +import httpx +import pytest +from faker import Faker +from strawberry.relay import GlobalID +from typing_extensions import TypeAlias + +from phoenix.db import models +from phoenix.server.api.types.ProjectSession import ProjectSession +from phoenix.server.api.types.Trace import Trace +from phoenix.server.types import DbSessionFactory + +from ...._helpers import _add_project, _add_project_session, _add_span, _add_trace, _node + +_Data: TypeAlias = tuple[ + list[models.ProjectSession], + list[models.Trace], + list[models.Project], +] + + +class TestProjectSession: + @staticmethod + async def _node( + field: str, + project_session: models.ProjectSession, + httpx_client: httpx.AsyncClient, + ) -> Any: + return await _node( + field, + ProjectSession.__name__, + project_session.id, + httpx_client, + ) + + @pytest.fixture + async def _data( + self, + db: DbSessionFactory, + fake: Faker, + ) -> _Data: + project_sessions = [] + traces = [] + async with db() as session: + project = await _add_project(session) + start_time = 
datetime.now(timezone.utc) + project_sessions.append( + await _add_project_session( + session, + project, + session_user="xyz", + start_time=start_time, + ) + ) + traces.append( + await _add_trace( + session, + project, + project_sessions[-1], + start_time=start_time, + ) + ) + await _add_span( + session, + traces[-1], + attributes={"input": {"value": "123"}, "output": {"value": "321"}}, + cumulative_llm_token_count_prompt=1, + cumulative_llm_token_count_completion=2, + ) + traces.append( + await _add_trace( + session, + project, + project_sessions[-1], + start_time=start_time + timedelta(seconds=1), + ) + ) + await _add_span( + session, + traces[-1], + attributes={"input": {"value": "1234"}, "output": {"value": "4321"}}, + cumulative_llm_token_count_prompt=3, + cumulative_llm_token_count_completion=4, + ) + project_sessions.append(await _add_project_session(session, project)) + return project_sessions, traces, [project] + + async def test_session_user( + self, + _data: _Data, + httpx_client: httpx.AsyncClient, + ) -> None: + assert await self._node("sessionUser", _data[0][0], httpx_client) == "xyz" + assert await self._node("sessionUser", _data[0][1], httpx_client) is None + + async def test_num_traces( + self, + _data: _Data, + httpx_client: httpx.AsyncClient, + ) -> None: + assert await self._node("numTraces", _data[0][0], httpx_client) == 2 + + async def test_first_input( + self, + _data: _Data, + httpx_client: httpx.AsyncClient, + ) -> None: + assert await self._node( + "firstInput{value mimeType}", + _data[0][0], + httpx_client, + ) == {"value": "123", "mimeType": "text"} + + async def test_last_output( + self, + _data: _Data, + httpx_client: httpx.AsyncClient, + ) -> None: + assert await self._node( + "lastOutput{value mimeType}", + _data[0][0], + httpx_client, + ) == {"value": "4321", "mimeType": "text"} + + async def test_traces( + self, + _data: _Data, + httpx_client: httpx.AsyncClient, + ) -> None: + traces = await self._node("traces{edges{node{id}}}", 
_data[0][0], httpx_client) + assert {edge["node"]["id"] for edge in traces["edges"]} == { + str(GlobalID(Trace.__name__, str(trace.id))) for trace in _data[1] + } + + async def test_token_usage( + self, + _data: _Data, + httpx_client: httpx.AsyncClient, + ) -> None: + assert await self._node( + "tokenUsage{prompt completion total}", + _data[0][0], + httpx_client, + ) == {"prompt": 4, "completion": 6, "total": 10} diff --git a/tests/unit/server/api/types/test_Trace.py b/tests/unit/server/api/types/test_Trace.py new file mode 100644 index 0000000000..1fb817d85f --- /dev/null +++ b/tests/unit/server/api/types/test_Trace.py @@ -0,0 +1,56 @@ +from typing import Any + +import httpx +import pytest +from strawberry.relay import GlobalID +from typing_extensions import TypeAlias + +from phoenix.db import models +from phoenix.server.api.types.ProjectSession import ProjectSession +from phoenix.server.api.types.Trace import Trace +from phoenix.server.types import DbSessionFactory + +from ...._helpers import _add_project, _add_project_session, _add_trace, _node + +_Data: TypeAlias = tuple[ + list[models.Trace], + list[models.ProjectSession], + list[models.Project], +] + + +class TestTrace: + @staticmethod + async def _node( + field: str, + trace: models.Trace, + httpx_client: httpx.AsyncClient, + ) -> Any: + return await _node( + field, + Trace.__name__, + trace.id, + httpx_client, + ) + + @pytest.fixture + async def _data(self, db: DbSessionFactory) -> _Data: + traces = [] + async with db() as session: + project = await _add_project(session) + project_session = await _add_project_session(session, project) + traces.append(await _add_trace(session, project)) + traces.append(await _add_trace(session, project, project_session)) + return traces, [project_session], [project] + + async def test_session( + self, + _data: _Data, + httpx_client: httpx.AsyncClient, + ) -> None: + traces = _data[0] + project_session = _data[1][0] + assert await self._node("session{id}", traces[0], 
httpx_client) is None + assert await self._node("session{id}", traces[1], httpx_client) == { + "id": str(GlobalID(ProjectSession.__name__, str(project_session.id))) + }