diff --git a/app/schema.graphql b/app/schema.graphql index e42847c242..806cbd6eb9 100644 --- a/app/schema.graphql +++ b/app/schema.graphql @@ -65,8 +65,6 @@ enum AuthMethod { union Bin = NominalBin | IntervalBin | MissingValueBin -union ChatCompletionChunk = TextChunk | ToolCallChunk - input ChatCompletionInput { messages: [ChatCompletionMessageInput!]! model: GenerativeModelInput! @@ -90,6 +88,8 @@ enum ChatCompletionMessageRole { AI } +union ChatCompletionSubscriptionPayload = TextChunk | ToolCallChunk | FinishedChatCompletion + input ClearProjectInput { id: GlobalID! @@ -821,6 +821,10 @@ type ExportedFile { fileName: String! } +type FinishedChatCompletion { + span: Span! +} + type FunctionCallChunk { name: String! arguments: String! @@ -1442,7 +1446,7 @@ enum SpanStatusCode { } type Subscription { - chatCompletion(input: ChatCompletionInput!): ChatCompletionChunk! + chatCompletion(input: ChatCompletionInput!): ChatCompletionSubscriptionPayload! } type SystemApiKey implements ApiKey & Node { diff --git a/src/phoenix/server/api/subscriptions.py b/src/phoenix/server/api/subscriptions.py index 86a1ed7ddb..5e79980b30 100644 --- a/src/phoenix/server/api/subscriptions.py +++ b/src/phoenix/server/api/subscriptions.py @@ -34,6 +34,7 @@ from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from opentelemetry.trace import StatusCode from sqlalchemy import insert, select +from sqlalchemy.orm import joinedload from strawberry import UNSET from strawberry.scalars import JSON as JSONScalarType from strawberry.types import Info @@ -45,6 +46,7 @@ from phoenix.server.api.input_types.InvocationParameters import InvocationParameters from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey +from phoenix.server.api.types.Span import Span, to_gql_span from phoenix.server.dml_event import SpanInsertEvent from phoenix.trace.attributes import unflatten from phoenix.utilities.json import jsonify @@ -94,8 +96,14 @@ class ToolCallChunk: function: FunctionCallChunk -ChatCompletionChunk: TypeAlias = Annotated[ - Union[TextChunk, ToolCallChunk], strawberry.union("ChatCompletionChunk") +@strawberry.type +class FinishedChatCompletion: + span: Span + + +ChatCompletionSubscriptionPayload: TypeAlias = Annotated[ + Union[TextChunk, ToolCallChunk, FinishedChatCompletion], + strawberry.union("ChatCompletionSubscriptionPayload"), ] @@ -160,7 +168,7 @@ class Subscription: @strawberry.subscription async def chat_completion( self, info: Info[Context, None], input: ChatCompletionInput - ) -> AsyncIterator[ChatCompletionChunk]: + ) -> AsyncIterator[ChatCompletionSubscriptionPayload]: from openai import NOT_GIVEN, AsyncAzureOpenAI, AsyncOpenAI from openai.types.chat import ChatCompletionStreamOptionsParam @@ -294,8 +302,10 @@ async def chat_completion( end_time=end_time, ) ) - await session.execute( - insert(models.Span).values( + span_id = await session.scalar( + insert(models.Span) + .returning(models.Span.id) + .values( trace_rowid=trace_rowid, span_id=span_id, parent_id=None, @@ -314,6 +324,15 @@ async def chat_completion( llm_token_count_completion=completion_tokens, ) ) + playground_span = await session.scalar( + select(models.Span) + .where(models.Span.id == span_id) + .options( + joinedload(models.Span.trace, innerjoin=True).load_only(models.Trace.trace_id) + ) + ) + assert playground_span is not None + yield FinishedChatCompletion(span=to_gql_span(playground_span)) info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))