Skip to content

Commit

Permalink
Merge pull request #1709 from pgjones/websocket
Browse files Browse the repository at this point in the history
Add support for WebSocket rules in the routing
  • Loading branch information
davidism authored Feb 4, 2020
2 parents 49cf35b + ecd0d75 commit 13b6ef0
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 31 deletions.
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ Unreleased
quality tags. Instead the initial order is preserved. :issue:`1686`
- Added ``Map.lock_class`` attribute for alternative
implementations. :pr:`1702`
- Support matching and building WebSocket rules in the routing system,
for use by async frameworks. :pr:`1709`


Version 0.16.1
Expand Down
36 changes: 36 additions & 0 deletions docs/routing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,39 @@ Variable parts are of course also possible in the host section::
Rule('/', endpoint='www_index', host='www.example.com'),
Rule('/', endpoint='user_index', host='<user>.example.com')
], host_matching=True)


WebSockets
==========

.. versionadded:: 1.0

If a :class:`Rule` is created with ``websocket=True``, it will only
match if the :class:`Map` is bound to a request with a ``url_scheme`` of
``ws`` or ``wss``.

.. note::

Werkzeug has no further WebSocket support beyond routing. This
functionality is mostly of use to ASGI projects.

.. code-block:: python
url_map = Map([
Rule("/ws", endpoint="comm", websocket=True),
])
adapter = map.bind("example.org", "/ws", url_scheme="ws")
assert adapter.match() == ("comm", {})
If the only match is a WebSocket rule and the bind is HTTP (or the
only match is HTTP and the bind is WebSocket) a
:exc:`WebsocketMismatch` (derives from
:exc:`~werkzeug.exceptions.BadRequest`) exception is raised.

As WebSocket URLs have a different scheme, rules are always built with a
scheme and host, ``force_external=True`` is implied.

.. code-block:: python
url = adapter.build("comm")
assert url == "ws://example.org/ws"
137 changes: 106 additions & 31 deletions src/werkzeug/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@
from .datastructures import ImmutableDict
from .datastructures import MultiDict
from .exceptions import BadHost
from .exceptions import BadRequest
from .exceptions import HTTPException
from .exceptions import MethodNotAllowed
from .exceptions import NotFound
Expand Down Expand Up @@ -329,6 +330,12 @@ def __str__(self):
return u"".join(message)


class WebsocketMismatch(BadRequest):
"""The only matched rule is either a WebSocket and the request is
HTTP, or the rule is HTTP and the request is a WebSocket.
"""


class ValidationError(ValueError):
"""Validation error. If a rule converter raises this exception the rule
does not match the current URL and the next URL is tried.
Expand Down Expand Up @@ -576,21 +583,12 @@ class Rule(RuleFactory):
`MethodNotAllowed` rather than `NotFound`. If `GET` is present in the
list of methods and `HEAD` is not, `HEAD` is added automatically.
.. versionchanged:: 0.6.1
`HEAD` is now automatically added to the methods if `GET` is
present. The reason for this is that existing code often did not
work properly in servers not rewriting `HEAD` to `GET`
automatically and it was not documented how `HEAD` should be
treated. This was considered a bug in Werkzeug because of that.
`strict_slashes`
Override the `Map` setting for `strict_slashes` only for this rule. If
not specified the `Map` setting is used.
`merge_slashes`
Override the ``Map`` setting for ``merge_slashes`` for this rule.
.. versionadded:: 1.0
Override :attr:`Map.merge_slashes` for this rule.
`build_only`
Set this to True and the rule will never match but will create a URL
Expand Down Expand Up @@ -631,8 +629,22 @@ def foo_with_slug(adapter, id):
used to provide a match rule for the whole host. This also means
that the subdomain feature is disabled.
`websocket`
If ``True``, this rule is only matches for WebSocket (``ws://``,
``wss://``) requests. By default, rules will only match for HTTP
requests.
.. versionadded:: 1.0
Added ``websocket``.
.. versionadded:: 1.0
Added ``merge_slashes``.
.. versionadded:: 0.7
The `alias` and `host` parameters were added.
Added ``alias`` and ``host``.
.. versionchanged:: 0.6.1
``HEAD`` is added to ``methods`` if ``GET`` is present.
"""

