Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue with Flask instrumentation when a request spawn children threads and copies the request context #1654

Merged
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fix Flask instrumentation to only close the span if it was created by the same thread.
([#1654](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1654))
srikanthccv marked this conversation as resolved.
Show resolved Hide resolved
- Fix TortoiseORM instrumentation `AttributeError: type object 'Config' has no attribute 'title'`
([#1575](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1575))
- Fix SQLAlchemy uninstrumentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,8 @@ def response_hook(span: Span, status: str, response_headers: List):
API
---
"""

from logging import getLogger
from threading import get_ident
from time import time_ns
from timeit import default_timer
from typing import Collection
Expand All @@ -265,6 +265,7 @@ def response_hook(span: Span, status: str, response_headers: List):
_ENVIRON_STARTTIME_KEY = "opentelemetry-flask.starttime_key"
_ENVIRON_SPAN_KEY = "opentelemetry-flask.span_key"
_ENVIRON_ACTIVATION_KEY = "opentelemetry-flask.activation_key"
_ENVIRON_THREAD_ID_KEY = "opentelemetry-flask.thread_id_key"
_ENVIRON_TOKEN = "opentelemetry-flask.token"

_excluded_urls_from_env = get_excluded_urls("FLASK")
Expand Down Expand Up @@ -398,6 +399,7 @@ def _before_request():
activation = trace.use_span(span, end_on_exit=True)
activation.__enter__() # pylint: disable=E1101
flask_request_environ[_ENVIRON_ACTIVATION_KEY] = activation
flask_request_environ[_ENVIRON_THREAD_ID_KEY] = get_ident()
flask_request_environ[_ENVIRON_SPAN_KEY] = span
flask_request_environ[_ENVIRON_TOKEN] = token

Expand Down Expand Up @@ -437,10 +439,17 @@ def _teardown_request(exc):
return

activation = flask.request.environ.get(_ENVIRON_ACTIVATION_KEY)
if not activation:
thread_id = flask.request.environ.get(_ENVIRON_THREAD_ID_KEY)
if not activation or thread_id != get_ident():
# This request didn't start a span, maybe because it was created in
# a way that doesn't run `before_request`, like when it is created
# with `app.test_request_context`.
srikanthccv marked this conversation as resolved.
Show resolved Hide resolved
#
# Similarly, check the thread_id against the current thread to ensure
# tear down only happens on the original thread. This situation can
# arise if the original thread handling the request spawn children
# threads and then uses something like copy_current_request_context
# to copy the request context.
return
if exc is None:
activation.__exit__(None, None, None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from concurrent.futures import ThreadPoolExecutor, as_completed
from random import randint

import flask
from werkzeug.test import Client
from werkzeug.wrappers import Response
Expand All @@ -34,6 +37,25 @@ def _sqlcommenter_endpoint():
)
return sqlcommenter_flask_values

@staticmethod
def _multithreaded_endpoint(count):
def do_random_stuff():
@flask.copy_current_request_context
def inner():
return randint(0, 100)

return inner

executor = ThreadPoolExecutor(count)
futures = []
for _ in range(count):
futures.append(executor.submit(do_random_stuff()))
numbers = []
for future in as_completed(futures):
numbers.append(future.result())

return " ".join([str(i) for i in numbers])

@staticmethod
def _custom_response_headers():
resp = flask.Response("test response")
Expand Down Expand Up @@ -61,6 +83,7 @@ def excluded2_endpoint():
# pylint: disable=no-member
self.app.route("/hello/<int:helloid>")(self._hello_endpoint)
self.app.route("/sqlcommenter")(self._sqlcommenter_endpoint)
self.app.route("/multithreaded")(self._multithreaded_endpoint)
self.app.route("/excluded/<int:helloid>")(self._hello_endpoint)
self.app.route("/excluded")(excluded_endpoint)
self.app.route("/excluded2")(excluded2_endpoint)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import flask
from werkzeug.test import Client
from werkzeug.wrappers import Response

from opentelemetry.instrumentation.flask import FlaskInstrumentor
from opentelemetry.test.wsgitestutil import WsgiTestBase

# pylint: disable=import-error
from .base_test import InstrumentationTest


class TestMultiThreading(InstrumentationTest, WsgiTestBase):
def setUp(self):
super().setUp()
FlaskInstrumentor().instrument()
self.app = flask.Flask(__name__)
self._common_initialization()

def tearDown(self):
super().tearDown()
with self.disable_logging():
FlaskInstrumentor().uninstrument()

def test_multithreaded(self):
"""Test that instrumentation tear down does not blow up
when the request thread spawn children threads and the request
context is copied to the children threads
"""
self.app = flask.Flask(__name__)
self.app.route("/multithreaded/<int:count>")(
self._multithreaded_endpoint
)
client = Client(self.app, Response)
count = 5
resp = client.get(f"/multithreaded/{count}")
self.assertEqual(200, resp.status_code)
# Should return the specified number of random integers
self.assertEqual(count, len(resp.text.split(" ")))