diff --git a/pyproject.toml b/pyproject.toml index 8773f87780..1c34b14f1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,7 @@ dev = [ "strawberry-graphql[debug-server]==0.208.2", "pre-commit", "arize[AutoEmbeddings, LLM_Evaluation]", - "llama-index>=0.9.0,<0.9.8", + "llama-index>=0.9.8", "langchain>=0.0.334", "litellm>=1.0.3" ] @@ -64,7 +64,7 @@ experimental = [ "tenacity", ] llama-index = [ - "llama-index>=0.9.0,<0.9.8" + "llama-index>=0.9.8", ] [project.urls] @@ -98,7 +98,7 @@ dependencies = [ "arize", "langchain>=0.0.334", "litellm>=1.0.3", - "llama-index>=0.9.0,<0.9.8", + "llama-index>=0.9.8", "openai>=1.0.0", "tenacity", "nltk==3.8.1", @@ -117,7 +117,7 @@ dependencies = [ [tool.hatch.envs.type] dependencies = [ "mypy==1.5.1", - "llama-index>=0.9.0,<0.9.8", + "llama-index>=0.9.8", "pandas-stubs<=2.0.2.230605", # version 2.0.3.230814 is causing a dependency conflict. "types-psutil", "types-tqdm", diff --git a/src/phoenix/trace/llama_index/callback.py b/src/phoenix/trace/llama_index/callback.py index 1027c80b85..e209615fb6 100644 --- a/src/phoenix/trace/llama_index/callback.py +++ b/src/phoenix/trace/llama_index/callback.py @@ -92,7 +92,7 @@ from phoenix.trace.utils import extract_version_triplet, get_stacktrace from phoenix.utilities.error_handling import graceful_fallback -LLAMA_INDEX_MINIMUM_VERSION_TRIPLET = (0, 9, 0) +LLAMA_INDEX_MINIMUM_VERSION_TRIPLET = (0, 9, 8) logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) @@ -355,12 +355,16 @@ def _add_spans_to_tracer( parent_child_id_stack: List[Tuple[Optional[SpanID], CBEventID]] = [ (None, root_event_id) for root_event_id in trace_map["root"] ] - span_exceptions: List[SpanEvent] = [] while parent_child_id_stack: parent_span_id, event_id = parent_child_id_stack.pop() event_data = event_id_to_event_data[event_id] event_type = event_data.event_type attributes = event_data.attributes + if not (start_event := event_data.start_event): + # if the callback system has broken its contract by calling + # on_event_end without on_event_start, do not create a span + continue + if event_type is CBEventType.LLM: while parent_child_id_stack: preceding_event_parent_span_id, preceding_event_id = parent_child_id_stack[-1] @@ -375,32 +379,11 @@ def _add_spans_to_tracer( # Add template attributes to the LLM span to which they belong. attributes.update(_template_attributes(preceding_payload)) - start_time = None - if start_event := event_data.start_event: - start_time = _timestamp_to_tz_aware_datetime(start_event.time) + start_time = _timestamp_to_tz_aware_datetime(start_event.time) + span_exceptions = _get_span_exceptions(event_data, start_time) end_time = _get_end_time(event_data, span_exceptions) start_time = start_time or end_time or datetime.now(timezone.utc) - if event_type is CBEventType.EXCEPTION: - # LlamaIndex has exception callback events that are sibling events of the events in - # which the exception occurred. We collect all the exception events and add them to - # the relevant span. - if ( - not start_event - or not start_event.payload - or (error := start_event.payload.get(EventPayload.EXCEPTION)) is None - ): - continue - span_exceptions.append( - SpanException( - message=str(error), - timestamp=start_time, - exception_type=type(error).__name__, - exception_stacktrace=get_stacktrace(error), - ) - ) - continue - name = event_name if (event_name := event_data.name) is not None else "unknown" span_kind = _get_span_kind(event_type) span = tracer.create_span( @@ -416,7 +399,6 @@ def _add_spans_to_tracer( events=sorted(span_exceptions, key=lambda event: event.timestamp) or None, conversation=None, ) - span_exceptions = [] new_parent_span_id = span.context.span_id for new_child_event_id in trace_map.get(event_id, []): parent_child_id_stack.append((new_parent_span_id, new_child_event_id)) @@ -511,7 +493,7 @@ def _get_response_output(response: Any) -> Iterator[Tuple[str, Any]]: yield OUTPUT_MIME_TYPE, MimeType.TEXT -def _get_end_time(event_data: CBEventData, span_events: List[SpanEvent]) -> Optional[datetime]: +def _get_end_time(event_data: CBEventData, span_events: Iterable[SpanEvent]) -> Optional[datetime]: """ A best-effort attempt to get the end time of an event. @@ -528,6 +510,22 @@ def _get_end_time(event_data: CBEventData, span_events: List[SpanEvent]) -> Opti return _tz_naive_to_tz_aware_datetime(tz_naive_end_time) +def _get_span_exceptions(event_data: CBEventData, start_time: datetime) -> List[SpanException]: + """Collects exceptions from the start and end events, if present.""" + span_exceptions = [] + for event in [event_data.start_event, event_data.end_event]: + if event and (payload := event.payload) and (error := payload.get(EventPayload.EXCEPTION)): + span_exceptions.append( + SpanException( + message=str(error), + timestamp=start_time, + exception_type=type(error).__name__, + exception_stacktrace=get_stacktrace(error), + ) + ) + return span_exceptions + + def _timestamp_to_tz_aware_datetime(timestamp: str) -> datetime: """Converts a timestamp string to a timezone-aware datetime.""" return _tz_naive_to_tz_aware_datetime(_timestamp_to_tz_naive_datetime(timestamp)) diff --git a/tests/trace/llama_index/test_callback.py b/tests/trace/llama_index/test_callback.py index d4c61c0253..7b8731f6ee 100644 --- a/tests/trace/llama_index/test_callback.py +++ b/tests/trace/llama_index/test_callback.py @@ -135,7 +135,7 @@ def test_callback_llm_span_contains_template_attributes( assert isinstance(span.attributes[LLM_PROMPT_TEMPLATE_VARIABLES], dict) -def test_callback_llm_internal_error_has_exception_event( +def test_callback_internal_error_has_exception_event( monkeypatch: pytest.MonkeyPatch, ) -> None: monkeypatch.setenv(OPENAI_API_KEY_ENVVAR_NAME, "sk-0123456789") @@ -160,10 +160,8 @@ def test_callback_llm_internal_error_has_exception_event( query_engine.query(query) spans = list(callback_handler.get_spans()) - assert all( - span.status_code == SpanStatusCode.OK for span in spans if span.span_kind != SpanKind.LLM - ) - span = next(span for span in spans if span.span_kind == SpanKind.LLM) + assert all(span.status_code == SpanStatusCode.OK for span in spans if span.name != "synthesize") + span = next(span for span in spans if span.name == "synthesize") assert span.status_code == SpanStatusCode.ERROR events = span.events event = events[0] @@ -175,6 +173,39 @@ def test_callback_llm_internal_error_has_exception_event( assert isinstance(event.attributes[EXCEPTION_STACKTRACE], str) +def test_callback_exception_event_produces_root_chain_span_with_exception_events() -> None: + llm = OpenAI(model="gpt-3.5-turbo", max_retries=1) + query = "What are the seven wonders of the world?" + callback_handler = OpenInferenceTraceCallbackHandler(exporter=NoOpExporter()) + index = ListIndex(nodes) + service_context = ServiceContext.from_defaults( + llm=llm, callback_manager=CallbackManager([callback_handler]) + ) + query_engine = index.as_query_engine(service_context=service_context) + + # mock the _query method to raise an exception before any event has begun + # to produce an independent exception event + with patch.object(query_engine, "_query") as mocked_query: + mocked_query.side_effect = Exception("message") + with pytest.raises(Exception): + query_engine.query(query) + + spans = list(callback_handler.get_spans()) + assert len(spans) == 1 + span = spans[0] + assert span.span_kind == SpanKind.CHAIN + assert span.status_code == SpanStatusCode.ERROR + assert span.name == "exception" + events = span.events + event = events[0] + assert isinstance(event, SpanException) + assert isinstance(event.timestamp, datetime) + assert len(event.attributes) == 3 + assert event.attributes[EXCEPTION_TYPE] == "Exception" + assert event.attributes[EXCEPTION_MESSAGE] == "message" + assert isinstance(event.attributes[EXCEPTION_STACKTRACE], str) + + def test_callback_llm_rate_limit_error_has_exception_event_with_missing_start( monkeypatch: pytest.MonkeyPatch, ) -> None: diff --git a/tutorials/internal/llama_index_tracing_example.ipynb b/tutorials/internal/llama_index_tracing_example.ipynb index 1a62062a8b..c1ce5ee974 100644 --- a/tutorials/internal/llama_index_tracing_example.ipynb +++ b/tutorials/internal/llama_index_tracing_example.ipynb @@ -54,7 +54,7 @@ "from llama_index.callbacks import CallbackManager\n", "from llama_index.embeddings.openai import OpenAIEmbedding\n", "from llama_index.graph_stores.simple import SimpleGraphStore\n", - "from llama_index.indices.postprocessor.cohere_rerank import CohereRerank\n", + "from llama_index.postprocessor.cohere_rerank import CohereRerank\n", "from phoenix.trace.llama_index import (\n", " OpenInferenceTraceCallbackHandler,\n", ")\n",