diff --git a/aiohttp/web.py b/aiohttp/web.py index 7b0bf4e2ff4..24ed8b6216d 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -8,7 +8,7 @@ from yarl import URL from . import (hdrs, web_exceptions, web_reqrep, web_server, web_urldispatcher, - web_ws) + web_ws, web_middlewares) from .abc import AbstractMatchInfo, AbstractRouter from .helpers import FrozenList, sentinel from .log import access_logger, web_logger @@ -19,6 +19,7 @@ from .web_server import Server from .web_urldispatcher import * # noqa from .web_urldispatcher import PrefixedSubAppResource +from .web_middlewares import * # noqa from .web_ws import * # noqa __all__ = (web_reqrep.__all__ + @@ -26,6 +27,7 @@ web_urldispatcher.__all__ + web_ws.__all__ + web_server.__all__ + + web_middlewares.__all__ + ('Application', 'HttpVersion', 'MsgType')) diff --git a/aiohttp/web_middlewares.py b/aiohttp/web_middlewares.py new file mode 100644 index 00000000000..acbe1d1b467 --- /dev/null +++ b/aiohttp/web_middlewares.py @@ -0,0 +1,75 @@ +import asyncio +import re + +from aiohttp.web_exceptions import HTTPMovedPermanently +from aiohttp.web_urldispatcher import SystemRoute + + +__all__ = ( + 'normalize_path_middleware', +) + + +@asyncio.coroutine +def _check_request_resolves(request, path): + alt_request = request.clone(rel_url=path) + + match_info = yield from request.app.router.resolve(alt_request) + alt_request._match_info = match_info + + if not isinstance(match_info.route, SystemRoute): + return True, alt_request + + return False, request + + +def normalize_path_middleware( + *, append_slash=True, merge_slashes=True, + redirect_class=HTTPMovedPermanently): + """ + Middleware that normalizes the path of a request. By normalizing + it means: + + - Add a trailing slash to the path. + - Double slashes are replaced by one. + + The middleware returns as soon as it finds a path that resolves + correctly. The order if all enable is 1) merge_slashes, 2) append_slash + and 3) both merge_slashes and append_slash. If the path resolves with + at least one of those conditions, it will redirect to the new path. + + If append_slash is True append slash when needed. If a resource is + defined with trailing slash and the request comes without it, it will + append it automatically. + + If merge_slashes is True, merge multiple consecutive slashes in the + path into one. + """ + + @asyncio.coroutine + def normalize_path_factory(app, handler): + + @asyncio.coroutine + def middleware(request): + + if isinstance(request.match_info.route, SystemRoute): + paths_to_check = [] + if merge_slashes: + paths_to_check.append(re.sub('//+', '/', request.path)) + if append_slash and not request.path.endswith('/'): + paths_to_check.append(request.path + '/') + if merge_slashes and append_slash: + paths_to_check.append( + re.sub('//+', '/', request.path + '/')) + + for path in paths_to_check: + resolves, request = yield from _check_request_resolves( + request, path) + if resolves: + return redirect_class(request.path) + + return (yield from handler(request)) + + return middleware + + return normalize_path_factory diff --git a/docs/web_reference.rst b/docs/web_reference.rst index b55bfd8b4df..3a5bb5607cc 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -2185,5 +2185,33 @@ Constants *no compression* + +Middlewares +----------- + +Normalize path middleware +^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. function:: normalize_path_middleware(*, append_slash=True, merge_slashes=True) + + Middleware that normalizes the path of a request. By normalizing + it means: + + - Add a trailing slash to the path. + - Double slashes are replaced by one. + + The middleware returns as soon as it finds a path that resolves + correctly. The order if all enabled is 1) merge_slashes, 2) append_slash + and 3) both merge_slashes and append_slash. If the path resolves with + at least one of those conditions, it will redirect to the new path. + + If append_slash is True append slash when needed. If a resource is + defined with trailing slash and the request comes without it, it will + append it automatically. + + If merge_slashes is True, merge multiple consecutive slashes in the + path into one. + + .. disqus:: :title: aiohttp server reference diff --git a/setup.cfg b/setup.cfg index efeed68aa18..090855dab3c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,8 +1,12 @@ +[pep8] +max-line-length=79 + [easy_install] zip_ok = false [flake8] ignore = N801,N802,N803,E226 +max-line-length=79 [tool:pytest] timeout = 10 diff --git a/tests/test_web_middleware.py b/tests/test_web_middleware.py index f262a351a4d..1166bcfef4a 100644 --- a/tests/test_web_middleware.py +++ b/tests/test_web_middleware.py @@ -92,3 +92,94 @@ def middleware(request): assert 200 == resp.status txt = yield from resp.text() assert 'OK[2][1]' == txt + + +@pytest.fixture +def cli(loop, test_client): + def wrapper(extra_middlewares): + app = web.Application(loop=loop) + app.router.add_route( + 'GET', '/resource1', lambda x: web.Response(text="OK")) + app.router.add_route( + 'GET', '/resource2/', lambda x: web.Response(text="OK")) + app.router.add_route( + 'GET', '/resource1/a/b', lambda x: web.Response(text="OK")) + app.router.add_route( + 'GET', '/resource2/a/b/', lambda x: web.Response(text="OK")) + app.middlewares.extend(extra_middlewares) + return test_client(app) + return wrapper + + +class TestNormalizePathMiddleware: + + @asyncio.coroutine + @pytest.mark.parametrize("path, status", [ + ('/resource1', 200), + ('/resource1/', 404), + ('/resource2', 200), + ('/resource2/', 200) + ]) + def test_add_trailing_when_necessary( + self, path, status, cli): + extra_middlewares = [ + web.normalize_path_middleware(merge_slashes=False)] + client = yield from cli(extra_middlewares) + + resp = yield from client.get(path) + assert resp.status == status + + @asyncio.coroutine + @pytest.mark.parametrize("path, status", [ + ('/resource1', 200), + ('/resource1/', 404), + ('/resource2', 404), + ('/resource2/', 200) + ]) + def test_no_trailing_slash_when_disabled( + self, path, status, cli): + extra_middlewares = [ + web.normalize_path_middleware( + append_slash=False, merge_slashes=False)] + client = yield from cli(extra_middlewares) + + resp = yield from client.get(path) + assert resp.status == status + + @asyncio.coroutine + @pytest.mark.parametrize("path, status", [ + ('/resource1/a/b', 200), + ('///resource1//a//b', 200), + ('/////resource1/a///b', 200), + ('/////resource1/a//b/', 404) + ]) + def test_merge_slash(self, path, status, cli): + extra_middlewares = [ + web.normalize_path_middleware(append_slash=False)] + client = yield from cli(extra_middlewares) + + resp = yield from client.get(path) + assert resp.status == status + + @asyncio.coroutine + @pytest.mark.parametrize("path, status", [ + ('/resource1/a/b', 200), + ('/resource1/a/b/', 404), + ('///resource1//a//b', 200), + ('///resource1//a//b/', 404), + ('/////resource1/a///b', 200), + ('/////resource1/a///b/', 404), + ('/resource2/a/b', 200), + ('///resource2//a//b', 200), + ('///resource2//a//b/', 200), + ('/////resource2/a///b', 200), + ('/////resource2/a///b/', 200) + ]) + def test_append_and_merge_slash(self, path, status, cli): + extra_middlewares = [ + web.normalize_path_middleware()] + + client = yield from cli(extra_middlewares) + + resp = yield from client.get(path) + assert resp.status == status