Skip to content

Commit

Permalink
Merge pull request #202 from lonewolf3739/sqlalchemy-semantic-conv
Browse files Browse the repository at this point in the history
  • Loading branch information
lzchen authored Nov 23, 2020
2 parents f52c88b + b7f8a5b commit e742cf7
Show file tree
Hide file tree
Showing 9 changed files with 51 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

## Unreleased

- Update sqlalchemy instrumentation to follow semantic conventions
([#202](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/202))

## Version 0.13b0

Released 2020-09-17
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
engine = create_engine("sqlite:///:memory:")
SQLAlchemyInstrumentor().instrument(
engine=engine,
service="service-A",
)
API
Expand Down Expand Up @@ -66,7 +65,6 @@ def _instrument(self, **kwargs):
**kwargs: Optional arguments
``engine``: a SQLAlchemy engine instance
``tracer_provider``: a TracerProvider, defaults to global
``service``: the name of the service to trace.
Returns:
An instrumented engine if passed in as an argument, None otherwise.
Expand All @@ -78,7 +76,6 @@ def _instrument(self, **kwargs):
_get_tracer(
kwargs.get("engine"), kwargs.get("tracer_provider")
),
kwargs.get("service"),
kwargs.get("engine"),
)
return None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,9 @@
_PORT = "net.peer.port"
# Database semantic conventions here:
# https://github.com/open-telemetry/opentelemetry-specification/blob/master/specification/trace/semantic_conventions/database.md
_ROWS = "sql.rows" # number of rows returned by a query
_STMT = "db.statement"
_DB = "db.type"
_URL = "db.url"
_DB = "db.name"
_USER = "db.user"


def _normalize_vendor(vendor):
Expand All @@ -39,7 +38,7 @@ def _normalize_vendor(vendor):
return "sqlite"

if "postgres" in vendor or vendor == "psycopg2":
return "postgres"
return "postgresql"

return vendor

Expand All @@ -58,17 +57,15 @@ def _wrap_create_engine(func, module, args, kwargs):
object that will listen to SQLAlchemy events.
"""
engine = func(*args, **kwargs)
EngineTracer(_get_tracer(engine), None, engine)
EngineTracer(_get_tracer(engine), engine)
return engine


class EngineTracer:
def __init__(self, tracer, service, engine):
def __init__(self, tracer, engine):
self.tracer = tracer
self.engine = engine
self.vendor = _normalize_vendor(engine.name)
self.service = service or self.vendor
self.name = "%s.query" % self.vendor
self.current_span = None

listen(engine, "before_cursor_execute", self._before_cur_exec)
Expand All @@ -77,11 +74,11 @@ def __init__(self, tracer, service, engine):

# pylint: disable=unused-argument
def _before_cur_exec(self, conn, cursor, statement, *args):
self.current_span = self.tracer.start_span(self.name)
self.current_span = self.tracer.start_span(statement)
with self.tracer.use_span(self.current_span, end_on_exit=False):
if self.current_span.is_recording():
self.current_span.set_attribute("service", self.vendor)
self.current_span.set_attribute(_STMT, statement)
self.current_span.set_attribute("db.system", self.vendor)

if not _set_attributes_from_url(
self.current_span, conn.engine.url
Expand All @@ -94,16 +91,7 @@ def _before_cur_exec(self, conn, cursor, statement, *args):
def _after_cur_exec(self, conn, cursor, statement, *args):
if self.current_span is None:
return

try:
if (
cursor
and cursor.rowcount >= 0
and self.current_span.is_recording()
):
self.current_span.set_attribute(_ROWS, cursor.rowcount)
finally:
self.current_span.end()
self.current_span.end()

def _handle_error(self, context):
if self.current_span is None:
Expand All @@ -127,6 +115,8 @@ def _set_attributes_from_url(span: trace.Span, url):
span.set_attribute(_PORT, url.port)
if url.database:
span.set_attribute(_DB, url.database)
if url.username:
span.set_attribute(_USER, url.username)

return bool(url.host)

Expand All @@ -135,7 +125,7 @@ def _set_attributes_from_cursor(span: trace.Span, vendor, cursor):
"""Attempt to set db connection attributes by introspecting the cursor."""
if not span.is_recording():
return
if vendor == "postgres":
if vendor == "postgresql":
# pylint: disable=import-outside-toplevel
from psycopg2.extensions import parse_dsn

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,14 @@ def tearDown(self):
def test_trace_integration(self):
engine = create_engine("sqlite:///:memory:")
SQLAlchemyInstrumentor().instrument(
engine=engine,
tracer_provider=self.tracer_provider,
service="my-database",
engine=engine, tracer_provider=self.tracer_provider,
)
cnx = engine.connect()
cnx.execute("SELECT 1 + 1;").fetchall()
spans = self.memory_exporter.get_finished_spans()

self.assertEqual(len(spans), 1)
self.assertEqual(spans[0].name, "sqlite.query")
self.assertEqual(spans[0].name, "SELECT 1 + 1;")

def test_not_recording(self):
mock_tracer = mock.Mock()
Expand All @@ -49,9 +47,7 @@ def test_not_recording(self):
tracer.return_value = mock_tracer
engine = create_engine("sqlite:///:memory:")
SQLAlchemyInstrumentor().instrument(
engine=engine,
tracer_provider=self.tracer_provider,
service="my-database",
engine=engine, tracer_provider=self.tracer_provider,
)
cnx = engine.connect()
cnx.execute("SELECT 1 + 1;").fetchall()
Expand All @@ -70,4 +66,4 @@ def test_create_engine_wrapper(self):
spans = self.memory_exporter.get_finished_spans()

self.assertEqual(len(spans), 1)
self.assertEqual(spans[0].name, "sqlite.query")
self.assertEqual(spans[0].name, "SELECT 1 + 1;")
33 changes: 21 additions & 12 deletions tests/opentelemetry-docker-tests/tests/sqlalchemy_tests/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from opentelemetry import trace
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
from opentelemetry.instrumentation.sqlalchemy.engine import _DB, _ROWS, _STMT
from opentelemetry.instrumentation.sqlalchemy.engine import _DB, _STMT
from opentelemetry.test.test_base import TestBase

Base = declarative_base()
Expand Down Expand Up @@ -109,9 +109,8 @@ def tearDown(self):
SQLAlchemyInstrumentor().uninstrument()
super().tearDown()

def _check_span(self, span):
self.assertEqual(span.name, "{}.query".format(self.VENDOR))
self.assertEqual(span.attributes.get("service"), self.SERVICE)
def _check_span(self, span, name):
self.assertEqual(span.name, name)
self.assertEqual(span.attributes.get(_DB), self.SQL_DB)
self.assertIs(span.status.status_code, trace.status.StatusCode.UNSET)
self.assertGreater((span.end_time - span.start_time), 0)
Expand All @@ -125,9 +124,13 @@ def test_orm_insert(self):
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
span = spans[0]
self._check_span(span)
stmt = "INSERT INTO players (id, name) VALUES "
if span.attributes.get("db.system") == "sqlite":
stmt += "(?, ?)"
else:
stmt += "(%(id)s, %(name)s)"
self._check_span(span, stmt)
self.assertIn("INSERT INTO players", span.attributes.get(_STMT))
self.assertEqual(span.attributes.get(_ROWS), 1)
self.check_meta(span)

def test_session_query(self):
Expand All @@ -138,7 +141,12 @@ def test_session_query(self):
spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
span = spans[0]
self._check_span(span)
stmt = "SELECT players.id AS players_id, players.name AS players_name \nFROM players \nWHERE players.name = "
if span.attributes.get("db.system") == "sqlite":
stmt += "?"
else:
stmt += "%(name_1)s"
self._check_span(span, stmt)
self.assertIn(
"SELECT players.id AS players_id, players.name AS players_name \nFROM players \nWHERE players.name",
span.attributes.get(_STMT),
Expand All @@ -147,24 +155,26 @@ def test_session_query(self):

def test_engine_connect_execute(self):
# ensures that engine.connect() is properly traced
stmt = "SELECT * FROM players"
with self.connection() as conn:
rows = conn.execute("SELECT * FROM players").fetchall()
rows = conn.execute(stmt).fetchall()
self.assertEqual(len(rows), 0)

spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
span = spans[0]
self._check_span(span)
self._check_span(span, stmt)
self.assertEqual(span.attributes.get(_STMT), "SELECT * FROM players")
self.check_meta(span)

def test_parent(self):
"""Ensure that sqlalchemy works with opentelemetry."""
stmt = "SELECT * FROM players"
tracer = self.tracer_provider.get_tracer("sqlalch_svc")

with tracer.start_as_current_span("sqlalch_op"):
with self.connection() as conn:
rows = conn.execute("SELECT * FROM players").fetchall()
rows = conn.execute(stmt).fetchall()
self.assertEqual(len(rows), 0)

spans = self.memory_exporter.get_finished_spans()
Expand All @@ -178,5 +188,4 @@ def test_parent(self):
self.assertEqual(parent_span.name, "sqlalch_op")
self.assertEqual(parent_span.instrumentation_info.name, "sqlalch_svc")

self.assertEqual(child_span.name, "{}.query".format(self.VENDOR))
self.assertEqual(child_span.attributes.get("service"), self.SERVICE)
self.assertEqual(child_span.name, stmt)
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def test_engine_traced(self):
self.assertEqual(len(traces), 1)
span = traces[0]
# check subset of span fields
self.assertEqual(span.name, "postgres.query")
self.assertEqual(span.attributes.get("service"), "postgres")
self.assertEqual(span.name, "SELECT 1")
self.assertIs(span.status.status_code, trace.status.StatusCode.UNSET)
self.assertGreater((span.end_time - span.start_time), 0)
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
_DB,
_HOST,
_PORT,
_ROWS,
_STMT,
_USER,
)

from .mixins import SQLAlchemyTestMixin
Expand All @@ -45,7 +45,6 @@ class MysqlConnectorTestCase(SQLAlchemyTestMixin):

VENDOR = "mysql"
SQL_DB = "opentelemetry-tests"
SERVICE = "mysql"
ENGINE_ARGS = {
"url": "mysql+mysqlconnector://%(user)s:%(password)s@%(host)s:%(port)s/%(database)s"
% MYSQL_CONFIG
Expand All @@ -55,6 +54,8 @@ def check_meta(self, span):
# check database connection tags
self.assertEqual(span.attributes.get(_HOST), MYSQL_CONFIG["host"])
self.assertEqual(span.attributes.get(_PORT), MYSQL_CONFIG["port"])
self.assertEqual(span.attributes.get(_DB), MYSQL_CONFIG["database"])
self.assertEqual(span.attributes.get(_USER), MYSQL_CONFIG["user"])

def test_engine_execute_errors(self):
# ensures that SQL errors are reported
Expand All @@ -66,13 +67,11 @@ def test_engine_execute_errors(self):
self.assertEqual(len(spans), 1)
span = spans[0]
# span fields
self.assertEqual(span.name, "{}.query".format(self.VENDOR))
self.assertEqual(span.attributes.get("service"), self.SERVICE)
self.assertEqual(span.name, "SELECT * FROM a_wrong_table")
self.assertEqual(
span.attributes.get(_STMT), "SELECT * FROM a_wrong_table"
)
self.assertEqual(span.attributes.get(_DB), self.SQL_DB)
self.assertIsNone(span.attributes.get(_ROWS))
self.check_meta(span)
self.assertTrue(span.end_time - span.start_time > 0)
# check the error
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
_DB,
_HOST,
_PORT,
_ROWS,
_STMT,
)

Expand All @@ -44,9 +43,8 @@ class PostgresTestCase(SQLAlchemyTestMixin):

__test__ = True

VENDOR = "postgres"
VENDOR = "postgresql"
SQL_DB = "opentelemetry-tests"
SERVICE = "postgres"
ENGINE_ARGS = {
"url": "postgresql://%(user)s:%(password)s@%(host)s:%(port)s/%(dbname)s"
% POSTGRES_CONFIG
Expand All @@ -67,13 +65,11 @@ def test_engine_execute_errors(self):
self.assertEqual(len(spans), 1)
span = spans[0]
# span fields
self.assertEqual(span.name, "{}.query".format(self.VENDOR))
self.assertEqual(span.attributes.get("service"), self.SERVICE)
self.assertEqual(span.name, "SELECT * FROM a_wrong_table")
self.assertEqual(
span.attributes.get(_STMT), "SELECT * FROM a_wrong_table"
)
self.assertEqual(span.attributes.get(_DB), self.SQL_DB)
self.assertIsNone(span.attributes.get(_ROWS))
self.check_meta(span)
self.assertTrue(span.end_time - span.start_time > 0)
# check the error
Expand All @@ -88,9 +84,8 @@ class PostgresCreatorTestCase(PostgresTestCase):
of `PostgresTestCase`, but it uses a specific `creator` function.
"""

VENDOR = "postgres"
VENDOR = "postgresql"
SQL_DB = "opentelemetry-tests"
SERVICE = "postgres"
ENGINE_ARGS = {
"url": "postgresql://",
"creator": lambda: psycopg2.connect(**POSTGRES_CONFIG),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from sqlalchemy.exc import OperationalError

from opentelemetry import trace
from opentelemetry.instrumentation.sqlalchemy.engine import _DB, _ROWS, _STMT
from opentelemetry.instrumentation.sqlalchemy.engine import _DB, _STMT

from .mixins import SQLAlchemyTestMixin

Expand All @@ -30,26 +30,24 @@ class SQLiteTestCase(SQLAlchemyTestMixin):

VENDOR = "sqlite"
SQL_DB = ":memory:"
SERVICE = "sqlite"
ENGINE_ARGS = {"url": "sqlite:///:memory:"}

def test_engine_execute_errors(self):
# ensures that SQL errors are reported
stmt = "SELECT * FROM a_wrong_table"
with pytest.raises(OperationalError):
with self.connection() as conn:
conn.execute("SELECT * FROM a_wrong_table").fetchall()
conn.execute(stmt).fetchall()

spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
span = spans[0]
# span fields
self.assertEqual(span.name, "{}.query".format(self.VENDOR))
self.assertEqual(span.attributes.get("service"), self.SERVICE)
self.assertEqual(span.name, stmt)
self.assertEqual(
span.attributes.get(_STMT), "SELECT * FROM a_wrong_table"
)
self.assertEqual(span.attributes.get(_DB), self.SQL_DB)
self.assertIsNone(span.attributes.get(_ROWS))
self.assertTrue((span.end_time - span.start_time) > 0)
# check the error
self.assertIs(
Expand Down

0 comments on commit e742cf7

Please sign in to comment.