Skip to content

Commit

Permalink
Add SQLAlchemy multithreading test (#468)
Browse files Browse the repository at this point in the history
  • Loading branch information
jomasti authored Jun 3, 2021
1 parent 3d7cc64 commit a3ecbc1
Showing 1 changed file with 46 additions and 2 deletions.
48 changes: 46 additions & 2 deletions tests/opentelemetry-docker-tests/tests/sqlalchemy_tests/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
# limitations under the License.

import contextlib
import logging
import threading

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy import Column, Integer, String, create_engine, insert
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import close_all_sessions, scoped_session, sessionmaker

from opentelemetry import trace
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
Expand Down Expand Up @@ -199,3 +201,45 @@ def test_parent(self):
self.assertEqual(parent_span.instrumentation_info.name, "sqlalch_svc")

self.assertEqual(child_span.name, "SELECT " + self.SQL_DB)

def test_multithreading(self):
"""Ensure spans are captured correctly in a multithreading scenario
We also expect no logged warnings about calling end() on an ended span.
"""

if self.VENDOR == "sqlite":
return

def insert_player(session):
_session = session()
player = Player(name="Player")
_session.add(player)
_session.commit()
_session.query(Player).all()

def insert_players(session):
_session = session()
players = []
for player_number in range(3):
players.append(Player(name=f"Player {player_number}"))
_session.add_all(players)
_session.commit()

session_factory = sessionmaker(bind=self.engine)
# pylint: disable=invalid-name
Session = scoped_session(session_factory)
thread_one = threading.Thread(target=insert_player, args=(Session,))
thread_two = threading.Thread(target=insert_players, args=(Session,))

logger = logging.getLogger("opentelemetry.sdk.trace")
with self.assertRaises(AssertionError):
with self.assertLogs(logger, level="WARNING"):
thread_one.start()
thread_two.start()
thread_one.join()
thread_two.join()
close_all_sessions()

spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 5)

0 comments on commit a3ecbc1

Please sign in to comment.