Skip to content

Commit

Permalink
Respect provided tracer provider when instrumenting SQLAlchemy (#728)
Browse files Browse the repository at this point in the history
* respect provided tracer provider when instrumenting sqlalchemy

This change updates the SQLALchemyInstrumentor to respect the tracer
provider that is passed in through the kwargs when patching the
`create_engine` functionality provided by SQLAlchemy. Previously, it
would default to the global tracer provider.

* feedback: pass in tracer_provider directly rather than kwargs

* feedback: update changelog

* build: lint
  • Loading branch information
jfmyers9 authored Oct 12, 2021
1 parent 5105820 commit e8af7a3
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 22 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
([#713](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/713))
- `opentelemetry-sdk-extension-aws` Move AWS X-Ray Propagator into its own `opentelemetry-propagators-aws` package
([#720](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/720))
- `opentelemetry-instrumentation-sqlalchemy` Respect provided tracer provider when instrumenting SQLAlchemy
([#728](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/728))


### Changed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,20 +88,23 @@ def _instrument(self, **kwargs):
Returns:
An instrumented engine if passed in as an argument, None otherwise.
"""
_w("sqlalchemy", "create_engine", _wrap_create_engine)
_w("sqlalchemy.engine", "create_engine", _wrap_create_engine)
tracer_provider = kwargs.get("tracer_provider")
_w("sqlalchemy", "create_engine", _wrap_create_engine(tracer_provider))
_w(
"sqlalchemy.engine",
"create_engine",
_wrap_create_engine(tracer_provider),
)
if parse_version(sqlalchemy.__version__).release >= (1, 4):
_w(
"sqlalchemy.ext.asyncio",
"create_async_engine",
_wrap_create_async_engine,
_wrap_create_async_engine(tracer_provider),
)

if kwargs.get("engine") is not None:
return EngineTracer(
_get_tracer(
kwargs.get("engine"), kwargs.get("tracer_provider")
),
_get_tracer(kwargs.get("engine"), tracer_provider),
kwargs.get("engine"),
)
return None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,24 +42,30 @@ def _get_tracer(engine, tracer_provider=None):
)


# pylint: disable=unused-argument
def _wrap_create_async_engine(func, module, args, kwargs):
"""Trace the SQLAlchemy engine, creating an `EngineTracer`
object that will listen to SQLAlchemy events.
"""
engine = func(*args, **kwargs)
EngineTracer(_get_tracer(engine), engine.sync_engine)
return engine
def _wrap_create_async_engine(tracer_provider=None):
# pylint: disable=unused-argument
def _wrap_create_async_engine_internal(func, module, args, kwargs):
"""Trace the SQLAlchemy engine, creating an `EngineTracer`
object that will listen to SQLAlchemy events.
"""
engine = func(*args, **kwargs)
EngineTracer(_get_tracer(engine, tracer_provider), engine.sync_engine)
return engine

return _wrap_create_async_engine_internal

# pylint: disable=unused-argument
def _wrap_create_engine(func, module, args, kwargs):
"""Trace the SQLAlchemy engine, creating an `EngineTracer`
object that will listen to SQLAlchemy events.
"""
engine = func(*args, **kwargs)
EngineTracer(_get_tracer(engine), engine)
return engine

def _wrap_create_engine(tracer_provider=None):
# pylint: disable=unused-argument
def _wrap_create_engine_internal(func, module, args, kwargs):
"""Trace the SQLAlchemy engine, creating an `EngineTracer`
object that will listen to SQLAlchemy events.
"""
engine = func(*args, **kwargs)
EngineTracer(_get_tracer(engine, tracer_provider), engine)
return engine

return _wrap_create_engine_internal


class EngineTracer:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

from opentelemetry import trace
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider, export
from opentelemetry.test.test_base import TestBase


Expand Down Expand Up @@ -95,6 +97,37 @@ def test_create_engine_wrapper(self):
self.assertEqual(spans[0].name, "SELECT :memory:")
self.assertEqual(spans[0].kind, trace.SpanKind.CLIENT)

def test_custom_tracer_provider(self):
provider = TracerProvider(
resource=Resource.create(
{
"service.name": "test",
"deployment.environment": "env",
"service.version": "1234",
},
),
)
provider.add_span_processor(
export.SimpleSpanProcessor(self.memory_exporter)
)

SQLAlchemyInstrumentor().instrument(tracer_provider=provider)
from sqlalchemy import create_engine # pylint: disable-all

engine = create_engine("sqlite:///:memory:")
cnx = engine.connect()
cnx.execute("SELECT 1 + 1;").fetchall()
spans = self.memory_exporter.get_finished_spans()

self.assertEqual(len(spans), 1)
self.assertEqual(spans[0].resource.attributes["service.name"], "test")
self.assertEqual(
spans[0].resource.attributes["deployment.environment"], "env"
)
self.assertEqual(
spans[0].resource.attributes["service.version"], "1234"
)

@pytest.mark.skipif(
not sqlalchemy.__version__.startswith("1.4"),
reason="only run async tests for 1.4",
Expand Down

0 comments on commit e8af7a3

Please sign in to comment.