feat: update OpenAIInstrumentor to support openai>=1.0.0 and deprecate support for openai<1.0.0 #1723

Merged: 10 commits, Nov 10, 2023
89 changes: 49 additions & 40 deletions src/phoenix/trace/openai/instrumentor.py
@@ -10,9 +10,10 @@
     List,
     Mapping,
     Optional,
-    cast,
 )

+from typing_extensions import TypeGuard
+
 from phoenix.trace.schemas import (
     SpanAttributes,
     SpanEvent,
@@ -44,7 +45,7 @@
 from ..tracer import Tracer

 if TYPE_CHECKING:
-    from openai.openai_response import OpenAIResponse
+    from openai.types.chat import ChatCompletion


 Parameters = Mapping[str, Any]
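
Note: the next hunk moves the monkey-patch off the removed APIRequestor and onto the v1 client class, openai.OpenAI. For context, a minimal usage sketch of the instrumentor (assumptions: OpenAIInstrumentor is re-exported from phoenix.trace.openai, and OPENAI_API_KEY is set in the environment):

import openai
from phoenix.trace.openai import OpenAIInstrumentor  # assumed import path

OpenAIInstrumentor().instrument()  # patches openai.OpenAI.request exactly once

client = openai.OpenAI()
response = client.chat.completions.create(
    model="gpt-4",
    messages=[{"role": "user", "content": "Hello"}],
)  # this call now produces an LLM span

Because the INSTRUMENTED_ATTRIBUTE_NAME marker is set on the class after patching, a second instrument() call is a no-op rather than a double wrap.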
@@ -75,21 +76,21 @@ def instrument(self) -> None:
         """
         openai = import_package("openai")
         is_instrumented = hasattr(
-            openai.api_requestor.APIRequestor.request,
+            openai.OpenAI,
             INSTRUMENTED_ATTRIBUTE_NAME,
         )
         if not is_instrumented:
-            openai.api_requestor.APIRequestor.request = _wrap_openai_api_requestor(
-                openai.api_requestor.APIRequestor.request, self._tracer
+            openai.OpenAI.request = _wrapped_openai_client_request_function(
+                openai.OpenAI.request, self._tracer
             )
             setattr(
-                openai.api_requestor.APIRequestor.request,
+                openai.OpenAI,
                 INSTRUMENTED_ATTRIBUTE_NAME,
                 True,
             )


-def _wrap_openai_api_requestor(
+def _wrapped_openai_client_request_function(
     request_fn: Callable[..., Any], tracer: Tracer
 ) -> Callable[..., Any]:
     """Wraps the OpenAI APIRequestor.request method to create spans for each API call.
@@ -105,9 +106,10 @@ def _wrap_openai_api_requestor(
     def wrapped(*args: Any, **kwargs: Any) -> Any:
         call_signature = signature(request_fn)
         bound_arguments = call_signature.bind(*args, **kwargs)
-        parameters = bound_arguments.arguments["params"]
-        is_streaming = parameters.get("stream", False)
-        url = bound_arguments.arguments["url"]
+        is_streaming = bound_arguments.arguments["stream"]
+        options = bound_arguments.arguments["options"]
+        parameters = options.json_data
+        url = options.url
         current_status_code = SpanStatusCode.UNSET
         events: List[SpanEvent] = []
         attributes: SpanAttributes = dict()
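
Note: the signature(...).bind(...) pattern above maps the wrapper's *args/**kwargs onto the wrapped method's parameter names, so arguments can be fetched by name regardless of how the caller passed them. A self-contained sketch (the toy request signature is illustrative, not the real client method):

from inspect import signature

def request(self, options, stream=False):  # toy stand-in
    ...

bound = signature(request).bind(None, {"model": "gpt-4"}, stream=True)
print(bound.arguments["stream"])   # True
print(bound.arguments["options"])  # {'model': 'gpt-4'}

One caveat: BoundArguments.arguments only contains parameters the caller actually passed, so arguments["stream"] raises KeyError when a caller relies on the default; calling bound.apply_defaults() first would make the lookup total.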
@@ -118,13 +120,13 @@ def wrapped(*args: Any, **kwargs: Any) -> Any:
         ) in _PARAMETER_ATTRIBUTE_FUNCTIONS.items():
             if (attribute_value := get_parameter_attribute_fn(parameters)) is not None:
                 attributes[attribute_name] = attribute_value
-        outputs = None
+        response = None
         try:
             start_time = datetime.now()
-            outputs = request_fn(*args, **kwargs)
+            response = request_fn(*args, **kwargs)
             end_time = datetime.now()
             current_status_code = SpanStatusCode.OK
-            return outputs
+            return response
         except Exception as error:
             end_time = datetime.now()
             current_status_code = SpanStatusCode.ERROR
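
Note: the parameter-attribute loop in the context above (a dict of extractor functions combined with the walrus operator, skipping None results) is easy to exercise in isolation. A sketch with hypothetical extractors and attribute keys:

parameters = {"model": "gpt-4", "temperature": 0.5}

extractors = {  # hypothetical keys, mirroring the _PARAMETER_ATTRIBUTE_FUNCTIONS shape
    "llm.model_name": lambda p: p.get("model"),
    "llm.prompt_template": lambda p: p.get("template"),  # missing -> None -> skipped
}

attributes = {}
for attribute_name, extract in extractors.items():
    if (attribute_value := extract(parameters)) is not None:
        attributes[attribute_name] = attribute_value

print(attributes)  # {'llm.model_name': 'gpt-4'}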
@@ -138,16 +140,17 @@ def wrapped(*args: Any, **kwargs: Any) -> Any:
             )
             raise
         finally:
-            if outputs:
-                response = outputs[0]
+            if _is_chat_completion(response):
                 for (
                     attribute_name,
-                    get_response_attribute_fn,
-                ) in _RESPONSE_ATTRIBUTE_FUNCTIONS.items():
-                    if (attribute_value := get_response_attribute_fn(response)) is not None:
+                    get_chat_completion_attribute_fn,
+                ) in _CHAT_COMPLETION_ATTRIBUTE_FUNCTIONS.items():
+                    if (
+                        attribute_value := get_chat_completion_attribute_fn(response)
+                    ) is not None:
                         attributes[attribute_name] = attribute_value
             tracer.create_span(
-                name="openai.ChatCompletion.create",
+                name="OpenAI Chat Completion",
                 span_kind=SpanKind.LLM,
                 start_time=start_time,
                 end_time=end_time,
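
Note: the try/except/finally shape above guarantees a span is emitted whether the request succeeds or raises, with the status code recording which it was. The control flow, reduced to a sketch with printing standing in for tracer.create_span:

from datetime import datetime

def call_with_span(fn, *args, **kwargs):
    status = "UNSET"
    start_time = datetime.now()
    try:
        result = fn(*args, **kwargs)
        status = "OK"
        return result
    except Exception:
        status = "ERROR"
        raise  # the caller still sees the exception
    finally:
        end_time = datetime.now()
        # stand-in for tracer.create_span(...): runs on success and on error alike
        print(f"span status={status} duration={end_time - start_time}")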
@@ -182,48 +185,46 @@ def _llm_invocation_parameters(
     return json.dumps(parameters)


-def _output_value(response: "OpenAIResponse") -> str:
-    return json.dumps(response.data)
+def _output_value(chat_completion: "ChatCompletion") -> str:
+    return chat_completion.json()


 def _output_mime_type(_: Any) -> MimeType:
     return MimeType.JSON


-def _llm_output_messages(response: "OpenAIResponse") -> List[OpenInferenceMessage]:
+def _llm_output_messages(chat_completion: "ChatCompletion") -> List[OpenInferenceMessage]:
     return [
-        _to_openinference_message(choice["message"], expects_name=False)
-        for choice in response.data["choices"]
+        _to_openinference_message(choice.message.dict(), expects_name=False)
+        for choice in chat_completion.choices
     ]


-def _llm_token_count_prompt(response: "OpenAIResponse") -> Optional[int]:
-    if token_usage := response.data.get("usage"):
-        return cast(int, token_usage["prompt_tokens"])
+def _llm_token_count_prompt(chat_completion: "ChatCompletion") -> Optional[int]:
+    if completion_usage := chat_completion.usage:
+        return completion_usage.prompt_tokens
     return None


-def _llm_token_count_completion(response: "OpenAIResponse") -> Optional[int]:
-    if token_usage := response.data.get("usage"):
-        return cast(int, token_usage["completion_tokens"])
+def _llm_token_count_completion(chat_completion: "ChatCompletion") -> Optional[int]:
+    if completion_usage := chat_completion.usage:
+        return completion_usage.completion_tokens
     return None


-def _llm_token_count_total(response: "OpenAIResponse") -> Optional[int]:
-    if token_usage := response.data.get("usage"):
-        return cast(int, token_usage["total_tokens"])
+def _llm_token_count_total(chat_completion: "ChatCompletion") -> Optional[int]:
+    if completion_usage := chat_completion.usage:
+        return completion_usage.total_tokens
     return None


 def _llm_function_call(
-    response: "OpenAIResponse",
+    chat_completion: "ChatCompletion",
 ) -> Optional[str]:
-    choices = response.data["choices"]
+    choices = chat_completion.choices
     choice = choices[0]
-    if choice.get("finish_reason") == "function_call" and (
-        function_call_data := choice["message"].get("function_call")
-    ):
-        return json.dumps(function_call_data)
+    if choice.finish_reason == "function_call" and (function_call := choice.message.function_call):
+        return function_call.json()
     return None


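Note: the accessors above can be checked against a canned response by validating a plain payload into the pydantic model. A sketch (assumption: openai>=1.0's ChatCompletion accepts nested dicts at construction and coerces them into Choice/CompletionUsage; run it from within this module so the private helpers are in scope):

from openai.types.chat import ChatCompletion

chat_completion = ChatCompletion(
    id="chatcmpl-123",  # canned example values throughout
    object="chat.completion",
    created=1699000000,
    model="gpt-4",
    choices=[{
        "index": 0,
        "finish_reason": "stop",
        "message": {"role": "assistant", "content": "Hello!"},
    }],
    usage={"prompt_tokens": 9, "completion_tokens": 3, "total_tokens": 12},
)

print(_llm_token_count_total(chat_completion))  # 12
print(_llm_function_call(chat_completion))      # None: finish_reason != "function_call"

On a streaming response usage is absent, so each token-count helper falls through to return None.
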
@@ -274,7 +275,7 @@ def _to_openinference_message(
     LLM_INPUT_MESSAGES: _llm_input_messages,
     LLM_INVOCATION_PARAMETERS: _llm_invocation_parameters,
 }
-_RESPONSE_ATTRIBUTE_FUNCTIONS: Dict[str, Callable[["OpenAIResponse"], Any]] = {
+_CHAT_COMPLETION_ATTRIBUTE_FUNCTIONS: Dict[str, Callable[["ChatCompletion"], Any]] = {
     OUTPUT_VALUE: _output_value,
     OUTPUT_MIME_TYPE: _output_mime_type,
     LLM_OUTPUT_MESSAGES: _llm_output_messages,
@@ -283,3 +284,11 @@
     LLM_TOKEN_COUNT_TOTAL: _llm_token_count_total,
     LLM_FUNCTION_CALL: _llm_function_call,
 }
+
+
+def _is_chat_completion(response: Any) -> TypeGuard["ChatCompletion"]:
+    """
+    Type guard for ChatCompletion.
+    """
+    openai = import_package("openai")
+    return isinstance(response, openai.types.chat.ChatCompletion)
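
Note: _is_chat_completion is a TypeGuard, so after "if _is_chat_completion(response):" a static type checker narrows response from Any to ChatCompletion; importing TypeGuard from typing_extensions keeps pre-3.10 Pythons working (typing.TypeGuard landed in 3.10), and resolving the openai module lazily via import_package keeps openai an optional dependency. The same idiom in miniature:

from typing import Any, List
from typing_extensions import TypeGuard

def is_str_list(values: List[Any]) -> TypeGuard[List[str]]:
    return all(isinstance(v, str) for v in values)

def shout(values: List[Any]) -> None:
    if is_str_list(values):
        # here a type checker treats `values` as List[str]
        print(", ".join(v.upper() for v in values))

shout(["a", "b"])  # prints: A, B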