Skip to content

Commit

Permalink
respect provided tracer provider when instrumenting sqlalchemy
Browse files Browse the repository at this point in the history
This change updates the SQLALchemyInstrumentor to respect the tracer
provider that is passed in through the kwargs when patching the
`create_engine` functionality provided by SQLAlchemy. Previously, it
would default to the global tracer provider.
  • Loading branch information
jfmyers9 committed Oct 12, 2021
1 parent 224780f commit ed5259e
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,13 @@ def _instrument(self, **kwargs):
Returns:
An instrumented engine if passed in as an argument, None otherwise.
"""
_w("sqlalchemy", "create_engine", _wrap_create_engine)
_w("sqlalchemy.engine", "create_engine", _wrap_create_engine)
_w("sqlalchemy", "create_engine", _wrap_create_engine(kwargs))
_w("sqlalchemy.engine", "create_engine", _wrap_create_engine(kwargs))
if parse_version(sqlalchemy.__version__).release >= (1, 4):
_w(
"sqlalchemy.ext.asyncio",
"create_async_engine",
_wrap_create_async_engine,
_wrap_create_async_engine(kwargs),
)

if kwargs.get("engine") is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,23 +43,33 @@ def _get_tracer(engine, tracer_provider=None):


# pylint: disable=unused-argument
def _wrap_create_async_engine(func, module, args, kwargs):
"""Trace the SQLAlchemy engine, creating an `EngineTracer`
object that will listen to SQLAlchemy events.
"""
engine = func(*args, **kwargs)
EngineTracer(_get_tracer(engine), engine.sync_engine)
return engine
def _wrap_create_async_engine(kwargs):
tracer_provider = kwargs.get("tracer_provider")

def _wrap_create_async_engine_internal(func, module, args, kwargs):
"""Trace the SQLAlchemy engine, creating an `EngineTracer`
object that will listen to SQLAlchemy events.
"""
engine = func(*args, **kwargs)
EngineTracer(_get_tracer(engine, tracer_provider), engine.sync_engine)
return engine

return _wrap_create_async_engine_internal


# pylint: disable=unused-argument
def _wrap_create_engine(func, module, args, kwargs):
"""Trace the SQLAlchemy engine, creating an `EngineTracer`
object that will listen to SQLAlchemy events.
"""
engine = func(*args, **kwargs)
EngineTracer(_get_tracer(engine), engine)
return engine
def _wrap_create_engine(kwargs):
tracer_provider = kwargs.get("tracer_provider")

def _wrap_create_engine_internal(func, module, args, kwargs):
"""Trace the SQLAlchemy engine, creating an `EngineTracer`
object that will listen to SQLAlchemy events.
"""
engine = func(*args, **kwargs)
EngineTracer(_get_tracer(engine, tracer_provider), engine)
return engine

return _wrap_create_engine_internal


class EngineTracer:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from sqlalchemy import create_engine

from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider, export
from opentelemetry.sdk.resources import Resource
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
from opentelemetry.test.test_base import TestBase

Expand Down Expand Up @@ -95,6 +97,28 @@ def test_create_engine_wrapper(self):
self.assertEqual(spans[0].name, "SELECT :memory:")
self.assertEqual(spans[0].kind, trace.SpanKind.CLIENT)

def test_custom_tracer_provider(self):
provider = TracerProvider(
resource=Resource.create(
{"service.name": "test", "deployment.environment": "env", "service.version": "1234"},
),
)
provider.add_span_processor(export.SimpleSpanProcessor(self.memory_exporter))

SQLAlchemyInstrumentor().instrument(tracer_provider=provider)
from sqlalchemy import create_engine # pylint: disable-all

engine = create_engine("sqlite:///:memory:")
cnx = engine.connect()
cnx.execute("SELECT 1 + 1;").fetchall()
spans = self.memory_exporter.get_finished_spans()

self.assertEqual(len(spans), 1)
self.assertEqual(spans[0].resource.attributes["service.name"], "test")
self.assertEqual(spans[0].resource.attributes["deployment.environment"], "env")
self.assertEqual(spans[0].resource.attributes["service.version"], "1234")


@pytest.mark.skipif(
not sqlalchemy.__version__.startswith("1.4"),
reason="only run async tests for 1.4",
Expand Down

0 comments on commit ed5259e

Please sign in to comment.