From 06aaa6f41e399f8523cd5e0a226c3d98c0e87d24 Mon Sep 17 00:00:00 2001
From: Dustin Ngo
Date: Fri, 20 Oct 2023 14:54:31 -0500
Subject: [PATCH] Add graceful fallback to langchain tracer

---
 src/phoenix/trace/langchain/tracer.py | 34 +++++++--------
 tests/trace/langchain/test_tracer.py  | 60 +++++++++++++++++++++++++++
 2 files changed, 77 insertions(+), 17 deletions(-)

diff --git a/src/phoenix/trace/langchain/tracer.py b/src/phoenix/trace/langchain/tracer.py
index 66c264698f..12489db443 100644
--- a/src/phoenix/trace/langchain/tracer.py
+++ b/src/phoenix/trace/langchain/tracer.py
@@ -2,16 +2,7 @@
 import logging
 from copy import deepcopy
 from datetime import datetime
-from typing import (
-    Any,
-    Dict,
-    Iterable,
-    Iterator,
-    List,
-    Mapping,
-    Optional,
-    Tuple,
-)
+from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, Tuple
 from uuid import UUID
 
 from langchain.callbacks.tracers.base import BaseTracer
@@ -20,13 +11,7 @@
 from langchain.schema.messages import BaseMessage
 
 from phoenix.trace.exporter import HttpExporter
-from phoenix.trace.schemas import (
-    Span,
-    SpanEvent,
-    SpanException,
-    SpanKind,
-    SpanStatusCode,
-)
+from phoenix.trace.schemas import Span, SpanEvent, SpanException, SpanKind, SpanStatusCode
 from phoenix.trace.semantic_conventions import (
     DOCUMENT_CONTENT,
     DOCUMENT_METADATA,
@@ -56,6 +41,7 @@
     MimeType,
 )
 from phoenix.trace.tracer import Tracer
+from phoenix.utilities.error_handling import graceful_fallback
 
 logger = logging.getLogger(__name__)
 
@@ -351,6 +337,20 @@ def _persist_run(self, run: Run) -> None:
         except Exception:
             logger.exception("Failed to convert run to spans")
 
+    def _null_fallback(
+        self,
+        serialized: Dict[str, Any],
+        messages: List[List[BaseMessage]],
+        *,
+        run_id: UUID,
+        tags: Optional[List[str]] = None,
+        parent_run_id: Optional[UUID] = None,
+        metadata: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> None:
+        pass
+
+    @graceful_fallback(_null_fallback)
     def on_chat_model_start(
         self,
         serialized: Dict[str, Any],
diff --git a/tests/trace/langchain/test_tracer.py b/tests/trace/langchain/test_tracer.py
index e4cff40b08..95a3a83bce 100644
--- a/tests/trace/langchain/test_tracer.py
+++ b/tests/trace/langchain/test_tracer.py
@@ -1,5 +1,6 @@
 from json import loads
 from typing import List
+from unittest.mock import patch
 from uuid import UUID
 
 import numpy as np
@@ -334,3 +335,62 @@ def test_tracer_retriever_with_exception() -> None:
 
     for span in spans.values():
         assert json_string_to_span(span_to_json(span)) == span
+
+
+@responses.activate
+@pytest.mark.parametrize(
+    "messages",
+    [
+        pytest.param(
+            [
+                ChatMessage(role="system", content="system-message-content"),
+                ChatMessage(role="user", content="user-message-content"),
+                ChatMessage(role="assistant", content="assistant-message-content"),
+                ChatMessage(role="function", content="function-message-content"),
+            ],
+            id="chat-messages",
+        ),
+        pytest.param(
+            [
+                SystemMessage(content="system-message-content"),
+                HumanMessage(content="user-message-content"),
+                AIMessage(content="assistant-message-content"),
+                FunctionMessage(name="function-name", content="function-message-content"),
+            ],
+            id="non-chat-messages",
+        ),
+    ],
+)
+def test_tracing_llm_chat_completions_fails_gracefully(
+    messages: List[BaseMessage], monkeypatch: pytest.MonkeyPatch, caplog
+) -> None:
+    monkeypatch.setenv(OPENAI_API_KEY_ENVVAR_NAME, "sk-0123456789")
+    tracer = OpenInferenceTracer(exporter=NoOpExporter())
+    with patch.object(OpenInferenceTracer, "_start_trace") as mock_tracer_internals:
+        mock_tracer_internals.side_effect = RuntimeError("This came from a test")
+        model_name = "gpt-4"
+        llm = ChatOpenAI(model_name=model_name)
+        expected_response = "response-text"
+        responses.post(
+            "https://api.openai.com/v1/chat/completions",
+            json={
+                "id": "chatcmpl-123",
+                "object": "chat.completion",
+                "created": 1677652288,
+                "model": model_name,
+                "choices": [
+                    {
+                        "index": 0,
+                        "message": {"role": "assistant", "content": expected_response},
+                        "finish_reason": "stop",
+                    }
+                ],
+                "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
+            },
+            status=200,
+        )
+        llm(messages, callbacks=[tracer])
+    fallback_log = caplog.records[0].message
+    assert "This came from a test" in fallback_log, "Error should be logged"
+    assert "Traceback" in fallback_log, "Traceback should be logged"
+    assert "Rerouting to fallback method" in fallback_log
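
Note for reviewers: the patch imports graceful_fallback from phoenix.utilities.error_handling but does not define it here. Below is a minimal, hypothetical sketch of how such a decorator could behave, inferred only from the test assertions (the original error message and its traceback are logged together with "Rerouting to fallback method", and the fallback receives the same arguments); the actual Phoenix implementation may differ.

# Hypothetical sketch only; not part of this patch. Behavior is inferred from the
# test assertions: log the wrapped method's exception and traceback with a
# "Rerouting to fallback method" message, then call the fallback with the same
# arguments so tracing failures never interrupt the instrumented LLM call.
import logging
import traceback
from functools import wraps
from typing import Any, Callable

logger = logging.getLogger(__name__)


def graceful_fallback(
    fallback_method: Callable[..., Any],
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            try:
                return func(*args, **kwargs)
            except Exception:
                # Include the traceback so the test's "Traceback" assertion holds.
                logger.warning(
                    f"Exception in {func.__name__}. Rerouting to fallback method.\n"
                    f"{traceback.format_exc()}"
                )
                return fallback_method(*args, **kwargs)

        return wrapper

    return decorator

Applied as @graceful_fallback(_null_fallback) on on_chat_model_start, a RuntimeError raised inside the tracer is logged and swallowed rather than propagated, which is what the new test verifies.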