diff --git a/src/phoenix/server/api/subscriptions.py b/src/phoenix/server/api/subscriptions.py index adc0d425ec..86a1ed7ddb 100644 --- a/src/phoenix/server/api/subscriptions.py +++ b/src/phoenix/server/api/subscriptions.py @@ -55,6 +55,7 @@ ) if TYPE_CHECKING: + from openai.types import CompletionUsage from openai.types.chat import ( ChatCompletionMessageParam, ) @@ -161,6 +162,7 @@ async def chat_completion( self, info: Info[Context, None], input: ChatCompletionInput ) -> AsyncIterator[ChatCompletionChunk]: from openai import NOT_GIVEN, AsyncAzureOpenAI, AsyncOpenAI + from openai.types.chat import ChatCompletionStreamOptionsParam client: Union[AsyncAzureOpenAI, AsyncOpenAI] @@ -208,14 +210,19 @@ async def chat_completion( text_chunks: List[TextChunk] = [] tool_call_chunks: DefaultDict[ToolCallIndex, List[ToolCallChunk]] = defaultdict(list) role: Optional[str] = None + token_usage: Optional[CompletionUsage] = None async for chunk in await client.chat.completions.create( messages=openai_messages, model=input.model.name, stream=True, tools=input.tools or NOT_GIVEN, + stream_options=ChatCompletionStreamOptionsParam(include_usage=True), **invocation_parameters, ): response_chunks.append(chunk) + if (usage := chunk.usage) is not None: + token_usage = usage + continue choice = chunk.choices[0] delta = choice.delta if role is None: @@ -246,6 +253,7 @@ async def chat_completion( dict( chain( _output_value_and_mime_type(response_chunks), + _llm_token_counts(token_usage) if token_usage is not None else [], _llm_output_messages(text_chunks, tool_call_chunks), ) ) @@ -257,6 +265,8 @@ async def chat_completion( assert (attributes := finished_span.attributes) is not None start_time = _datetime(epoch_nanoseconds=finished_span.start_time) end_time = _datetime(epoch_nanoseconds=finished_span.end_time) + prompt_tokens = token_usage.prompt_tokens if token_usage is not None else 0 + completion_tokens = token_usage.completion_tokens if token_usage is not None else 0 trace_id = _hex(finished_span.context.trace_id) span_id = _hex(finished_span.context.span_id) status = finished_span.status @@ -298,10 +308,10 @@ async def chat_completion( status_code=status.status_code.name, status_message=status.description or "", cumulative_error_count=int(not status.is_ok), - cumulative_llm_token_count_prompt=0, - cumulative_llm_token_count_completion=0, - llm_token_count_prompt=0, - llm_token_count_completion=0, + cumulative_llm_token_count_prompt=prompt_tokens, + cumulative_llm_token_count_completion=completion_tokens, + llm_token_count_prompt=prompt_tokens, + llm_token_count_completion=completion_tokens, ) ) info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,))) @@ -324,6 +334,12 @@ def _llm_tools(tools: List[JSONScalarType]) -> Iterator[Tuple[str, Any]]: yield f"{LLM_TOOLS}.{tool_index}.{TOOL_JSON_SCHEMA}", json.dumps(tool) +def _llm_token_counts(usage: "CompletionUsage") -> Iterator[Tuple[str, Any]]: + yield LLM_TOKEN_COUNT_PROMPT, usage.prompt_tokens + yield LLM_TOKEN_COUNT_COMPLETION, usage.completion_tokens + yield LLM_TOKEN_COUNT_TOTAL, usage.total_tokens + + def _input_value_and_mime_type(input: ChatCompletionInput) -> Iterator[Tuple[str, Any]]: assert any(field.name == (api_key := "api_key") for field in fields(ChatCompletionInput)) yield INPUT_MIME_TYPE, JSON @@ -419,6 +435,9 @@ def _template_formatter(template_language: TemplateLanguage) -> TemplateFormatte LLM_MODEL_NAME = SpanAttributes.LLM_MODEL_NAME LLM_INVOCATION_PARAMETERS = SpanAttributes.LLM_INVOCATION_PARAMETERS LLM_TOOLS = SpanAttributes.LLM_TOOLS +LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT +LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION +LLM_TOKEN_COUNT_TOTAL = SpanAttributes.LLM_TOKEN_COUNT_TOTAL MESSAGE_CONTENT = MessageAttributes.MESSAGE_CONTENT MESSAGE_ROLE = MessageAttributes.MESSAGE_ROLE