diff --git a/src/phoenix/server/api/subscriptions.py b/src/phoenix/server/api/subscriptions.py index fb984500d2..49768f95cd 100644 --- a/src/phoenix/server/api/subscriptions.py +++ b/src/phoenix/server/api/subscriptions.py @@ -65,7 +65,7 @@ PLAYGROUND_PROJECT_NAME = "playground" -ToolCallIndex: TypeAlias = int +ToolCallID: TypeAlias = str @strawberry.enum @@ -315,8 +315,10 @@ def _build_anthropic_messages( anthropic_messages.append({"role": "assistant", "content": content}) elif role == ChatCompletionMessageRole.SYSTEM: system_prompt += content + "\n" + elif role == ChatCompletionMessageRole.TOOL: + raise NotImplementedError else: - raise ValueError(f"Unsupported role: {role}") + assert_never(role) return anthropic_messages, system_prompt @@ -369,7 +371,7 @@ async def chat_completion( ) as span: response_chunks = [] text_chunks: List[TextChunk] = [] - tool_call_chunks: DefaultDict[ToolCallIndex, List[ToolCallChunk]] = defaultdict(list) + tool_call_chunks: DefaultDict[ToolCallID, List[ToolCallChunk]] = defaultdict(list) async for chunk in llm_client.chat_completion_create( messages=messages, @@ -382,8 +384,7 @@ async def chat_completion( text_chunks.append(chunk) elif isinstance(chunk, ToolCallChunk): yield chunk - tool_call_index = int(chunk.id) if chunk.id.isdigit() else 0 - tool_call_chunks[tool_call_index].append(chunk) + tool_call_chunks[chunk.id].append(chunk) span.set_status(StatusCode.OK) llm_client_attributes = llm_client.attributes @@ -499,7 +500,7 @@ def _llm_input_messages( def _llm_output_messages( text_chunks: List[TextChunk], - tool_call_chunks: DefaultDict[ToolCallIndex, List[ToolCallChunk]], + tool_call_chunks: DefaultDict[ToolCallID, List[ToolCallChunk]], ) -> Iterator[Tuple[str, Any]]: yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}", "assistant" if content := "".join(chunk.content for chunk in text_chunks):