diff --git a/src/phoenix/trace/openai/instrumentor.py b/src/phoenix/trace/openai/instrumentor.py
index 9a812a7dee..b4661b05a4 100644
--- a/src/phoenix/trace/openai/instrumentor.py
+++ b/src/phoenix/trace/openai/instrumentor.py
@@ -10,9 +10,10 @@
     List,
     Mapping,
     Optional,
-    cast,
 )

+from typing_extensions import TypeGuard
+
 from phoenix.trace.schemas import (
     SpanAttributes,
     SpanEvent,
@@ -44,7 +45,7 @@
 from ..tracer import Tracer

 if TYPE_CHECKING:
-    from openai.openai_response import OpenAIResponse
+    from openai.types.chat import ChatCompletion


 Parameters = Mapping[str, Any]
@@ -75,21 +76,21 @@ def instrument(self) -> None:
         """
         openai = import_package("openai")
         is_instrumented = hasattr(
-            openai.api_requestor.APIRequestor.request,
+            openai.OpenAI,
             INSTRUMENTED_ATTRIBUTE_NAME,
         )
         if not is_instrumented:
-            openai.api_requestor.APIRequestor.request = _wrap_openai_api_requestor(
-                openai.api_requestor.APIRequestor.request, self._tracer
+            openai.OpenAI.request = _wrapped_openai_client_request_function(
+                openai.OpenAI.request, self._tracer
             )
             setattr(
-                openai.api_requestor.APIRequestor.request,
+                openai.OpenAI,
                 INSTRUMENTED_ATTRIBUTE_NAME,
                 True,
             )


-def _wrap_openai_api_requestor(
+def _wrapped_openai_client_request_function(
     request_fn: Callable[..., Any], tracer: Tracer
 ) -> Callable[..., Any]:
     """Wraps the OpenAI APIRequestor.request method to create spans for each API call.
@@ -105,9 +106,10 @@
     def wrapped(*args: Any, **kwargs: Any) -> Any:
         call_signature = signature(request_fn)
         bound_arguments = call_signature.bind(*args, **kwargs)
-        parameters = bound_arguments.arguments["params"]
-        is_streaming = parameters.get("stream", False)
-        url = bound_arguments.arguments["url"]
+        is_streaming = bound_arguments.arguments["stream"]
+        options = bound_arguments.arguments["options"]
+        parameters = options.json_data
+        url = options.url
         current_status_code = SpanStatusCode.UNSET
         events: List[SpanEvent] = []
         attributes: SpanAttributes = dict()
@@ -118,13 +120,13 @@ def wrapped(*args: Any, **kwargs: Any) -> Any:
         ) in _PARAMETER_ATTRIBUTE_FUNCTIONS.items():
             if (attribute_value := get_parameter_attribute_fn(parameters)) is not None:
                 attributes[attribute_name] = attribute_value
-        outputs = None
+        response = None
         try:
             start_time = datetime.now()
-            outputs = request_fn(*args, **kwargs)
+            response = request_fn(*args, **kwargs)
             end_time = datetime.now()
             current_status_code = SpanStatusCode.OK
-            return outputs
+            return response
         except Exception as error:
             end_time = datetime.now()
             current_status_code = SpanStatusCode.ERROR
@@ -138,16 +140,17 @@ def wrapped(*args: Any, **kwargs: Any) -> Any:
             )
             raise
         finally:
