diff --git a/tests/opentelemetry-docker-tests/tests/sqlalchemy_tests/mixins.py b/tests/opentelemetry-docker-tests/tests/sqlalchemy_tests/mixins.py index ccf37ae425..c2e5548ab1 100644 --- a/tests/opentelemetry-docker-tests/tests/sqlalchemy_tests/mixins.py +++ b/tests/opentelemetry-docker-tests/tests/sqlalchemy_tests/mixins.py @@ -13,10 +13,12 @@ # limitations under the License. import contextlib +import logging +import threading -from sqlalchemy import Column, Integer, String, create_engine +from sqlalchemy import Column, Integer, String, create_engine, insert from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import close_all_sessions, scoped_session, sessionmaker from opentelemetry import trace from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor @@ -199,3 +201,45 @@ def test_parent(self): self.assertEqual(parent_span.instrumentation_info.name, "sqlalch_svc") self.assertEqual(child_span.name, "SELECT " + self.SQL_DB) + + def test_multithreading(self): + """Ensure spans are captured correctly in a multithreading scenario + + We also expect no logged warnings about calling end() on an ended span. + """ + + if self.VENDOR == "sqlite": + return + + def insert_player(session): + _session = session() + player = Player(name="Player") + _session.add(player) + _session.commit() + _session.query(Player).all() + + def insert_players(session): + _session = session() + players = [] + for player_number in range(3): + players.append(Player(name=f"Player {player_number}")) + _session.add_all(players) + _session.commit() + + session_factory = sessionmaker(bind=self.engine) + # pylint: disable=invalid-name + Session = scoped_session(session_factory) + thread_one = threading.Thread(target=insert_player, args=(Session,)) + thread_two = threading.Thread(target=insert_players, args=(Session,)) + + logger = logging.getLogger("opentelemetry.sdk.trace") + with self.assertRaises(AssertionError): + with self.assertLogs(logger, level="WARNING"): + thread_one.start() + thread_two.start() + thread_one.join() + thread_two.join() + close_all_sessions() + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 5)