Skip to content

Commit

Permalink
feat: add get_current_span helper function for llama-index (#1165)
Browse files Browse the repository at this point in the history
  • Loading branch information
RogerHYang authored Dec 12, 2024
1 parent 64d5ac6 commit b46931c
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ packages = ["src/openinference"]

[tool.pytest.ini_options]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"
testpaths = [
"tests",
]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import logging
from typing import Any, Collection
from typing import Any, Collection, Optional

from opentelemetry import trace as trace_api
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor # type: ignore
from opentelemetry.trace import Span

from openinference.instrumentation import OITracer, TraceConfig
from openinference.instrumentation.llama_index.package import _instruments
Expand Down Expand Up @@ -104,3 +105,13 @@ def _uninstrument(self, **kwargs: Any) -> None:
dispatcher.event_handlers,
)
self._event_handler = None


def get_current_span() -> Optional[Span]:
from llama_index.core.instrumentation.span import active_span_id

if not isinstance(id_ := active_span_id.get(), str):
return None
if (span := LlamaIndexInstrumentor()._span_handler.open_spans.get(id_)) is None:
return None
return span._otel_span
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from asyncio import create_task, gather, sleep
from random import random
from typing import Iterator

import pytest
from llama_index.core.instrumentation import get_dispatcher
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

from openinference.instrumentation.llama_index import LlamaIndexInstrumentor, get_current_span
from openinference.semconv.trace import SpanAttributes

dispatcher = get_dispatcher(__name__)


@dispatcher.span # type: ignore[misc,unused-ignore]
async def foo(k: int) -> str:
child = create_task(foo(k - 1)) if k > 1 else None
await sleep(random() / 100)
span = get_current_span()
if child:
await child
return str(span.get_span_context().span_id) if span else ""


async def test_get_current_span(
in_memory_span_exporter: InMemorySpanExporter,
) -> None:
n, k = 10, 5
await gather(*(foo(k) for _ in range(n)))
spans = in_memory_span_exporter.get_finished_spans()
assert len(spans) == n * k
seen = set()
for span in spans:
assert span.attributes and span.context
assert (expected := str(span.context.span_id)) not in seen
seen.add(expected)
assert span.attributes.get(OUTPUT_VALUE) == expected


@pytest.fixture(autouse=True)
def instrument(
tracer_provider: TracerProvider,
in_memory_span_exporter: InMemorySpanExporter,
) -> Iterator[None]:
LlamaIndexInstrumentor().instrument(tracer_provider=tracer_provider)
yield
LlamaIndexInstrumentor().uninstrument()


OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE

0 comments on commit b46931c

Please sign in to comment.