-            if outputs:
-                response = outputs[0]
+            if _is_chat_completion(response):
                 for (
                     attribute_name,
-                    get_response_attribute_fn,
-                ) in _RESPONSE_ATTRIBUTE_FUNCTIONS.items():
-                    if (attribute_value := get_response_attribute_fn(response)) is not None:
+                    get_chat_completion_attribute_fn,
+                ) in _CHAT_COMPLETION_ATTRIBUTE_FUNCTIONS.items():
+                    if (
+                        attribute_value := get_chat_completion_attribute_fn(response)
+                    ) is not None:
                         attributes[attribute_name] = attribute_value
             tracer.create_span(
-                name="openai.ChatCompletion.create",
+                name="OpenAI Chat Completion",
                 span_kind=SpanKind.LLM,
                 start_time=start_time,
                 end_time=end_time,
@@ -182,48 +185,46 @@ def _llm_invocation_parameters(
     return json.dumps(parameters)


-def _output_value(response: "OpenAIResponse") -> str:
-    return json.dumps(response.data)
+def _output_value(chat_completion: "ChatCompletion") -> str:
+    return chat_completion.json()


 def _output_mime_type(_: Any) -> MimeType:
     return MimeType.JSON


-def _llm_output_messages(response: "OpenAIResponse") -> List[OpenInferenceMessage]:
+def _llm_output_messages(chat_completion: "ChatCompletion") -> List[OpenInferenceMessage]:
     return [
-        _to_openinference_message(choice["message"], expects_name=False)
-        for choice in response.data["choices"]
+        _to_openinference_message(choice.message.dict(), expects_name=False)
+        for choice in chat_completion.choices
     ]


-def _llm_token_count_prompt(response: "OpenAIResponse") -> Optional[int]:
-    if token_usage := response.data.get("usage"):
-        return cast(int, token_usage["prompt_tokens"])
+def _llm_token_count_prompt(chat_completion: "ChatCompletion") -> Optional[int]:
+    if completion_usage := chat_completion.usage:
+        return completion_usage.prompt_tokens
     return None


-def _llm_token_count_completion(response: "OpenAIResponse") -> Optional[int]:
-    if token_usage := response.data.get("usage"):
-        return cast(int, token_usage["completion_tokens"])
+def _llm_token_count_completion(chat_completion: "ChatCompletion") -> Optional[int]:
+    if completion_usage := chat_completion.usage:
+        return completion_usage.completion_tokens
     return None


-def _llm_token_count_total(response: "OpenAIResponse") -> Optional[int]:
-    if token_usage := response.data.get("usage"):
-        return cast(int, token_usage["total_tokens"])
+def _llm_token_count_total(chat_completion: "ChatCompletion") -> Optional[int]:
+    if completion_usage := chat_completion.usage:
+        return completion_usage.total_tokens
     return None


 def _llm_function_call(
-    response: "OpenAIResponse",
+    chat_completion: "ChatCompletion",
 ) -> Optional[str]:
-    choices = response.data["choices"]
+    choices = chat_completion.choices
     choice = choices[0]
-    if choice.get("finish_reason") == "function_call" and (
-        function_call_data := choice["message"].get("function_call")
-    ):
-        return json.dumps(function_call_data)
+    if choice.finish_reason == "function_call" and (function_call := choice.message.function_call):
+        return function_call.json()
     return None
@@ -274,7 +275,7 @@ def _to_openinference_message(
     LLM_INPUT_MESSAGES: _llm_input_messages,
     LLM_INVOCATION_PARAMETERS: _llm_invocation_parameters,
 }
-_RESPONSE_ATTRIBUTE_FUNCTIONS: Dict[str, Callable[["OpenAIResponse"], Any]] = {
+_CHAT_COMPLETION_ATTRIBUTE_FUNCTIONS: Dict[str, Callable[["ChatCompletion"], Any]] = {
     OUTPUT_VALUE: _output_value,
     OUTPUT_MIME_TYPE: _output_mime_type,
     LLM_OUTPUT_MESSAGES: _llm_output_messages,
@@ -283,3 +284,11 @@ def _to_openinference_message(
     LLM_TOKEN_COUNT_TOTAL: _llm_token_count_total,
     LLM_FUNCTION_CALL: _llm_function_call,
 }
+
+
+def _is_chat_completion(response: Any) -> TypeGuard["ChatCompletion"]:
+    """
+    Type guard for ChatCompletion.
+    """
+    openai = import_package("openai")
+    return isinstance(response, openai.types.chat.ChatCompletion)
diff --git a/tests/trace/openai/test_instrumentor.py b/tests/trace/openai/test_instrumentor.py
index 36eff54151..e18b0e0959 100644
--- a/tests/trace/openai/test_instrumentor.py
+++ b/tests/trace/openai/test_instrumentor.py
@@ -1,9 +1,10 @@
 import json
+import sys
 from importlib import reload
+from types import ModuleType

 import openai
 import pytest
-import respx
 from httpx import Response
 from openai import AuthenticationError, OpenAI
 from phoenix.trace.openai.instrumentor import OpenAIInstrumentor
@@ -31,29 +32,43 @@
     MimeType,
 )
 from phoenix.trace.tracer import Tracer
+from respx import MockRouter


 @pytest.fixture
-def reload_openai_api_requestor() -> None:
-    """Reloads openai.api_requestor to reset the instrumented class method."""
-    reload(openai.api_requestor)
+def openai_module() -> ModuleType:
+    """
+    Reloads openai module to reset patched class. Both the top-level module and
+    the sub-module containing the patched client class must be reloaded.
+    """
+    # Cannot be reloaded with reload(openai._client) due to a naming conflict with a variable.
+    reload(sys.modules["openai._client"])
+    return reload(openai)


 @pytest.fixture
-def openai_api_key(monkeypatch) -> None:
+def openai_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
+    """
+    Monkeypatches the environment variable for the OpenAI API key.
+    """
     api_key = "sk-0123456789"
     monkeypatch.setenv("OPENAI_API_KEY", api_key)
     return api_key


 @pytest.fixture
-def client(openai_api_key) -> OpenAI:
-    return OpenAI(api_key=openai_api_key)
+def client(openai_api_key: str, openai_module: ModuleType) -> OpenAI:
+    """
+    Instantiates the OpenAI client using the reloaded openai module, which is
+    necessary when running multiple tests at once due to the patch applied by
+    the OpenAIInstrumentor.
+    """
+    return openai_module.OpenAI(api_key=openai_api_key)


-@respx.mock
 def test_openai_instrumentor_includes_llm_attributes_on_chat_completion_success(
-    reload_openai_api_requestor, client
+    client: OpenAI,
+    respx_mock: MockRouter,
 ) -> None:
     tracer = Tracer()
     OpenAIInstrumentor(tracer).instrument()
@@ -61,7 +76,7 @@ def test_openai_instrumentor_includes_llm_attributes_on_chat_completion_success(
     messages = [{"role": "user", "content": "Who won the World Cup in 2018?"}]
     temperature = 0.23
     expected_response_text = "France won the World Cup in 2018."
-    respx.post(url="https://api.openai.com/v1/chat/completions").mock(
+    respx_mock.post("https://api.openai.com/v1/chat/completions").mock(
         return_value=Response(
             status_code=200,
             json={
@@ -86,7 +101,7 @@ def test_openai_instrumentor_includes_llm_attributes_on_chat_completion_success(
     response = client.chat.completions.create(
         model=model, messages=messages, temperature=temperature
     )
-    response_text = response.choices[0]["message"]["content"]
+    response_text = response.choices[0].message.content

     assert response_text == expected_response_text

@@ -122,9 +137,9 @@ def test_openai_instrumentor_includes_llm_attributes_on_chat_completion_success(
     assert attributes[OUTPUT_MIME_TYPE] == MimeType.JSON


-@respx.mock
 def test_openai_instrumentor_includes_function_call_attributes(
-    reload_openai_api_requestor, openai_api_key
+    client: OpenAI,
+    respx_mock: MockRouter,
 ) -> None:
     tracer = Tracer()
     OpenAIInstrumentor(tracer).instrument()
@@ -149,9 +164,9 @@ def test_openai_instrumentor_includes_function_call_attributes(
         }
     ]
     model = "gpt-4"
-    respx.post(url="https://api.openai.com/v1/chat/completions").mock(
+    respx_mock.post("https://api.openai.com/v1/chat/completions").mock(
         return_value=Response(
-            status=200,
+            status_code=200,
             json={
@@ -177,10 +192,9 @@ def test_openai_instrumentor_includes_function_call_attributes(
     )

     response = client.chat.completions.create(model=model, messages=messages, functions=functions)
-    function_call_data = response.choices[0]["message"]["function_call"]
-    assert set(function_call_data.keys()) == {"name", "arguments"}
-    assert function_call_data["name"] == "get_current_weather"
-    assert json.loads(function_call_data["arguments"]) == {"location": "Boston, MA"}
+    function_call = response.choices[0].message.function_call
+    assert function_call.name == "get_current_weather"
+    assert json.loads(function_call.arguments) == {"location": "Boston, MA"}

     spans = list(tracer.get_spans())
     assert len(spans) == 1
@@ -213,9 +227,9 @@ def test_openai_instrumentor_includes_function_call_attributes(
     assert span.events == []


-@respx.mock
 def test_openai_instrumentor_includes_function_call_message_attributes(
-    reload_openai_api_requestor, client
+    client: OpenAI,
+    respx_mock: MockRouter,
 ) -> None:
     tracer = Tracer()
     OpenAIInstrumentor(tracer).instrument()
@@ -253,9 +267,9 @@ def test_openai_instrumentor_includes_function_call_message_attributes(
         }
     ]
     model = "gpt-4"
-    respx.post(url="https://api.openai.com/v1/chat/completions").mock(
+    respx_mock.post("https://api.openai.com/v1/chat/completions").mock(
         return_value=Response(
-            status=200,
+            status_code=200,
             json={
@@ -280,7 +294,7 @@ def test_openai_instrumentor_includes_function_call_message_attributes(
     )

     response = client.chat.completions.create(model=model, messages=messages, functions=functions)
-    response_text = response.choices[0]["message"]["content"]
+    response_text = response.choices[0].message.content
     spans = list(tracer.get_spans())
     span = spans[0]
     attributes = span.attributes
@@ -314,15 +328,15 @@ def test_openai_instrumentor_includes_function_call_message_attributes(
     assert LLM_FUNCTION_CALL not in attributes


-@respx.mock
 def test_openai_instrumentor_records_authentication_error(
-    reload_openai_api_requestor, openai_api_key
+    client: OpenAI,
+    respx_mock: MockRouter,
 ) -> None:
     tracer = Tracer()
     OpenAIInstrumentor(tracer).instrument()
-    respx.post(url="https://api.openai.com/v1/chat/completions").mock(
+    respx_mock.post("https://api.openai.com/v1/chat/completions").mock(
         return_value=Response(
-            status=401,
+            status_code=401,
             json={
                 "error": {
                     "message": "error-message",
@@ -348,21 +362,21 @@ def test_openai_instrumentor_records_authentication_error(
     assert isinstance(event, SpanException)
     attributes = event.attributes
     assert attributes[EXCEPTION_TYPE] == "AuthenticationError"
-    assert attributes[EXCEPTION_MESSAGE] == "error-message"
+    assert "error-message" in attributes[EXCEPTION_MESSAGE]
     assert "Traceback" in attributes[EXCEPTION_STACKTRACE]


-@respx.mock
 def test_openai_instrumentor_does_not_interfere_with_completions_api(
-    reload_openai_api_requestor, client
+    client: OpenAI,
+    respx_mock: MockRouter,
 ) -> None:
     tracer = Tracer()
     OpenAIInstrumentor(tracer).instrument()
     model = "gpt-3.5-turbo-instruct"
     prompt = "Who won the World Cup in 2018?"
-    respx.post(url="https://api.openai.com/v1/completions").mock(
+    respx_mock.post("https://api.openai.com/v1/completions").mock(
         return_value=Response(
-            status=200,
+            status_code=200,
             json={
@@ -381,25 +395,25 @@ def test_openai_instrumentor_does_not_interfere_with_completions_api(
         )
     )
     response = client.completions.create(model=model, prompt=prompt)
-    response_text = response.choices[0]["text"]
+    response_text = response.choices[0].text
     spans = list(tracer.get_spans())

     assert "france" in response_text.lower() or "french" in response_text.lower()
     assert spans == []


-@respx.mock
 def test_openai_instrumentor_instrument_method_is_idempotent(
-    reload_openai_api_requestor, openai_api_key
+    client: OpenAI,
+    respx_mock: MockRouter,
 ) -> None:
     tracer = Tracer()
     OpenAIInstrumentor(tracer).instrument()  # first call
     OpenAIInstrumentor(tracer).instrument()  # second call
     model = "gpt-4"
     messages = [{"role": "user", "content": "Who won the World Cup in 2018?"}]
-    respx.post(url="https://api.openai.com/v1/chat/completions").mock(
+    respx_mock.post("https://api.openai.com/v1/chat/completions").mock(
         return_value=Response(
-            status=200,
+            status_code=200,
             json={
@@ -420,7 +434,7 @@ def test_openai_instrumentor_instrument_method_is_idempotent(
         )
     )

     response = client.chat.completions.create(model=model, messages=messages)
-    response_text = response.choices[0]["message"]["content"]
+    response_text = response.choices[0].message.content
     spans = list(tracer.get_spans())
     span = spans[0]
diff --git a/tutorials/evals/evaluate_code_readability_classifications.ipynb b/tutorials/evals/evaluate_code_readability_classifications.ipynb
index 6baeadf7bf..17ade2ff26 100644
--- a/tutorials/evals/evaluate_code_readability_classifications.ipynb
+++ b/tutorials/evals/evaluate_code_readability_classifications.ipynb
@@ -759,7 +759,9 @@
    ],
    "source": [
     "# Let's view the data\n",
-    "merged_df = pd.merge(small_df_sample, readability_classifications_df, left_index=True, right_index=True)\n",
+    "merged_df = pd.merge(\n",
+    "    small_df_sample, readability_classifications_df, left_index=True, right_index=True\n",
+    ")\n",
     "merged_df[[\"query\", \"code\", \"label\", \"explanation\"]].head()"
    ]
   },
@@ -800,7 +802,9 @@
     "readability_classifications = llm_classify(\n",
     "    dataframe=df,\n",
     "    template=CODE_READABILITY_PROMPT_TEMPLATE_STR,\n",
-    "    model=OpenAIModel(model_name=\"gpt-3.5-turbo\", temperature=0.0, request_timeout=20, max_retries=0),\n",
+    "    model=OpenAIModel(\n",
+    "        model_name=\"gpt-3.5-turbo\", temperature=0.0, request_timeout=20, max_retries=0\n",
+    "    ),\n",
     "    rails=rails,\n",
     ")[\"label\"]"
    ]
   },
diff --git a/tutorials/evals/evaluate_hallucination_classifications.ipynb b/tutorials/evals/evaluate_hallucination_classifications.ipynb
index c898c25407..fca1206364 100644
--- a/tutorials/evals/evaluate_hallucination_classifications.ipynb
+++ b/tutorials/evals/evaluate_hallucination_classifications.ipynb
@@ -620,7 +620,9 @@
    ],
    "source": [
     "# Let's view the data\n",
-    "merged_df = pd.merge(small_df_sample, hallucination_classifications_df, left_index=True, right_index=True)\n",
+    "merged_df = pd.merge(\n",
+    "    small_df_sample, hallucination_classifications_df, left_index=True, right_index=True\n",
+    ")\n",
     "merged_df[[\"query\", \"reference\", \"response\", \"is_hallucination\", \"label\", \"explanation\"]].head()"
    ]
   },
diff --git a/tutorials/evals/evaluate_relevance_classifications.ipynb b/tutorials/evals/evaluate_relevance_classifications.ipynb
index 84756c141c..33a50339ed 100644
--- a/tutorials/evals/evaluate_relevance_classifications.ipynb
+++ b/tutorials/evals/evaluate_relevance_classifications.ipynb
@@ -645,7 +645,9 @@
    ],
    "source": [
     "# Let's view the data\n",
-    "merged_df = pd.merge(small_df_sample, relevance_classifications_df, left_index=True, right_index=True)\n",
+    "merged_df = pd.merge(\n",
+    "    small_df_sample, relevance_classifications_df, left_index=True, right_index=True\n",
+    ")\n",
     "merged_df[[\"query\", \"reference\", \"label\", \"explanation\"]].head()"
    ]
   },
diff --git a/tutorials/tracing/openai_tracing_tutorial.ipynb b/tutorials/tracing/openai_tracing_tutorial.ipynb
index c12e2365e4..53849d4bed 100644
--- a/tutorials/tracing/openai_tracing_tutorial.ipynb
+++ b/tutorials/tracing/openai_tracing_tutorial.ipynb
@@ -50,7 +50,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "!pip install \"openai<1\" arize-phoenix jsonschema"
+    "!pip install \"openai>=1.0.0\" arize-phoenix jsonschema"
    ]
   },
   {
@@ -72,9 +72,9 @@
     "from typing import Any, Dict, Literal, TypedDict\n",
     "\n",
     "import jsonschema\n",
-    "import openai\n",
     "import pandas as pd\n",
     "import phoenix as px\n",
+    "from openai import OpenAI\n",
     "from phoenix.trace.exporter import HttpExporter\n",
     "from phoenix.trace.openai import OpenAIInstrumentor\n",
     "from phoenix.trace.tracer import Tracer\n",
@@ -91,7 +91,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## 2. Configure Your OpenAI API Key\n",
+    "## 2. Configure Your OpenAI API Key and Instantiate Your OpenAI Client\n",
     "\n",
     "Set your OpenAI API key if it is not already set as an environment variable."
    ]
@@ -104,8 +104,7 @@
    "source": [
     "if not (openai_api_key := os.getenv(\"OPENAI_API_KEY\")):\n",
     "    openai_api_key = getpass(\"🔑 Enter your OpenAI API key: \")\n",
-    "openai.api_key = openai_api_key\n",
-    "os.environ[\"OPENAI_API_KEY\"] = openai_api_key"
+    "client = OpenAI(api_key=openai_api_key)"
    ]
   },
   {
@@ -225,9 +224,13 @@
     "\n",
     "@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))\n",
     "def extract_raw_travel_request_attributes_string(\n",
-    "    travel_request: str, function_schema: Dict[str, Any], system_message: str, model: str = \"gpt-4\"\n",
+    "    travel_request: str,\n",
+    "    function_schema: Dict[str, Any],\n",
+    "    system_message: str,\n",
+    "    client: OpenAI,\n",
+    "    model: str = \"gpt-4\",\n",
     ") -> str:\n",
-    "    response = openai.ChatCompletion.create(\n",
+    "    chat_completion = client.chat.completions.create(\n",
     "        model=model,\n",
     "        messages=[\n",
     "            {\"role\": \"system\", \"content\": system_message},\n",
@@ -238,8 +241,8 @@
     "        # The line below forces the LLM to call the function so that the output conforms to the schema.\n",
     "        function_call={\"name\": function_schema[\"name\"]},\n",
     "    )\n",
-    "    function_call_data = response[\"choices\"][0][\"message\"][\"function_call\"]\n",
-    "    return function_call_data[\"arguments\"]"
+    "    function_call = chat_completion.choices[0].message.function_call\n",
+    "    return function_call.arguments"
    ]
   },
@@ -262,7 +265,7 @@
     "    print(travel_request)\n",
     "    print()\n",
     "    raw_travel_attributes = extract_raw_travel_request_attributes_string(\n",
-    "        travel_request, function_schema, system_message\n",
+    "        travel_request, function_schema, system_message, client\n",
     "    )\n",
     "    raw_travel_attributes_column.append(raw_travel_attributes)\n",
     "    print(\"Raw Travel Attributes:\")\n",