feat(playground): add token counts for anthropic (#5161)
axiomofjoy authored Oct 23, 2024
1 parent 19021e4 commit 2eae8c5
Showing 2 changed files with 77 additions and 35 deletions.
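In short, this commit threads an optional `set_span_attributes` callback into each playground streaming client so token counts are written onto the live OpenTelemetry span as the provider reports usage, replacing the per-client `attributes` property that was previously read back after streaming finished. Below is a minimal, self-contained sketch of the pattern; the class name and the literal attribute-key strings are illustrative stand-ins (the real module imports keys such as `LLM_TOKEN_COUNT_PROMPT` as constants), not code copied from this diff.

```python
from typing import Callable, Dict, Optional

# Stand-in alias; the diff imports opentelemetry.util.types.AttributeValue.
AttributeValue = object
SetSpanAttributesFn = Callable[[Dict[str, AttributeValue]], None]

# Illustrative key strings, assumed to follow OpenInference conventions.
LLM_TOKEN_COUNT_PROMPT = "llm.token_count.prompt"
LLM_TOKEN_COUNT_COMPLETION = "llm.token_count.completion"


class SketchStreamingClient:
    def __init__(self, set_span_attributes: Optional[SetSpanAttributesFn] = None) -> None:
        # The subscription hands in span.set_attributes, so the client can
        # record attributes the moment usage data arrives mid-stream.
        self._set_span_attributes = set_span_attributes

    def _record_usage(self, prompt_tokens: int, completion_tokens: int) -> None:
        if self._set_span_attributes:
            self._set_span_attributes(
                {
                    LLM_TOKEN_COUNT_PROMPT: prompt_tokens,
                    LLM_TOKEN_COUNT_COMPLETION: completion_tokens,
                }
            )
```

A callback suits this better than a polled property because the providers report usage at different points in the stream, and the span gets populated without the subscription needing to know those details.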
106 changes: 73 additions & 33 deletions src/phoenix/server/api/subscriptions.py
@@ -35,6 +35,7 @@
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
from opentelemetry.trace import StatusCode
from opentelemetry.util.types import AttributeValue
from sqlalchemy import insert, select
from strawberry import UNSET
from strawberry.scalars import JSON as JSONScalarType
@@ -68,6 +69,7 @@
PLAYGROUND_PROJECT_NAME = "playground"

ToolCallID: TypeAlias = str
SetSpanAttributesFn: TypeAlias = Callable[[Dict[str, AttributeValue]], None]


@strawberry.enum
@@ -147,7 +149,13 @@ def decorator(cls: Type["PlaygroundStreamingClient"]) -> Type["PlaygroundStreamingClient"]:


class PlaygroundStreamingClient(ABC):
def __init__(self, model: GenerativeModelInput, api_key: Optional[str] = None) -> None: ...
def __init__(
self,
model: GenerativeModelInput,
api_key: Optional[str] = None,
set_span_attributes: Optional[SetSpanAttributesFn] = None,
) -> None:
self._set_span_attributes = set_span_attributes

@abstractmethod
async def chat_completion_create(
@@ -162,19 +170,20 @@ async def chat_completion_create(
# https://mypy.readthedocs.io/en/stable/more_types.html#asynchronous-iterators
yield TextChunk(content="")

@property
@abstractmethod
def attributes(self) -> Dict[str, Any]: ...


@register_llm_client(GenerativeProviderKey.OPENAI)
class OpenAIStreamingClient(PlaygroundStreamingClient):
def __init__(self, model: GenerativeModelInput, api_key: Optional[str] = None) -> None:
def __init__(
self,
model: GenerativeModelInput,
api_key: Optional[str] = None,
set_span_attributes: Optional[SetSpanAttributesFn] = None,
) -> None:
from openai import AsyncOpenAI

super().__init__(model=model, api_key=api_key, set_span_attributes=set_span_attributes)
self.client = AsyncOpenAI(api_key=api_key)
self.model_name = model.name
self._attributes: Dict[str, Any] = {}

async def chat_completion_create(
self,
@@ -225,8 +234,8 @@ async def chat_completion_create(
),
)
yield tool_call_chunk
if token_usage is not None:
self._attributes.update(_llm_token_counts(token_usage))
if token_usage is not None and self._set_span_attributes:
self._set_span_attributes(dict(self._llm_token_counts(token_usage)))

def to_openai_chat_completion_param(
self,
@@ -297,16 +306,24 @@ def to_openai_tool_call_param(
type="function",
)

@property
def attributes(self) -> Dict[str, Any]:
return self._attributes
@staticmethod
def _llm_token_counts(usage: "CompletionUsage") -> Iterator[Tuple[str, Any]]:
yield LLM_TOKEN_COUNT_PROMPT, usage.prompt_tokens
yield LLM_TOKEN_COUNT_COMPLETION, usage.completion_tokens
yield LLM_TOKEN_COUNT_TOTAL, usage.total_tokens


@register_llm_client(GenerativeProviderKey.AZURE_OPENAI)
class AzureOpenAIStreamingClient(OpenAIStreamingClient):
def __init__(self, model: GenerativeModelInput, api_key: Optional[str] = None):
def __init__(
self,
model: GenerativeModelInput,
api_key: Optional[str] = None,
set_span_attributes: Optional[SetSpanAttributesFn] = None,
):
from openai import AsyncAzureOpenAI

super().__init__(model=model, api_key=api_key, set_span_attributes=set_span_attributes)
if model.endpoint is None or model.api_version is None:
raise ValueError("endpoint and api_version are required for Azure OpenAI models")
self.client = AsyncAzureOpenAI(
@@ -318,9 +335,15 @@ def __init__(self, model: GenerativeModelInput, api_key: Optional[str] = None):

@register_llm_client(GenerativeProviderKey.ANTHROPIC)
class AnthropicStreamingClient(PlaygroundStreamingClient):
def __init__(self, model: GenerativeModelInput, api_key: Optional[str] = None) -> None:
def __init__(
self,
model: GenerativeModelInput,
api_key: Optional[str] = None,
set_span_attributes: Optional[SetSpanAttributesFn] = None,
) -> None:
import anthropic

super().__init__(model=model, api_key=api_key, set_span_attributes=set_span_attributes)
self.client = anthropic.AsyncAnthropic(api_key=api_key)
self.model_name = model.name

@@ -332,6 +355,9 @@ async def chat_completion_create(
tools: List[JSONScalarType],
**invocation_parameters: Any,
) -> AsyncIterator[ChatCompletionSubscriptionPayload]:
import anthropic.lib.streaming as anthropic_streaming
import anthropic.types as anthropic_types

anthropic_messages, system_prompt = self._build_anthropic_messages(messages)

anthropic_params = {
@@ -341,10 +367,35 @@
"max_tokens": 1024,
**invocation_parameters,
}

async with self.client.messages.stream(**anthropic_params) as stream:
async for text in stream.text_stream:
yield TextChunk(content=text)
async for event in stream:
if isinstance(event, anthropic_types.RawMessageStartEvent):
if self._set_span_attributes:
self._set_span_attributes(
{LLM_TOKEN_COUNT_PROMPT: event.message.usage.input_tokens}
)
elif isinstance(event, anthropic_streaming.TextEvent):
yield TextChunk(content=event.text)
elif isinstance(event, anthropic_streaming.MessageStopEvent):
if self._set_span_attributes:
self._set_span_attributes(
{LLM_TOKEN_COUNT_COMPLETION: event.message.usage.output_tokens}
)
elif isinstance(
event,
(
anthropic_types.RawContentBlockStartEvent,
anthropic_types.RawContentBlockDeltaEvent,
anthropic_types.RawMessageDeltaEvent,
anthropic_streaming.ContentBlockStopEvent,
),
):
# event types emitted by the stream that don't contain useful information
pass
elif isinstance(event, anthropic_streaming.InputJsonEvent):
raise NotImplementedError
else:
assert_never(event)

def _build_anthropic_messages(
self,
@@ -366,10 +417,6 @@ def _build_anthropic_messages(

return anthropic_messages, system_prompt

@property
def attributes(self) -> Dict[str, Any]:
return dict()


@strawberry.type
class Subscription:
@@ -383,8 +430,6 @@ async def chat_completion(
if llm_client_class is None:
raise ValueError(f"No LLM client registered for provider '{provider_key}'")

llm_client = llm_client_class(model=input.model, api_key=input.api_key)

messages = [
(
message.role,
@@ -424,6 +469,9 @@ async def chat_completion(
text_chunks: List[TextChunk] = []
tool_call_chunks: DefaultDict[ToolCallID, List[ToolCallChunk]] = defaultdict(list)

llm_client = llm_client_class(
model=input.model, api_key=input.api_key, set_span_attributes=span.set_attributes
)
async for chunk in llm_client.chat_completion_create(
messages=messages,
tools=input.tools or [],
@@ -438,13 +486,11 @@ async def chat_completion(
tool_call_chunks[chunk.id].append(chunk)

span.set_status(StatusCode.OK)
llm_client_attributes = llm_client.attributes

span.set_attributes(
dict(
chain(
_output_value_and_mime_type(response_chunks),
llm_client_attributes.items(),
_llm_output_messages(text_chunks, tool_call_chunks),
)
)
@@ -456,8 +502,8 @@ async def chat_completion(
assert (attributes := finished_span.attributes) is not None
start_time = _datetime(epoch_nanoseconds=finished_span.start_time)
end_time = _datetime(epoch_nanoseconds=finished_span.end_time)
prompt_tokens = llm_client_attributes.get(LLM_TOKEN_COUNT_PROMPT, 0)
completion_tokens = llm_client_attributes.get(LLM_TOKEN_COUNT_COMPLETION, 0)
prompt_tokens = attributes.get(LLM_TOKEN_COUNT_PROMPT, 0)
completion_tokens = attributes.get(LLM_TOKEN_COUNT_COMPLETION, 0)
trace_id = _hex(finished_span.context.trace_id)
span_id = _hex(finished_span.context.span_id)
status = finished_span.status
@@ -524,12 +570,6 @@ def _llm_tools(tools: List[JSONScalarType]) -> Iterator[Tuple[str, Any]]:
yield f"{LLM_TOOLS}.{tool_index}.{TOOL_JSON_SCHEMA}", json.dumps(tool)


def _llm_token_counts(usage: "CompletionUsage") -> Iterator[Tuple[str, Any]]:
yield LLM_TOKEN_COUNT_PROMPT, usage.prompt_tokens
yield LLM_TOKEN_COUNT_COMPLETION, usage.completion_tokens
yield LLM_TOKEN_COUNT_TOTAL, usage.total_tokens


def _input_value_and_mime_type(input: ChatCompletionInput) -> Iterator[Tuple[str, Any]]:
assert (api_key := "api_key") in (input_data := jsonify(input))
input_data = {k: v for k, v in input_data.items() if k != api_key}
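One detail of the Anthropic handler above worth calling out: Anthropic splits usage across two stream events, with `input_tokens` arriving on the initial `RawMessageStartEvent` and `output_tokens` on the final `MessageStopEvent`, which is why `set_span_attributes` is invoked twice. Here is a compact sketch of just that event handling, assuming the same `anthropic` SDK imports used in the diff; the prints are illustrative placeholders for the span writes.

```python
import anthropic
import anthropic.lib.streaming as anthropic_streaming
import anthropic.types as anthropic_types


async def stream_with_token_counts(client: anthropic.AsyncAnthropic, **params) -> None:
    async with client.messages.stream(**params) as stream:
        async for event in stream:
            if isinstance(event, anthropic_types.RawMessageStartEvent):
                # Prompt-side usage is known as soon as the message starts.
                print("prompt tokens:", event.message.usage.input_tokens)
            elif isinstance(event, anthropic_streaming.TextEvent):
                print(event.text, end="", flush=True)
            elif isinstance(event, anthropic_streaming.MessageStopEvent):
                # Completion-side usage is only known once the stream ends.
                print("\ncompletion tokens:", event.message.usage.output_tokens)
```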
6 changes: 4 additions & 2 deletions tests/unit/server/api/test_subscriptions.py
@@ -642,8 +642,8 @@ async def test_anthropic_text_response_emits_expected_payloads_and_records_expec
assert isinstance(token_count_total := span.pop("tokenCountTotal"), int)
assert isinstance(token_count_prompt := span.pop("tokenCountPrompt"), int)
assert isinstance(token_count_completion := span.pop("tokenCountCompletion"), int)
assert token_count_prompt == 0
assert token_count_completion == 0
assert token_count_prompt > 0
assert token_count_completion > 0
assert token_count_total == token_count_prompt + token_count_completion
assert (input := span.pop("input")).pop("mimeType") == "json"
assert (input_value := input.pop("value"))
@@ -672,6 +672,8 @@ async def test_anthropic_text_response_emits_expected_payloads_and_records_expec
assert attributes.pop(OPENINFERENCE_SPAN_KIND) == LLM
assert attributes.pop(LLM_MODEL_NAME) == "claude-3-5-sonnet-20240620"
assert attributes.pop(LLM_INVOCATION_PARAMETERS) == json.dumps({"temperature": 0.1})
assert attributes.pop(LLM_TOKEN_COUNT_PROMPT) == token_count_prompt
assert attributes.pop(LLM_TOKEN_COUNT_COMPLETION) == token_count_completion
assert attributes.pop(INPUT_VALUE)
assert attributes.pop(INPUT_MIME_TYPE) == JSON
assert attributes.pop(OUTPUT_VALUE)

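Design note on the test change: the assertions flip from expecting zero prompt and completion token counts to expecting positive ones, and two new checks verify that the counts round-trip as `LLM_TOKEN_COUNT_PROMPT` and `LLM_TOKEN_COUNT_COMPLETION` span attributes. That matches the server-side change, where `prompt_tokens` and `completion_tokens` are now read from the finished span's attributes rather than from state held on the client.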