diff --git a/ext/opentelemetry-ext-wsgi/src/opentelemetry/ext/wsgi/__init__.py b/ext/opentelemetry-ext-wsgi/src/opentelemetry/ext/wsgi/__init__.py index 28caa77e52b..61f061948d3 100644 --- a/ext/opentelemetry-ext-wsgi/src/opentelemetry/ext/wsgi/__init__.py +++ b/ext/opentelemetry-ext-wsgi/src/opentelemetry/ext/wsgi/__init__.py @@ -44,14 +44,34 @@ def _add_request_attributes(span, environ): span.set_attribute("component", "http") span.set_attribute("http.method", environ["REQUEST_METHOD"]) - host = environ.get("HTTP_HOST") or environ["SERVER_NAME"] + host = environ.get("HTTP_HOST") + if not host: + host = environ["SERVER_NAME"] + port = environ["SERVER_PORT"] + if ( + port != "80" + and environ["wsgi.url_scheme"] == "http" + or port != "443" + ): + host += ":" + port + + # NOTE: Nonstandard span.set_attribute("http.host", host) - url = ( - environ.get("REQUEST_URI") - or environ.get("RAW_URI") - or wsgiref_util.request_uri(environ, include_query=False) - ) + url = environ.get("REQUEST_URI") or environ.get("RAW_URI") + + if url: + if url[0] == "/": + # We assume that no scheme-relative URLs will be in url here. + # After all, if a request is made to http://myserver//foo, we may get + # //foo which looks like scheme-relative but isn't. + url = environ["wsgi.url_scheme"] + "://" + host + url + elif not url.startswith(environ["wsgi.url_scheme"] + ":"): + # Something fishy is in REQUEST_URI/RAW_URI. 
Let's fall back to request_uri() + url = wsgiref_util.request_uri(environ) + else: + url = wsgiref_util.request_uri(environ) + span.set_attribute("http.url", url) @staticmethod @@ -85,24 +105,27 @@ def __call__(self, environ, start_response): tracer = trace.tracer() path_info = environ["PATH_INFO"] or "/" - parent_span = propagators.extract(get_header_from_environ, environ) + parent_span = propagators.extract(_get_header_from_environ, environ) - with tracer.start_span( + span = tracer.create_span( path_info, parent_span, kind=trace.SpanKind.SERVER - ) as span: - self._add_request_attributes(span, environ) - start_response = self._create_start_response(span, start_response) + ) + span.start() + try: + with tracer.use_span(span): + self._add_request_attributes(span, environ) + start_response = self._create_start_response( + span, start_response + ) - iterable = self.wsgi(environ, start_response) - try: - for yielded in iterable: - yield yielded - finally: - if hasattr(iterable, "close"): - iterable.close() + iterable = self.wsgi(environ, start_response) + return _end_span_after_iterating(iterable, span, tracer) + except: # noqa + span.end() + raise -def get_header_from_environ( +def _get_header_from_environ( environ: dict, header_name: str ) -> typing.List[str]: """Retrieve the header value from the wsgi environ dictionary. @@ -115,3 +138,18 @@ def get_header_from_environ( if value: return [value] return [] + + +# Put this in a subfunction to not delay the call to the wrapped +# WSGI application (instrumentation should change the application +# behavior as little as possible). 
+def _end_span_after_iterating(iterable, span, tracer): + try: + with tracer.use_span(span): + for yielded in iterable: + yield yielded + finally: + close = getattr(iterable, "close", None) + if close: + close() + span.end() diff --git a/ext/opentelemetry-ext-wsgi/tests/test_wsgi_middleware.py b/ext/opentelemetry-ext-wsgi/tests/test_wsgi_middleware.py index 52cffc051ad..cd1778c492e 100644 --- a/ext/opentelemetry-ext-wsgi/tests/test_wsgi_middleware.py +++ b/ext/opentelemetry-ext-wsgi/tests/test_wsgi_middleware.py @@ -17,6 +17,7 @@ import unittest import unittest.mock as mock import wsgiref.util as wsgiref_util +from urllib.parse import urlparse from opentelemetry import trace as trace_api from opentelemetry.ext.wsgi import OpenTelemetryMiddleware @@ -52,6 +53,15 @@ def iter_wsgi(environ, start_response): return iter_wsgi +def create_gen_wsgi(response): + def gen_wsgi(environ, start_response): + result = create_iter_wsgi(response)(environ, start_response) + yield from result + getattr(result, "close", lambda: None)() + + return gen_wsgi + + def error_wsgi(environ, start_response): assert isinstance(environ, dict) try: @@ -66,18 +76,15 @@ def error_wsgi(environ, start_response): class TestWsgiApplication(unittest.TestCase): def setUp(self): tracer = trace_api.tracer() - self.span_context_manager = mock.MagicMock() - self.span_context_manager.__enter__.return_value = mock.create_autospec( - trace_api.Span, spec_set=True - ) - self.patcher = mock.patch.object( + self.span = mock.create_autospec(trace_api.Span, spec_set=True) + self.create_span_patcher = mock.patch.object( tracer, - "start_span", + "create_span", autospec=True, spec_set=True, - return_value=self.span_context_manager, + return_value=self.span, ) - self.start_span = self.patcher.start() + self.create_span = self.create_span_patcher.start() self.write_buffer = io.BytesIO() self.write = self.write_buffer.write @@ -90,11 +97,11 @@ def setUp(self): self.exc_info = None def tearDown(self): - self.patcher.stop() 
+ self.create_span_patcher.stop() def start_response(self, status, response_headers, exc_info=None): # The span should have started already - self.span_context_manager.__enter__.assert_called_with() + self.span.start.assert_called_once_with() self.status = status self.response_headers = response_headers @@ -105,12 +112,10 @@ def validate_response(self, response, error=None): while True: try: value = next(response) - self.span_context_manager.__exit__.assert_not_called() + self.assertEqual(0, self.span.end.call_count) self.assertEqual(value, b"*") except StopIteration: - self.span_context_manager.__exit__.assert_called_with( - None, None, None - ) + self.span.end.assert_called_once_with() break self.assertEqual(self.status, "200 OK") @@ -125,9 +130,10 @@ def validate_response(self, response, error=None): self.assertIsNone(self.exc_info) # Verify that start_span has been called - self.start_span.assert_called_once_with( + self.create_span.assert_called_with( "/", trace_api.INVALID_SPAN_CONTEXT, kind=trace_api.SpanKind.SERVER ) + self.span.start.assert_called_with() def test_basic_wsgi_call(self): app = OpenTelemetryMiddleware(simple_wsgi) @@ -139,12 +145,24 @@ def test_wsgi_iterable(self): iter_wsgi = create_iter_wsgi(original_response) app = OpenTelemetryMiddleware(iter_wsgi) response = app(self.environ, self.start_response) - # Verify that start_response has not been called yet + # Verify that start_response has been called + self.assertTrue(self.status) + self.validate_response(response) + + # Verify that close has been called exactly once + self.assertEqual(original_response.close_calls, 1) + + def test_wsgi_generator(self): + original_response = Response() + gen_wsgi = create_gen_wsgi(original_response) + app = OpenTelemetryMiddleware(gen_wsgi) + response = app(self.environ, self.start_response) + # Verify that start_response has not been called self.assertIsNone(self.status) self.validate_response(response) # Verify that close has been called exactly once - 
assert original_response.close_calls == 1 + self.assertEqual(original_response.close_calls, 1) def test_wsgi_exc_info(self): app = OpenTelemetryMiddleware(error_wsgi) @@ -159,18 +177,87 @@ def setUp(self): self.span = mock.create_autospec(trace_api.Span, spec_set=True) def test_request_attributes(self): + self.environ["QUERY_STRING"] = "foo=bar" + OpenTelemetryMiddleware._add_request_attributes( # noqa pylint: disable=protected-access self.span, self.environ ) + expected = ( mock.call("component", "http"), mock.call("http.method", "GET"), mock.call("http.host", "127.0.0.1"), - mock.call("http.url", "http://127.0.0.1/"), + mock.call("http.url", "http://127.0.0.1/?foo=bar"), ) self.assertEqual(self.span.set_attribute.call_count, len(expected)) self.span.set_attribute.assert_has_calls(expected, any_order=True) + def validate_url(self, expected_url): + OpenTelemetryMiddleware._add_request_attributes( # noqa pylint: disable=protected-access + self.span, self.environ + ) + attrs = { + args[0][0]: args[0][1] + for args in self.span.set_attribute.call_args_list + } + self.assertIn("http.url", attrs) + self.assertEqual(attrs["http.url"], expected_url) + self.assertIn("http.host", attrs) + self.assertEqual( + attrs["http.host"], urlparse(attrs["http.url"]).netloc + ) + + def test_request_attributes_with_partial_raw_uri(self): + self.environ["RAW_URI"] = "/#top" + self.validate_url("http://127.0.0.1/#top") + + def test_request_attributes_with_partial_raw_uri_and_nonstandard_port( + self + ): + self.environ["RAW_URI"] = "/?" 
+ del self.environ["HTTP_HOST"] + self.environ["SERVER_PORT"] = "8080" + self.validate_url("http://127.0.0.1:8080/?") + + def test_https_uri_port(self): + del self.environ["HTTP_HOST"] + self.environ["SERVER_PORT"] = "443" + self.environ["wsgi.url_scheme"] = "https" + self.validate_url("https://127.0.0.1/") + + self.environ["SERVER_PORT"] = "8080" + self.validate_url("https://127.0.0.1:8080/") + + self.environ["SERVER_PORT"] = "80" + self.validate_url("https://127.0.0.1:80/") + + def test_request_attributes_with_nonstandard_port_and_no_host(self): + del self.environ["HTTP_HOST"] + self.environ["SERVER_PORT"] = "8080" + self.validate_url("http://127.0.0.1:8080/") + + self.environ["SERVER_PORT"] = "443" + self.validate_url("http://127.0.0.1:443/") + + def test_request_attributes_with_nonstandard_port(self): + self.environ["HTTP_HOST"] += ":8080" + self.validate_url("http://127.0.0.1:8080/") + + def test_request_attributes_with_faux_scheme_relative_raw_uri(self): + self.environ["RAW_URI"] = "//127.0.0.1/?" + self.validate_url("http://127.0.0.1//127.0.0.1/?") + + def test_request_attributes_with_pathless_raw_uri(self): + self.environ["PATH_INFO"] = "" + self.environ["RAW_URI"] = "http://hello" + self.environ["HTTP_HOST"] = "hello" + self.validate_url("http://hello") + + def test_request_attributes_with_full_request_uri(self): + self.environ["HTTP_HOST"] = "127.0.0.1:8080" + self.environ["REQUEST_URI"] = "http://127.0.0.1:8080/?foo=bar#top" + self.validate_url("http://127.0.0.1:8080/?foo=bar#top") + def test_response_attributes(self): OpenTelemetryMiddleware._add_response_attributes( # noqa pylint: disable=protected-access self.span, "404 Not Found"