def __init__(
Expand All @@ -648,6 +660,7 @@ def __init__(
redirect_to=None,
alias=False,
host=None,
websocket=False,
):
if not string.startswith("/"):
raise ValueError("urls must start with a leading slash")
Expand All @@ -662,14 +675,23 @@ def __init__(
self.defaults = defaults
self.build_only = build_only
self.alias = alias
if methods is None:
self.methods = None
else:
self.websocket = websocket

if methods is not None:
if isinstance(methods, str):
raise TypeError("param `methods` should be `Iterable[str]`, not `str`")
self.methods = set([x.upper() for x in methods])
if "HEAD" not in self.methods and "GET" in self.methods:
self.methods.add("HEAD")
raise TypeError("'methods' should be a list of strings.")

methods = {x.upper() for x in methods}

if "HEAD" not in methods and "GET" in methods:
methods.add("HEAD")

if websocket and methods - {"GET", "HEAD", "OPTIONS"}:
raise ValueError(
"WebSocket rules can only use 'GET', 'HEAD', and 'OPTIONS' methods."
)

self.methods = methods
self.endpoint = endpoint
self.redirect_to = redirect_to

Expand Down Expand Up @@ -1359,6 +1381,10 @@ class Map(object):
enabled the `host` parameter to rules is used
instead of the `subdomain` one.
.. versionchanged:: 1.0
If ``url_scheme`` is ``ws`` or ``wss``, only WebSocket rules
will match.
.. versionchanged:: 1.0
Added ``merge_slashes``.
Expand Down Expand Up @@ -1484,14 +1510,18 @@ def bind(
no defined. If there is no `default_subdomain` you cannot use the
subdomain feature.
.. versionadded:: 0.7
`query_args` added
.. versionadded:: 0.8
`query_args` can now also be a string.
.. versionchanged:: 1.0
If ``url_scheme`` is ``ws`` or ``wss``, only WebSocket rules
will match.
.. versionchanged:: 0.15
``path_info`` defaults to ``'/'`` if ``None``.
.. versionchanged:: 0.8
``query_args`` can be a string.
.. versionchanged:: 0.7
Added ``query_args``.
"""
server_name = server_name.lower()
if self.host_matching:
Expand Down Expand Up @@ -1663,6 +1693,7 @@ def __init__(
self.path_info = to_unicode(path_info)
self.default_method = to_unicode(default_method)
self.query_args = query_args
self.websocket = self.url_scheme in {"ws", "wss"}

def dispatch(
self, view_func, path_info=None, method=None, catch_http_exceptions=False
Expand Down Expand Up @@ -1720,7 +1751,14 @@ def application(environ, start_response):
return e
raise

def match(self, path_info=None, method=None, return_rule=False, query_args=None):
def match(
self,
path_info=None,
method=None,
return_rule=False,
query_args=None,
websocket=None,
):
"""The usage is simple: you just pass the match method the current
path info as well as the method (which defaults to `GET`). The
following things can then happen:
Expand All @@ -1741,6 +1779,11 @@ def match(self, path_info=None, method=None, return_rule=False, query_args=None)
You can use the `RequestRedirect` instance as response-like object
similar to all other subclasses of `HTTPException`.
- you receive a ``WebsocketMismatch`` exception if the only
match is a WebSocket rule but the bind is an HTTP request, or
if the match is an HTTP rule but the bind is a WebSocket
request.
- you get a tuple in the form ``(endpoint, arguments)`` if there is
a match (unless `return_rule` is True, in which case you get a tuple
in the form ``(rule, arguments)``)
Expand Down Expand Up @@ -1787,15 +1830,21 @@ def match(self, path_info=None, method=None, return_rule=False, query_args=None)
automatic redirects as string or dictionary. It's
currently not possible to use the query arguments
for URL matching.
:param websocket: Match WebSocket instead of HTTP requests. A
websocket request has a ``ws`` or ``wss``
:attr:`url_scheme`. This overrides that detection.
.. versionadded:: 0.6
`return_rule` was added.
.. versionadded:: 1.0
Added ``websocket``.
.. versionchanged:: 0.8
``query_args`` can be a string.
.. versionadded:: 0.7
`query_args` was added.
Added ``query_args``.
.. versionchanged:: 0.8
`query_args` can now also be a string.
.. versionadded:: 0.6
Added ``return_rule``.
"""
self.map.update()
if path_info is None:
Expand All @@ -1806,6 +1855,9 @@ def match(self, path_info=None, method=None, return_rule=False, query_args=None)
query_args = self.query_args
method = (method or self.default_method).upper()

if websocket is None:
websocket = self.websocket

require_redirect = False

path = u"%s|%s" % (
Expand All @@ -1814,6 +1866,8 @@ def match(self, path_info=None, method=None, return_rule=False, query_args=None)
)

have_match_for = set()
websocket_mismatch = False

for rule in self.map._rules:
try:
rv = rule.match(path, method)
Expand All @@ -1836,6 +1890,10 @@ def match(self, path_info=None, method=None, return_rule=False, query_args=None)
have_match_for.update(rule.methods)
continue

if rule.websocket != websocket:
websocket_mismatch = True
continue

if self.map.redirect_defaults:
redirect_url = self.get_default_redirect(rule, method, rv, query_args)
if redirect_url is not None:
Expand Down Expand Up @@ -1880,6 +1938,10 @@ def _handle_match(match):

if have_match_for:
raise MethodNotAllowed(valid_methods=list(have_match_for))

if websocket_mismatch:
raise WebsocketMismatch()

raise NotFound()

def test(self, path_info=None, method=None):
Expand Down Expand Up @@ -2005,6 +2067,7 @@ def _partial_build(self, endpoint, values, method, append_unknown):
rv = rule.build(values, append_unknown)

if rv is not None:
rv = (rv[0], rv[1], rule.websocket)
if self.map.host_matching:
if rv[0] == self.server_name:
return rv
Expand Down Expand Up @@ -2114,10 +2177,22 @@ def build(
rv = self._partial_build(endpoint, values, method, append_unknown)
if rv is None:
raise BuildError(endpoint, values, method, self)
domain_part, path = rv

domain_part, path, websocket = rv
host = self.get_host(domain_part)

# Always build WebSocket routes with the scheme (browsers
# require full URLs). If bound to a WebSocket, ensure that HTTP
# routes are built with an HTTP scheme.
url_scheme = self.url_scheme
secure = url_scheme in {"https", "wss"}

if websocket:
force_external = True
url_scheme = "wss" if secure else "ws"
elif url_scheme:
url_scheme = "https" if secure else "http"

# shortcut this.
if not force_external and (
(self.map.host_matching and host == self.server_name)
Expand All @@ -2127,7 +2202,7 @@ def build(
return str(
"%s//%s%s/%s"
% (
self.url_scheme + ":" if self.url_scheme else "",
url_scheme + ":" if url_scheme else "",
host,
self.script_name[:-1],
path.lstrip("/"),
Expand Down
Loading

0 comments on commit 13b6ef0

Please sign in to comment.