From ed5259ea9bbfcf0557a29238bfa93b487422d77f Mon Sep 17 00:00:00 2001 From: Jim Myers Date: Tue, 12 Oct 2021 12:08:17 -0400 Subject: [PATCH] respect provided tracer provider when instrumenting sqlalchemy This change updates the SQLALchemyInstrumentor to respect the tracer provider that is passed in through the kwargs when patching the `create_engine` functionality provided by SQLAlchemy. Previously, it would default to the global tracer provider. --- .../instrumentation/sqlalchemy/__init__.py | 6 +-- .../instrumentation/sqlalchemy/engine.py | 38 ++++++++++++------- .../tests/test_sqlalchemy.py | 24 ++++++++++++ 3 files changed, 51 insertions(+), 17 deletions(-) diff --git a/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/__init__.py b/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/__init__.py index 05e6451626..42524fdfcd 100644 --- a/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/__init__.py @@ -88,13 +88,13 @@ def _instrument(self, **kwargs): Returns: An instrumented engine if passed in as an argument, None otherwise. """ - _w("sqlalchemy", "create_engine", _wrap_create_engine) - _w("sqlalchemy.engine", "create_engine", _wrap_create_engine) + _w("sqlalchemy", "create_engine", _wrap_create_engine(kwargs)) + _w("sqlalchemy.engine", "create_engine", _wrap_create_engine(kwargs)) if parse_version(sqlalchemy.__version__).release >= (1, 4): _w( "sqlalchemy.ext.asyncio", "create_async_engine", - _wrap_create_async_engine, + _wrap_create_async_engine(kwargs), ) if kwargs.get("engine") is not None: diff --git a/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py b/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py index ed1dfb1976..97cb3eb6b4 100644 --- a/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py +++ b/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py @@ -43,23 +43,33 @@ def _get_tracer(engine, tracer_provider=None): # pylint: disable=unused-argument -def _wrap_create_async_engine(func, module, args, kwargs): - """Trace the SQLAlchemy engine, creating an `EngineTracer` - object that will listen to SQLAlchemy events. - """ - engine = func(*args, **kwargs) - EngineTracer(_get_tracer(engine), engine.sync_engine) - return engine +def _wrap_create_async_engine(kwargs): + tracer_provider = kwargs.get("tracer_provider") + + def _wrap_create_async_engine_internal(func, module, args, kwargs): + """Trace the SQLAlchemy engine, creating an `EngineTracer` + object that will listen to SQLAlchemy events. + """ + engine = func(*args, **kwargs) + EngineTracer(_get_tracer(engine, tracer_provider), engine.sync_engine) + return engine + + return _wrap_create_async_engine_internal # pylint: disable=unused-argument -def _wrap_create_engine(func, module, args, kwargs): - """Trace the SQLAlchemy engine, creating an `EngineTracer` - object that will listen to SQLAlchemy events. - """ - engine = func(*args, **kwargs) - EngineTracer(_get_tracer(engine), engine) - return engine +def _wrap_create_engine(kwargs): + tracer_provider = kwargs.get("tracer_provider") + + def _wrap_create_engine_internal(func, module, args, kwargs): + """Trace the SQLAlchemy engine, creating an `EngineTracer` + object that will listen to SQLAlchemy events. + """ + engine = func(*args, **kwargs) + EngineTracer(_get_tracer(engine, tracer_provider), engine) + return engine + + return _wrap_create_engine_internal class EngineTracer: diff --git a/instrumentation/opentelemetry-instrumentation-sqlalchemy/tests/test_sqlalchemy.py b/instrumentation/opentelemetry-instrumentation-sqlalchemy/tests/test_sqlalchemy.py index bed2b5f312..536e323a2c 100644 --- a/instrumentation/opentelemetry-instrumentation-sqlalchemy/tests/test_sqlalchemy.py +++ b/instrumentation/opentelemetry-instrumentation-sqlalchemy/tests/test_sqlalchemy.py @@ -19,6 +19,8 @@ from sqlalchemy import create_engine from opentelemetry import trace +from opentelemetry.sdk.trace import TracerProvider, export +from opentelemetry.sdk.resources import Resource from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor from opentelemetry.test.test_base import TestBase @@ -95,6 +97,28 @@ def test_create_engine_wrapper(self): self.assertEqual(spans[0].name, "SELECT :memory:") self.assertEqual(spans[0].kind, trace.SpanKind.CLIENT) + def test_custom_tracer_provider(self): + provider = TracerProvider( + resource=Resource.create( + {"service.name": "test", "deployment.environment": "env", "service.version": "1234"}, + ), + ) + provider.add_span_processor(export.SimpleSpanProcessor(self.memory_exporter)) + + SQLAlchemyInstrumentor().instrument(tracer_provider=provider) + from sqlalchemy import create_engine # pylint: disable-all + + engine = create_engine("sqlite:///:memory:") + cnx = engine.connect() + cnx.execute("SELECT 1 + 1;").fetchall() + spans = self.memory_exporter.get_finished_spans() + + self.assertEqual(len(spans), 1) + self.assertEqual(spans[0].resource.attributes["service.name"], "test") + self.assertEqual(spans[0].resource.attributes["deployment.environment"], "env") + self.assertEqual(spans[0].resource.attributes["service.version"], "1234") + + @pytest.mark.skipif( not sqlalchemy.__version__.startswith("1.4"), reason="only run async tests for 1.4",