From 2eae8c5df25c4454352d4167b3435675db19ae75 Mon Sep 17 00:00:00 2001 From: Xander Song Date: Wed, 23 Oct 2024 14:55:43 -0700 Subject: [PATCH] feat(playground): add token counts for anthropic (#5161) --- src/phoenix/server/api/subscriptions.py | 106 ++++++++++++++------ tests/unit/server/api/test_subscriptions.py | 6 +- 2 files changed, 77 insertions(+), 35 deletions(-) diff --git a/src/phoenix/server/api/subscriptions.py b/src/phoenix/server/api/subscriptions.py index bc865b8ced..31bd00c6d9 100644 --- a/src/phoenix/server/api/subscriptions.py +++ b/src/phoenix/server/api/subscriptions.py @@ -35,6 +35,7 @@ from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from opentelemetry.trace import StatusCode +from opentelemetry.util.types import AttributeValue from sqlalchemy import insert, select from strawberry import UNSET from strawberry.scalars import JSON as JSONScalarType @@ -68,6 +69,7 @@ PLAYGROUND_PROJECT_NAME = "playground" ToolCallID: TypeAlias = str +SetSpanAttributesFn: TypeAlias = Callable[[Dict[str, AttributeValue]], None] @strawberry.enum @@ -147,7 +149,13 @@ def decorator(cls: Type["PlaygroundStreamingClient"]) -> Type["PlaygroundStreami class PlaygroundStreamingClient(ABC): - def __init__(self, model: GenerativeModelInput, api_key: Optional[str] = None) -> None: ... + def __init__( + self, + model: GenerativeModelInput, + api_key: Optional[str] = None, + set_span_attributes: Optional[SetSpanAttributesFn] = None, + ) -> None: + self._set_span_attributes = set_span_attributes @abstractmethod async def chat_completion_create( @@ -162,19 +170,20 @@ async def chat_completion_create( # https://mypy.readthedocs.io/en/stable/more_types.html#asynchronous-iterators yield TextChunk(content="") - @property - @abstractmethod - def attributes(self) -> Dict[str, Any]: ... 
-
 
 @register_llm_client(GenerativeProviderKey.OPENAI)
 class OpenAIStreamingClient(PlaygroundStreamingClient):
-    def __init__(self, model: GenerativeModelInput, api_key: Optional[str] = None) -> None:
+    def __init__(
+        self,
+        model: GenerativeModelInput,
+        api_key: Optional[str] = None,
+        set_span_attributes: Optional[SetSpanAttributesFn] = None,
+    ) -> None:
         from openai import AsyncOpenAI
 
+        super().__init__(model=model, api_key=api_key, set_span_attributes=set_span_attributes)
         self.client = AsyncOpenAI(api_key=api_key)
         self.model_name = model.name
-        self._attributes: Dict[str, Any] = {}
 
     async def chat_completion_create(
         self,
@@ -225,8 +234,8 @@ async def chat_completion_create(
                         ),
                     )
                     yield tool_call_chunk
-        if token_usage is not None:
-            self._attributes.update(_llm_token_counts(token_usage))
+        if token_usage is not None and self._set_span_attributes:
+            self._set_span_attributes(dict(self._llm_token_counts(token_usage)))
 
     def to_openai_chat_completion_param(
         self,
@@ -297,16 +306,24 @@ def to_openai_tool_call_param(
             type="function",
         )
 
-    @property
-    def attributes(self) -> Dict[str, Any]:
-        return self._attributes
+    @staticmethod
+    def _llm_token_counts(usage: "CompletionUsage") -> Iterator[Tuple[str, Any]]:
+        yield LLM_TOKEN_COUNT_PROMPT, usage.prompt_tokens
+        yield LLM_TOKEN_COUNT_COMPLETION, usage.completion_tokens
+        yield LLM_TOKEN_COUNT_TOTAL, usage.total_tokens
 
 
 @register_llm_client(GenerativeProviderKey.AZURE_OPENAI)
 class AzureOpenAIStreamingClient(OpenAIStreamingClient):
-    def __init__(self, model: GenerativeModelInput, api_key: Optional[str] = None):
+    def __init__(
+        self,
+        model: GenerativeModelInput,
+        api_key: Optional[str] = None,
+        set_span_attributes: Optional[SetSpanAttributesFn] = None,
+    ):
         from openai import AsyncAzureOpenAI
 
+        super().__init__(model=model, api_key=api_key, set_span_attributes=set_span_attributes)
         if model.endpoint is None or model.api_version is None:
             raise ValueError("endpoint and api_version are required for Azure OpenAI models")
         self.client = AsyncAzureOpenAI(
@@ -318,9 +335,15 @@ def __init__(self, model: GenerativeModelInput, api_key: Optional[str] = None):
 
 @register_llm_client(GenerativeProviderKey.ANTHROPIC)
 class AnthropicStreamingClient(PlaygroundStreamingClient):
-    def __init__(self, model: GenerativeModelInput, api_key: Optional[str] = None) -> None:
+    def __init__(
+        self,
+        model: GenerativeModelInput,
+        api_key: Optional[str] = None,
+        set_span_attributes: Optional[SetSpanAttributesFn] = None,
+    ) -> None:
         import anthropic
 
+        super().__init__(model=model, api_key=api_key, set_span_attributes=set_span_attributes)
         self.client = anthropic.AsyncAnthropic(api_key=api_key)
         self.model_name = model.name
 
@@ -332,6 +355,9 @@ async def chat_completion_create(
         tools: List[JSONScalarType],
         **invocation_parameters: Any,
     ) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
+        import anthropic.lib.streaming as anthropic_streaming
+        import anthropic.types as anthropic_types
+
         anthropic_messages, system_prompt = self._build_anthropic_messages(messages)
 
         anthropic_params = {
             "messages": anthropic_messages,
@@ -341,10 +367,35 @@ async def chat_completion_create(
             "max_tokens": 1024,
             **invocation_parameters,
         }
         async with self.client.messages.stream(**anthropic_params) as stream:
-            async for text in stream.text_stream:
-                yield TextChunk(content=text)
+            async for event in stream:
+                if isinstance(event, anthropic_types.RawMessageStartEvent):
+                    if self._set_span_attributes:
+                        self._set_span_attributes(
+                            {LLM_TOKEN_COUNT_PROMPT: event.message.usage.input_tokens}
+                        )
+                elif isinstance(event, anthropic_streaming.TextEvent):
+                    yield TextChunk(content=event.text)
+                elif isinstance(event, anthropic_streaming.MessageStopEvent):
+                    if self._set_span_attributes:
+                        self._set_span_attributes(
+                            {LLM_TOKEN_COUNT_COMPLETION: event.message.usage.output_tokens}
+                        )
+                elif isinstance(
+                    event,
+                    (
+                        anthropic_types.RawContentBlockStartEvent,
+                        anthropic_types.RawContentBlockDeltaEvent,
+                        anthropic_types.RawMessageDeltaEvent,
+                        anthropic_streaming.ContentBlockStopEvent,
+                    ),
+                ):
+                    # event types emitted by the stream that don't contain useful information
+                    pass
+                elif isinstance(event, anthropic_streaming.InputJsonEvent):
+                    raise NotImplementedError
+                else:
+                    assert_never(event)
 
     def _build_anthropic_messages(
         self,
@@ -366,10 +417,6 @@ def _build_anthropic_messages(
 
         return anthropic_messages, system_prompt
 
-    @property
-    def attributes(self) -> Dict[str, Any]:
-        return dict()
-
 
 @strawberry.type
 class Subscription:
@@ -383,8 +430,6 @@ async def chat_completion(
         if llm_client_class is None:
             raise ValueError(f"No LLM client registered for provider '{provider_key}'")
 
-        llm_client = llm_client_class(model=input.model, api_key=input.api_key)
-
         messages = [
             (
                 message.role,
@@ -424,6 +469,9 @@ async def chat_completion(
 
         text_chunks: List[TextChunk] = []
         tool_call_chunks: DefaultDict[ToolCallID, List[ToolCallChunk]] = defaultdict(list)
+        llm_client = llm_client_class(
+            model=input.model, api_key=input.api_key, set_span_attributes=span.set_attributes
+        )
         async for chunk in llm_client.chat_completion_create(
             messages=messages,
             tools=input.tools or [],
@@ -438,13 +486,11 @@ async def chat_completion(
                     tool_call_chunks[chunk.id].append(chunk)
             span.set_status(StatusCode.OK)
 
-        llm_client_attributes = llm_client.attributes
 
         span.set_attributes(
             dict(
                 chain(
                     _output_value_and_mime_type(response_chunks),
-                    llm_client_attributes.items(),
                     _llm_output_messages(text_chunks, tool_call_chunks),
                 )
             )
         )
@@ -456,8 +502,8 @@ async def chat_completion(
         assert (attributes := finished_span.attributes) is not None
         start_time = _datetime(epoch_nanoseconds=finished_span.start_time)
         end_time = _datetime(epoch_nanoseconds=finished_span.end_time)
-        prompt_tokens = llm_client_attributes.get(LLM_TOKEN_COUNT_PROMPT, 0)
-        completion_tokens = llm_client_attributes.get(LLM_TOKEN_COUNT_COMPLETION, 0)
+        prompt_tokens = attributes.get(LLM_TOKEN_COUNT_PROMPT, 0)
+        completion_tokens = attributes.get(LLM_TOKEN_COUNT_COMPLETION, 0)
         trace_id = _hex(finished_span.context.trace_id)
         span_id = _hex(finished_span.context.span_id)
         status = finished_span.status
@@ -524,12 +570,6 @@ def _llm_tools(tools: List[JSONScalarType]) -> Iterator[Tuple[str, Any]]:
         yield f"{LLM_TOOLS}.{tool_index}.{TOOL_JSON_SCHEMA}", json.dumps(tool)
 
 
-def _llm_token_counts(usage: "CompletionUsage") -> Iterator[Tuple[str, Any]]:
-    yield LLM_TOKEN_COUNT_PROMPT, usage.prompt_tokens
-    yield LLM_TOKEN_COUNT_COMPLETION, usage.completion_tokens
-    yield LLM_TOKEN_COUNT_TOTAL, usage.total_tokens
-
-
 def _input_value_and_mime_type(input: ChatCompletionInput) -> Iterator[Tuple[str, Any]]:
     assert (api_key := "api_key") in (input_data := jsonify(input))
     input_data = {k: v for k, v in input_data.items() if k != api_key}
diff --git a/tests/unit/server/api/test_subscriptions.py b/tests/unit/server/api/test_subscriptions.py
index 22fb9ae7de..655eed91c1 100644
--- a/tests/unit/server/api/test_subscriptions.py
+++ b/tests/unit/server/api/test_subscriptions.py
@@ -642,8 +642,8 @@ async def test_anthropic_text_response_emits_expected_payloads_and_records_expec
         assert isinstance(token_count_total := span.pop("tokenCountTotal"), int)
         assert isinstance(token_count_prompt := span.pop("tokenCountPrompt"), int)
         assert isinstance(token_count_completion := span.pop("tokenCountCompletion"), int)
-        assert token_count_prompt == 0
-        assert token_count_completion == 0
+        assert token_count_prompt > 0
+        assert token_count_completion > 0
         assert token_count_total == token_count_prompt + token_count_completion
         assert (input := span.pop("input")).pop("mimeType") == "json"
         assert (input_value := input.pop("value"))
@@ -672,6 +672,8 @@ async def test_anthropic_text_response_emits_expected_payloads_and_records_expec
         assert attributes.pop(OPENINFERENCE_SPAN_KIND) == LLM
         assert attributes.pop(LLM_MODEL_NAME) == "claude-3-5-sonnet-20240620"
         assert attributes.pop(LLM_INVOCATION_PARAMETERS) == json.dumps({"temperature": 0.1})
+        assert attributes.pop(LLM_TOKEN_COUNT_PROMPT) == token_count_prompt
+        assert attributes.pop(LLM_TOKEN_COUNT_COMPLETION) == token_count_completion
         assert attributes.pop(INPUT_VALUE)
         assert attributes.pop(INPUT_MIME_TYPE) == JSON
         assert attributes.pop(OUTPUT_VALUE)
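
Appendix (illustrative, not part of the patch): the core of this change is threading a
set_span_attributes callback from the subscription's OpenTelemetry span into each streaming
client, so token counts land on the span as soon as the provider reports them — at message
start for prompt tokens and at message stop for completion tokens in the Anthropic case.
Below is a minimal, runnable sketch of that flow; ToySpan and ToyStreamingClient are
hypothetical stand-ins (not Phoenix or OpenTelemetry classes), and the llm.token_count.*
strings are assumed to match the OpenInference attribute names used above.

    # Hypothetical stand-ins for illustration only; only the callback shape
    # mirrors the patch.
    from typing import Callable, Dict, Optional, Union

    AttributeValue = Union[str, int, float, bool]
    SetSpanAttributesFn = Callable[[Dict[str, AttributeValue]], None]

    LLM_TOKEN_COUNT_PROMPT = "llm.token_count.prompt"
    LLM_TOKEN_COUNT_COMPLETION = "llm.token_count.completion"

    class ToySpan:
        """Stand-in for an OpenTelemetry span."""

        def __init__(self) -> None:
            self.attributes: Dict[str, AttributeValue] = {}

        def set_attributes(self, attributes: Dict[str, AttributeValue]) -> None:
            # Attributes accumulate, so a client can report prompt tokens at
            # message start and completion tokens at message stop.
            self.attributes.update(attributes)

    class ToyStreamingClient:
        """Stand-in mirroring PlaygroundStreamingClient's callback wiring."""

        def __init__(self, set_span_attributes: Optional[SetSpanAttributesFn] = None) -> None:
            self._set_span_attributes = set_span_attributes

        def stream(self) -> None:
            # Analogous to RawMessageStartEvent: input tokens are known up front.
            if self._set_span_attributes:
                self._set_span_attributes({LLM_TOKEN_COUNT_PROMPT: 12})
            # Analogous to MessageStopEvent: output tokens arrive at the end.
            if self._set_span_attributes:
                self._set_span_attributes({LLM_TOKEN_COUNT_COMPLETION: 34})

    span = ToySpan()
    client = ToyStreamingClient(set_span_attributes=span.set_attributes)
    client.stream()
    assert span.attributes == {
        "llm.token_count.prompt": 12,
        "llm.token_count.completion": 34,
    }

Passing span.set_attributes as a plain callable keeps the clients decoupled from the tracing
backend: they never import or hold a span, and AzureOpenAIStreamingClient inherits the wiring
from OpenAIStreamingClient through super().__init__().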