diff --git a/CHANGELOG.md b/CHANGELOG.md index 841382c865..a969841508 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ([#713](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/713)) - `opentelemetry-sdk-extension-aws` Move AWS X-Ray Propagator into its own `opentelemetry-propagators-aws` package ([#720](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/720)) +- `opentelemetry-instrumentation-sqlalchemy` Respect provided tracer provider when instrumenting SQLAlchemy + ([#728](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/728)) ### Changed diff --git a/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/__init__.py b/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/__init__.py index 05e6451626..0c81f2f0da 100644 --- a/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/__init__.py @@ -88,20 +88,23 @@ def _instrument(self, **kwargs): Returns: An instrumented engine if passed in as an argument, None otherwise. """ - _w("sqlalchemy", "create_engine", _wrap_create_engine) - _w("sqlalchemy.engine", "create_engine", _wrap_create_engine) + tracer_provider = kwargs.get("tracer_provider") + _w("sqlalchemy", "create_engine", _wrap_create_engine(tracer_provider)) + _w( + "sqlalchemy.engine", + "create_engine", + _wrap_create_engine(tracer_provider), + ) if parse_version(sqlalchemy.__version__).release >= (1, 4): _w( "sqlalchemy.ext.asyncio", "create_async_engine", - _wrap_create_async_engine, + _wrap_create_async_engine(tracer_provider), ) if kwargs.get("engine") is not None: return EngineTracer( - _get_tracer( - kwargs.get("engine"), kwargs.get("tracer_provider") - ), + _get_tracer(kwargs.get("engine"), tracer_provider), kwargs.get("engine"), ) return None diff --git a/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py b/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py index ed1dfb1976..f516e54193 100644 --- a/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py +++ b/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py @@ -42,24 +42,30 @@ def _get_tracer(engine, tracer_provider=None): ) -# pylint: disable=unused-argument -def _wrap_create_async_engine(func, module, args, kwargs): - """Trace the SQLAlchemy engine, creating an `EngineTracer` - object that will listen to SQLAlchemy events. - """ - engine = func(*args, **kwargs) - EngineTracer(_get_tracer(engine), engine.sync_engine) - return engine +def _wrap_create_async_engine(tracer_provider=None): + # pylint: disable=unused-argument + def _wrap_create_async_engine_internal(func, module, args, kwargs): + """Trace the SQLAlchemy engine, creating an `EngineTracer` + object that will listen to SQLAlchemy events. + """ + engine = func(*args, **kwargs) + EngineTracer(_get_tracer(engine, tracer_provider), engine.sync_engine) + return engine + return _wrap_create_async_engine_internal -# pylint: disable=unused-argument -def _wrap_create_engine(func, module, args, kwargs): - """Trace the SQLAlchemy engine, creating an `EngineTracer` - object that will listen to SQLAlchemy events. - """ - engine = func(*args, **kwargs) - EngineTracer(_get_tracer(engine), engine) - return engine + +def _wrap_create_engine(tracer_provider=None): + # pylint: disable=unused-argument + def _wrap_create_engine_internal(func, module, args, kwargs): + """Trace the SQLAlchemy engine, creating an `EngineTracer` + object that will listen to SQLAlchemy events. + """ + engine = func(*args, **kwargs) + EngineTracer(_get_tracer(engine, tracer_provider), engine) + return engine + + return _wrap_create_engine_internal class EngineTracer: diff --git a/instrumentation/opentelemetry-instrumentation-sqlalchemy/tests/test_sqlalchemy.py b/instrumentation/opentelemetry-instrumentation-sqlalchemy/tests/test_sqlalchemy.py index bed2b5f312..c71f1ab8bd 100644 --- a/instrumentation/opentelemetry-instrumentation-sqlalchemy/tests/test_sqlalchemy.py +++ b/instrumentation/opentelemetry-instrumentation-sqlalchemy/tests/test_sqlalchemy.py @@ -20,6 +20,8 @@ from opentelemetry import trace from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace import TracerProvider, export from opentelemetry.test.test_base import TestBase @@ -95,6 +97,37 @@ def test_create_engine_wrapper(self): self.assertEqual(spans[0].name, "SELECT :memory:") self.assertEqual(spans[0].kind, trace.SpanKind.CLIENT) + def test_custom_tracer_provider(self): + provider = TracerProvider( + resource=Resource.create( + { + "service.name": "test", + "deployment.environment": "env", + "service.version": "1234", + }, + ), + ) + provider.add_span_processor( + export.SimpleSpanProcessor(self.memory_exporter) + ) + + SQLAlchemyInstrumentor().instrument(tracer_provider=provider) + from sqlalchemy import create_engine # pylint: disable-all + + engine = create_engine("sqlite:///:memory:") + cnx = engine.connect() + cnx.execute("SELECT 1 + 1;").fetchall() + spans = self.memory_exporter.get_finished_spans() + + self.assertEqual(len(spans), 1) + self.assertEqual(spans[0].resource.attributes["service.name"], "test") + self.assertEqual( + spans[0].resource.attributes["deployment.environment"], "env" + ) + self.assertEqual( + spans[0].resource.attributes["service.version"], "1234" + ) + @pytest.mark.skipif( not sqlalchemy.__version__.startswith("1.4"), reason="only run async tests for 1.4",