From 547a26b685d08cac0aa64e5e65f7867ac0ea9bc0 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 23 May 2021 18:51:27 +0200 Subject: [PATCH] Use constant-time comparison for passwords. Backport of c91b4c2a and dfecbd03. --- docs/changelog.rst | 6 ++++++ src/websockets/legacy/auth.py | 28 +++++++++++++++------------- tests/legacy/test_auth.py | 11 +++++++++-- 3 files changed, 30 insertions(+), 15 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 1064af736..f3e1acf08 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -30,6 +30,12 @@ They may change at any time. *In development* +.. note:: + + **Version 9.1 fixes a security issue introduced in version 8.0.** + + Version 8.0 was vulnerable to timing attacks on HTTP Basic Auth passwords. + 9.0.2 ..... diff --git a/src/websockets/legacy/auth.py b/src/websockets/legacy/auth.py index e0beede57..80ceff28d 100644 --- a/src/websockets/legacy/auth.py +++ b/src/websockets/legacy/auth.py @@ -6,6 +6,7 @@ import functools +import hmac import http from typing import Any, Awaitable, Callable, Iterable, Optional, Tuple, Union, cast @@ -132,24 +133,23 @@ def basic_auth_protocol_factory( if credentials is not None: if is_credentials(credentials): - - async def check_credentials(username: str, password: str) -> bool: - return (username, password) == credentials - + credentials_list = [cast(Credentials, credentials)] elif isinstance(credentials, Iterable): credentials_list = list(credentials) - if all(is_credentials(item) for item in credentials_list): - credentials_dict = dict(credentials_list) - - async def check_credentials(username: str, password: str) -> bool: - return credentials_dict.get(username) == password - - else: + if not all(is_credentials(item) for item in credentials_list): raise TypeError(f"invalid credentials argument: {credentials}") - else: raise TypeError(f"invalid credentials argument: {credentials}") + credentials_dict = dict(credentials_list) + + async def check_credentials(username: str, password: str) -> bool: + try: + expected_password = credentials_dict[username] + except KeyError: + return False + return hmac.compare_digest(expected_password, password) + if create_protocol is None: # Not sure why mypy cannot figure this out. create_protocol = cast( @@ -158,5 +158,7 @@ async def check_credentials(username: str, password: str) -> bool: ) return functools.partial( - create_protocol, realm=realm, check_credentials=check_credentials + create_protocol, + realm=realm, + check_credentials=check_credentials, ) diff --git a/tests/legacy/test_auth.py b/tests/legacy/test_auth.py index bb8c6a6eb..3d8eb90d7 100644 --- a/tests/legacy/test_auth.py +++ b/tests/legacy/test_auth.py @@ -1,3 +1,4 @@ +import hmac import unittest import urllib.error @@ -76,7 +77,7 @@ def test_basic_auth_bad_multiple_credentials(self): ) async def check_credentials(username, password): - return password == "iloveyou" + return hmac.compare_digest(password, "iloveyou") create_protocol_check_credentials = basic_auth_protocol_factory( realm="auth-tests", @@ -140,7 +141,13 @@ def test_basic_auth_unsupported_credentials_details(self): self.assertEqual(raised.exception.read().decode(), "Unsupported credentials\n") @with_server(create_protocol=create_protocol) - def test_basic_auth_invalid_credentials(self): + def test_basic_auth_invalid_username(self): + with self.assertRaises(InvalidStatusCode) as raised: + self.start_client(user_info=("goodbye", "iloveyou")) + self.assertEqual(raised.exception.status_code, 401) + + @with_server(create_protocol=create_protocol) + def test_basic_auth_invalid_password(self): with self.assertRaises(InvalidStatusCode) as raised: self.start_client(user_info=("hello", "ihateyou")) self.assertEqual(raised.exception.status_code, 401)