diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index a88deaa6..99ae58b4 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -13,6 +13,7 @@ * :exc:`InvalidProxyMessage` * :exc:`InvalidProxyStatus` * :exc:`InvalidMessage` + * :exc:`InvalidMethod` * :exc:`InvalidStatus` * :exc:`InvalidStatusCode` (legacy) * :exc:`InvalidHeader` @@ -53,6 +54,7 @@ "InvalidProxyMessage", "InvalidProxyStatus", "InvalidMessage", + "InvalidMethod", "InvalidStatus", "InvalidHeader", "InvalidHeaderFormat", @@ -242,6 +244,19 @@ class InvalidMessage(InvalidHandshake): """ +class InvalidMethod(InvalidHandshake): + """ + Raised when the HTTP method isn't GET. + + """ + + def __init__(self, method: str) -> None: + self.method = method + + def __str__(self) -> str: + return f"invalid HTTP method; expected GET; got {self.method}" + + class InvalidStatus(InvalidHandshake): """ Raised when a handshake response rejects the WebSocket upgrade. diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 5af73eb0..9c0351e7 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -9,7 +9,7 @@ from typing import Callable from .datastructures import Headers -from .exceptions import SecurityError +from .exceptions import InvalidMethod, SecurityError from .version import version as websockets_version @@ -125,6 +125,7 @@ def parse( Raises: EOFError: If the connection is closed without a full HTTP request. SecurityError: If the request exceeds a security limit. + InvalidMethod: If the HTTP method isn't GET. ValueError: If the request isn't well formatted. """ @@ -148,7 +149,7 @@ def parse( f"unsupported protocol; expected HTTP/1.1: {d(request_line)}" ) if method != b"GET": - raise ValueError(f"unsupported HTTP method; expected GET; got {d(method)}") + raise InvalidMethod(d(method)) path = raw_path.decode("ascii", "surrogateescape") headers = yield from parse_headers(read_line) diff --git a/src/websockets/server.py b/src/websockets/server.py index de2c6354..e56180ea 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -15,6 +15,7 @@ InvalidHeader, InvalidHeaderValue, InvalidMessage, + InvalidMethod, InvalidOrigin, InvalidUpgrade, NegotiationError, @@ -547,6 +548,23 @@ def parse(self) -> Generator[None]: request = yield from Request.parse( self.reader.read_line, ) + except InvalidMethod as exc: + self.handshake_exc = exc + # Build 405 response without calling reject() to maintain layering + body = f"Failed to open a WebSocket connection: {exc}.\n".encode() + headers = Headers( + [ + ("Date", email.utils.formatdate(usegmt=True)), + ("Connection", "close"), + ("Content-Length", str(len(body))), + ("Content-Type", "text/plain; charset=utf-8"), + ("Allow", "GET"), + ] + ) + response = Response(405, "Method Not Allowed", headers, body) + self.send_response(response) + yield + return except Exception as exc: self.handshake_exc = InvalidMessage( "did not receive a valid HTTP request" @@ -556,6 +574,7 @@ def parse(self) -> Generator[None]: self.parser = self.discard() next(self.parser) # start coroutine yield + return if self.debug: self.logger.debug("< GET %s HTTP/1.1", request.path) diff --git a/tests/test_http11.py b/tests/test_http11.py index 3328b3b5..048aebd7 100644 --- a/tests/test_http11.py +++ b/tests/test_http11.py @@ -1,5 +1,5 @@ from websockets.datastructures import Headers -from websockets.exceptions import SecurityError +from websockets.exceptions import InvalidMethod, SecurityError from websockets.http11 import * from websockets.http11 import parse_headers from websockets.streams import StreamReader @@ -61,11 +61,11 @@ def test_parse_unsupported_protocol(self): def test_parse_unsupported_method(self): self.reader.feed_data(b"OPTIONS * HTTP/1.1\r\n\r\n") - with self.assertRaises(ValueError) as raised: + with self.assertRaises(InvalidMethod) as raised: next(self.parse()) self.assertEqual( str(raised.exception), - "unsupported HTTP method; expected GET; got OPTIONS", + "invalid HTTP method; expected GET; got OPTIONS", ) def test_parse_invalid_header(self): diff --git a/tests/test_server.py b/tests/test_server.py index 43970a7c..31aa1040 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -9,6 +9,7 @@ from websockets.exceptions import ( InvalidHeader, InvalidMessage, + InvalidMethod, InvalidOrigin, InvalidUpgrade, NegotiationError, @@ -257,6 +258,70 @@ def test_receive_junk_request(self): "invalid HTTP request line: HELO relay.invalid", ) + @patch("email.utils.formatdate", return_value=DATE) + def test_receive_head_request(self, _formatdate): + """Server receives a HEAD request and returns 405 Method Not Allowed.""" + server = ServerProtocol() + server.receive_data( + ( + f"HEAD /test HTTP/1.1\r\n" + f"Host: example.com\r\n" + f"\r\n" + ).encode(), + ) + + self.assertEqual(server.events_received(), []) + self.assertIsInstance(server.handshake_exc, InvalidMethod) + self.assertEqual(str(server.handshake_exc), "invalid HTTP method; expected GET; got HEAD") + self.assertEqual( + server.data_to_send(), + [ + f"HTTP/1.1 405 Method Not Allowed\r\n" + f"Date: {DATE}\r\n" + f"Connection: close\r\n" + f"Content-Length: 84\r\n" + f"Content-Type: text/plain; charset=utf-8\r\n" + f"Allow: GET\r\n" + f"\r\n" + f"Failed to open a WebSocket connection: " + f"invalid HTTP method; expected GET; got HEAD.\n".encode(), + b"", + ], + ) + self.assertTrue(server.close_expected()) + + @patch("email.utils.formatdate", return_value=DATE) + def test_receive_post_request(self, _formatdate): + """Server receives a POST request and returns 405 Method Not Allowed.""" + server = ServerProtocol() + server.receive_data( + ( + f"POST /test HTTP/1.1\r\n" + f"Host: example.com\r\n" + f"\r\n" + ).encode(), + ) + + self.assertEqual(server.events_received(), []) + self.assertIsInstance(server.handshake_exc, InvalidMethod) + self.assertEqual(str(server.handshake_exc), "invalid HTTP method; expected GET; got POST") + self.assertEqual( + server.data_to_send(), + [ + f"HTTP/1.1 405 Method Not Allowed\r\n" + f"Date: {DATE}\r\n" + f"Connection: close\r\n" + f"Content-Length: 84\r\n" + f"Content-Type: text/plain; charset=utf-8\r\n" + f"Allow: GET\r\n" + f"\r\n" + f"Failed to open a WebSocket connection: " + f"invalid HTTP method; expected GET; got POST.\n".encode(), + b"", + ], + ) + self.assertTrue(server.close_expected()) + class ResponseTests(unittest.TestCase): """Test generating opening handshake responses."""