-
Notifications
You must be signed in to change notification settings - Fork 678
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(SageMaker): Add SageMaker instrumentation (#2028)
Co-authored-by: Nir Gazit <nirga@users.noreply.github.com>
- Loading branch information
1 parent
b814cdd
commit 346d752
Showing
22 changed files
with
1,604 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
[flake8] | ||
exclude = | ||
.git, | ||
__pycache__, | ||
build, | ||
dist, | ||
.tox, | ||
venv, | ||
.venv, | ||
.pytest_cache | ||
max-line-length = 120 |
1 change: 1 addition & 0 deletions
1
packages/opentelemetry-instrumentation-sagemaker/.python-version
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
3.9.5 |
33 changes: 33 additions & 0 deletions
33
packages/opentelemetry-instrumentation-sagemaker/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# OpenTelemetry SageMaker Instrumentation | ||
|
||
<a href="https://pypi.org/project/opentelemetry-instrumentation-sagemaker/"> | ||
<img src="https://badge.fury.io/py/opentelemetry-instrumentation-sagemaker.svg"> | ||
</a> | ||
|
||
This library allows tracing of any models deployed on Amazon SageMaker and invoked with [Boto3](https://github.com/boto/boto3) to SageMaker. | ||
|
||
## Installation | ||
|
||
```bash | ||
pip install opentelemetry-instrumentation-sagemaker | ||
``` | ||
|
||
## Example usage | ||
|
||
```python | ||
from opentelemetry.instrumentation.sagemaker import SageMakerInstrumentor | ||
|
||
SageMakerInstrumentor().instrument() | ||
``` | ||
|
||
## Privacy | ||
|
||
**By default, this instrumentation logs SageMaker endpoint request bodies and responses to span attributes**. This gives you a clear visibility into how your LLM application is working, and can make it easy to debug and evaluate the quality of the outputs. | ||
|
||
However, you may want to disable this logging for privacy reasons, as they may contain highly sensitive data from your users. You may also simply want to reduce the size of your traces. | ||
|
||
To disable logging, set the `TRACELOOP_TRACE_CONTENT` environment variable to `false`. | ||
|
||
```bash | ||
TRACELOOP_TRACE_CONTENT=false | ||
``` |
198 changes: 198 additions & 0 deletions
198
...entelemetry-instrumentation-sagemaker/opentelemetry/instrumentation/sagemaker/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,198 @@ | ||
"""OpenTelemetry SageMaker instrumentation""" | ||
|
||
from functools import wraps | ||
import json | ||
import logging | ||
import os | ||
from typing import Collection | ||
from opentelemetry.instrumentation.sagemaker.config import Config | ||
from opentelemetry.instrumentation.sagemaker.reusable_streaming_body import ( | ||
ReusableStreamingBody, | ||
) | ||
from opentelemetry.instrumentation.sagemaker.streaming_wrapper import StreamingWrapper | ||
from opentelemetry.instrumentation.sagemaker.utils import dont_throw | ||
from wrapt import wrap_function_wrapper | ||
|
||
from opentelemetry import context as context_api | ||
from opentelemetry.trace import get_tracer, SpanKind | ||
|
||
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor | ||
from opentelemetry.instrumentation.utils import ( | ||
_SUPPRESS_INSTRUMENTATION_KEY, | ||
unwrap, | ||
) | ||
|
||
from opentelemetry.semconv_ai import ( | ||
SUPPRESS_LANGUAGE_MODEL_INSTRUMENTATION_KEY, | ||
SpanAttributes, | ||
) | ||
from opentelemetry.instrumentation.sagemaker.version import __version__ | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
_instruments = ("boto3 >= 1.28.57",) | ||
|
||
WRAPPED_METHODS = [ | ||
{ | ||
"package": "botocore.client", | ||
"object": "ClientCreator", | ||
"method": "create_client", | ||
}, | ||
{"package": "botocore.session", "object": "Session", "method": "create_client"}, | ||
] | ||
|
||
|
||
def should_send_prompts(): | ||
return ( | ||
os.getenv("TRACELOOP_TRACE_CONTENT") or "true" | ||
).lower() == "true" or context_api.get_value("override_enable_content_tracing") | ||
|
||
|
||
def _set_span_attribute(span, name, value): | ||
if value is not None: | ||
if value != "": | ||
span.set_attribute(name, value) | ||
return | ||
|
||
|
||
def _with_tracer_wrapper(func): | ||
"""Helper for providing tracer for wrapper functions.""" | ||
|
||
def _with_tracer(tracer, to_wrap): | ||
def wrapper(wrapped, instance, args, kwargs): | ||
return func(tracer, to_wrap, wrapped, instance, args, kwargs) | ||
|
||
return wrapper | ||
|
||
return _with_tracer | ||
|
||
|
||
@_with_tracer_wrapper | ||
def _wrap(tracer, to_wrap, wrapped, instance, args, kwargs): | ||
"""Instruments and calls every function defined in TO_WRAP.""" | ||
if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY): | ||
return wrapped(*args, **kwargs) | ||
|
||
if kwargs.get("service_name") == "sagemaker-runtime": | ||
client = wrapped(*args, **kwargs) | ||
client.invoke_endpoint = _instrumented_endpoint_invoke( | ||
client.invoke_endpoint, tracer | ||
) | ||
client.invoke_endpoint_with_response_stream = ( | ||
_instrumented_endpoint_invoke_with_response_stream( | ||
client.invoke_endpoint_with_response_stream, tracer | ||
) | ||
) | ||
|
||
return client | ||
|
||
return wrapped(*args, **kwargs) | ||
|
||
|
||
def _instrumented_endpoint_invoke(fn, tracer): | ||
@wraps(fn) | ||
def with_instrumentation(*args, **kwargs): | ||
if context_api.get_value(SUPPRESS_LANGUAGE_MODEL_INSTRUMENTATION_KEY): | ||
return fn(*args, **kwargs) | ||
|
||
with tracer.start_as_current_span( | ||
"sagemaker.completion", kind=SpanKind.CLIENT | ||
) as span: | ||
response = fn(*args, **kwargs) | ||
|
||
if span.is_recording(): | ||
_handle_call(span, kwargs, response) | ||
|
||
return response | ||
|
||
return with_instrumentation | ||
|
||
|
||
def _instrumented_endpoint_invoke_with_response_stream(fn, tracer): | ||
@wraps(fn) | ||
def with_instrumentation(*args, **kwargs): | ||
if context_api.get_value(SUPPRESS_LANGUAGE_MODEL_INSTRUMENTATION_KEY): | ||
return fn(*args, **kwargs) | ||
|
||
span = tracer.start_span("sagemaker.completion", kind=SpanKind.CLIENT) | ||
response = fn(*args, **kwargs) | ||
|
||
if span.is_recording(): | ||
_handle_stream_call(span, kwargs, response) | ||
|
||
return response | ||
|
||
return with_instrumentation | ||
|
||
|
||
def _handle_stream_call(span, kwargs, response): | ||
@dont_throw | ||
def stream_done(response_body): | ||
request_body = json.loads(kwargs.get("Body")) | ||
|
||
endpoint_name = kwargs.get("EndpointName") | ||
|
||
_set_span_attribute(span, SpanAttributes.LLM_REQUEST_MODEL, endpoint_name) | ||
_set_span_attribute( | ||
span, SpanAttributes.TRACELOOP_ENTITY_INPUT, json.dumps(request_body) | ||
) | ||
_set_span_attribute( | ||
span, SpanAttributes.TRACELOOP_ENTITY_OUTPUT, json.dumps(response_body) | ||
) | ||
|
||
span.end() | ||
|
||
response["Body"] = StreamingWrapper(response["Body"], stream_done) | ||
|
||
|
||
@dont_throw | ||
def _handle_call(span, kwargs, response): | ||
response["Body"] = ReusableStreamingBody( | ||
response["Body"]._raw_stream, response["Body"]._content_length | ||
) | ||
request_body = json.loads(kwargs.get("Body")) | ||
response_body = json.loads(response.get("Body").read()) | ||
|
||
endpoint_name = kwargs.get("EndpointName") | ||
|
||
_set_span_attribute(span, SpanAttributes.LLM_REQUEST_MODEL, endpoint_name) | ||
_set_span_attribute( | ||
span, SpanAttributes.TRACELOOP_ENTITY_INPUT, json.dumps(request_body) | ||
) | ||
_set_span_attribute( | ||
span, SpanAttributes.TRACELOOP_ENTITY_OUTPUT, json.dumps(response_body) | ||
) | ||
|
||
|
||
class SageMakerInstrumentor(BaseInstrumentor): | ||
"""An instrumentor for Bedrock's client library.""" | ||
|
||
def __init__(self, enrich_token_usage: bool = False, exception_logger=None): | ||
super().__init__() | ||
Config.enrich_token_usage = enrich_token_usage | ||
Config.exception_logger = exception_logger | ||
|
||
def instrumentation_dependencies(self) -> Collection[str]: | ||
return _instruments | ||
|
||
def _instrument(self, **kwargs): | ||
tracer_provider = kwargs.get("tracer_provider") | ||
tracer = get_tracer(__name__, __version__, tracer_provider) | ||
for wrapped_method in WRAPPED_METHODS: | ||
wrap_package = wrapped_method.get("package") | ||
wrap_object = wrapped_method.get("object") | ||
wrap_method = wrapped_method.get("method") | ||
wrap_function_wrapper( | ||
wrap_package, | ||
f"{wrap_object}.{wrap_method}", | ||
_wrap(tracer, wrapped_method), | ||
) | ||
|
||
def _uninstrument(self, **kwargs): | ||
for wrapped_method in WRAPPED_METHODS: | ||
wrap_package = wrapped_method.get("package") | ||
wrap_object = wrapped_method.get("object") | ||
unwrap( | ||
f"{wrap_package}.{wrap_object}", | ||
wrapped_method.get("method"), | ||
) |
3 changes: 3 additions & 0 deletions
3
...opentelemetry-instrumentation-sagemaker/opentelemetry/instrumentation/sagemaker/config.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
class Config: | ||
enrich_token_usage = False | ||
exception_logger = None |
43 changes: 43 additions & 0 deletions
43
...trumentation-sagemaker/opentelemetry/instrumentation/sagemaker/reusable_streaming_body.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from botocore.response import StreamingBody | ||
from botocore.exceptions import ( | ||
ReadTimeoutError, | ||
ResponseStreamingError, | ||
) | ||
from urllib3.exceptions import ProtocolError as URLLib3ProtocolError | ||
from urllib3.exceptions import ReadTimeoutError as URLLib3ReadTimeoutError | ||
|
||
|
||
class ReusableStreamingBody(StreamingBody): | ||
"""Wrapper around StreamingBody that allows the body to be read multiple times.""" | ||
|
||
def __init__(self, raw_stream, content_length): | ||
super().__init__(raw_stream, content_length) | ||
self._buffer = None | ||
self._buffer_cursor = 0 | ||
|
||
def read(self, amt=None): | ||
"""Read at most amt bytes from the stream. | ||
If the amt argument is omitted, read all data. | ||
""" | ||
if self._buffer is None: | ||
try: | ||
self._buffer = self._raw_stream.read() | ||
except URLLib3ReadTimeoutError as e: | ||
# TODO: the url will be None as urllib3 isn't setting it yet | ||
raise ReadTimeoutError(endpoint_url=e.url, error=e) | ||
except URLLib3ProtocolError as e: | ||
raise ResponseStreamingError(error=e) | ||
|
||
self._amount_read += len(self._buffer) | ||
if amt is None or (not self._buffer and amt > 0): | ||
# If the server sends empty contents or | ||
# we ask to read all of the contents, then we know | ||
# we need to verify the content length. | ||
self._verify_content_length() | ||
|
||
if amt is None: | ||
return self._buffer[self._buffer_cursor:] | ||
else: | ||
self._buffer_cursor += amt | ||
return self._buffer[self._buffer_cursor-amt:self._buffer_cursor] |
29 changes: 29 additions & 0 deletions
29
...ry-instrumentation-sagemaker/opentelemetry/instrumentation/sagemaker/streaming_wrapper.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from opentelemetry.instrumentation.sagemaker.utils import dont_throw | ||
from wrapt import ObjectProxy | ||
|
||
|
||
class StreamingWrapper(ObjectProxy): | ||
def __init__( | ||
self, | ||
response, | ||
stream_done_callback=None, | ||
): | ||
super().__init__(response) | ||
|
||
self._stream_done_callback = stream_done_callback | ||
self._accumulating_body = "" | ||
|
||
def __iter__(self): | ||
for event in self.__wrapped__: | ||
self._process_event(event) | ||
yield event | ||
self._stream_done_callback(self._accumulating_body) | ||
|
||
@dont_throw | ||
def _process_event(self, event): | ||
payload_part = event.get("PayloadPart") | ||
if not payload_part: | ||
return | ||
|
||
decoded_payload_part = payload_part.get("Bytes").decode() | ||
self._accumulating_body += decoded_payload_part |
29 changes: 29 additions & 0 deletions
29
.../opentelemetry-instrumentation-sagemaker/opentelemetry/instrumentation/sagemaker/utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import logging | ||
import traceback | ||
|
||
from opentelemetry.instrumentation.sagemaker.config import Config | ||
|
||
|
||
def dont_throw(func): | ||
""" | ||
A decorator that wraps the passed in function and logs exceptions instead of throwing them. | ||
@param func: The function to wrap | ||
@return: The wrapper function | ||
""" | ||
# Obtain a logger specific to the function's module | ||
logger = logging.getLogger(func.__module__) | ||
|
||
def wrapper(*args, **kwargs): | ||
try: | ||
return func(*args, **kwargs) | ||
except Exception as e: | ||
logger.debug( | ||
"OpenLLMetry failed to trace in %s, error: %s", | ||
func.__name__, | ||
traceback.format_exc(), | ||
) | ||
if Config.exception_logger: | ||
Config.exception_logger(e) | ||
|
||
return wrapper |
1 change: 1 addition & 0 deletions
1
...pentelemetry-instrumentation-sagemaker/opentelemetry/instrumentation/sagemaker/version.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
__version__ = "0.25.6" |
Oops, something went wrong.