refactor(anthropic): reorganize imports and improve response handling
teocns committed Nov 14, 2024
1 parent a0d85c5 commit e926814
Showing 1 changed file with 54 additions and 20 deletions.
74 changes: 54 additions & 20 deletions agentops/llms/anthropic.py
@@ -1,19 +1,22 @@
 import json
 import pprint
 from typing import Optional
 
+from anthropic import APIResponse
+from anthropic._legacy_response import LegacyAPIResponse
+
 from agentops.llms.instrumented_provider import InstrumentedProvider
 from agentops.time_travel import fetch_completion_override_from_time_travel_cache
 
 from ..event import ErrorEvent, LLMEvent, ToolEvent
-from ..session import Session
-from ..log_config import logger
 from ..helpers import check_call_stack_for_agent_id, get_ISO_time
+from ..log_config import logger
+from ..session import Session
 from ..singleton import singleton
 
 
 @singleton
 class AnthropicProvider(InstrumentedProvider):
 
     original_create = None
     original_create_async = None
@@ -27,9 +30,9 @@ def handle_response(
         self, response, kwargs, init_timestamp, session: Optional[Session] = None
     ):
         """Handle responses for Anthropic"""
-        from anthropic import Stream, AsyncStream
-        from anthropic.resources import AsyncMessages
         import anthropic.resources.beta.messages.messages as beta_messages
+        from anthropic import AsyncStream, Stream
+        from anthropic.resources import AsyncMessages
         from anthropic.types import Message
 
         llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs)
@@ -66,9 +69,9 @@ def handle_stream_chunk(chunk: Message):
                     llm_event.completion["content"] += chunk.delta.text
 
                 elif chunk.delta.type == "input_json_delta":
-                    self.tool_event[self.tool_id].logs[
-                        "input"
-                    ] += chunk.delta.partial_json
+                    self.tool_event[self.tool_id].logs["input"] += (
+                        chunk.delta.partial_json
+                    )
 
             elif chunk.type == "content_block_stop":
                 pass
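
The reformatted block above implements a simple accumulation contract: "text_delta" chunks extend the assistant message, while "input_json_delta" chunks extend the partial JSON arguments of the tool call currently being streamed. A minimal runnable sketch of that contract, using hypothetical stand-ins (ToolLog, on_delta) for the provider's ToolEvent bookkeeping and for Anthropic's chunk objects:

from dataclasses import dataclass, field


@dataclass
class ToolLog:
    # Hypothetical stand-in for the provider's per-tool ToolEvent record.
    logs: dict = field(default_factory=lambda: {"input": ""})


completion = {"role": "assistant", "content": ""}
tool_events = {"toolu_01": ToolLog()}


def on_delta(delta_type, payload, tool_id="toolu_01"):
    # Mirrors the two delta branches above: text extends the completion,
    # partial JSON extends the active tool's "input" log.
    if delta_type == "text_delta":
        completion["content"] += payload
    elif delta_type == "input_json_delta":
        tool_events[tool_id].logs["input"] += payload


on_delta("text_delta", "Let me check that stack trace.")
on_delta("input_json_delta", '{"command": "bt"}')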
@@ -125,17 +128,45 @@ async def async_generator():
 
         # Handle object responses
         try:
-            llm_event.returns = response.model_dump()
-            llm_event.agent_id = check_call_stack_for_agent_id()
-            llm_event.prompt = kwargs["messages"]
-            llm_event.prompt_tokens = response.usage.input_tokens
-            llm_event.completion = {
-                "role": "assistant",
-                "content": response.content[0].text,
-            }
-            llm_event.completion_tokens = response.usage.output_tokens
-            llm_event.model = response.model
+            # AttributeError("'LegacyAPIResponse' object has no attribute 'model_dump'")
+            if isinstance(response, (APIResponse, LegacyAPIResponse)) or not hasattr(
+                response, "model_dump"
+            ):
+                """
+                response's data structure:
+                dict_keys(['id', 'type', 'role', 'model', 'content', 'stop_reason', 'stop_sequence', 'usage'])
+                {'id': 'msg_018Gk9N2pcWaYLS7mxXbPD5i', 'type': 'message', 'role': 'assistant', 'model': 'claude-3-5-sonnet-20241022', 'content': [{'type': 'text', 'text': 'I\'ll help you investigate, but I notice you\'ve just written "stack" without any context. Before I can assist effectively, I need to know:\n\n1. What kind of stack information are you looking for? (e.g., technology stack, call stack, stack trace from an error)\n2. If there\'s a specific program or process you want to examine\n3. What specific details you\'re trying to understand\n\nPlease provide more details about what stack information you\'re interested in, and I\'ll help you analyze it using the available tools.'}], 'stop_reason': 'end_turn', 'stop_sequence': None, 'usage': {'input_tokens': 2419, 'output_tokens': 116}}
+                """
+                response_data = json.loads(response.text)
+                llm_event.returns = response_data
+                llm_event.model = response_data["model"]
+                llm_event.completion = {
+                    "role": response_data.get("role"),
+                    "content": response_data.get("content")[0].get("text")
+                    if response_data.get("content")
+                    else "",
+                }
+                if usage := response_data.get("usage"):
+                    llm_event.prompt_tokens = usage.get("input_tokens")
+                    llm_event.completion_tokens = usage.get("output_tokens")
+
+            # Handle parsed Message objects
+            else:
+                # Assumes the response object exposes a model_dump() method
+                llm_event.returns = response.model_dump()
+                llm_event.prompt_tokens = response.usage.input_tokens
+                llm_event.completion_tokens = response.usage.output_tokens
+
+                llm_event.completion = {
+                    "role": "assistant",
+                    "content": response.content[0].text,
+                }
+                llm_event.model = response.model
+
+            llm_event.end_timestamp = get_ISO_time()
+            llm_event.prompt = kwargs["messages"]
+            llm_event.agent_id = check_call_stack_for_agent_id()
 
             self._safe_record(session, llm_event)
         except Exception as e:
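
The new branching boils down to duck typing: parsed Message objects are pydantic models that support model_dump(), while APIResponse and LegacyAPIResponse wrappers only carry the raw JSON body as text. A minimal sketch of that fallback, with normalize_response as a hypothetical helper name rather than anything the provider defines:

import json


def normalize_response(response):
    # Parsed Message objects are pydantic models and expose model_dump();
    # raw (Legacy)APIResponse wrappers only expose the JSON body as text.
    if hasattr(response, "model_dump"):
        return response.model_dump()
    return json.loads(response.text)

Keying on hasattr rather than on the wrapper types alone keeps the fallback working for any future response type that carries a JSON text body.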
@@ -155,8 +186,8 @@ def override(self):
         self._override_async_completion()
 
     def _override_completion(self):
-        from anthropic.resources import messages
         import anthropic.resources.beta.messages.messages as beta_messages
+        from anthropic.resources import messages
         from anthropic.types import (
             Message,
             RawContentBlockDeltaEvent,
@@ -175,6 +206,9 @@ def create_patched_function(is_beta=False):
             def patched_function(*args, **kwargs):
                 init_timestamp = get_ISO_time()
                 session = kwargs.get("session", None)
+                # if is_beta:
+                #     breakpoint()
+
                 if "session" in kwargs.keys():
                     del kwargs["session"]

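The patch shown here follows the usual monkey-patching recipe: capture the original method, pop the library-foreign session kwarg so the Anthropic client never sees it, then delegate and hand the result to handle_response together with the start timestamp. A simplified, self-contained sketch of that recipe; FakeMessages and the printed call are illustrative stand-ins, not the provider's actual wiring:

import time


class FakeMessages:
    # Stand-in for anthropic.resources.messages.Messages.
    def create(self, **kwargs):
        return {"echo": kwargs}


original_create = FakeMessages.create


def patched_create(self, **kwargs):
    init_timestamp = time.time()  # the provider records get_ISO_time() here
    session = kwargs.pop("session", None)  # must not reach the real client
    response = original_create(self, **kwargs)
    # ...the provider would now call handle_response(response, kwargs,
    # init_timestamp, session=session) and return its result...
    return response


FakeMessages.create = patched_create
print(FakeMessages().create(model="claude-3-5-sonnet-20241022", session=None))
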
@@ -229,6 +263,7 @@ def patched_function(*args, **kwargs):
         beta_messages.Messages.create = create_patched_function(is_beta=True)
 
     def _override_async_completion(self):
+        import anthropic.resources.beta.messages.messages as beta_messages
         from anthropic.resources import messages
         from anthropic.types import (
             Message,
@@ -239,7 +274,6 @@ def _override_async_completion(self):
             RawMessageStartEvent,
             RawMessageStopEvent,
         )
-        import anthropic.resources.beta.messages.messages as beta_messages
 
         # Store the original method
         self.original_create_async = messages.AsyncMessages.create
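
The async override mirrors the synchronous one; the only difference is that the delegation is awaited. A minimal sketch under the same caveat, with FakeAsyncMessages as a hypothetical stand-in for anthropic.resources.messages.AsyncMessages:

import asyncio


class FakeAsyncMessages:
    async def create(self, **kwargs):
        return {"echo": kwargs}


original_create_async = FakeAsyncMessages.create


async def patched_create_async(self, **kwargs):
    kwargs.pop("session", None)  # strip the tracking kwarg before delegating
    return await original_create_async(self, **kwargs)


FakeAsyncMessages.create = patched_create_async
print(asyncio.run(FakeAsyncMessages().create(max_tokens=1024)))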
