fix: enhance llama-index callback support for exception events (#1814)
axiomofjoy authored Nov 28, 2023
1 parent 9c314e8 commit 8db01df
Showing 4 changed files with 66 additions and 37 deletions.
8 changes: 4 additions & 4 deletions pyproject.toml
@@ -56,15 +56,15 @@ dev = [
"strawberry-graphql[debug-server]==0.208.2",
"pre-commit",
"arize[AutoEmbeddings, LLM_Evaluation]",
"llama-index>=0.9.0,<0.9.8",
"llama-index>=0.9.8",
"langchain>=0.0.334",
"litellm>=1.0.3"
]
experimental = [
"tenacity",
]
llama-index = [
"llama-index>=0.9.0,<0.9.8"
"llama-index>=0.9.8",
]

[project.urls]
@@ -98,7 +98,7 @@ dependencies = [
"arize",
"langchain>=0.0.334",
"litellm>=1.0.3",
"llama-index>=0.9.0,<0.9.8",
"llama-index>=0.9.8",
"openai>=1.0.0",
"tenacity",
"nltk==3.8.1",
@@ -117,7 +117,7 @@ dependencies = [
[tool.hatch.envs.type]
dependencies = [
"mypy==1.5.1",
"llama-index>=0.9.0,<0.9.8",
"llama-index>=0.9.8",
"pandas-stubs<=2.0.2.230605", # version 2.0.3.230814 is causing a dependency conflict.
"types-psutil",
"types-tqdm",
52 changes: 25 additions & 27 deletions src/phoenix/trace/llama_index/callback.py
@@ -92,7 +92,7 @@
from phoenix.trace.utils import extract_version_triplet, get_stacktrace
from phoenix.utilities.error_handling import graceful_fallback

-LLAMA_INDEX_MINIMUM_VERSION_TRIPLET = (0, 9, 0)
+LLAMA_INDEX_MINIMUM_VERSION_TRIPLET = (0, 9, 8)
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
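As context for the constant above: the hunk also shows that the callback module imports extract_version_triplet, so the installed llama-index version is presumably compared against this floor; the exact reaction to a mismatch is not shown here. Below is a minimal sketch of one way such a guard could look. Only the constant's value comes from the diff; the function names, regex, and error handling are illustrative assumptions, not the module's actual code.

# Illustrative sketch of a minimum-version guard; only the constant's value
# is taken from the diff above, the rest is an assumption for illustration.
import re
from typing import Optional, Tuple

LLAMA_INDEX_MINIMUM_VERSION_TRIPLET = (0, 9, 8)


def extract_version_triplet(version: str) -> Optional[Tuple[int, int, int]]:
    # Pull the leading "major.minor.patch" numbers out of a version string.
    if match := re.search(r"(\d+)\.(\d+)\.(\d+)", version):
        major, minor, patch = (int(group) for group in match.groups())
        return (major, minor, patch)
    return None


def check_llama_index_version(installed_version: str) -> None:
    # Hypothetical guard: reject llama-index versions below the supported floor.
    triplet = extract_version_triplet(installed_version)
    if triplet is None or triplet < LLAMA_INDEX_MINIMUM_VERSION_TRIPLET:
        minimum = ".".join(map(str, LLAMA_INDEX_MINIMUM_VERSION_TRIPLET))
        raise RuntimeError(f"llama-index>={minimum} is required, found {installed_version!r}")


check_llama_index_version("0.9.8")  # passes; "0.9.0" would raise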

@@ -355,12 +355,16 @@ def _add_spans_to_tracer(
parent_child_id_stack: List[Tuple[Optional[SpanID], CBEventID]] = [
(None, root_event_id) for root_event_id in trace_map["root"]
]
-span_exceptions: List[SpanEvent] = []
while parent_child_id_stack:
parent_span_id, event_id = parent_child_id_stack.pop()
event_data = event_id_to_event_data[event_id]
event_type = event_data.event_type
attributes = event_data.attributes
+if not (start_event := event_data.start_event):
+    # if the callback system has broken its contract by calling
+    # on_event_end without on_event_start, do not create a span
+    continue

if event_type is CBEventType.LLM:
while parent_child_id_stack:
preceding_event_parent_span_id, preceding_event_id = parent_child_id_stack[-1]
@@ -375,32 +379,11 @@
# Add template attributes to the LLM span to which they belong.
attributes.update(_template_attributes(preceding_payload))

-start_time = None
-if start_event := event_data.start_event:
-    start_time = _timestamp_to_tz_aware_datetime(start_event.time)
+start_time = _timestamp_to_tz_aware_datetime(start_event.time)
+span_exceptions = _get_span_exceptions(event_data, start_time)
end_time = _get_end_time(event_data, span_exceptions)
-start_time = start_time or end_time or datetime.now(timezone.utc)

-if event_type is CBEventType.EXCEPTION:
-    # LlamaIndex has exception callback events that are sibling events of the events in
-    # which the exception occurred. We collect all the exception events and add them to
-    # the relevant span.
-    if (
-        not start_event
-        or not start_event.payload
-        or (error := start_event.payload.get(EventPayload.EXCEPTION)) is None
-    ):
-        continue
-    span_exceptions.append(
-        SpanException(
-            message=str(error),
-            timestamp=start_time,
-            exception_type=type(error).__name__,
-            exception_stacktrace=get_stacktrace(error),
-        )
-    )
-    continue

name = event_name if (event_name := event_data.name) is not None else "unknown"
span_kind = _get_span_kind(event_type)
span = tracer.create_span(
@@ -416,7 +399,6 @@
events=sorted(span_exceptions, key=lambda event: event.timestamp) or None,
conversation=None,
)
-span_exceptions = []
new_parent_span_id = span.context.span_id
for new_child_event_id in trace_map.get(event_id, []):
parent_child_id_stack.append((new_parent_span_id, new_child_event_id))
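The loop shown above walks LlamaIndex's trace_map with an explicit stack of (parent span id, event id) pairs, so each event becomes a span whose parent id is already known when it is visited. A self-contained sketch of that traversal pattern follows; the ids and trace_map contents are made up for the example and stand in for the real CBEventID/SpanID types.

# Toy version of the stack-based traversal in _add_spans_to_tracer;
# ids and trace_map contents here are made up for the example.
from typing import Dict, List, Optional, Tuple

trace_map: Dict[str, List[str]] = {"root": ["query"], "query": ["retrieve", "synthesize"]}
stack: List[Tuple[Optional[str], str]] = [(None, root_id) for root_id in trace_map["root"]]
while stack:
    parent_id, event_id = stack.pop()
    # a real implementation would create a span here and record its id
    print(f"span for {event_id!r} has parent {parent_id!r}")
    for child_id in trace_map.get(event_id, []):
        stack.append((event_id, child_id))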
@@ -511,7 +493,7 @@ def _get_response_output(response: Any) -> Iterator[Tuple[str, Any]]:
yield OUTPUT_MIME_TYPE, MimeType.TEXT


-def _get_end_time(event_data: CBEventData, span_events: List[SpanEvent]) -> Optional[datetime]:
+def _get_end_time(event_data: CBEventData, span_events: Iterable[SpanEvent]) -> Optional[datetime]:
"""
A best-effort attempt to get the end time of an event.
@@ -528,6 +510,22 @@ def _get_end_time(event_data: CBEventData, span_events: Iterable[SpanEvent]) -> Optional[datetime]:
return _tz_naive_to_tz_aware_datetime(tz_naive_end_time)


+def _get_span_exceptions(event_data: CBEventData, start_time: datetime) -> List[SpanException]:
+    """Collects exceptions from the start and end events, if present."""
+    span_exceptions = []
+    for event in [event_data.start_event, event_data.end_event]:
+        if event and (payload := event.payload) and (error := payload.get(EventPayload.EXCEPTION)):
+            span_exceptions.append(
+                SpanException(
+                    message=str(error),
+                    timestamp=start_time,
+                    exception_type=type(error).__name__,
+                    exception_stacktrace=get_stacktrace(error),
+                )
+            )
+    return span_exceptions


def _timestamp_to_tz_aware_datetime(timestamp: str) -> datetime:
"""Converts a timestamp string to a timezone-aware datetime."""
return _tz_naive_to_tz_aware_datetime(_timestamp_to_tz_naive_datetime(timestamp))
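The new helper above, _get_span_exceptions, is the heart of this change: exceptions are now read off an event's own start and end payloads rather than collected from sibling CBEventType.EXCEPTION events. Below is a simplified, runnable sketch of that pattern; the Event/EventData classes and the payload key are stand-ins for the real llama-index callback types, not the actual API.

# Simplified stand-ins for the llama-index callback types, for illustration only.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

EXCEPTION_PAYLOAD_KEY = "exception"  # stand-in for EventPayload.EXCEPTION


@dataclass
class Event:
    payload: Optional[Dict[str, Any]] = None


@dataclass
class EventData:
    start_event: Optional[Event] = None
    end_event: Optional[Event] = None


def collect_exception_messages(event_data: EventData) -> List[str]:
    # Same shape as _get_span_exceptions above: check both the start and end
    # events and record any exception found in their payloads.
    messages: List[str] = []
    for event in (event_data.start_event, event_data.end_event):
        if event and (payload := event.payload) and (error := payload.get(EXCEPTION_PAYLOAD_KEY)):
            messages.append(f"{type(error).__name__}: {error}")
    return messages


# An exception recorded on the end event only.
data = EventData(start_event=Event(payload={}), end_event=Event(payload={EXCEPTION_PAYLOAD_KEY: ValueError("boom")}))
print(collect_exception_messages(data))  # ['ValueError: boom']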
41 changes: 36 additions & 5 deletions tests/trace/llama_index/test_callback.py
@@ -135,7 +135,7 @@ def test_callback_llm_span_contains_template_attributes(
assert isinstance(span.attributes[LLM_PROMPT_TEMPLATE_VARIABLES], dict)


-def test_callback_llm_internal_error_has_exception_event(
+def test_callback_internal_error_has_exception_event(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setenv(OPENAI_API_KEY_ENVVAR_NAME, "sk-0123456789")
@@ -160,10 +160,8 @@ def test_callback_llm_internal_error_has_exception_event(
query_engine.query(query)

spans = list(callback_handler.get_spans())
-assert all(
-    span.status_code == SpanStatusCode.OK for span in spans if span.span_kind != SpanKind.LLM
-)
-span = next(span for span in spans if span.span_kind == SpanKind.LLM)
+assert all(span.status_code == SpanStatusCode.OK for span in spans if span.name != "synthesize")
+span = next(span for span in spans if span.name == "synthesize")
assert span.status_code == SpanStatusCode.ERROR
events = span.events
event = events[0]
@@ -175,6 +173,39 @@
assert isinstance(event.attributes[EXCEPTION_STACKTRACE], str)


+def test_callback_exception_event_produces_root_chain_span_with_exception_events() -> None:
+    llm = OpenAI(model="gpt-3.5-turbo", max_retries=1)
+    query = "What are the seven wonders of the world?"
+    callback_handler = OpenInferenceTraceCallbackHandler(exporter=NoOpExporter())
+    index = ListIndex(nodes)
+    service_context = ServiceContext.from_defaults(
+        llm=llm, callback_manager=CallbackManager([callback_handler])
+    )
+    query_engine = index.as_query_engine(service_context=service_context)
+
+    # mock the _query method to raise an exception before any event has begun
+    # to produce an independent exception event
+    with patch.object(query_engine, "_query") as mocked_query:
+        mocked_query.side_effect = Exception("message")
+        with pytest.raises(Exception):
+            query_engine.query(query)
+
+    spans = list(callback_handler.get_spans())
+    assert len(spans) == 1
+    span = spans[0]
+    assert span.span_kind == SpanKind.CHAIN
+    assert span.status_code == SpanStatusCode.ERROR
+    assert span.name == "exception"
+    events = span.events
+    event = events[0]
+    assert isinstance(event, SpanException)
+    assert isinstance(event.timestamp, datetime)
+    assert len(event.attributes) == 3
+    assert event.attributes[EXCEPTION_TYPE] == "Exception"
+    assert event.attributes[EXCEPTION_MESSAGE] == "message"
+    assert isinstance(event.attributes[EXCEPTION_STACKTRACE], str)


def test_callback_llm_rate_limit_error_has_exception_event_with_missing_start(
monkeypatch: pytest.MonkeyPatch,
) -> None:
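For reference, the wiring these tests exercise looks roughly like the following. This is a rough sketch, not an excerpt: the TextNode content is made up, the default handler is used instead of the tests' NoOpExporter, and the import paths follow llama-index 0.9 conventions.

# Rough usage sketch based on the tests above, not an exact excerpt.
from llama_index import ListIndex, ServiceContext
from llama_index.callbacks import CallbackManager
from llama_index.llms import OpenAI
from llama_index.schema import TextNode
from phoenix.trace.llama_index import OpenInferenceTraceCallbackHandler

callback_handler = OpenInferenceTraceCallbackHandler()  # the tests pass exporter=NoOpExporter()
service_context = ServiceContext.from_defaults(
    llm=OpenAI(model="gpt-3.5-turbo", max_retries=1),
    callback_manager=CallbackManager([callback_handler]),
)
nodes = [TextNode(text="The Great Pyramid of Giza is one of the seven wonders.")]
query_engine = ListIndex(nodes).as_query_engine(service_context=service_context)
try:
    query_engine.query("What are the seven wonders of the world?")
except Exception:
    pass  # even a failed query should yield spans carrying exception events
spans = list(callback_handler.get_spans())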
2 changes: 1 addition & 1 deletion tutorials/internal/llama_index_tracing_example.ipynb
@@ -54,7 +54,7 @@
"from llama_index.callbacks import CallbackManager\n",
"from llama_index.embeddings.openai import OpenAIEmbedding\n",
"from llama_index.graph_stores.simple import SimpleGraphStore\n",
"from llama_index.indices.postprocessor.cohere_rerank import CohereRerank\n",
"from llama_index.postprocessor.cohere_rerank import CohereRerank\n",
"from phoenix.trace.llama_index import (\n",
" OpenInferenceTraceCallbackHandler,\n",
")\n",
