feat(SageMaker): Add SageMaker instrumentation (#2028)
Co-authored-by: Nir Gazit <nirga@users.noreply.github.com>
bobbywlindsey and nirga authored Oct 2, 2024
1 parent b814cdd commit 346d752
Showing 22 changed files with 1,604 additions and 29 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -116,6 +116,7 @@ OpenLLMetry can instrument everything that [OpenTelemetry already instruments](h
- ✅ Mistral AI
- ✅ HuggingFace
- ✅ Bedrock (AWS)
- ✅ SageMaker (AWS)
- ✅ Replicate
- ✅ Vertex AI (GCP)
- ✅ Google Generative AI (Gemini)
11 changes: 11 additions & 0 deletions packages/opentelemetry-instrumentation-sagemaker/.flake8
@@ -0,0 +1,11 @@
[flake8]
exclude =
.git,
__pycache__,
build,
dist,
.tox,
venv,
.venv,
.pytest_cache
max-line-length = 120
1 change: 1 addition & 0 deletions packages/opentelemetry-instrumentation-sagemaker/.python-version
@@ -0,0 +1 @@
3.9.5
33 changes: 33 additions & 0 deletions packages/opentelemetry-instrumentation-sagemaker/README.md
@@ -0,0 +1,33 @@
# OpenTelemetry SageMaker Instrumentation

<a href="https://pypi.org/project/opentelemetry-instrumentation-sagemaker/">
<img src="https://badge.fury.io/py/opentelemetry-instrumentation-sagemaker.svg">
</a>

This library allows tracing of models deployed on Amazon SageMaker and invoked through [Boto3](https://github.com/boto/boto3).

## Installation

```bash
pip install opentelemetry-instrumentation-sagemaker
```

## Example usage

```python
from opentelemetry.instrumentation.sagemaker import SageMakerInstrumentor

SageMakerInstrumentor().instrument()
```
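
For example, after instrumenting, an endpoint invocation made through Boto3 is traced as a `sagemaker.completion` span. The sketch below assumes a hypothetical endpoint named `my-llm-endpoint`; the request body shape depends on the model you deployed.

```python
import json

import boto3
from opentelemetry.instrumentation.sagemaker import SageMakerInstrumentor

SageMakerInstrumentor().instrument()

# SageMaker runtime clients created after instrumenting are patched automatically.
client = boto3.client("sagemaker-runtime")

response = client.invoke_endpoint(
    EndpointName="my-llm-endpoint",  # hypothetical endpoint name
    Body=json.dumps({"inputs": "Tell me a joke about OpenTelemetry"}),
    ContentType="application/json",
)
print(response["Body"].read())
```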

## Privacy

**By default, this instrumentation logs SageMaker endpoint request bodies and responses to span attributes**. This gives you clear visibility into how your LLM application is working, and makes it easy to debug and evaluate the quality of its outputs.

However, you may want to disable this logging for privacy reasons, as these payloads may contain highly sensitive data from your users. You may also simply want to reduce the size of your traces.

To disable logging, set the `TRACELOOP_TRACE_CONTENT` environment variable to `false`.

```bash
TRACELOOP_TRACE_CONTENT=false
```
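
If you prefer to configure this from code, you can also set the variable in-process before initializing the instrumentation (a minimal sketch; however you set it, it must be in place before content is traced):

```python
import os

# Disable logging of request/response payloads to span attributes.
os.environ["TRACELOOP_TRACE_CONTENT"] = "false"

from opentelemetry.instrumentation.sagemaker import SageMakerInstrumentor

SageMakerInstrumentor().instrument()
```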
198 changes: 198 additions & 0 deletions packages/opentelemetry-instrumentation-sagemaker/opentelemetry/instrumentation/sagemaker/__init__.py
@@ -0,0 +1,198 @@
"""OpenTelemetry SageMaker instrumentation"""

from functools import wraps
import json
import logging
import os
from typing import Collection
from opentelemetry.instrumentation.sagemaker.config import Config
from opentelemetry.instrumentation.sagemaker.reusable_streaming_body import (
ReusableStreamingBody,
)
from opentelemetry.instrumentation.sagemaker.streaming_wrapper import StreamingWrapper
from opentelemetry.instrumentation.sagemaker.utils import dont_throw
from wrapt import wrap_function_wrapper

from opentelemetry import context as context_api
from opentelemetry.trace import get_tracer, SpanKind

from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.utils import (
_SUPPRESS_INSTRUMENTATION_KEY,
unwrap,
)

from opentelemetry.semconv_ai import (
SUPPRESS_LANGUAGE_MODEL_INSTRUMENTATION_KEY,
SpanAttributes,
)
from opentelemetry.instrumentation.sagemaker.version import __version__

logger = logging.getLogger(__name__)

_instruments = ("boto3 >= 1.28.57",)

# Botocore entry points that create service clients; wrapping them lets the
# instrumentation patch SageMaker runtime clients as they are created.
WRAPPED_METHODS = [
{
"package": "botocore.client",
"object": "ClientCreator",
"method": "create_client",
},
{"package": "botocore.session", "object": "Session", "method": "create_client"},
]


def should_send_prompts():
return (
os.getenv("TRACELOOP_TRACE_CONTENT") or "true"
).lower() == "true" or context_api.get_value("override_enable_content_tracing")


def _set_span_attribute(span, name, value):
if value is not None:
if value != "":
span.set_attribute(name, value)
return


def _with_tracer_wrapper(func):
"""Helper for providing tracer for wrapper functions."""

def _with_tracer(tracer, to_wrap):
def wrapper(wrapped, instance, args, kwargs):
return func(tracer, to_wrap, wrapped, instance, args, kwargs)

return wrapper

return _with_tracer


@_with_tracer_wrapper
def _wrap(tracer, to_wrap, wrapped, instance, args, kwargs):
"""Instruments and calls every function defined in TO_WRAP."""
if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
return wrapped(*args, **kwargs)

if kwargs.get("service_name") == "sagemaker-runtime":
client = wrapped(*args, **kwargs)
client.invoke_endpoint = _instrumented_endpoint_invoke(
client.invoke_endpoint, tracer
)
client.invoke_endpoint_with_response_stream = (
_instrumented_endpoint_invoke_with_response_stream(
client.invoke_endpoint_with_response_stream, tracer
)
)

return client

return wrapped(*args, **kwargs)


def _instrumented_endpoint_invoke(fn, tracer):
@wraps(fn)
def with_instrumentation(*args, **kwargs):
if context_api.get_value(SUPPRESS_LANGUAGE_MODEL_INSTRUMENTATION_KEY):
return fn(*args, **kwargs)

with tracer.start_as_current_span(
"sagemaker.completion", kind=SpanKind.CLIENT
) as span:
response = fn(*args, **kwargs)

if span.is_recording():
_handle_call(span, kwargs, response)

return response

return with_instrumentation


def _instrumented_endpoint_invoke_with_response_stream(fn, tracer):
@wraps(fn)
def with_instrumentation(*args, **kwargs):
if context_api.get_value(SUPPRESS_LANGUAGE_MODEL_INSTRUMENTATION_KEY):
return fn(*args, **kwargs)

        # Start the span manually rather than as a context manager; it is ended
        # in the stream-done callback once the response stream is fully consumed.
        span = tracer.start_span("sagemaker.completion", kind=SpanKind.CLIENT)
response = fn(*args, **kwargs)

if span.is_recording():
_handle_stream_call(span, kwargs, response)

return response

return with_instrumentation


def _handle_stream_call(span, kwargs, response):
@dont_throw
def stream_done(response_body):
request_body = json.loads(kwargs.get("Body"))

endpoint_name = kwargs.get("EndpointName")

_set_span_attribute(span, SpanAttributes.LLM_REQUEST_MODEL, endpoint_name)
_set_span_attribute(
span, SpanAttributes.TRACELOOP_ENTITY_INPUT, json.dumps(request_body)
)
_set_span_attribute(
span, SpanAttributes.TRACELOOP_ENTITY_OUTPUT, json.dumps(response_body)
)

span.end()

response["Body"] = StreamingWrapper(response["Body"], stream_done)


@dont_throw
def _handle_call(span, kwargs, response):
response["Body"] = ReusableStreamingBody(
response["Body"]._raw_stream, response["Body"]._content_length
)
request_body = json.loads(kwargs.get("Body"))
response_body = json.loads(response.get("Body").read())

endpoint_name = kwargs.get("EndpointName")

_set_span_attribute(span, SpanAttributes.LLM_REQUEST_MODEL, endpoint_name)
_set_span_attribute(
span, SpanAttributes.TRACELOOP_ENTITY_INPUT, json.dumps(request_body)
)
_set_span_attribute(
span, SpanAttributes.TRACELOOP_ENTITY_OUTPUT, json.dumps(response_body)
)


class SageMakerInstrumentor(BaseInstrumentor):
"""An instrumentor for Bedrock's client library."""

def __init__(self, enrich_token_usage: bool = False, exception_logger=None):
super().__init__()
Config.enrich_token_usage = enrich_token_usage
Config.exception_logger = exception_logger

def instrumentation_dependencies(self) -> Collection[str]:
return _instruments

def _instrument(self, **kwargs):
tracer_provider = kwargs.get("tracer_provider")
tracer = get_tracer(__name__, __version__, tracer_provider)
for wrapped_method in WRAPPED_METHODS:
wrap_package = wrapped_method.get("package")
wrap_object = wrapped_method.get("object")
wrap_method = wrapped_method.get("method")
wrap_function_wrapper(
wrap_package,
f"{wrap_object}.{wrap_method}",
_wrap(tracer, wrapped_method),
)

def _uninstrument(self, **kwargs):
for wrapped_method in WRAPPED_METHODS:
wrap_package = wrapped_method.get("package")
wrap_object = wrapped_method.get("object")
unwrap(
f"{wrap_package}.{wrap_object}",
wrapped_method.get("method"),
)
3 changes: 3 additions & 0 deletions packages/opentelemetry-instrumentation-sagemaker/opentelemetry/instrumentation/sagemaker/config.py
@@ -0,0 +1,3 @@
class Config:
enrich_token_usage = False
exception_logger = None
43 changes: 43 additions & 0 deletions packages/opentelemetry-instrumentation-sagemaker/opentelemetry/instrumentation/sagemaker/reusable_streaming_body.py
@@ -0,0 +1,43 @@
from botocore.response import StreamingBody
from botocore.exceptions import (
ReadTimeoutError,
ResponseStreamingError,
)
from urllib3.exceptions import ProtocolError as URLLib3ProtocolError
from urllib3.exceptions import ReadTimeoutError as URLLib3ReadTimeoutError


class ReusableStreamingBody(StreamingBody):
"""Wrapper around StreamingBody that allows the body to be read multiple times."""

def __init__(self, raw_stream, content_length):
super().__init__(raw_stream, content_length)
self._buffer = None
self._buffer_cursor = 0

def read(self, amt=None):
"""Read at most amt bytes from the stream.
If the amt argument is omitted, read all data.
"""
if self._buffer is None:
try:
self._buffer = self._raw_stream.read()
except URLLib3ReadTimeoutError as e:
# TODO: the url will be None as urllib3 isn't setting it yet
raise ReadTimeoutError(endpoint_url=e.url, error=e)
except URLLib3ProtocolError as e:
raise ResponseStreamingError(error=e)

self._amount_read += len(self._buffer)
if amt is None or (not self._buffer and amt > 0):
# If the server sends empty contents or
# we ask to read all of the contents, then we know
# we need to verify the content length.
self._verify_content_length()

if amt is None:
return self._buffer[self._buffer_cursor:]
else:
self._buffer_cursor += amt
return self._buffer[self._buffer_cursor-amt:self._buffer_cursor]
29 changes: 29 additions & 0 deletions packages/opentelemetry-instrumentation-sagemaker/opentelemetry/instrumentation/sagemaker/streaming_wrapper.py
@@ -0,0 +1,29 @@
from opentelemetry.instrumentation.sagemaker.utils import dont_throw
from wrapt import ObjectProxy


class StreamingWrapper(ObjectProxy):
def __init__(
self,
response,
stream_done_callback=None,
):
super().__init__(response)

self._stream_done_callback = stream_done_callback
self._accumulating_body = ""

    def __iter__(self):
        for event in self.__wrapped__:
            self._process_event(event)
            yield event
        # Invoke the callback (if any) once the stream has been fully consumed.
        if self._stream_done_callback:
            self._stream_done_callback(self._accumulating_body)

@dont_throw
def _process_event(self, event):
payload_part = event.get("PayloadPart")
if not payload_part:
return

decoded_payload_part = payload_part.get("Bytes").decode()
self._accumulating_body += decoded_payload_part
29 changes: 29 additions & 0 deletions packages/opentelemetry-instrumentation-sagemaker/opentelemetry/instrumentation/sagemaker/utils.py
@@ -0,0 +1,29 @@
import logging
import traceback

from opentelemetry.instrumentation.sagemaker.config import Config


def dont_throw(func):
"""
    A decorator that wraps the passed-in function and logs exceptions instead of throwing them.
@param func: The function to wrap
@return: The wrapper function
"""
# Obtain a logger specific to the function's module
logger = logging.getLogger(func.__module__)

def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as e:
logger.debug(
"OpenLLMetry failed to trace in %s, error: %s",
func.__name__,
traceback.format_exc(),
)
if Config.exception_logger:
Config.exception_logger(e)

return wrapper
1 change: 1 addition & 0 deletions packages/opentelemetry-instrumentation-sagemaker/opentelemetry/instrumentation/sagemaker/version.py
@@ -0,0 +1 @@
__version__ = "0.25.6"