Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: enhance llama-index callback support for exception events #1814

Merged
Merged 12 commits
Nov 28, 2023
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,15 @@ dev = [
"strawberry-graphql[debug-server]==0.208.2",
"pre-commit",
"arize[AutoEmbeddings, LLM_Evaluation]",
"llama-index>=0.9.0,<0.9.8",
"llama-index>=0.9.8",
"langchain>=0.0.334",
"litellm>=1.0.3"
]
experimental = [
"tenacity",
]
llama-index = [
"llama-index>=0.9.0,<0.9.8"
"llama-index>=0.9.8",
]

[project.urls]
Expand Down Expand Up @@ -98,7 +98,7 @@ dependencies = [
"arize",
"langchain>=0.0.334",
"litellm>=1.0.3",
"llama-index>=0.9.0,<0.9.8",
"llama-index>=0.9.8",
"openai>=1.0.0",
"tenacity",
"nltk==3.8.1",
Expand All @@ -117,7 +117,7 @@ dependencies = [
[tool.hatch.envs.type]
dependencies = [
"mypy==1.5.1",
"llama-index>=0.9.0,<0.9.8",
"llama-index>=0.9.8",
"pandas-stubs<=2.0.2.230605", # version 2.0.3.230814 is causing a dependency conflict.
"types-psutil",
"types-tqdm",
Expand Down
52 changes: 25 additions & 27 deletions src/phoenix/trace/llama_index/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@
from phoenix.trace.utils import extract_version_triplet, get_stacktrace
from phoenix.utilities.error_handling import graceful_fallback

LLAMA_INDEX_MINIMUM_VERSION_TRIPLET = (0, 9, 0)
LLAMA_INDEX_MINIMUM_VERSION_TRIPLET = (0, 9, 8)
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

Expand Down Expand Up @@ -355,12 +355,16 @@ def _add_spans_to_tracer(
parent_child_id_stack: List[Tuple[Optional[SpanID], CBEventID]] = [
(None, root_event_id) for root_event_id in trace_map["root"]
]
span_exceptions: List[SpanEvent] = []
while parent_child_id_stack:
parent_span_id, event_id = parent_child_id_stack.pop()
event_data = event_id_to_event_data[event_id]
event_type = event_data.event_type
attributes = event_data.attributes
if not (start_event := event_data.start_event):
# if the callback system has broken its contract by calling
# on_event_end without on_event_start, do not create a span
axiomofjoy marked this conversation as resolved.
Show resolved Hide resolved
continue

if event_type is CBEventType.LLM:
while parent_child_id_stack:
preceding_event_parent_span_id, preceding_event_id = parent_child_id_stack[-1]
Expand All @@ -375,32 +379,11 @@ def _add_spans_to_tracer(
# Add template attributes to the LLM span to which they belong.
attributes.update(_template_attributes(preceding_payload))

start_time = None
if start_event := event_data.start_event:
start_time = _timestamp_to_tz_aware_datetime(start_event.time)
start_time = _timestamp_to_tz_aware_datetime(start_event.time)
span_exceptions = _get_span_exceptions(event_data, start_time)
end_time = _get_end_time(event_data, span_exceptions)
start_time = start_time or end_time or datetime.now(timezone.utc)

if event_type is CBEventType.EXCEPTION:
# LlamaIndex has exception callback events that are sibling events of the events in
# which the exception occurred. We collect all the exception events and add them to
# the relevant span.
if (
not start_event
or not start_event.payload
or (error := start_event.payload.get(EventPayload.EXCEPTION)) is None
):
continue
span_exceptions.append(
SpanException(
message=str(error),
timestamp=start_time,
exception_type=type(error).__name__,
exception_stacktrace=get_stacktrace(error),
)
)
continue

name = event_name if (event_name := event_data.name) is not None else "unknown"
span_kind = _get_span_kind(event_type)
span = tracer.create_span(
Expand All @@ -416,7 +399,6 @@ def _add_spans_to_tracer(
events=sorted(span_exceptions, key=lambda event: event.timestamp) or None,
conversation=None,
)
span_exceptions = []
new_parent_span_id = span.context.span_id
for new_child_event_id in trace_map.get(event_id, []):
parent_child_id_stack.append((new_parent_span_id, new_child_event_id))
Expand Down Expand Up @@ -511,7 +493,7 @@ def _get_response_output(response: Any) -> Iterator[Tuple[str, Any]]:
yield OUTPUT_MIME_TYPE, MimeType.TEXT


def _get_end_time(event_data: CBEventData, span_events: List[SpanEvent]) -> Optional[datetime]:
def _get_end_time(event_data: CBEventData, span_events: Iterable[SpanEvent]) -> Optional[datetime]:
"""
A best-effort attempt to get the end time of an event.

Expand All @@ -528,6 +510,22 @@ def _get_end_time(event_data: CBEventData, span_events: List[SpanEvent]) -> Opti
return _tz_naive_to_tz_aware_datetime(tz_naive_end_time)


def _get_span_exceptions(event_data: CBEventData, start_time: datetime) -> List[SpanException]:
    """Builds SpanException span events from any exceptions recorded on the
    payloads of the start and end events of the given callback event data."""
    exceptions: List[SpanException] = []
    for candidate_event in (event_data.start_event, event_data.end_event):
        if not candidate_event:
            continue
        payload = candidate_event.payload
        if not payload:
            continue
        error = payload.get(EventPayload.EXCEPTION)
        if not error:
            # no exception recorded on this event's payload
            continue
        exceptions.append(
            SpanException(
                message=str(error),
                timestamp=start_time,
                exception_type=type(error).__name__,
                exception_stacktrace=get_stacktrace(error),
            )
        )
    return exceptions


def _timestamp_to_tz_aware_datetime(timestamp: str) -> datetime:
    """Converts a timestamp string to a timezone-aware datetime."""
    # parse to a naive datetime first, then attach timezone information
    tz_naive = _timestamp_to_tz_naive_datetime(timestamp)
    return _tz_naive_to_tz_aware_datetime(tz_naive)
Expand Down
41 changes: 36 additions & 5 deletions tests/trace/llama_index/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def test_callback_llm_span_contains_template_attributes(
assert isinstance(span.attributes[LLM_PROMPT_TEMPLATE_VARIABLES], dict)


def test_callback_llm_internal_error_has_exception_event(
def test_callback_internal_error_has_exception_event(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setenv(OPENAI_API_KEY_ENVVAR_NAME, "sk-0123456789")
Expand All @@ -160,10 +160,8 @@ def test_callback_llm_internal_error_has_exception_event(
query_engine.query(query)

spans = list(callback_handler.get_spans())
assert all(
span.status_code == SpanStatusCode.OK for span in spans if span.span_kind != SpanKind.LLM
)
span = next(span for span in spans if span.span_kind == SpanKind.LLM)
assert all(span.status_code == SpanStatusCode.OK for span in spans if span.name != "synthesize")
span = next(span for span in spans if span.name == "synthesize")
assert span.status_code == SpanStatusCode.ERROR
events = span.events
event = events[0]
Expand All @@ -175,6 +173,39 @@ def test_callback_llm_internal_error_has_exception_event(
assert isinstance(event.attributes[EXCEPTION_STACKTRACE], str)


def test_callback_exception_event_produces_root_chain_span_with_exception_events() -> None:
    """An exception raised before any callback event has begun should surface as a
    single root CHAIN span whose events carry the exception details."""
    handler = OpenInferenceTraceCallbackHandler(exporter=NoOpExporter())
    engine = ListIndex(nodes).as_query_engine(
        service_context=ServiceContext.from_defaults(
            llm=OpenAI(model="gpt-3.5-turbo", max_retries=1),
            callback_manager=CallbackManager([handler]),
        )
    )

    # mock the _query method to raise an exception before any event has begun
    # to produce an independent exception event
    with patch.object(engine, "_query") as mocked_query:
        mocked_query.side_effect = Exception("message")
        with pytest.raises(Exception):
            engine.query("What are the seven wonders of the world?")

    spans = list(handler.get_spans())
    assert len(spans) == 1
    root_span = spans[0]
    assert root_span.span_kind == SpanKind.CHAIN
    assert root_span.status_code == SpanStatusCode.ERROR
    assert root_span.name == "exception"
    exception_event = root_span.events[0]
    assert isinstance(exception_event, SpanException)
    assert isinstance(exception_event.timestamp, datetime)
    assert len(exception_event.attributes) == 3
    assert exception_event.attributes[EXCEPTION_TYPE] == "Exception"
    assert exception_event.attributes[EXCEPTION_MESSAGE] == "message"
    assert isinstance(exception_event.attributes[EXCEPTION_STACKTRACE], str)


def test_callback_llm_rate_limit_error_has_exception_event_with_missing_start(
monkeypatch: pytest.MonkeyPatch,
) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tutorials/internal/llama_index_tracing_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
"from llama_index.callbacks import CallbackManager\n",
"from llama_index.embeddings.openai import OpenAIEmbedding\n",
"from llama_index.graph_stores.simple import SimpleGraphStore\n",
"from llama_index.indices.postprocessor.cohere_rerank import CohereRerank\n",
"from llama_index.postprocessor.cohere_rerank import CohereRerank\n",
"from phoenix.trace.llama_index import (\n",
" OpenInferenceTraceCallbackHandler,\n",
")\n",
Expand Down
Loading