implement token counts (#5073)
axiomofjoy authored Oct 17, 2024
1 parent 491ac18 commit 954e394
Showing 1 changed file with 23 additions and 4 deletions.
src/phoenix/server/api/subscriptions.py (23 additions, 4 deletions)
@@ -55,6 +55,7 @@
 )
 
 if TYPE_CHECKING:
+    from openai.types import CompletionUsage
     from openai.types.chat import (
         ChatCompletionMessageParam,
     )
@@ -161,6 +162,7 @@ async def chat_completion(
         self, info: Info[Context, None], input: ChatCompletionInput
     ) -> AsyncIterator[ChatCompletionChunk]:
         from openai import NOT_GIVEN, AsyncAzureOpenAI, AsyncOpenAI
+        from openai.types.chat import ChatCompletionStreamOptionsParam
 
         client: Union[AsyncAzureOpenAI, AsyncOpenAI]
 
@@ -208,14 +210,19 @@ async def chat_completion(
         text_chunks: List[TextChunk] = []
         tool_call_chunks: DefaultDict[ToolCallIndex, List[ToolCallChunk]] = defaultdict(list)
         role: Optional[str] = None
+        token_usage: Optional[CompletionUsage] = None
         async for chunk in await client.chat.completions.create(
             messages=openai_messages,
             model=input.model.name,
             stream=True,
             tools=input.tools or NOT_GIVEN,
+            stream_options=ChatCompletionStreamOptionsParam(include_usage=True),
             **invocation_parameters,
         ):
             response_chunks.append(chunk)
+            if (usage := chunk.usage) is not None:
+                token_usage = usage
+                continue
             choice = chunk.choices[0]
             delta = choice.delta
             if role is None:
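With `include_usage` enabled, the OpenAI streaming API emits one final chunk whose `choices` list is empty and whose `usage` field carries the token counts, which is why the loop above stores `chunk.usage` and skips straight to the next chunk. A minimal standalone sketch of the same pattern (the model name and prompt are illustrative):

    import asyncio
    from typing import Optional

    from openai import AsyncOpenAI
    from openai.types import CompletionUsage

    async def main() -> None:
        client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
        token_usage: Optional[CompletionUsage] = None
        async for chunk in await client.chat.completions.create(
            model="gpt-4o-mini",  # illustrative model name
            messages=[{"role": "user", "content": "Say hello."}],
            stream=True,
            stream_options={"include_usage": True},
        ):
            # The usage-bearing chunk arrives last and has an empty `choices` list.
            if chunk.usage is not None:
                token_usage = chunk.usage
                continue
            if chunk.choices and chunk.choices[0].delta.content:
                print(chunk.choices[0].delta.content, end="")
        if token_usage is not None:
            print(f"\nprompt={token_usage.prompt_tokens}, completion={token_usage.completion_tokens}")

    asyncio.run(main())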
@@ -246,6 +253,7 @@
                 dict(
                     chain(
                         _output_value_and_mime_type(response_chunks),
+                        _llm_token_counts(token_usage) if token_usage is not None else [],
                         _llm_output_messages(text_chunks, tool_call_chunks),
                     )
                 )
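Here `dict(chain(...))` flattens several `(key, value)` generators into a single attributes dict, and the `else []` branch lets the token-count attributes drop out cleanly when no usage was captured. A small sketch of that pattern with hypothetical generators:

    from itertools import chain
    from typing import Any, Iterator, Optional, Tuple

    def _base_attributes() -> Iterator[Tuple[str, Any]]:  # hypothetical generator
        yield "output.value", "hello"

    def _token_attributes(total: int) -> Iterator[Tuple[str, Any]]:  # hypothetical generator
        yield "llm.token_count.total", total

    total: Optional[int] = 42
    attributes = dict(
        chain(
            _base_attributes(),
            _token_attributes(total) if total is not None else [],
        )
    )
    print(attributes)  # {'output.value': 'hello', 'llm.token_count.total': 42}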
@@ -257,6 +265,8 @@
         assert (attributes := finished_span.attributes) is not None
         start_time = _datetime(epoch_nanoseconds=finished_span.start_time)
         end_time = _datetime(epoch_nanoseconds=finished_span.end_time)
+        prompt_tokens = token_usage.prompt_tokens if token_usage is not None else 0
+        completion_tokens = token_usage.completion_tokens if token_usage is not None else 0
         trace_id = _hex(finished_span.context.trace_id)
         span_id = _hex(finished_span.context.span_id)
         status = finished_span.status
@@ -298,10 +308,10 @@
                 status_code=status.status_code.name,
                 status_message=status.description or "",
                 cumulative_error_count=int(not status.is_ok),
-                cumulative_llm_token_count_prompt=0,
-                cumulative_llm_token_count_completion=0,
-                llm_token_count_prompt=0,
-                llm_token_count_completion=0,
+                cumulative_llm_token_count_prompt=prompt_tokens,
+                cumulative_llm_token_count_completion=completion_tokens,
+                llm_token_count_prompt=prompt_tokens,
+                llm_token_count_completion=completion_tokens,
             )
         )
         info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))
@@ -324,6 +334,12 @@ def _llm_tools(tools: List[JSONScalarType]) -> Iterator[Tuple[str, Any]]:
         yield f"{LLM_TOOLS}.{tool_index}.{TOOL_JSON_SCHEMA}", json.dumps(tool)
 
 
+def _llm_token_counts(usage: "CompletionUsage") -> Iterator[Tuple[str, Any]]:
+    yield LLM_TOKEN_COUNT_PROMPT, usage.prompt_tokens
+    yield LLM_TOKEN_COUNT_COMPLETION, usage.completion_tokens
+    yield LLM_TOKEN_COUNT_TOTAL, usage.total_tokens
+
+
 def _input_value_and_mime_type(input: ChatCompletionInput) -> Iterator[Tuple[str, Any]]:
     assert any(field.name == (api_key := "api_key") for field in fields(ChatCompletionInput))
     yield INPUT_MIME_TYPE, JSON
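Given a `CompletionUsage`, the new `_llm_token_counts` helper yields one `(attribute_name, value)` pair per count. Constructing the usage object by hand makes the output easy to see; the exact attribute strings depend on the OpenInference constants defined below:

    from openai.types import CompletionUsage

    usage = CompletionUsage(prompt_tokens=12, completion_tokens=30, total_tokens=42)
    print(dict(_llm_token_counts(usage)))
    # Expected, assuming the conventional OpenInference names:
    # {'llm.token_count.prompt': 12, 'llm.token_count.completion': 30, 'llm.token_count.total': 42}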
Expand Down Expand Up @@ -419,6 +435,9 @@ def _template_formatter(template_language: TemplateLanguage) -> TemplateFormatte
LLM_MODEL_NAME = SpanAttributes.LLM_MODEL_NAME
LLM_INVOCATION_PARAMETERS = SpanAttributes.LLM_INVOCATION_PARAMETERS
LLM_TOOLS = SpanAttributes.LLM_TOOLS
LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
LLM_TOKEN_COUNT_TOTAL = SpanAttributes.LLM_TOKEN_COUNT_TOTAL

MESSAGE_CONTENT = MessageAttributes.MESSAGE_CONTENT
MESSAGE_ROLE = MessageAttributes.MESSAGE_ROLE
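These constants are re-exported from the OpenInference semantic conventions package; a quick way to check the concrete attribute names they map to (the printed values are my expectation, not verified here):

    from openinference.semconv.trace import SpanAttributes

    print(SpanAttributes.LLM_TOKEN_COUNT_PROMPT)      # expected: llm.token_count.prompt
    print(SpanAttributes.LLM_TOKEN_COUNT_COMPLETION)  # expected: llm.token_count.completion
    print(SpanAttributes.LLM_TOKEN_COUNT_TOTAL)       # expected: llm.token_count.total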