diff --git a/src/labthings/views/__init__.py b/src/labthings/views/__init__.py index 07d241ed..0e2c0a96 100644 --- a/src/labthings/views/__init__.py +++ b/src/labthings/views/__init__.py @@ -1,5 +1,5 @@ from flask.views import MethodView, http_method_funcs -from flask import request +from flask import request, abort from werkzeug.wrappers import Response as ResponseBase from werkzeug.exceptions import BadRequest @@ -107,6 +107,21 @@ def get_value(self): else: # Unless somehow an HTTP response isn't returned... return response + def _find_request_method(self): + meth = getattr(self, request.method.lower(), None) + if meth is None and request.method == "HEAD": + meth = getattr(self, "get", None) + + # Handle the case of a GET request asking for WS upgrade where + # no websocket method is defined on the view + if request.method == "GET" and request.environ.get("wsgi.websocket"): + ws_meth = getattr(self, "websocket", None) + if ws_meth is None: + abort(400, "Unable to upgrade websocket connection") + return ws_meth + + return meth + def dispatch_request(self, *args, **kwargs): """ @@ -114,12 +129,7 @@ def dispatch_request(self, *args, **kwargs): :param **kwargs: """ - meth = getattr(self, request.method.lower(), None) - - # If the request method is HEAD and we don't have a handler for it - # retry with GET. - if meth is None and request.method == "HEAD": - meth = getattr(self, "get", None) + meth = self._find_request_method() # Generate basic response return self.represent_response(meth(*args, **kwargs)) @@ -272,7 +282,7 @@ def dispatch_request(self, *args, **kwargs): :param **kwargs: """ - meth = getattr(self, request.method.lower(), None) + meth = self._find_request_method() # Let base View handle non-POST requests if request.method != "POST": @@ -414,12 +424,7 @@ def dispatch_request(self, *args, **kwargs): :param **kwargs: """ - meth = getattr(self, request.method.lower(), None) - - # If the request method is HEAD and we don't have a handler for it - # retry with GET. - if meth is None and request.method == "HEAD": - meth = getattr(self, "get", None) + meth = self._find_request_method() # POST and PUT methods can be used to write properties # In all other cases, ignore arguments