From 80195eea0c600a63edabb49289d9673184386bed Mon Sep 17 00:00:00 2001 From: Roger Yang Date: Wed, 20 Mar 2024 16:40:13 -0700 Subject: [PATCH] feat: add trace node --- app/schema.graphql | 33 ++ app/src/pages/trace/TracePage.tsx | 111 ++++--- .../__generated__/TracePageQuery.graphql.ts | 301 +++++++++--------- src/phoenix/core/project.py | 163 +++++++++- src/phoenix/server/api/types/Evaluation.py | 22 +- src/phoenix/server/api/types/Project.py | 56 +++- src/phoenix/server/api/types/Trace.py | 44 +++ 7 files changed, 510 insertions(+), 220 deletions(-) create mode 100644 src/phoenix/server/api/types/Trace.py diff --git a/app/schema.graphql b/app/schema.graphql index 5b9a7b3b7f..bd815ac562 100644 --- a/app/schema.graphql +++ b/app/schema.graphql @@ -546,8 +546,14 @@ type Project implements Node { tokenCountTotal: Int! latencyMsP50: Float latencyMsP99: Float + trace(traceId: ID!): Trace spans(timeRange: TimeRange, traceIds: [ID!], first: Int = 50, last: Int, after: String, before: String, sort: SpanSort, rootSpansOnly: Boolean, filterCondition: String): SpanConnection! + """ + Names of all available evaluations for traces. (The list contains no duplicates.) + """ + traceEvaluationNames: [String!]! + """ Names of all available evaluations for spans. (The list contains no duplicates.) """ @@ -555,6 +561,7 @@ type Project implements Node { """Names of available document evaluations.""" documentEvaluationNames(spanId: ID): [String!]! + traceEvaluationSummary(evaluationName: String!, timeRange: TimeRange): EvaluationSummary spanEvaluationSummary(evaluationName: String!, timeRange: TimeRange, filterCondition: String): EvaluationSummary documentEvaluationSummary(evaluationName: String!, timeRange: TimeRange, filterCondition: String): DocumentEvaluationSummary streamingLastUpdatedAt: DateTime @@ -795,6 +802,32 @@ type TimeSeriesDataPoint { value: Float } +type Trace { + traceId: ID! + spans(first: Int, last: Int, after: String, before: String): SpanConnection! + + """Evaluations associated with the trace""" + traceEvaluations: [TraceEvaluation!]! +} + +type TraceEvaluation implements Evaluation { + """Name of the evaluation, e.g. 'helpfulness' or 'relevance'.""" + name: String! + + """Result of the evaluation in the form of a numeric score.""" + score: Float + + """ + Result of the evaluation in the form of a string, e.g. 'helpful' or 'not helpful'. Note that the label is not necessarily binary. + """ + label: String + + """ + The evaluator's explanation for the evaluation result (i.e. score or label, or both) given to the subject. + """ + explanation: String +} + type UMAPPoint { id: GlobalID! diff --git a/app/src/pages/trace/TracePage.tsx b/app/src/pages/trace/TracePage.tsx index f9b9bf32f5..b2d46159a5 100644 --- a/app/src/pages/trace/TracePage.tsx +++ b/app/src/pages/trace/TracePage.tsx @@ -87,8 +87,8 @@ import { import { SpanEvaluationsTable } from "./SpanEvaluationsTable"; type Span = NonNullable< - TracePageQuery$data["project"]["spans"] ->["edges"][number]["span"]; + TracePageQuery$data["project"]["trace"] +>["spans"]["edges"][number]["span"]; type DocumentEvaluation = Span["documentEvaluations"][number]; /** * A span attribute object that is a map of string to an unknown value @@ -181,59 +181,57 @@ export function TracePage() { query TracePageQuery($traceId: ID!, $id: GlobalID!) { project: node(id: $id) { ... on Project { - spans( - traceIds: [$traceId] - sort: { col: startTime, dir: asc } - first: 1000 - ) { - edges { - span: node { - context { - spanId - } - name - spanKind - statusCode: propagatedStatusCode - statusMessage - startTime - parentId - latencyMs - tokenCountTotal - tokenCountPrompt - tokenCountCompletion - input { - value - mimeType - } - output { - value - mimeType - } - attributes - events { - name - message - timestamp - } - spanEvaluations { - name - label - score - } - documentRetrievalMetrics { - evaluationName - ndcg - precision - hit - } - documentEvaluations { - documentPosition + trace(traceId: $traceId) { + spans { + edges { + span: node { + context { + spanId + } name - label - score - explanation + spanKind + statusCode: propagatedStatusCode + statusMessage + startTime + parentId + latencyMs + tokenCountTotal + tokenCountPrompt + tokenCountCompletion + input { + value + mimeType + } + output { + value + mimeType + } + attributes + events { + name + message + timestamp + } + spanEvaluations { + name + label + score + } + documentRetrievalMetrics { + evaluationName + ndcg + precision + hit + } + documentEvaluations { + documentPosition + name + label + score + explanation + } + ...SpanEvaluationsTable_evals } - ...SpanEvaluationsTable_evals } } } @@ -246,10 +244,9 @@ export function TracePage() { fetchPolicy: "store-and-network", } ); - const spansList = useMemo(() => { - const gqlSpans = - data.project.spans || ([] as NonNullable); - return gqlSpans.edges.map((edge) => edge.span); + const spansList: Span[] = useMemo(() => { + const gqlSpans = data.project.trace?.spans.edges || []; + return gqlSpans.map((node) => node.span); }, [data]); const urlSelectedSpanId = searchParams.get("selectedSpanId"); const selectedSpanId = urlSelectedSpanId ?? spansList[0].context.spanId; diff --git a/app/src/pages/trace/__generated__/TracePageQuery.graphql.ts b/app/src/pages/trace/__generated__/TracePageQuery.graphql.ts index 9a6336e9e7..096c3355eb 100644 --- a/app/src/pages/trace/__generated__/TracePageQuery.graphql.ts +++ b/app/src/pages/trace/__generated__/TracePageQuery.graphql.ts @@ -1,5 +1,5 @@ /** - * @generated SignedSource<> + * @generated SignedSource<<13a08661653247c5bf9f082a0fe988ff>> * @lightSyntaxTransform * @nogrep */ @@ -19,58 +19,60 @@ export type TracePageQuery$variables = { }; export type TracePageQuery$data = { readonly project: { - readonly spans?: { - readonly edges: ReadonlyArray<{ - readonly span: { - readonly attributes: string; - readonly context: { - readonly spanId: string; - }; - readonly documentEvaluations: ReadonlyArray<{ - readonly documentPosition: number; - readonly explanation: string | null; - readonly label: string | null; - readonly name: string; - readonly score: number | null; - }>; - readonly documentRetrievalMetrics: ReadonlyArray<{ - readonly evaluationName: string; - readonly hit: number | null; - readonly ndcg: number | null; - readonly precision: number | null; - }>; - readonly events: ReadonlyArray<{ - readonly message: string; + readonly trace?: { + readonly spans: { + readonly edges: ReadonlyArray<{ + readonly span: { + readonly attributes: string; + readonly context: { + readonly spanId: string; + }; + readonly documentEvaluations: ReadonlyArray<{ + readonly documentPosition: number; + readonly explanation: string | null; + readonly label: string | null; + readonly name: string; + readonly score: number | null; + }>; + readonly documentRetrievalMetrics: ReadonlyArray<{ + readonly evaluationName: string; + readonly hit: number | null; + readonly ndcg: number | null; + readonly precision: number | null; + }>; + readonly events: ReadonlyArray<{ + readonly message: string; + readonly name: string; + readonly timestamp: string; + }>; + readonly input: { + readonly mimeType: MimeType; + readonly value: string; + } | null; + readonly latencyMs: number | null; readonly name: string; - readonly timestamp: string; - }>; - readonly input: { - readonly mimeType: MimeType; - readonly value: string; - } | null; - readonly latencyMs: number | null; - readonly name: string; - readonly output: { - readonly mimeType: MimeType; - readonly value: string; - } | null; - readonly parentId: string | null; - readonly spanEvaluations: ReadonlyArray<{ - readonly label: string | null; - readonly name: string; - readonly score: number | null; - }>; - readonly spanKind: SpanKind; - readonly startTime: string; - readonly statusCode: SpanStatusCode; - readonly statusMessage: string; - readonly tokenCountCompletion: number | null; - readonly tokenCountPrompt: number | null; - readonly tokenCountTotal: number | null; - readonly " $fragmentSpreads": FragmentRefs<"SpanEvaluationsTable_evals">; - }; - }>; - }; + readonly output: { + readonly mimeType: MimeType; + readonly value: string; + } | null; + readonly parentId: string | null; + readonly spanEvaluations: ReadonlyArray<{ + readonly label: string | null; + readonly name: string; + readonly score: number | null; + }>; + readonly spanKind: SpanKind; + readonly startTime: string; + readonly statusCode: SpanStatusCode; + readonly statusMessage: string; + readonly tokenCountCompletion: number | null; + readonly tokenCountPrompt: number | null; + readonly tokenCountTotal: number | null; + readonly " $fragmentSpreads": FragmentRefs<"SpanEvaluationsTable_evals">; + }; + }>; + }; + } | null; }; }; export type TracePageQuery = { @@ -98,28 +100,9 @@ v2 = [ ], v3 = [ { - "kind": "Literal", - "name": "first", - "value": 1000 - }, - { - "kind": "Literal", - "name": "sort", - "value": { - "col": "startTime", - "dir": "asc" - } - }, - { - "items": [ - { - "kind": "Variable", - "name": "traceIds.0", - "variableName": "traceId" - } - ], - "kind": "ListValue", - "name": "traceIds" + "kind": "Variable", + "name": "traceId", + "variableName": "traceId" } ], v4 = { @@ -385,62 +368,73 @@ return { { "alias": null, "args": (v3/*: any*/), - "concreteType": "SpanConnection", + "concreteType": "Trace", "kind": "LinkedField", - "name": "spans", + "name": "trace", "plural": false, "selections": [ { "alias": null, "args": null, - "concreteType": "SpanEdge", + "concreteType": "SpanConnection", "kind": "LinkedField", - "name": "edges", - "plural": true, + "name": "spans", + "plural": false, "selections": [ { - "alias": "span", + "alias": null, "args": null, - "concreteType": "Span", + "concreteType": "SpanEdge", "kind": "LinkedField", - "name": "node", - "plural": false, + "name": "edges", + "plural": true, "selections": [ - (v4/*: any*/), - (v5/*: any*/), - (v6/*: any*/), - (v7/*: any*/), - (v8/*: any*/), - (v9/*: any*/), - (v10/*: any*/), - (v11/*: any*/), - (v12/*: any*/), - (v13/*: any*/), - (v14/*: any*/), - (v16/*: any*/), - (v17/*: any*/), - (v18/*: any*/), - (v19/*: any*/), { - "alias": null, + "alias": "span", "args": null, - "concreteType": "SpanEvaluation", + "concreteType": "Span", "kind": "LinkedField", - "name": "spanEvaluations", - "plural": true, + "name": "node", + "plural": false, "selections": [ + (v4/*: any*/), (v5/*: any*/), - (v20/*: any*/), - (v21/*: any*/) + (v6/*: any*/), + (v7/*: any*/), + (v8/*: any*/), + (v9/*: any*/), + (v10/*: any*/), + (v11/*: any*/), + (v12/*: any*/), + (v13/*: any*/), + (v14/*: any*/), + (v16/*: any*/), + (v17/*: any*/), + (v18/*: any*/), + (v19/*: any*/), + { + "alias": null, + "args": null, + "concreteType": "SpanEvaluation", + "kind": "LinkedField", + "name": "spanEvaluations", + "plural": true, + "selections": [ + (v5/*: any*/), + (v20/*: any*/), + (v21/*: any*/) + ], + "storageKey": null + }, + (v22/*: any*/), + (v24/*: any*/), + { + "args": null, + "kind": "FragmentSpread", + "name": "SpanEvaluationsTable_evals" + } ], "storageKey": null - }, - (v22/*: any*/), - (v24/*: any*/), - { - "args": null, - "kind": "FragmentSpread", - "name": "SpanEvaluationsTable_evals" } ], "storageKey": null @@ -492,59 +486,70 @@ return { { "alias": null, "args": (v3/*: any*/), - "concreteType": "SpanConnection", + "concreteType": "Trace", "kind": "LinkedField", - "name": "spans", + "name": "trace", "plural": false, "selections": [ { "alias": null, "args": null, - "concreteType": "SpanEdge", + "concreteType": "SpanConnection", "kind": "LinkedField", - "name": "edges", - "plural": true, + "name": "spans", + "plural": false, "selections": [ { - "alias": "span", + "alias": null, "args": null, - "concreteType": "Span", + "concreteType": "SpanEdge", "kind": "LinkedField", - "name": "node", - "plural": false, + "name": "edges", + "plural": true, "selections": [ - (v4/*: any*/), - (v5/*: any*/), - (v6/*: any*/), - (v7/*: any*/), - (v8/*: any*/), - (v9/*: any*/), - (v10/*: any*/), - (v11/*: any*/), - (v12/*: any*/), - (v13/*: any*/), - (v14/*: any*/), - (v16/*: any*/), - (v17/*: any*/), - (v18/*: any*/), - (v19/*: any*/), { - "alias": null, + "alias": "span", "args": null, - "concreteType": "SpanEvaluation", + "concreteType": "Span", "kind": "LinkedField", - "name": "spanEvaluations", - "plural": true, + "name": "node", + "plural": false, "selections": [ + (v4/*: any*/), (v5/*: any*/), - (v20/*: any*/), - (v21/*: any*/), - (v23/*: any*/) + (v6/*: any*/), + (v7/*: any*/), + (v8/*: any*/), + (v9/*: any*/), + (v10/*: any*/), + (v11/*: any*/), + (v12/*: any*/), + (v13/*: any*/), + (v14/*: any*/), + (v16/*: any*/), + (v17/*: any*/), + (v18/*: any*/), + (v19/*: any*/), + { + "alias": null, + "args": null, + "concreteType": "SpanEvaluation", + "kind": "LinkedField", + "name": "spanEvaluations", + "plural": true, + "selections": [ + (v5/*: any*/), + (v20/*: any*/), + (v21/*: any*/), + (v23/*: any*/) + ], + "storageKey": null + }, + (v22/*: any*/), + (v24/*: any*/) ], "storageKey": null - }, - (v22/*: any*/), - (v24/*: any*/) + } ], "storageKey": null } @@ -575,16 +580,16 @@ return { ] }, "params": { - "cacheID": "21ef83105845ac504ea473ee46fafe2a", + "cacheID": "16e7e048ca173d571bd9183a5e8e5e07", "id": null, "metadata": {}, "name": "TracePageQuery", "operationKind": "query", - "text": "query TracePageQuery(\n $traceId: ID!\n $id: GlobalID!\n) {\n project: node(id: $id) {\n __typename\n ... on Project {\n spans(traceIds: [$traceId], sort: {col: startTime, dir: asc}, first: 1000) {\n edges {\n span: node {\n context {\n spanId\n }\n name\n spanKind\n statusCode: propagatedStatusCode\n statusMessage\n startTime\n parentId\n latencyMs\n tokenCountTotal\n tokenCountPrompt\n tokenCountCompletion\n input {\n value\n mimeType\n }\n output {\n value\n mimeType\n }\n attributes\n events {\n name\n message\n timestamp\n }\n spanEvaluations {\n name\n label\n score\n }\n documentRetrievalMetrics {\n evaluationName\n ndcg\n precision\n hit\n }\n documentEvaluations {\n documentPosition\n name\n label\n score\n explanation\n }\n ...SpanEvaluationsTable_evals\n }\n }\n }\n }\n __isNode: __typename\n id\n }\n}\n\nfragment SpanEvaluationsTable_evals on Span {\n spanEvaluations {\n name\n label\n score\n explanation\n }\n}\n" + "text": "query TracePageQuery(\n $traceId: ID!\n $id: GlobalID!\n) {\n project: node(id: $id) {\n __typename\n ... on Project {\n trace(traceId: $traceId) {\n spans {\n edges {\n span: node {\n context {\n spanId\n }\n name\n spanKind\n statusCode: propagatedStatusCode\n statusMessage\n startTime\n parentId\n latencyMs\n tokenCountTotal\n tokenCountPrompt\n tokenCountCompletion\n input {\n value\n mimeType\n }\n output {\n value\n mimeType\n }\n attributes\n events {\n name\n message\n timestamp\n }\n spanEvaluations {\n name\n label\n score\n }\n documentRetrievalMetrics {\n evaluationName\n ndcg\n precision\n hit\n }\n documentEvaluations {\n documentPosition\n name\n label\n score\n explanation\n }\n ...SpanEvaluationsTable_evals\n }\n }\n }\n }\n }\n __isNode: __typename\n id\n }\n}\n\nfragment SpanEvaluationsTable_evals on Span {\n spanEvaluations {\n name\n label\n score\n explanation\n }\n}\n" } }; })(); -(node as any).hash = "89241948b75d8d017caaa59858e4a77b"; +(node as any).hash = "a5c570251288f47b3e84e7cc62c1b019"; export default node; diff --git a/src/phoenix/core/project.py b/src/phoenix/core/project.py index 5592cee276..de99c66f86 100644 --- a/src/phoenix/core/project.py +++ b/src/phoenix/core/project.py @@ -111,9 +111,20 @@ def add_span(self, span: Span) -> None: def add_eval(self, pb_eval: pb.Evaluation) -> None: self._evals.add(pb_eval) + def has_trace(self, trace_id: TraceID) -> bool: + return self._spans.has_trace(trace_id) + def get_trace(self, trace_id: TraceID) -> Iterator[WrappedSpan]: yield from self._spans.get_trace(trace_id) + def get_trace_ids( + self, + start_time: Optional[datetime] = None, + stop_time: Optional[datetime] = None, + trace_ids: Optional[Iterable[TraceID]] = None, + ) -> Iterator[TraceID]: + yield from self._spans.get_trace_ids(start_time, stop_time, trace_ids) + def get_spans( self, start_time: Optional[datetime] = None, @@ -155,6 +166,21 @@ def token_count_total(self) -> int: def right_open_time_range(self) -> Tuple[Optional[datetime], Optional[datetime]]: return self._spans.right_open_time_range + def get_trace_evaluation(self, trace_id: TraceID, name: str) -> Optional[pb.Evaluation]: + return self._evals.get_trace_evaluation(trace_id, name) + + def get_trace_evaluation_names(self) -> List[EvaluationName]: + return self._evals.get_trace_evaluation_names() + + def get_trace_evaluation_labels(self, name: EvaluationName) -> Tuple[str, ...]: + return self._evals.get_trace_evaluation_labels(name) + + def get_trace_evaluation_trace_ids(self, name: EvaluationName) -> Tuple[TraceID, ...]: + return self._evals.get_trace_evaluation_trace_ids(name) + + def get_evaluations_by_trace_id(self, trace_id: TraceID) -> List[pb.Evaluation]: + return self._evals.get_evaluations_by_trace_id(trace_id) + def get_span_evaluation(self, span_id: SpanID, name: str) -> Optional[pb.Evaluation]: return self._evals.get_span_evaluation(span_id, name) @@ -201,12 +227,43 @@ def is_archived(self) -> bool: return self._is_archived +class _Trace: + def __init__(self, span: WrappedSpan) -> None: + self._trace_id: TraceID = span.context.trace_id + self._min_start_time: datetime = span.start_time + self._max_end_time: datetime = span.end_time + self._spans: List[WrappedSpan] = [span] + + @property + def trace_id(self) -> TraceID: + return self._trace_id + + @property + def start_time(self) -> datetime: + return self._min_start_time + + @property + def latency_ms(self) -> float: + return (self._max_end_time - self._min_start_time).total_seconds() * 1000 + + def add(self, span: WrappedSpan) -> None: + self._min_start_time = min(self._min_start_time, span.start_time) + self._max_end_time = max(self._max_end_time, span.end_time) + self._spans.append(span) + + def __eq__(self, other: Any) -> bool: + return self is other + + def __iter__(self) -> Iterator[WrappedSpan]: + yield from self._spans + + class _Spans: def __init__(self) -> None: self._lock = RLock() self._spans: Dict[SpanID, WrappedSpan] = {} self._parent_span_ids: Dict[SpanID, _ParentSpanID] = {} - self._traces: DefaultDict[TraceID, Set[WrappedSpan]] = defaultdict(set) + self._traces: Dict[TraceID, _Trace] = {} self._child_spans: DefaultDict[SpanID, Set[WrappedSpan]] = defaultdict(set) self._num_documents: DefaultDict[SpanID, int] = defaultdict(int) self._start_time_sorted_spans: SortedKeyList[WrappedSpan] = SortedKeyList( @@ -221,12 +278,18 @@ def __init__(self) -> None: (or will not arrive). For spans whose parent is not None, the root span status is temporary and will be revoked when its parent span arrives. """ - self._latency_sorted_root_spans: SortedKeyList[WrappedSpan] = SortedKeyList( - key=lambda span: span[ComputedAttributes.LATENCY_MS], + self._latency_sorted_traces: SortedKeyList[_Trace] = SortedKeyList( + key=lambda trace: trace.latency_ms, + ) + self._start_time_sorted_traces: SortedKeyList[_Trace] = SortedKeyList( + key=lambda trace: trace.start_time, ) self._token_count_total: int = 0 self._last_updated_at: Optional[datetime] = None + def has_trace(self, trace_id: TraceID) -> bool: + return trace_id in self._traces + def get_trace(self, trace_id: TraceID) -> Iterator[WrappedSpan]: with self._lock: # make a copy because source data can mutate during iteration @@ -236,6 +299,46 @@ def get_trace(self, trace_id: TraceID) -> Iterator[WrappedSpan]: for span in spans: yield span + def get_trace_ids( + self, + start_time: Optional[datetime] = None, + stop_time: Optional[datetime] = None, + trace_ids: Optional[Iterable[TraceID]] = None, + ) -> Iterator[TraceID]: + if not self._spans: + return + if start_time is None or stop_time is None: + min_start_time, max_stop_time = cast( + Tuple[datetime, datetime], + self.right_open_time_range, + ) + start_time = start_time or min_start_time + stop_time = stop_time or max_stop_time + if trace_ids is not None: + with self._lock: + traces = tuple( + trace + for trace_id in trace_ids + if ( + (trace := self._traces.get(trace_id)) + and start_time <= trace.start_time < stop_time + ) + ) + else: + sorted_traces = self._start_time_sorted_traces + # make a copy because source data can mutate during iteration + with self._lock: + traces = tuple( + sorted_traces.irange_key( + start_time.astimezone(timezone.utc), + stop_time.astimezone(timezone.utc), + inclusive=(True, False), + reverse=True, # most recent traces first + ) + ) + for trace in traces: + yield trace.trace_id + def get_spans( self, start_time: Optional[datetime] = None, @@ -289,15 +392,15 @@ def get_num_documents(self, span_id: SpanID) -> int: def root_span_latency_ms_quantiles(self, probability: float) -> Optional[float]: """Root span latency quantiles in milliseconds""" with self._lock: - spans = self._latency_sorted_root_spans - if not (n := len(spans)): + traces = self._latency_sorted_traces + if not (n := len(traces)): return None if probability >= 1: - return cast(float, spans[-1][ComputedAttributes.LATENCY_MS]) + return cast(float, traces[-1].latency_ms) if probability <= 0: - return cast(float, spans[0][ComputedAttributes.LATENCY_MS]) + return cast(float, traces[0].latency_ms) k = max(0, round(n * probability) - 1) - return cast(float, spans[k][ComputedAttributes.LATENCY_MS]) + return cast(float, traces[k].latency_ms) def get_descendant_spans(self, span_id: SpanID) -> Iterator[WrappedSpan]: for span in self._get_descendant_spans(span_id): @@ -373,7 +476,6 @@ def _add_span(self, span: WrappedSpan) -> None: # A root span is a span whose parent span is not in our collection. # Now that their parent span has arrived, they are no longer root spans. self._start_time_sorted_root_spans.remove(child_span) - self._latency_sorted_root_spans.remove(child_span) # Add computed attributes to span start_time = span.start_time @@ -383,11 +485,10 @@ def _add_span(self, span: WrappedSpan) -> None: # Store the new span (after adding computed attributes) self._spans[span_id] = span - self._traces[span.context.trace_id].add(span) + self._add_span_to_trace(span) self._start_time_sorted_spans.add(span) if parent_span_id is None or parent_span_id not in self._spans: self._start_time_sorted_root_spans.add(span) - self._latency_sorted_root_spans.add(span) self._propagate_cumulative_values(span) self._update_cached_statistics(span) @@ -395,6 +496,18 @@ def _add_span(self, span: WrappedSpan) -> None: # when they should refresh the page. self._last_updated_at = datetime.now(timezone.utc) + def _add_span_to_trace(self, span: WrappedSpan) -> None: + trace_id = span.context.trace_id + if (trace := self._traces.get(trace_id)) is None: + self._traces[trace_id] = trace = _Trace(span) + else: + # Must remove trace before mutating it. + self._latency_sorted_traces.remove(trace) + self._start_time_sorted_traces.remove(trace) + trace.add(span) + self._latency_sorted_traces.add(trace) + self._start_time_sorted_traces.add(trace) + def _update_cached_statistics(self, span: WrappedSpan) -> None: # Update statistics for quick access later span_id = span.context.span_id @@ -444,6 +557,7 @@ def __init__(self) -> None: self._evaluations_by_trace_id: DefaultDict[TraceID, Dict[EvaluationName, pb.Evaluation]] = ( defaultdict(dict) ) + self._trace_evaluation_labels: DefaultDict[EvaluationName, Set[str]] = defaultdict(set) self._span_evaluations_by_name: DefaultDict[EvaluationName, Dict[SpanID, pb.Evaluation]] = ( defaultdict(dict) ) @@ -484,6 +598,9 @@ def _add(self, evaluation: pb.Evaluation) -> None: trace_id = TraceID(subject_id.trace_id) self._evaluations_by_trace_id[trace_id][name] = evaluation self._trace_evaluations_by_name[name][trace_id] = evaluation + if evaluation.result.HasField("label"): + label = evaluation.result.label.value + self._trace_evaluation_labels[name].add(label) elif subject_id_kind is None: logger.warning( f"discarding evaluation with missing subject_id: {MessageToDict(evaluation)}" @@ -496,6 +613,30 @@ def _add(self, evaluation: pb.Evaluation) -> None: def last_updated_at(self) -> Optional[datetime]: return self._last_updated_at + def get_trace_evaluation(self, trace_id: TraceID, name: str) -> Optional[pb.Evaluation]: + with self._lock: + trace_evaluations = self._evaluations_by_trace_id.get(trace_id) + return trace_evaluations.get(name) if trace_evaluations else None + + def get_trace_evaluation_names(self) -> List[EvaluationName]: + with self._lock: + return list(self._trace_evaluations_by_name) + + def get_trace_evaluation_labels(self, name: EvaluationName) -> Tuple[str, ...]: + with self._lock: + labels = self._trace_evaluation_labels.get(name) + return tuple(labels) if labels else () + + def get_trace_evaluation_trace_ids(self, name: EvaluationName) -> Tuple[TraceID, ...]: + with self._lock: + trace_evaluations = self._trace_evaluations_by_name.get(name) + return tuple(trace_evaluations.keys()) if trace_evaluations else () + + def get_evaluations_by_trace_id(self, trace_id: TraceID) -> List[pb.Evaluation]: + with self._lock: + evaluations = self._evaluations_by_trace_id.get(trace_id) + return list(evaluations.values()) if evaluations else [] + def get_span_evaluation(self, span_id: SpanID, name: str) -> Optional[pb.Evaluation]: with self._lock: span_evaluations = self._evaluations_by_span_id.get(span_id) diff --git a/src/phoenix/server/api/types/Evaluation.py b/src/phoenix/server/api/types/Evaluation.py index 5b4ccd0915..c08bbc1aac 100644 --- a/src/phoenix/server/api/types/Evaluation.py +++ b/src/phoenix/server/api/types/Evaluation.py @@ -3,7 +3,7 @@ import strawberry import phoenix.trace.v1 as pb -from phoenix.trace.schemas import SpanID +from phoenix.trace.schemas import SpanID, TraceID @strawberry.interface @@ -24,6 +24,26 @@ class Evaluation: ) +@strawberry.type +class TraceEvaluation(Evaluation): + trace_id: strawberry.Private[TraceID] + + @staticmethod + def from_pb_evaluation(evaluation: pb.Evaluation) -> "TraceEvaluation": + result = evaluation.result + score = result.score.value if result.HasField("score") else None + label = result.label.value if result.HasField("label") else None + explanation = result.explanation.value if result.HasField("explanation") else None + trace_id = TraceID(evaluation.subject_id.trace_id) + return TraceEvaluation( + name=evaluation.name, + score=score, + label=label, + explanation=explanation, + trace_id=trace_id, + ) + + @strawberry.type class SpanEvaluation(Evaluation): span_id: strawberry.Private[SpanID] diff --git a/src/phoenix/server/api/types/Project.py b/src/phoenix/server/api/types/Project.py index 307efdfddd..926d28296b 100644 --- a/src/phoenix/server/api/types/Project.py +++ b/src/phoenix/server/api/types/Project.py @@ -19,6 +19,7 @@ connection_from_list, ) from phoenix.server.api.types.Span import Span, to_gql_span +from phoenix.server.api.types.Trace import Trace from phoenix.server.api.types.ValidationResult import ValidationResult from phoenix.trace.dsl import SpanFilter from phoenix.trace.schemas import SpanID, TraceID @@ -69,6 +70,12 @@ def latency_ms_p50(self) -> Optional[float]: def latency_ms_p99(self) -> Optional[float]: return self.project.root_span_latency_ms_quantiles(0.99) + @strawberry.field + def trace(self, trace_id: ID) -> Optional[Trace]: + if self.project.has_trace(TraceID(trace_id)): + return Trace(trace_id=trace_id, project=self.project) + return None + @strawberry.field def spans( self, @@ -88,7 +95,12 @@ def spans( last=last, before=before if isinstance(before, Cursor) else None, ) - if not (project := self.project).span_count(): + start_time = time_range.start if time_range else None + stop_time = time_range.end if time_range else None + if not (project := self.project).span_count( + start_time=start_time, + stop_time=stop_time, + ): return connection_from_list(data=[], args=args) predicate = ( SpanFilter( @@ -100,8 +112,8 @@ def spans( ) if not trace_ids: spans = project.get_spans( - start_time=time_range.start if time_range else None, - stop_time=time_range.end if time_range else None, + start_time=start_time, + stop_time=stop_time, root_spans_only=root_spans_only, ) else: @@ -115,6 +127,13 @@ def spans( data = [to_gql_span(span, project) for span in spans] return connection_from_list(data=data, args=args) + @strawberry.field( + description="Names of all available evaluations for traces. " + "(The list contains no duplicates.)" + ) # type: ignore + def trace_evaluation_names(self) -> List[str]: + return self.project.get_trace_evaluation_names() + @strawberry.field( description="Names of all available evaluations for spans. " "(The list contains no duplicates.)" @@ -133,6 +152,37 @@ def document_evaluation_names( None if span_id is UNSET else SpanID(span_id), ) + @strawberry.field + def trace_evaluation_summary( + self, + evaluation_name: str, + time_range: Optional[TimeRange] = UNSET, + ) -> Optional[EvaluationSummary]: + project = self.project + eval_trace_ids = project.get_trace_evaluation_trace_ids(evaluation_name) + if not eval_trace_ids: + return None + trace_ids = project.get_trace_ids( + start_time=time_range.start if time_range else None, + stop_time=time_range.end if time_range else None, + trace_ids=eval_trace_ids, + ) + evaluations = tuple( + evaluation + for trace_id in trace_ids + if ( + evaluation := project.get_trace_evaluation( + trace_id, + evaluation_name, + ) + ) + is not None + ) + if not evaluations: + return None + labels = project.get_trace_evaluation_labels(evaluation_name) + return EvaluationSummary(evaluations, labels) + @strawberry.field def span_evaluation_summary( self, diff --git a/src/phoenix/server/api/types/Trace.py b/src/phoenix/server/api/types/Trace.py new file mode 100644 index 0000000000..eee01fc672 --- /dev/null +++ b/src/phoenix/server/api/types/Trace.py @@ -0,0 +1,44 @@ +from typing import List, Optional + +import strawberry +from strawberry import ID, UNSET, Private + +from phoenix.core.project import Project +from phoenix.server.api.types.Evaluation import TraceEvaluation +from phoenix.server.api.types.pagination import ( + Connection, + ConnectionArgs, + Cursor, + connection_from_list, +) +from phoenix.server.api.types.Span import Span, to_gql_span +from phoenix.trace.schemas import TraceID + + +@strawberry.type +class Trace: + trace_id: ID + project: Private[Project] + + @strawberry.field + def spans( + self, + first: Optional[int] = UNSET, + last: Optional[int] = UNSET, + after: Optional[Cursor] = UNSET, + before: Optional[Cursor] = UNSET, + ) -> Connection[Span]: + args = ConnectionArgs( + first=first, + after=after if isinstance(after, Cursor) else None, + last=last, + before=before if isinstance(before, Cursor) else None, + ) + spans = self.project.get_trace(TraceID(self.trace_id)) + data = [to_gql_span(span, self.project) for span in spans] + return connection_from_list(data=data, args=args) + + @strawberry.field(description="Evaluations associated with the trace") # type: ignore + def trace_evaluations(self) -> List[TraceEvaluation]: + evaluations = self.project.get_evaluations_by_trace_id(TraceID(self.trace_id)) + return [TraceEvaluation.from_pb_evaluation(evaluation) for evaluation in evaluations]