Skip to content

Commit

Permalink
fix(llmobs): sync and async generators can be decorated [backport 2.1…
Browse files Browse the repository at this point in the history
…7] (#11531)

Backport 451d88d from #11413 to 2.17.

MLOB-1892

## What does this PR do
Fixes support for decorating sync and async generator functions. 

Some notes:

- While we will still capture the function name and inputs (inputs are
captured only for decorators other than `llm` and `embedding`), we will
not automatically capture outputs from the generator.
- This change is intended to allow user annotations to still be captured.
Previously, the generator object itself was captured as the "output", so
user annotations never went through.
- Async generator functions are not identified by `iscoroutinefunction` or
`isgeneratorfunction`, so we rely on `isasyncgenfunction` instead
(verified in tests).

## Checklist
- [x] PR author has checked that all the criteria below are met
- The PR description includes an overview of the change
- The PR description articulates the motivation for the change
- The change includes tests OR the PR description describes a testing
strategy
- The PR description notes risks associated with the change, if any
- Newly-added code is easy to change
- The change follows the [library release note
guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html)
- The change includes or references documentation updates if necessary
- Backport labels are set (if
[applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting))

## Reviewer Checklist
- [x] Reviewer has checked that all the criteria below are met 
- Title is accurate
- All changes are related to the pull request's stated goal
- Avoids breaking
[API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces)
changes
- Testing strategy adequately addresses listed risks
- Newly-added code is easy to change
- Release note makes sense to a user of the library
- If necessary, author has acknowledged and discussed the performance
implications of this PR as reported in the benchmarks PR comment
- Backport labels are set in a manner that is consistent with the
[release branch maintenance
policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)

Co-authored-by: Sam Brenner <106700075+sabrenner@users.noreply.github.com>
  • Loading branch information
github-actions[bot] and sabrenner authored Nov 25, 2024
1 parent 1c2fdc0 commit 22184b0
Show file tree
Hide file tree
Showing 3 changed files with 486 additions and 26 deletions.
158 changes: 132 additions & 26 deletions ddtrace/llmobs/decorators.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from functools import wraps
from inspect import isasyncgenfunction
from inspect import signature
import sys
from typing import Callable
from typing import Optional

from ddtrace.internal.compat import iscoroutinefunction
from ddtrace.internal.compat import isgeneratorfunction
from ddtrace.internal.logger import get_logger
from ddtrace.llmobs import LLMObs
from ddtrace.llmobs._constants import OUTPUT_VALUE
Expand All @@ -13,6 +16,44 @@
log = get_logger(__name__)


def _get_llmobs_span_options(name, model_name, func):
traced_model_name = model_name
if traced_model_name is None:
traced_model_name = "custom"

span_name = name
if span_name is None:
span_name = func.__name__

return traced_model_name, span_name


async def yield_from_async_gen(func, span, args, kwargs):
    """Delegate to the async generator made by ``func`` while ``span`` is open.

    This is the async counterpart of ``yield from``: values yielded by the
    wrapped generator are re-yielded, values sent in by the consumer are
    forwarded with ``asend()``, and exceptions thrown in by the consumer are
    forwarded with ``athrow()``.

    :param func: async generator function to call as ``func(*args, **kwargs)``.
    :param span: object exposing ``set_exc_info(type, value, tb)`` and
        ``finish()``; ``finish()`` is always called exactly once on exit.
    :param args: positional arguments for ``func``.
    :param kwargs: keyword arguments for ``func``.
    :raises Exception: any exception escaping the wrapped generator is
        recorded on ``span`` via ``set_exc_info`` and re-raised.
    """
    try:
        gen = func(*args, **kwargs)
        try:
            next_val = await gen.asend(None)
            while True:
                try:
                    sent = yield next_val
                except GeneratorExit:
                    # The consumer closed us: close the wrapped generator too
                    # and end normally (yielding again here would be an error).
                    await gen.aclose()
                    break
                except Exception as e:
                    # Forward the thrown exception. If the wrapped generator
                    # handles it and yields, that value is relayed to the
                    # consumer; if it does not, the exception propagates from
                    # here and is recorded on the span below.
                    next_val = await gen.athrow(e)
                else:
                    next_val = await gen.asend(sent)
        except StopAsyncIteration:
            # The wrapped generator is exhausted (possibly before its first
            # yield). It must not be re-raised: letting StopAsyncIteration
            # escape an async generator body is turned into RuntimeError.
            pass
    except Exception:
        span.set_exc_info(*sys.exc_info())
        raise
    finally:
        span.finish()


def _model_decorator(operation_kind):
def decorator(
original_func: Optional[Callable] = None,
Expand All @@ -23,20 +64,31 @@ def decorator(
ml_app: Optional[str] = None,
):
def inner(func):
if iscoroutinefunction(func):
if iscoroutinefunction(func) or isasyncgenfunction(func):

@wraps(func)
def generator_wrapper(*args, **kwargs):
if not LLMObs.enabled:
log.warning(SPAN_START_WHILE_DISABLED_WARNING)
return func(*args, **kwargs)
traced_model_name, span_name = _get_llmobs_span_options(name, model_name, func)
traced_operation = getattr(LLMObs, operation_kind, LLMObs.llm)
span = traced_operation(
model_name=traced_model_name,
model_provider=model_provider,
name=span_name,
session_id=session_id,
ml_app=ml_app,
)
return yield_from_async_gen(func, span, args, kwargs)

@wraps(func)
async def wrapper(*args, **kwargs):
if not LLMObs.enabled:
log.warning(SPAN_START_WHILE_DISABLED_WARNING)
return await func(*args, **kwargs)
traced_model_name = model_name
if traced_model_name is None:
traced_model_name = "custom"
span_name = name
if span_name is None:
span_name = func.__name__
traced_operation = getattr(LLMObs, operation_kind, "llm")
traced_model_name, span_name = _get_llmobs_span_options(name, model_name, func)
traced_operation = getattr(LLMObs, operation_kind, LLMObs.llm)
with traced_operation(
model_name=traced_model_name,
model_provider=model_provider,
Expand All @@ -48,18 +100,38 @@ async def wrapper(*args, **kwargs):

else:

@wraps(func)
def generator_wrapper(*args, **kwargs):
if not LLMObs.enabled:
log.warning(SPAN_START_WHILE_DISABLED_WARNING)
yield from func(*args, **kwargs)
else:
traced_model_name, span_name = _get_llmobs_span_options(name, model_name, func)
traced_operation = getattr(LLMObs, operation_kind, LLMObs.llm)
span = traced_operation(
model_name=traced_model_name,
model_provider=model_provider,
name=span_name,
session_id=session_id,
ml_app=ml_app,
)
try:
yield from func(*args, **kwargs)
except (StopIteration, GeneratorExit):
raise
except Exception:
span.set_exc_info(*sys.exc_info())
raise
finally:
span.finish()

@wraps(func)
def wrapper(*args, **kwargs):
if not LLMObs.enabled:
log.warning(SPAN_START_WHILE_DISABLED_WARNING)
return func(*args, **kwargs)
traced_model_name = model_name
if traced_model_name is None:
traced_model_name = "custom"
span_name = name
if span_name is None:
span_name = func.__name__
traced_operation = getattr(LLMObs, operation_kind, "llm")
traced_model_name, span_name = _get_llmobs_span_options(name, model_name, func)
traced_operation = getattr(LLMObs, operation_kind, LLMObs.llm)
with traced_operation(
model_name=traced_model_name,
model_provider=model_provider,
Expand All @@ -69,7 +141,7 @@ def wrapper(*args, **kwargs):
):
return func(*args, **kwargs)

return wrapper
return generator_wrapper if (isgeneratorfunction(func) or isasyncgenfunction(func)) else wrapper

if original_func and callable(original_func):
return inner(original_func)
Expand All @@ -87,17 +159,29 @@ def decorator(
_automatic_io_annotation: bool = True,
):
def inner(func):
if iscoroutinefunction(func):
if iscoroutinefunction(func) or isasyncgenfunction(func):

@wraps(func)
def generator_wrapper(*args, **kwargs):
if not LLMObs.enabled:
log.warning(SPAN_START_WHILE_DISABLED_WARNING)
return func(*args, **kwargs)
_, span_name = _get_llmobs_span_options(name, None, func)
traced_operation = getattr(LLMObs, operation_kind, LLMObs.workflow)
span = traced_operation(name=span_name, session_id=session_id, ml_app=ml_app)
func_signature = signature(func)
bound_args = func_signature.bind_partial(*args, **kwargs)
if _automatic_io_annotation and bound_args.arguments:
LLMObs.annotate(span=span, input_data=bound_args.arguments)
return yield_from_async_gen(func, span, args, kwargs)

@wraps(func)
async def wrapper(*args, **kwargs):
if not LLMObs.enabled:
log.warning(SPAN_START_WHILE_DISABLED_WARNING)
return await func(*args, **kwargs)
span_name = name
if span_name is None:
span_name = func.__name__
traced_operation = getattr(LLMObs, operation_kind, "workflow")
_, span_name = _get_llmobs_span_options(name, None, func)
traced_operation = getattr(LLMObs, operation_kind, LLMObs.workflow)
with traced_operation(name=span_name, session_id=session_id, ml_app=ml_app) as span:
func_signature = signature(func)
bound_args = func_signature.bind_partial(*args, **kwargs)
Expand All @@ -115,15 +199,37 @@ async def wrapper(*args, **kwargs):

else:

@wraps(func)
def generator_wrapper(*args, **kwargs):
if not LLMObs.enabled:
log.warning(SPAN_START_WHILE_DISABLED_WARNING)
yield from func(*args, **kwargs)
else:
_, span_name = _get_llmobs_span_options(name, None, func)
traced_operation = getattr(LLMObs, operation_kind, LLMObs.workflow)
span = traced_operation(name=span_name, session_id=session_id, ml_app=ml_app)
func_signature = signature(func)
bound_args = func_signature.bind_partial(*args, **kwargs)
if _automatic_io_annotation and bound_args.arguments:
LLMObs.annotate(span=span, input_data=bound_args.arguments)
try:
yield from func(*args, **kwargs)
except (StopIteration, GeneratorExit):
raise
except Exception:
span.set_exc_info(*sys.exc_info())
raise
finally:
if span:
span.finish()

@wraps(func)
def wrapper(*args, **kwargs):
if not LLMObs.enabled:
log.warning(SPAN_START_WHILE_DISABLED_WARNING)
return func(*args, **kwargs)
span_name = name
if span_name is None:
span_name = func.__name__
traced_operation = getattr(LLMObs, operation_kind, "workflow")
_, span_name = _get_llmobs_span_options(name, None, func)
traced_operation = getattr(LLMObs, operation_kind, LLMObs.workflow)
with traced_operation(name=span_name, session_id=session_id, ml_app=ml_app) as span:
func_signature = signature(func)
bound_args = func_signature.bind_partial(*args, **kwargs)
Expand All @@ -139,7 +245,7 @@ def wrapper(*args, **kwargs):
LLMObs.annotate(span=span, output_data=resp)
return resp

return wrapper
return generator_wrapper if (isgeneratorfunction(func) or isasyncgenfunction(func)) else wrapper

if original_func and callable(original_func):
return inner(original_func)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
fixes:
- |
LLM Observability: Fixes an issue where decorators were not tracing generator functions properly.
Loading

0 comments on commit 22184b0

Please sign in to comment.