Skip to content

Commit

Permalink
return completed span
Browse files Browse the repository at this point in the history
  • Loading branch information
axiomofjoy committed Oct 18, 2024
1 parent 95cf580 commit 1693a4d
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 8 deletions.
10 changes: 7 additions & 3 deletions app/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,6 @@ enum AuthMethod {

union Bin = NominalBin | IntervalBin | MissingValueBin

union ChatCompletionChunk = TextChunk | ToolCallChunk

input ChatCompletionInput {
messages: [ChatCompletionMessageInput!]!
model: GenerativeModelInput!
Expand All @@ -90,6 +88,8 @@ enum ChatCompletionMessageRole {
AI
}

union ChatCompletionSubscriptionPayload = TextChunk | ToolCallChunk | FinishedChatCompletion

input ClearProjectInput {
id: GlobalID!

Expand Down Expand Up @@ -821,6 +821,10 @@ type ExportedFile {
fileName: String!
}

type FinishedChatCompletion {
span: Span!
}

type FunctionCallChunk {
name: String!
arguments: String!
Expand Down Expand Up @@ -1442,7 +1446,7 @@ enum SpanStatusCode {
}

type Subscription {
chatCompletion(input: ChatCompletionInput!): ChatCompletionChunk!
chatCompletion(input: ChatCompletionInput!): ChatCompletionSubscriptionPayload!
}

type SystemApiKey implements ApiKey & Node {
Expand Down
29 changes: 24 additions & 5 deletions src/phoenix/server/api/subscriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
from opentelemetry.trace import StatusCode
from sqlalchemy import insert, select
from sqlalchemy.orm import joinedload
from strawberry import UNSET
from strawberry.scalars import JSON as JSONScalarType
from strawberry.types import Info
Expand All @@ -45,6 +46,7 @@
from phoenix.server.api.input_types.InvocationParameters import InvocationParameters
from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey
from phoenix.server.api.types.Span import Span, to_gql_span
from phoenix.server.dml_event import SpanInsertEvent
from phoenix.trace.attributes import unflatten
from phoenix.utilities.json import jsonify
Expand Down Expand Up @@ -94,8 +96,14 @@ class ToolCallChunk:
function: FunctionCallChunk


ChatCompletionChunk: TypeAlias = Annotated[
Union[TextChunk, ToolCallChunk], strawberry.union("ChatCompletionChunk")
@strawberry.type
class FinishedChatCompletion:
span: Span


ChatCompletionSubscriptionPayload: TypeAlias = Annotated[
Union[TextChunk, ToolCallChunk, FinishedChatCompletion],
strawberry.union("ChatCompletionSubscriptionPayload"),
]


Expand Down Expand Up @@ -160,7 +168,7 @@ class Subscription:
@strawberry.subscription
async def chat_completion(
self, info: Info[Context, None], input: ChatCompletionInput
) -> AsyncIterator[ChatCompletionChunk]:
) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
from openai import NOT_GIVEN, AsyncAzureOpenAI, AsyncOpenAI
from openai.types.chat import ChatCompletionStreamOptionsParam

Expand Down Expand Up @@ -294,8 +302,10 @@ async def chat_completion(
end_time=end_time,
)
)
await session.execute(
insert(models.Span).values(
span_id = await session.scalar(
insert(models.Span)
.returning(models.Span.id)
.values(
trace_rowid=trace_rowid,
span_id=span_id,
parent_id=None,
Expand All @@ -314,6 +324,15 @@ async def chat_completion(
llm_token_count_completion=completion_tokens,
)
)
playground_span = await session.scalar(
select(models.Span)
.where(models.Span.id == span_id)
.options(
joinedload(models.Span.trace, innerjoin=True).load_only(models.Trace.trace_id)
)
)
assert playground_span is not None
yield FinishedChatCompletion(span=to_gql_span(playground_span))
info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))


Expand Down

0 comments on commit 1693a4d

Please sign in to comment.