From 86da7ffa939e60c81b4b74a40a63e148e9bdff09 Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Tue, 5 Jul 2022 11:00:00 +0300 Subject: [PATCH 01/23] A CredentialsProvider class has been added to allow the user to add his own provider for password rotation --- CHANGES | 1 + docs/examples/connection_examples.ipynb | 88 +++++++++++++++++++++- redis/__init__.py | 2 + redis/client.py | 2 + redis/cluster.py | 1 + redis/connection.py | 63 +++++++++++++--- tests/test_connection.py | 97 ++++++++++++++++++++++++- 7 files changed, 239 insertions(+), 15 deletions(-) diff --git a/CHANGES b/CHANGES index f5c267bdda..0a19bf73fd 100644 --- a/CHANGES +++ b/CHANGES @@ -23,6 +23,7 @@ * ClusterPipeline Doesn't Handle ConnectionError for Dead Hosts (#2225) * Remove compatibility code for old versions of Hiredis, drop Packaging dependency * The `deprecated` library is no longer a dependency + * Added CredentialsProvider class to support password rotation * 4.1.3 (Feb 8, 2022) * Fix flushdb and flushall (#1926) diff --git a/docs/examples/connection_examples.ipynb b/docs/examples/connection_examples.ipynb index b0084ff055..34e1ad0e81 100644 --- a/docs/examples/connection_examples.ipynb +++ b/docs/examples/connection_examples.ipynb @@ -97,6 +97,92 @@ "user_connection.ping()" ] }, + { + "cell_type": "markdown", + "metadata": { + "collapsed": false + }, + "source": [ + "## Connecting to a redis instance with AWS Secrets Manager credentials provider." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "import redis\n", + "import boto3\n", + "import json\n", + "import cachetools.func\n", + "\n", + "sm_client = boto3.client('secretsmanager')\n", + " \n", + "def sm_auth_provider(secret_id, version_id=None, version_stage='AWSCURRENT'):\n", + " @cachetools.func.ttl_cache(maxsize=128, ttl=24 * 60 * 60) #24h\n", + " def get_sm_user_credentials(secret_id, version_id, version_stage):\n", + " secret = sm_client.get_secret_value(secret_id, version_id)\n", + " return json.loads(secret['SecretString'])\n", + " creds = get_sm_user_credentials(secret_id, version_id, version_stage)\n", + " return creds['username'], creds['password']\n", + "\n", + "secret_id = \"EXAMPLE1-90ab-cdef-fedc-ba987SECRET1\"\n", + "creds_provider = redis.CredentialsProvider(supplier=sm_auth_provider, secret_id=secret_id)\n", + "user_connection = redis.Redis(host=\"localhost\", port=6379, credentials_provider=creds_provider)\n", + "user_connection.ping()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Connecting to a redis instance with ElastiCache IAM credentials provider." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import redis\n", + "import boto3\n", + "import cachetools.func\n", + "\n", + "ec_client = boto3.client('elasticache')\n", + "\n", + "def iam_auth_provider(user, endpoint, port=6379, region=\"us-east-1\"):\n", + " @cachetools.func.ttl_cache(maxsize=128, ttl=15 * 60) # 15m\n", + " def get_iam_auth_token(user, endpoint, port, region):\n", + " return ec_client.generate_iam_auth_token(user, endpoint, port, region)\n", + " iam_auth_token = get_iam_auth_token(endpoint, port, user, region)\n", + " return iam_auth_token\n", + "\n", + "username = \"barshaul\"\n", + "endpoint = \"test-001.use1.cache.amazonaws.com\"\n", + "creds_provider = redis.CredentialsProvider(supplier=iam_auth_provider, user=username,\n", + " endpoint=endpoint)\n", + "user_connection = redis.Redis(host=endpoint, port=6379, credentials_provider=creds_provider)\n", + "user_connection.ping()" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -176,4 +262,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} +} \ No newline at end of file diff --git a/redis/__init__.py b/redis/__init__.py index b7560a6715..777ba3ba90 100644 --- a/redis/__init__.py +++ b/redis/__init__.py @@ -6,6 +6,7 @@ BlockingConnectionPool, Connection, ConnectionPool, + CredentialsProvider, SSLConnection, UnixDomainSocketConnection, ) @@ -62,6 +63,7 @@ def int_or_str(value): "Connection", "ConnectionError", "ConnectionPool", + "CredentialsProvider", "DataError", "from_url", "InvalidResponse", diff --git a/redis/client.py b/redis/client.py index 75a0dac226..4d2ded2e22 100755 --- a/redis/client.py +++ b/redis/client.py @@ -938,6 +938,7 @@ def __init__( username=None, retry=None, redis_connect_func=None, + credentials_provider=None, ): """ Initialize a new Redis client. @@ -985,6 +986,7 @@ def __init__( "health_check_interval": health_check_interval, "client_name": client_name, "redis_connect_func": redis_connect_func, + "credentials_provider": credentials_provider, } # based on input, setup appropriate connection args if unix_socket_path is not None: diff --git a/redis/cluster.py b/redis/cluster.py index cee578b075..13ff353100 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -121,6 +121,7 @@ def parse_cluster_shards(resp, **options): "connection_class", "connection_pool", "client_name", + "credentials_provider", "db", "decode_responses", "encoding", diff --git a/redis/connection.py b/redis/connection.py index 2e33e31d2f..d04ffd5e55 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -476,6 +476,42 @@ def read_response(self, disable_decoding=False): DefaultParser = PythonParser +class CredentialsProvider: + def __init__(self, username="", password="", supplier=None, *args, **kwargs): + """ + Initialize a new Credentials Provider. + :param supplier: a supplier function that returns the username and password. + def supplier(arg1, arg2, ...) -> (username, password) + For examples see examples/connection_examples.ipynb + :param args: arguments to pass to the supplier function + :param kwargs: keyword arguments to pass to the supplier function + """ + self.username = username + self.password = password + self.supplier = supplier + self.args = args + self.kwargs = kwargs + + def get_credentials(self): + if self.supplier: + self.username, self.password = self.supplier(*self.args, **self.kwargs) + if self.username: + auth_args = (self.username, self.password or "") + else: + auth_args = (self.password,) + return auth_args + + def get_password(self, call_supplier=True): + if call_supplier and self.supplier: + self.username, self.password = self.supplier(*self.args, **self.kwargs) + return self.password + + def get_username(self, call_supplier=True): + if call_supplier and self.supplier: + self.username, self.password = self.supplier(*self.args, **self.kwargs) + return self.username + + class Connection: "Manages TCP communication to and from a Redis server" @@ -502,6 +538,7 @@ def __init__( username=None, retry=None, redis_connect_func=None, + credentials_provider=None, ): """ Initialize a new Connection. @@ -514,9 +551,10 @@ def __init__( self.host = host self.port = int(port) self.db = db - self.username = username self.client_name = client_name - self.password = password + self.credentials_provider = credentials_provider + if not self.credentials_provider and (username or password): + self.credentials_provider = CredentialsProvider(username, password) self.socket_timeout = socket_timeout self.socket_connect_timeout = socket_connect_timeout or socket_timeout self.socket_keepalive = socket_keepalive @@ -675,12 +713,9 @@ def on_connect(self): "Initialize the connection, authenticate and select a database" self._parser.on_connect(self) - # if username and/or password are set, authenticate - if self.username or self.password: - if self.username: - auth_args = (self.username, self.password or "") - else: - auth_args = (self.password,) + # if credentials provider is set, authenticate + if self.credentials_provider: + auth_args = self.credentials_provider.get_credentials() # avoid checking health here -- PING will fail if we try # to check the health prior to the AUTH self.send_command("AUTH", *auth_args, check_health=False) @@ -692,7 +727,11 @@ def on_connect(self): # server seems to be < 6.0.0 which expects a single password # arg. retry auth with just the password. # https://github.com/andymccurdy/redis-py/issues/1274 - self.send_command("AUTH", self.password, check_health=False) + self.send_command( + "AUTH", + self.credentials_provider.get_password(), + check_health=False, + ) auth_response = self.read_response() if str_if_bytes(auth_response) != "OK": @@ -1050,6 +1089,7 @@ def __init__( client_name=None, retry=None, redis_connect_func=None, + credentials_provider=None, ): """ Initialize a new UnixDomainSocketConnection. @@ -1061,9 +1101,10 @@ def __init__( self.pid = os.getpid() self.path = path self.db = db - self.username = username self.client_name = client_name - self.password = password + self.credentials_provider = credentials_provider + if not self.credentials_provider and (username or password): + self.credentials_provider = CredentialsProvider(username, password) self.socket_timeout = socket_timeout self.retry_on_timeout = retry_on_timeout if retry_on_error is SENTINEL: diff --git a/tests/test_connection.py b/tests/test_connection.py index d9251c31dc..abaa35ecfa 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,17 +1,25 @@ +import random import socket +import string import types from unittest import mock from unittest.mock import patch import pytest +import redis from redis.backoff import NoBackoff -from redis.connection import Connection -from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError +from redis.connection import Connection, CredentialsProvider +from redis.exceptions import ( + ConnectionError, + InvalidResponse, + ResponseError, + TimeoutError, +) from redis.retry import Retry from redis.utils import HIREDIS_AVAILABLE -from .conftest import skip_if_server_version_lt +from .conftest import _get_client, skip_if_redis_enterprise, skip_if_server_version_lt @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @@ -122,3 +130,86 @@ def test_connect_timeout_error_without_retry(self): assert conn._connect.call_count == 1 assert str(e.value) == "Timeout connecting to server" self.clear(conn) + + +class TestCredentialsProvider: + @skip_if_redis_enterprise() + def test_credentials_provider_without_supplier(self, r, request): + # first, test for default user (`username` is supposed to be optional) + default_username = "default" + temp_pass = "temp_pass" + creds_provider = CredentialsProvider(default_username, temp_pass) + r.config_set("requirepass", temp_pass) + creds = creds_provider.get_credentials() + assert r.auth(creds[1], creds[0]) is True + assert r.auth(creds_provider.get_password()) is True + + # test for other users + username = "redis-py-auth" + password = "strong_password" + + def teardown(): + try: + r.auth(temp_pass) + except ResponseError: + r.auth("default", "") + r.config_set("requirepass", "") + r.acl_deluser(username) + + request.addfinalizer(teardown) + + assert r.acl_setuser( + username, + enabled=True, + passwords=["+" + password], + keys="~*", + commands=["+ping", "+command", "+info", "+select", "+flushdb", "+cluster"], + ) + + creds_provider2 = CredentialsProvider(username, password) + r2 = _get_client( + redis.Redis, request, flushdb=False, credentials_provider=creds_provider2 + ) + + assert r2.ping() is True + + @skip_if_redis_enterprise() + def test_credentials_provider_with_supplier(self, r, request): + import functools + + @functools.lru_cache(maxsize=10) + def auth_supplier(user, endpoint): + def get_random_string(length): + letters = string.ascii_lowercase + result_str = "".join(random.choice(letters) for i in range(length)) + return result_str + + auth_token = get_random_string(5) + user + "_" + endpoint + return user, auth_token + + username = "redis-py-auth" + creds_provider = CredentialsProvider( + supplier=auth_supplier, + user=username, + endpoint="localhost", + ) + password = creds_provider.get_password() + + assert r.acl_setuser( + username, + enabled=True, + passwords=["+" + password], + keys="~*", + commands=["+ping", "+command", "+info", "+select", "+flushdb", "+cluster"], + ) + + def teardown(): + r.acl_deluser(username) + + request.addfinalizer(teardown) + + r2 = _get_client( + redis.Redis, request, flushdb=False, credentials_provider=creds_provider + ) + + assert r2.ping() is True From 5dfddde86b06d64c966e5c6bc544eef608643972 Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Mon, 18 Jul 2022 12:20:50 +0300 Subject: [PATCH 02/23] Moved CredentialsProvider to a separate file, added type hints --- redis/__init__.py | 2 +- redis/connection.py | 37 +-------------- redis/credentials.py | 41 +++++++++++++++++ tests/test_connection.py | 97 ++------------------------------------- tests/test_credentials.py | 90 ++++++++++++++++++++++++++++++++++++ 5 files changed, 136 insertions(+), 131 deletions(-) create mode 100644 redis/credentials.py create mode 100644 tests/test_credentials.py diff --git a/redis/__init__.py b/redis/__init__.py index 777ba3ba90..7122e1a457 100644 --- a/redis/__init__.py +++ b/redis/__init__.py @@ -6,10 +6,10 @@ BlockingConnectionPool, Connection, ConnectionPool, - CredentialsProvider, SSLConnection, UnixDomainSocketConnection, ) +from redis.credentials import CredentialsProvider from redis.exceptions import ( AuthenticationError, AuthenticationWrongNumberOfArgsError, diff --git a/redis/connection.py b/redis/connection.py index d04ffd5e55..aef2ed68da 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -11,6 +11,7 @@ from urllib.parse import parse_qs, unquote, urlparse from redis.backoff import NoBackoff +from redis.credentials import CredentialsProvider from redis.exceptions import ( AuthenticationError, AuthenticationWrongNumberOfArgsError, @@ -476,42 +477,6 @@ def read_response(self, disable_decoding=False): DefaultParser = PythonParser -class CredentialsProvider: - def __init__(self, username="", password="", supplier=None, *args, **kwargs): - """ - Initialize a new Credentials Provider. - :param supplier: a supplier function that returns the username and password. - def supplier(arg1, arg2, ...) -> (username, password) - For examples see examples/connection_examples.ipynb - :param args: arguments to pass to the supplier function - :param kwargs: keyword arguments to pass to the supplier function - """ - self.username = username - self.password = password - self.supplier = supplier - self.args = args - self.kwargs = kwargs - - def get_credentials(self): - if self.supplier: - self.username, self.password = self.supplier(*self.args, **self.kwargs) - if self.username: - auth_args = (self.username, self.password or "") - else: - auth_args = (self.password,) - return auth_args - - def get_password(self, call_supplier=True): - if call_supplier and self.supplier: - self.username, self.password = self.supplier(*self.args, **self.kwargs) - return self.password - - def get_username(self, call_supplier=True): - if call_supplier and self.supplier: - self.username, self.password = self.supplier(*self.args, **self.kwargs) - return self.username - - class Connection: "Manages TCP communication to and from a Redis server" diff --git a/redis/credentials.py b/redis/credentials.py new file mode 100644 index 0000000000..d2e9c92911 --- /dev/null +++ b/redis/credentials.py @@ -0,0 +1,41 @@ +class CredentialsProvider: + def __init__( + self, + username: str = "", + password: str = "", + supplier: callable = None, + *args, + **kwargs, + ): + """ + Initialize a new Credentials Provider. + :param supplier: a supplier function that returns the username and password. + def supplier(arg1, arg2, ...) -> (username, password) + For examples see examples/connection_examples.ipynb + :param args: arguments to pass to the supplier function + :param kwargs: keyword arguments to pass to the supplier function + """ + self.username = username + self.password = password + self.supplier = supplier + self.args = args + self.kwargs = kwargs + + def get_credentials(self): + if self.supplier: + self.username, self.password = self.supplier(*self.args, **self.kwargs) + if self.username: + auth_args = (self.username, self.password or "") + else: + auth_args = (self.password,) + return auth_args + + def get_password(self, call_supplier: bool = True): + if call_supplier and self.supplier: + self.username, self.password = self.supplier(*self.args, **self.kwargs) + return self.password + + def get_username(self, call_supplier: bool = True): + if call_supplier and self.supplier: + self.username, self.password = self.supplier(*self.args, **self.kwargs) + return self.username diff --git a/tests/test_connection.py b/tests/test_connection.py index abaa35ecfa..d9251c31dc 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,25 +1,17 @@ -import random import socket -import string import types from unittest import mock from unittest.mock import patch import pytest -import redis from redis.backoff import NoBackoff -from redis.connection import Connection, CredentialsProvider -from redis.exceptions import ( - ConnectionError, - InvalidResponse, - ResponseError, - TimeoutError, -) +from redis.connection import Connection +from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError from redis.retry import Retry from redis.utils import HIREDIS_AVAILABLE -from .conftest import _get_client, skip_if_redis_enterprise, skip_if_server_version_lt +from .conftest import skip_if_server_version_lt @pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") @@ -130,86 +122,3 @@ def test_connect_timeout_error_without_retry(self): assert conn._connect.call_count == 1 assert str(e.value) == "Timeout connecting to server" self.clear(conn) - - -class TestCredentialsProvider: - @skip_if_redis_enterprise() - def test_credentials_provider_without_supplier(self, r, request): - # first, test for default user (`username` is supposed to be optional) - default_username = "default" - temp_pass = "temp_pass" - creds_provider = CredentialsProvider(default_username, temp_pass) - r.config_set("requirepass", temp_pass) - creds = creds_provider.get_credentials() - assert r.auth(creds[1], creds[0]) is True - assert r.auth(creds_provider.get_password()) is True - - # test for other users - username = "redis-py-auth" - password = "strong_password" - - def teardown(): - try: - r.auth(temp_pass) - except ResponseError: - r.auth("default", "") - r.config_set("requirepass", "") - r.acl_deluser(username) - - request.addfinalizer(teardown) - - assert r.acl_setuser( - username, - enabled=True, - passwords=["+" + password], - keys="~*", - commands=["+ping", "+command", "+info", "+select", "+flushdb", "+cluster"], - ) - - creds_provider2 = CredentialsProvider(username, password) - r2 = _get_client( - redis.Redis, request, flushdb=False, credentials_provider=creds_provider2 - ) - - assert r2.ping() is True - - @skip_if_redis_enterprise() - def test_credentials_provider_with_supplier(self, r, request): - import functools - - @functools.lru_cache(maxsize=10) - def auth_supplier(user, endpoint): - def get_random_string(length): - letters = string.ascii_lowercase - result_str = "".join(random.choice(letters) for i in range(length)) - return result_str - - auth_token = get_random_string(5) + user + "_" + endpoint - return user, auth_token - - username = "redis-py-auth" - creds_provider = CredentialsProvider( - supplier=auth_supplier, - user=username, - endpoint="localhost", - ) - password = creds_provider.get_password() - - assert r.acl_setuser( - username, - enabled=True, - passwords=["+" + password], - keys="~*", - commands=["+ping", "+command", "+info", "+select", "+flushdb", "+cluster"], - ) - - def teardown(): - r.acl_deluser(username) - - request.addfinalizer(teardown) - - r2 = _get_client( - redis.Redis, request, flushdb=False, credentials_provider=creds_provider - ) - - assert r2.ping() is True diff --git a/tests/test_credentials.py b/tests/test_credentials.py new file mode 100644 index 0000000000..f199e7c87a --- /dev/null +++ b/tests/test_credentials.py @@ -0,0 +1,90 @@ +import random +import string + +import redis +from redis import ResponseError +from redis.credentials import CredentialsProvider +from tests.conftest import _get_client, skip_if_redis_enterprise + + +class TestCredentialsProvider: + @skip_if_redis_enterprise() + def test_credentials_provider_without_supplier(self, r, request): + # first, test for default user (`username` is supposed to be optional) + default_username = "default" + temp_pass = "temp_pass" + creds_provider = CredentialsProvider(default_username, temp_pass) + r.config_set("requirepass", temp_pass) + creds = creds_provider.get_credentials() + assert r.auth(creds[1], creds[0]) is True + assert r.auth(creds_provider.get_password()) is True + + # test for other users + username = "redis-py-auth" + password = "strong_password" + + def teardown(): + try: + r.auth(temp_pass) + except ResponseError: + r.auth("default", "") + r.config_set("requirepass", "") + r.acl_deluser(username) + + request.addfinalizer(teardown) + + assert r.acl_setuser( + username, + enabled=True, + passwords=["+" + password], + keys="~*", + commands=["+ping", "+command", "+info", "+select", "+flushdb", "+cluster"], + ) + + creds_provider2 = CredentialsProvider(username, password) + r2 = _get_client( + redis.Redis, request, flushdb=False, credentials_provider=creds_provider2 + ) + + assert r2.ping() is True + + @skip_if_redis_enterprise() + def test_credentials_provider_with_supplier(self, r, request): + import functools + + @functools.lru_cache(maxsize=10) + def auth_supplier(user, endpoint): + def get_random_string(length): + letters = string.ascii_lowercase + result_str = "".join(random.choice(letters) for i in range(length)) + return result_str + + auth_token = get_random_string(5) + user + "_" + endpoint + return user, auth_token + + username = "redis-py-auth" + creds_provider = CredentialsProvider( + supplier=auth_supplier, + user=username, + endpoint="localhost", + ) + password = creds_provider.get_password() + + assert r.acl_setuser( + username, + enabled=True, + passwords=["+" + password], + keys="~*", + commands=["+ping", "+command", "+info", "+select", "+flushdb", "+cluster"], + ) + + def teardown(): + r.acl_deluser(username) + + request.addfinalizer(teardown) + + r2 = _get_client( + redis.Redis, request, flushdb=False, credentials_provider=creds_provider + ) + + assert r2.ping() is True From 243b24423962847d6f0d0fc3ebb08c1bc08c0bea Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Tue, 9 Aug 2022 18:37:16 +0300 Subject: [PATCH 03/23] Changed username and password to properties --- redis/connection.py | 4 ++-- redis/credentials.py | 37 +++++++++++++++++++++++-------------- tests/test_credentials.py | 18 +++++++++++++----- 3 files changed, 38 insertions(+), 21 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index aef2ed68da..5630c2736b 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -694,7 +694,7 @@ def on_connect(self): # https://github.com/andymccurdy/redis-py/issues/1274 self.send_command( "AUTH", - self.credentials_provider.get_password(), + self.credentials_provider.password, check_health=False, ) auth_response = self.read_response() @@ -1068,7 +1068,7 @@ def __init__( self.db = db self.client_name = client_name self.credentials_provider = credentials_provider - if not self.credentials_provider and (username or password): + if (username or password) and self.credentials_provider is None: self.credentials_provider = CredentialsProvider(username, password) self.socket_timeout = socket_timeout self.retry_on_timeout = retry_on_timeout diff --git a/redis/credentials.py b/redis/credentials.py index d2e9c92911..10a5c127af 100644 --- a/redis/credentials.py +++ b/redis/credentials.py @@ -1,9 +1,12 @@ +from typing import Callable, Optional + + class CredentialsProvider: def __init__( self, username: str = "", password: str = "", - supplier: callable = None, + supplier: Optional[Callable] = None, *args, **kwargs, ): @@ -15,8 +18,8 @@ def supplier(arg1, arg2, ...) -> (username, password) :param args: arguments to pass to the supplier function :param kwargs: keyword arguments to pass to the supplier function """ - self.username = username - self.password = password + self._username = "" if username is None else username + self._password = "" if password is None else password self.supplier = supplier self.args = args self.kwargs = kwargs @@ -24,18 +27,24 @@ def supplier(arg1, arg2, ...) -> (username, password) def get_credentials(self): if self.supplier: self.username, self.password = self.supplier(*self.args, **self.kwargs) - if self.username: - auth_args = (self.username, self.password or "") - else: - auth_args = (self.password,) - return auth_args + return self._username, self._password - def get_password(self, call_supplier: bool = True): - if call_supplier and self.supplier: + @property + def password(self): + if self.supplier and not self._password: self.username, self.password = self.supplier(*self.args, **self.kwargs) - return self.password + return self._password - def get_username(self, call_supplier: bool = True): - if call_supplier and self.supplier: + @password.setter + def password(self, value): + self._password = value + + @property + def username(self): + if self.supplier and not self._username: self.username, self.password = self.supplier(*self.args, **self.kwargs) - return self.username + return self._username + + @username.setter + def username(self, value): + self._username = value diff --git a/tests/test_credentials.py b/tests/test_credentials.py index f199e7c87a..bf9126e9d6 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -1,6 +1,8 @@ import random import string +import pytest + import redis from redis import ResponseError from redis.credentials import CredentialsProvider @@ -17,7 +19,7 @@ def test_credentials_provider_without_supplier(self, r, request): r.config_set("requirepass", temp_pass) creds = creds_provider.get_credentials() assert r.auth(creds[1], creds[0]) is True - assert r.auth(creds_provider.get_password()) is True + assert r.auth(creds_provider.password) is True # test for other users username = "redis-py-auth" @@ -48,8 +50,12 @@ def teardown(): assert r2.ping() is True + @pytest.mark.parametrize("username", ["redis-py-auth", ""]) + @pytest.mark.parametrize("use_password", [True, False]) @skip_if_redis_enterprise() - def test_credentials_provider_with_supplier(self, r, request): + def test_credentials_provider_with_supplier( + self, r, request, username, use_password + ): import functools @functools.lru_cache(maxsize=10) @@ -59,16 +65,18 @@ def get_random_string(length): result_str = "".join(random.choice(letters) for i in range(length)) return result_str - auth_token = get_random_string(5) + user + "_" + endpoint + if use_password: + auth_token = get_random_string(5) + user + "_" + endpoint + else: + auth_token = "" return user, auth_token - username = "redis-py-auth" creds_provider = CredentialsProvider( supplier=auth_supplier, user=username, endpoint="localhost", ) - password = creds_provider.get_password() + password = creds_provider.password assert r.acl_setuser( username, From af8e5603db39dfc22a496289cf9abdea67038c00 Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Tue, 23 Aug 2022 18:08:19 +0300 Subject: [PATCH 04/23] Added: StaticCredentialProvider, examples, tests Changed: CredentialsProvider to CredentialProvider Fixed: calling AUTH only with password --- docs/examples/connection_examples.ipynb | 113 +++++++++++++- redis/__init__.py | 5 +- redis/client.py | 5 +- redis/cluster.py | 2 +- redis/connection.py | 26 ++-- redis/credentials.py | 61 +++++--- tests/test_credentials.py | 188 ++++++++++++++++++------ 7 files changed, 307 insertions(+), 93 deletions(-) diff --git a/docs/examples/connection_examples.ipynb b/docs/examples/connection_examples.ipynb index 34e1ad0e81..db6feac35e 100644 --- a/docs/examples/connection_examples.ipynb +++ b/docs/examples/connection_examples.ipynb @@ -97,13 +97,110 @@ "user_connection.ping()" ] }, + { + "cell_type": "markdown", + "source": [ + "## Connecting to a redis instance with static credential provider" + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "import redis\n", + "\n", + "creds_provider = redis.StaticCredentialProvider(\"username\", \"password\")\n", + "user_connection = redis.Redis(host=\"localhost\", port=6379, credential_provider=creds_provider)\n", + "user_connection.ping()" + ], + "metadata": {} + } + }, + { + "cell_type": "markdown", + "source": [ + "## Connecting to a redis instance with standard credential provider" + ], + "metadata": {} + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "import redis\n", + "\n", + "creds_map = {\"user_1\": \"pass_1\",\n", + " \"user_2\": \"pass_2\"}\n", + "\n", + "# Create a default connection to set the ACL user\n", + "default_connection = redis.Redis(host=\"localhost\", port=6379)\n", + "default_connection.acl_setuser(\n", + " \"user_1\",\n", + " enabled=True,\n", + " passwords=[\"+\" + \"pass_1\"],\n", + " keys=\"~*\",\n", + " commands=[\"+ping\", \"+command\", \"+info\", \"+select\", \"+flushdb\"],\n", + ")\n", + "\n", + "def creds_provider(self):\n", + " return self.username, creds_map.get(self.username)\n", + "\n", + "# Create a CredentialProvider instance for user_1\n", + "creds_provider = redis.CredentialProvider(username=\"user_1\", supplier=creds_provider)\n", + "# Initiate user connection with the credential provider\n", + "user_connection = redis.Redis(host=\"localhost\", port=6379,\n", + " credential_provider=creds_provider)\n", + "user_connection.ping()" + ], + "metadata": {} + } + }, + { + "cell_type": "markdown", + "source": [ + "## Connecting to a redis instance first with an initial credential set and then calling the credential provider" + ], + "metadata": {} + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "import redis\n", + "\n", + "def call_external_supplier():\n", + " # Call to an external credential supplier\n", + " raise NotImplementedError\n", + "\n", + "def creds_supplier(self):\n", + " call_supplier = self.supplier_kwargs.get(\"call_supplier\", True)\n", + " if call_supplier:\n", + " return call_external_supplier()\n", + " # Use the init set only for the first time\n", + " self.kwargs.update({\"call_supplier\": True})\n", + " return self.username, self.password\n", + "\n", + "cred_provider = redis.CredentialProvider(username=\"init_user\",\n", + " password=\"init_pass\",\n", + " call_supplier=False,\n", + " supplier=creds_supplier)" + ], + "metadata": {} + } + }, { "cell_type": "markdown", "metadata": { "collapsed": false }, "source": [ - "## Connecting to a redis instance with AWS Secrets Manager credentials provider." + "## Connecting to a redis instance with AWS Secrets Manager credential provider." ] }, { @@ -124,7 +221,7 @@ "\n", "sm_client = boto3.client('secretsmanager')\n", " \n", - "def sm_auth_provider(secret_id, version_id=None, version_stage='AWSCURRENT'):\n", + "def sm_auth_provider(self, secret_id, version_id=None, version_stage='AWSCURRENT'):\n", " @cachetools.func.ttl_cache(maxsize=128, ttl=24 * 60 * 60) #24h\n", " def get_sm_user_credentials(secret_id, version_id, version_stage):\n", " secret = sm_client.get_secret_value(secret_id, version_id)\n", @@ -133,8 +230,8 @@ " return creds['username'], creds['password']\n", "\n", "secret_id = \"EXAMPLE1-90ab-cdef-fedc-ba987SECRET1\"\n", - "creds_provider = redis.CredentialsProvider(supplier=sm_auth_provider, secret_id=secret_id)\n", - "user_connection = redis.Redis(host=\"localhost\", port=6379, credentials_provider=creds_provider)\n", + "creds_provider = redis.CredentialProvider(supplier=sm_auth_provider, secret_id=secret_id)\n", + "user_connection = redis.Redis(host=\"localhost\", port=6379, credential_provider=creds_provider)\n", "user_connection.ping()" ] }, @@ -142,7 +239,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Connecting to a redis instance with ElastiCache IAM credentials provider." + "## Connecting to a redis instance with ElastiCache IAM credential provider." ] }, { @@ -168,7 +265,7 @@ "\n", "ec_client = boto3.client('elasticache')\n", "\n", - "def iam_auth_provider(user, endpoint, port=6379, region=\"us-east-1\"):\n", + "def iam_auth_provider(self, user, endpoint, port=6379, region=\"us-east-1\"):\n", " @cachetools.func.ttl_cache(maxsize=128, ttl=15 * 60) # 15m\n", " def get_iam_auth_token(user, endpoint, port, region):\n", " return ec_client.generate_iam_auth_token(user, endpoint, port, region)\n", @@ -177,9 +274,9 @@ "\n", "username = \"barshaul\"\n", "endpoint = \"test-001.use1.cache.amazonaws.com\"\n", - "creds_provider = redis.CredentialsProvider(supplier=iam_auth_provider, user=username,\n", + "creds_provider = redis.CredentialProvider(supplier=iam_auth_provider, user=username,\n", " endpoint=endpoint)\n", - "user_connection = redis.Redis(host=endpoint, port=6379, credentials_provider=creds_provider)\n", + "user_connection = redis.Redis(host=endpoint, port=6379, credential_provider=creds_provider)\n", "user_connection.ping()" ] }, diff --git a/redis/__init__.py b/redis/__init__.py index 7122e1a457..d4ce090562 100644 --- a/redis/__init__.py +++ b/redis/__init__.py @@ -9,7 +9,7 @@ SSLConnection, UnixDomainSocketConnection, ) -from redis.credentials import CredentialsProvider +from redis.credentials import CredentialProvider, StaticCredentialProvider from redis.exceptions import ( AuthenticationError, AuthenticationWrongNumberOfArgsError, @@ -63,7 +63,7 @@ def int_or_str(value): "Connection", "ConnectionError", "ConnectionPool", - "CredentialsProvider", + "CredentialProvider", "DataError", "from_url", "InvalidResponse", @@ -78,6 +78,7 @@ def int_or_str(value): "SentinelManagedConnection", "SentinelManagedSSLConnection", "SSLConnection", + "StaticCredentialProvider", "StrictRedis", "TimeoutError", "UnixDomainSocketConnection", diff --git a/redis/client.py b/redis/client.py index 4d2ded2e22..fb7bb60509 100755 --- a/redis/client.py +++ b/redis/client.py @@ -13,6 +13,7 @@ list_or_args, ) from redis.connection import ConnectionPool, SSLConnection, UnixDomainSocketConnection +from redis.credentials import CredentialProvider from redis.exceptions import ( ConnectionError, ExecAbortError, @@ -938,7 +939,7 @@ def __init__( username=None, retry=None, redis_connect_func=None, - credentials_provider=None, + credential_provider: CredentialProvider = None, ): """ Initialize a new Redis client. @@ -986,7 +987,7 @@ def __init__( "health_check_interval": health_check_interval, "client_name": client_name, "redis_connect_func": redis_connect_func, - "credentials_provider": credentials_provider, + "credential_provider": credential_provider, } # based on input, setup appropriate connection args if unix_socket_path is not None: diff --git a/redis/cluster.py b/redis/cluster.py index 13ff353100..f6c9264d39 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -121,7 +121,7 @@ def parse_cluster_shards(resp, **options): "connection_class", "connection_pool", "client_name", - "credentials_provider", + "credential_provider", "db", "decode_responses", "encoding", diff --git a/redis/connection.py b/redis/connection.py index 5630c2736b..be891db505 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -11,7 +11,7 @@ from urllib.parse import parse_qs, unquote, urlparse from redis.backoff import NoBackoff -from redis.credentials import CredentialsProvider +from redis.credentials import CredentialProvider, StaticCredentialProvider from redis.exceptions import ( AuthenticationError, AuthenticationWrongNumberOfArgsError, @@ -503,7 +503,7 @@ def __init__( username=None, retry=None, redis_connect_func=None, - credentials_provider=None, + credential_provider: CredentialProvider = None, ): """ Initialize a new Connection. @@ -517,9 +517,10 @@ def __init__( self.port = int(port) self.db = db self.client_name = client_name - self.credentials_provider = credentials_provider - if not self.credentials_provider and (username or password): - self.credentials_provider = CredentialsProvider(username, password) + self.credential_provider = credential_provider + if (username or password) and self.credential_provider is None: + # username and password backward compatibility + self.credential_provider = StaticCredentialProvider(username, password) self.socket_timeout = socket_timeout self.socket_connect_timeout = socket_connect_timeout or socket_timeout self.socket_keepalive = socket_keepalive @@ -679,8 +680,8 @@ def on_connect(self): self._parser.on_connect(self) # if credentials provider is set, authenticate - if self.credentials_provider: - auth_args = self.credentials_provider.get_credentials() + if self.credential_provider is not None: + auth_args = self.credential_provider.get_credentials() # avoid checking health here -- PING will fail if we try # to check the health prior to the AUTH self.send_command("AUTH", *auth_args, check_health=False) @@ -694,7 +695,7 @@ def on_connect(self): # https://github.com/andymccurdy/redis-py/issues/1274 self.send_command( "AUTH", - self.credentials_provider.password, + self.credential_provider.password, check_health=False, ) auth_response = self.read_response() @@ -1054,7 +1055,7 @@ def __init__( client_name=None, retry=None, redis_connect_func=None, - credentials_provider=None, + credential_provider: CredentialProvider = None, ): """ Initialize a new UnixDomainSocketConnection. @@ -1067,9 +1068,10 @@ def __init__( self.path = path self.db = db self.client_name = client_name - self.credentials_provider = credentials_provider - if (username or password) and self.credentials_provider is None: - self.credentials_provider = CredentialsProvider(username, password) + self.credential_provider = credential_provider + if (username or password) and self.credential_provider is None: + # username and password backward compatibility + self.credential_provider = StaticCredentialProvider(username, password) self.socket_timeout = socket_timeout self.retry_on_timeout = retry_on_timeout if retry_on_error is SENTINEL: diff --git a/redis/credentials.py b/redis/credentials.py index 10a5c127af..00795b762e 100644 --- a/redis/credentials.py +++ b/redis/credentials.py @@ -1,38 +1,43 @@ -from typing import Callable, Optional +from typing import Callable, Optional, Union -class CredentialsProvider: +class CredentialProvider: def __init__( self, - username: str = "", - password: str = "", + username: Union[str, None] = "", + password: Union[str, None] = "", supplier: Optional[Callable] = None, - *args, - **kwargs, + *supplier_args, + **supplier_kwargs, ): """ Initialize a new Credentials Provider. :param supplier: a supplier function that returns the username and password. - def supplier(arg1, arg2, ...) -> (username, password) - For examples see examples/connection_examples.ipynb - :param args: arguments to pass to the supplier function - :param kwargs: keyword arguments to pass to the supplier function + def supplier(self, arg1, arg2, ...) -> (username, password) + See examples/connection_examples.ipynb + :param supplier_args: arguments to pass to the supplier function + :param supplier_kwargs: keyword arguments to pass to the supplier function """ self._username = "" if username is None else username self._password = "" if password is None else password self.supplier = supplier - self.args = args - self.kwargs = kwargs + self.supplier_args = supplier_args + self.supplier_kwargs = supplier_kwargs def get_credentials(self): - if self.supplier: - self.username, self.password = self.supplier(*self.args, **self.kwargs) - return self._username, self._password + if self.supplier is not None: + self.username, self.password = self.supplier( + self, *self.supplier_args, **self.supplier_kwargs + ) + + return (self._username, self._password) if self._username else (self._password,) @property def password(self): - if self.supplier and not self._password: - self.username, self.password = self.supplier(*self.args, **self.kwargs) + if self.supplier is not None and not self._password: + self.username, self.password = self.supplier( + self, *self.supplier_args, **self.supplier_kwargs + ) return self._password @password.setter @@ -41,10 +46,28 @@ def password(self, value): @property def username(self): - if self.supplier and not self._username: - self.username, self.password = self.supplier(*self.args, **self.kwargs) + if self.supplier is not None and not self._username: + self.username, self.password = self.supplier( + self, *self.supplier_args, **self.supplier_kwargs + ) return self._username @username.setter def username(self, value): self._username = value + + +class StaticCredentialProvider(CredentialProvider): + """ + Simple implementation of CredentialProvider that just wraps static + username and password. + """ + + def __init__( + self, username: Union[str, None] = "", password: Union[str, None] = "" + ): + super().__init__( + username=username, + password=password, + credential_provider=lambda self: (self.username, self.password), + ) diff --git a/tests/test_credentials.py b/tests/test_credentials.py index bf9126e9d6..9f4a61bdd8 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -5,94 +5,184 @@ import redis from redis import ResponseError -from redis.credentials import CredentialsProvider +from redis.credentials import CredentialProvider, StaticCredentialProvider from tests.conftest import _get_client, skip_if_redis_enterprise -class TestCredentialsProvider: - @skip_if_redis_enterprise() - def test_credentials_provider_without_supplier(self, r, request): - # first, test for default user (`username` is supposed to be optional) - default_username = "default" - temp_pass = "temp_pass" - creds_provider = CredentialsProvider(default_username, temp_pass) - r.config_set("requirepass", temp_pass) - creds = creds_provider.get_credentials() - assert r.auth(creds[1], creds[0]) is True - assert r.auth(creds_provider.password) is True +def init_acl_user(r, request, username, password): + # reset the user + r.acl_deluser(username) + if password: + assert ( + r.acl_setuser( + username, + enabled=True, + passwords=["+" + password], + keys="~*", + commands=[ + "+ping", + "+command", + "+info", + "+select", + "+flushdb", + "+cluster", + ], + ) + is True + ) + else: + assert ( + r.acl_setuser( + username, + enabled=True, + keys="~*", + commands=[ + "+ping", + "+command", + "+info", + "+select", + "+flushdb", + "+cluster", + ], + nopass=True, + ) + is True + ) - # test for other users - username = "redis-py-auth" - password = "strong_password" + if request is not None: def teardown(): - try: - r.auth(temp_pass) - except ResponseError: - r.auth("default", "") - r.config_set("requirepass", "") r.acl_deluser(username) request.addfinalizer(teardown) - assert r.acl_setuser( - username, - enabled=True, - passwords=["+" + password], - keys="~*", - commands=["+ping", "+command", "+info", "+select", "+flushdb", "+cluster"], + +def init_required_pass(r, request, password): + r.config_set("requirepass", password) + + def teardown(): + try: + r.auth(password) + except ResponseError: + r.auth("default", "") + r.config_set("requirepass", "") + + request.addfinalizer(teardown) + + +class TestCredentialsProvider: + @skip_if_redis_enterprise() + def test_credential_provider_without_supplier_only_pass(self, r, request): + # test for default user (`username` is supposed to be optional) + password = "password" + creds_provider = CredentialProvider(password=password) + init_required_pass(r, request, password) + assert r.auth(creds_provider.password) is True + + r2 = _get_client( + redis.Redis, request, flushdb=False, credential_provider=creds_provider ) - creds_provider2 = CredentialsProvider(username, password) + assert r2.ping() + + @skip_if_redis_enterprise() + def test_credential_provider_without_supplier_acl_user_and_pass(self, r, request): + # test for other users + username = "username" + password = "password" + + init_acl_user(r, request, username, password) + creds_provider = CredentialProvider(username, password) r2 = _get_client( - redis.Redis, request, flushdb=False, credentials_provider=creds_provider2 + redis.Redis, request, flushdb=False, credential_provider=creds_provider ) assert r2.ping() is True - @pytest.mark.parametrize("username", ["redis-py-auth", ""]) - @pytest.mark.parametrize("use_password", [True, False]) + @pytest.mark.parametrize("username", ["username", ""]) @skip_if_redis_enterprise() - def test_credentials_provider_with_supplier( - self, r, request, username, use_password - ): + @pytest.mark.onlynoncluster + def test_credential_provider_with_supplier(self, r, request, username): import functools @functools.lru_cache(maxsize=10) - def auth_supplier(user, endpoint): + def auth_supplier(self, user, endpoint): def get_random_string(length): letters = string.ascii_lowercase result_str = "".join(random.choice(letters) for i in range(length)) return result_str - if use_password: - auth_token = get_random_string(5) + user + "_" + endpoint - else: - auth_token = "" + auth_token = get_random_string(5) + user + "_" + endpoint return user, auth_token - creds_provider = CredentialsProvider( + creds_provider = CredentialProvider( supplier=auth_supplier, user=username, endpoint="localhost", ) password = creds_provider.password - assert r.acl_setuser( - username, - enabled=True, - passwords=["+" + password], - keys="~*", - commands=["+ping", "+command", "+info", "+select", "+flushdb", "+cluster"], + if username: + init_acl_user(r, request, username, password) + else: + init_required_pass(r, request, password) + + r2 = _get_client( + redis.Redis, request, flushdb=False, credential_provider=creds_provider ) - def teardown(): - r.acl_deluser(username) + assert r2.ping() is True - request.addfinalizer(teardown) + def test_credential_provider_no_password_success(self, r, request): + def creds_provider(self): + return "username", "" + init_acl_user(r, request, "username", "") + creds_provider = CredentialProvider(supplier=creds_provider) r2 = _get_client( - redis.Redis, request, flushdb=False, credentials_provider=creds_provider + redis.Redis, request, flushdb=False, credential_provider=creds_provider ) + assert r2.ping() is True + @pytest.mark.onlynoncluster + def test_credential_provider_no_password_error(self, r, request): + def bad_creds_provider(self): + return "username", "" + + init_acl_user(r, request, "username", "password") + creds_provider = CredentialProvider(supplier=bad_creds_provider) + with pytest.raises(ResponseError) as e: + _get_client( + redis.Redis, request, flushdb=False, credential_provider=creds_provider + ) + assert e.match("WRONGPASS") + + +class TestStaticCredentialProvider: + def test_static_credential_provider_acl_user_and_pass(self, r, request): + username = "username" + password = "password" + provider = StaticCredentialProvider(username, password) + assert provider.username == username + assert provider.password == password + assert provider.get_credentials() == (username, password) + init_acl_user(r, request, provider.username, provider.password) + r2 = _get_client( + redis.Redis, request, flushdb=False, credential_provider=provider + ) + assert r2.ping() is True + + def test_static_credential_provider_only_password(self, r, request): + password = "password" + provider = StaticCredentialProvider(password=password) + assert provider.username == "" + assert provider.password == password + assert provider.get_credentials() == (password,) + + init_required_pass(r, request, password) + + r2 = _get_client( + redis.Redis, request, flushdb=False, credential_provider=provider + ) + assert r2.auth(provider.password) is True assert r2.ping() is True From 2261cb08a4a46782dbab2339ec928cf493479829 Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Tue, 30 Aug 2022 11:38:47 +0300 Subject: [PATCH 05/23] Changed private members' prefix to __ --- redis/credentials.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/redis/credentials.py b/redis/credentials.py index 00795b762e..d970f78bd2 100644 --- a/redis/credentials.py +++ b/redis/credentials.py @@ -18,8 +18,8 @@ def supplier(self, arg1, arg2, ...) -> (username, password) :param supplier_args: arguments to pass to the supplier function :param supplier_kwargs: keyword arguments to pass to the supplier function """ - self._username = "" if username is None else username - self._password = "" if password is None else password + self.__username = "" if username is None else username + self.__password = "" if password is None else password self.supplier = supplier self.supplier_args = supplier_args self.supplier_kwargs = supplier_kwargs @@ -30,31 +30,31 @@ def get_credentials(self): self, *self.supplier_args, **self.supplier_kwargs ) - return (self._username, self._password) if self._username else (self._password,) + return (self.__username, self.__password) if self.__username else (self.__password,) @property def password(self): - if self.supplier is not None and not self._password: + if self.supplier is not None and not self.__password: self.username, self.password = self.supplier( self, *self.supplier_args, **self.supplier_kwargs ) - return self._password + return self.__password @password.setter def password(self, value): - self._password = value + self.__password = value @property def username(self): - if self.supplier is not None and not self._username: + if self.supplier is not None and not self.__username: self.username, self.password = self.supplier( self, *self.supplier_args, **self.supplier_kwargs ) - return self._username + return self.__username @username.setter def username(self, value): - self._username = value + self.__username = value class StaticCredentialProvider(CredentialProvider): From ddfe1ea3cc4dab4435afb54bb6c5c367fc35371a Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Thu, 1 Sep 2022 17:10:53 +0300 Subject: [PATCH 06/23] fixed linters --- redis/credentials.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/redis/credentials.py b/redis/credentials.py index d970f78bd2..24b6e770d9 100644 --- a/redis/credentials.py +++ b/redis/credentials.py @@ -30,7 +30,11 @@ def get_credentials(self): self, *self.supplier_args, **self.supplier_kwargs ) - return (self.__username, self.__password) if self.__username else (self.__password,) + return ( + (self.__username, self.__password) + if self.__username + else (self.__password,) + ) @property def password(self): From b4810676f37b5bc6606016da218d2981bb75b12d Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Sun, 4 Sep 2022 10:51:38 +0300 Subject: [PATCH 07/23] fixed auth test --- tests/test_commands.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_commands.py b/tests/test_commands.py index 1c9a5c27eb..929b64f6df 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -10,6 +10,7 @@ import redis from redis import exceptions from redis.client import parse_info +from redis.credentials import StaticCredentialProvider from .conftest import ( _get_client, @@ -95,7 +96,9 @@ def teardown(): # error when switching to the db 9 because we're not authenticated yet # setting the password on the connection itself triggers the # authentication in the connection's `on_connect` method - r.connection.password = temp_pass + r.connection.credential_provider = StaticCredentialProvider( + password=temp_pass + ) except AttributeError: # connection field is not set in Redis Cluster, but that's ok # because the problem discussed above does not apply to Redis Cluster From 686d172fa620bf327d8246625b47c6209d892d3d Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Sun, 4 Sep 2022 11:04:32 +0300 Subject: [PATCH 08/23] fixed credential test --- tests/test_credentials.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_credentials.py b/tests/test_credentials.py index 9f4a61bdd8..4f08651482 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -4,7 +4,7 @@ import pytest import redis -from redis import ResponseError +from redis import AuthenticationError, ResponseError from redis.credentials import CredentialProvider, StaticCredentialProvider from tests.conftest import _get_client, skip_if_redis_enterprise @@ -63,7 +63,7 @@ def init_required_pass(r, request, password): def teardown(): try: r.auth(password) - except ResponseError: + except (ResponseError, AuthenticationError): r.auth("default", "") r.config_set("requirepass", "") @@ -151,11 +151,11 @@ def bad_creds_provider(self): init_acl_user(r, request, "username", "password") creds_provider = CredentialProvider(supplier=bad_creds_provider) - with pytest.raises(ResponseError) as e: + with pytest.raises(AuthenticationError) as e: _get_client( redis.Redis, request, flushdb=False, credential_provider=creds_provider ) - assert e.match("WRONGPASS") + assert e.match("invalid username-password") class TestStaticCredentialProvider: From d1d10af85c563a62f9b018e402205f7e7fe75424 Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Tue, 20 Sep 2022 14:45:19 +0300 Subject: [PATCH 09/23] Raise an error if username or password are passed along with credential_provider --- redis/connection.py | 31 ++++++++++++++++++++++++++----- tests/test_credentials.py | 11 ++++++++++- 2 files changed, 36 insertions(+), 6 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index be891db505..9d07ef6ffd 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -517,9 +517,20 @@ def __init__( self.port = int(port) self.db = db self.client_name = client_name + if (username or password) and credential_provider is not None: + raise DataError( + "'username' and 'password' cannot be passed along with 'credential_provider'. " + "Please provide only one of the following arguments: \n" + "1. 'password' and (optional) 'username'\n" + "2. 'credential_provider'" + ) + self.credential_provider = credential_provider - if (username or password) and self.credential_provider is None: - # username and password backward compatibility + self.username = username + self.password = password + if username or password: + # Keep backward compatibility by creating a static credential provider + # for the passed username and password self.credential_provider = StaticCredentialProvider(username, password) self.socket_timeout = socket_timeout self.socket_connect_timeout = socket_connect_timeout or socket_timeout @@ -695,7 +706,7 @@ def on_connect(self): # https://github.com/andymccurdy/redis-py/issues/1274 self.send_command( "AUTH", - self.credential_provider.password, + auth_args[-1], check_health=False, ) auth_response = self.read_response() @@ -1068,9 +1079,19 @@ def __init__( self.path = path self.db = db self.client_name = client_name + if (username or password) and credential_provider is not None: + raise DataError( + "'username' and 'password' cannot be passed along with 'credential_provider'. " + "Please provide only one of the following arguments: \n" + "1. 'password' and (optional) 'username'\n" + "2. 'credential_provider'" + ) self.credential_provider = credential_provider - if (username or password) and self.credential_provider is None: - # username and password backward compatibility + self.username = username + self.password = password + if username or password: + # Keep backward compatibility by creating a static credential provider + # for the passed username and password self.credential_provider = StaticCredentialProvider(username, password) self.socket_timeout = socket_timeout self.retry_on_timeout = retry_on_timeout diff --git a/tests/test_credentials.py b/tests/test_credentials.py index 4f08651482..4a082032ca 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -4,7 +4,7 @@ import pytest import redis -from redis import AuthenticationError, ResponseError +from redis import AuthenticationError, ResponseError, DataError from redis.credentials import CredentialProvider, StaticCredentialProvider from tests.conftest import _get_client, skip_if_redis_enterprise @@ -186,3 +186,12 @@ def test_static_credential_provider_only_password(self, r, request): ) assert r2.auth(provider.password) is True assert r2.ping() is True + + def test_password_and_username_together_with_cred_provider_raise_error(self, request): + provider = StaticCredentialProvider(password="password") + with pytest.raises(DataError) as e: + _get_client( + redis.Redis, request, flushdb=False, credential_provider=provider + ) + assert e.match("'username' and 'password' cannot be passed along with " + "'credential_provider'.") From 9de8d211350f27c2c64c87cedde89186df7a7a49 Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Tue, 20 Sep 2022 14:58:44 +0300 Subject: [PATCH 10/23] fixing linters --- redis/connection.py | 8 ++++---- tests/test_credentials.py | 18 ++++++++++-------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index 9d07ef6ffd..fd4cd09323 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -519,8 +519,8 @@ def __init__( self.client_name = client_name if (username or password) and credential_provider is not None: raise DataError( - "'username' and 'password' cannot be passed along with 'credential_provider'. " - "Please provide only one of the following arguments: \n" + "'username' and 'password' cannot be passed along with 'credential_" + "provider'. Please provide only one of the following arguments: \n" "1. 'password' and (optional) 'username'\n" "2. 'credential_provider'" ) @@ -1081,8 +1081,8 @@ def __init__( self.client_name = client_name if (username or password) and credential_provider is not None: raise DataError( - "'username' and 'password' cannot be passed along with 'credential_provider'. " - "Please provide only one of the following arguments: \n" + "'username' and 'password' cannot be passed along with 'credential_" + "provider'. Please provide only one of the following arguments: \n" "1. 'password' and (optional) 'username'\n" "2. 'credential_provider'" ) diff --git a/tests/test_credentials.py b/tests/test_credentials.py index 4a082032ca..ae90ae3004 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -4,7 +4,7 @@ import pytest import redis -from redis import AuthenticationError, ResponseError, DataError +from redis import AuthenticationError, DataError, ResponseError from redis.credentials import CredentialProvider, StaticCredentialProvider from tests.conftest import _get_client, skip_if_redis_enterprise @@ -187,11 +187,13 @@ def test_static_credential_provider_only_password(self, r, request): assert r2.auth(provider.password) is True assert r2.ping() is True - def test_password_and_username_together_with_cred_provider_raise_error(self, request): - provider = StaticCredentialProvider(password="password") + def test_password_and_username_together_with_cred_provider_raise_error( + self, request + ): + creds = StaticCredentialProvider(password="password") with pytest.raises(DataError) as e: - _get_client( - redis.Redis, request, flushdb=False, credential_provider=provider - ) - assert e.match("'username' and 'password' cannot be passed along with " - "'credential_provider'.") + _get_client(redis.Redis, request, flushdb=False, credential_provider=creds) + assert e.match( + "'username' and 'password' cannot be passed along with " + "'credential_provider'." + ) From def996bcec8a47be34f0661d152ac99c769b0b7e Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Wed, 21 Sep 2022 19:17:17 +0300 Subject: [PATCH 11/23] fixing test --- tests/test_credentials.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/tests/test_credentials.py b/tests/test_credentials.py index ae90ae3004..2717ad4c3b 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -157,6 +157,28 @@ def bad_creds_provider(self): ) assert e.match("invalid username-password") + @pytest.mark.onlynoncluster + def test_password_and_username_together_with_cred_provider_raise_error( + self, r, request + ): + init_acl_user(r, request, "username", "password") + cred_provider = StaticCredentialProvider( + username="username", password="password" + ) + with pytest.raises(DataError) as e: + _get_client( + redis.Redis, + request, + flushdb=False, + username="username", + password="password", + credential_provider=cred_provider, + ) + assert e.match( + "'username' and 'password' cannot be passed along with " + "'credential_provider'." + ) + class TestStaticCredentialProvider: def test_static_credential_provider_acl_user_and_pass(self, r, request): @@ -186,14 +208,3 @@ def test_static_credential_provider_only_password(self, r, request): ) assert r2.auth(provider.password) is True assert r2.ping() is True - - def test_password_and_username_together_with_cred_provider_raise_error( - self, request - ): - creds = StaticCredentialProvider(password="password") - with pytest.raises(DataError) as e: - _get_client(redis.Redis, request, flushdb=False, credential_provider=creds) - assert e.match( - "'username' and 'password' cannot be passed along with " - "'credential_provider'." - ) From 29c80061378c7c65539a2e14afe135a456b544e3 Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Sun, 2 Oct 2022 09:42:07 +0300 Subject: [PATCH 12/23] Changed dundered to single per side underscore --- redis/credentials.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/redis/credentials.py b/redis/credentials.py index 24b6e770d9..113627184d 100644 --- a/redis/credentials.py +++ b/redis/credentials.py @@ -18,8 +18,8 @@ def supplier(self, arg1, arg2, ...) -> (username, password) :param supplier_args: arguments to pass to the supplier function :param supplier_kwargs: keyword arguments to pass to the supplier function """ - self.__username = "" if username is None else username - self.__password = "" if password is None else password + self._username_ = "" if username is None else username + self._password_ = "" if password is None else password self.supplier = supplier self.supplier_args = supplier_args self.supplier_kwargs = supplier_kwargs @@ -31,34 +31,34 @@ def get_credentials(self): ) return ( - (self.__username, self.__password) - if self.__username - else (self.__password,) + (self._username_, self._password_) + if self._username_ + else (self._password_,) ) @property def password(self): - if self.supplier is not None and not self.__password: + if self.supplier is not None and not self._password_: self.username, self.password = self.supplier( self, *self.supplier_args, **self.supplier_kwargs ) - return self.__password + return self._password_ @password.setter def password(self, value): - self.__password = value + self._password_ = value @property def username(self): - if self.supplier is not None and not self.__username: + if self.supplier is not None and not self._username_: self.username, self.password = self.supplier( self, *self.supplier_args, **self.supplier_kwargs ) - return self.__username + return self._username_ @username.setter def username(self, value): - self.__username = value + self._username_ = value class StaticCredentialProvider(CredentialProvider): From 6b8cf1f87a51ea4ef123872861a771405c728404 Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Sun, 2 Oct 2022 18:53:32 +0300 Subject: [PATCH 13/23] Changed Connection class members username and password to properties to enable backward compatibility with changing the members value on existing connection. --- redis/connection.py | 69 ++++++++++++++++++++++++++++++++++++--- redis/credentials.py | 14 +++----- tests/test_credentials.py | 29 ++++++++++++++++ 3 files changed, 99 insertions(+), 13 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index fd4cd09323..d560992a84 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -8,6 +8,7 @@ from itertools import chain from queue import Empty, Full, LifoQueue from time import time +from typing import Optional from urllib.parse import parse_qs, unquote, urlparse from redis.backoff import NoBackoff @@ -526,8 +527,6 @@ def __init__( ) self.credential_provider = credential_provider - self.username = username - self.password = password if username or password: # Keep backward compatibility by creating a static credential provider # for the passed username and password @@ -564,6 +563,38 @@ def __init__( self._connect_callbacks = [] self._buffer_cutoff = 6000 + @property + def password(self) -> Optional[str]: + if self.credential_provider is not None: + return self.credential_provider.password + else: + return None + + @password.setter + def password(self, value: Optional[str]): + if value is None: + # Delete the credential provider + self.credential_provider = None + return + if self.credential_provider is not None: + self.credential_provider.password = value + else: + self.credential_provider = StaticCredentialProvider(password=value) + + @property + def username(self) -> Optional[str]: + if self.credential_provider is not None: + return self.credential_provider.username + else: + return None + + @username.setter + def username(self, value: Optional[str]): + if self.credential_provider is not None: + self.credential_provider.username = value + else: + self.credential_provider = StaticCredentialProvider(username=value) + def __repr__(self): repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()]) return f"{self.__class__.__name__}<{repr_args}>" @@ -1087,8 +1118,6 @@ def __init__( "2. 'credential_provider'" ) self.credential_provider = credential_provider - self.username = username - self.password = password if username or password: # Keep backward compatibility by creating a static credential provider # for the passed username and password @@ -1121,6 +1150,38 @@ def __init__( self._connect_callbacks = [] self._buffer_cutoff = 6000 + @property + def password(self) -> Optional[str]: + if self.credential_provider is not None: + return self.credential_provider.password + else: + return None + + @password.setter + def password(self, value: Optional[str]): + if value is None: + # Delete the credential provider + self.credential_provider = None + return + if self.credential_provider is not None: + self.credential_provider.password = value + else: + self.credential_provider = StaticCredentialProvider(password=value) + + @property + def username(self) -> Optional[str]: + if self.credential_provider is not None: + return self.credential_provider.username + else: + return None + + @username.setter + def username(self, value: Optional[str]): + if self.credential_provider is not None: + self.credential_provider.username = value + else: + self.credential_provider = StaticCredentialProvider(username=value) + def repr_pieces(self): pieces = [("path", self.path), ("db", self.db)] if self.client_name: diff --git a/redis/credentials.py b/redis/credentials.py index 113627184d..76635488ba 100644 --- a/redis/credentials.py +++ b/redis/credentials.py @@ -18,8 +18,8 @@ def supplier(self, arg1, arg2, ...) -> (username, password) :param supplier_args: arguments to pass to the supplier function :param supplier_kwargs: keyword arguments to pass to the supplier function """ - self._username_ = "" if username is None else username - self._password_ = "" if password is None else password + self._username_ = username if username is not None else "" + self._password_ = password if password is not None else "" self.supplier = supplier self.supplier_args = supplier_args self.supplier_kwargs = supplier_kwargs @@ -46,7 +46,7 @@ def password(self): @password.setter def password(self, value): - self._password_ = value + self._password_ = value if value is not None else "" @property def username(self): @@ -58,7 +58,7 @@ def username(self): @username.setter def username(self, value): - self._username_ = value + self._username_ = value if value is not None else "" class StaticCredentialProvider(CredentialProvider): @@ -70,8 +70,4 @@ class StaticCredentialProvider(CredentialProvider): def __init__( self, username: Union[str, None] = "", password: Union[str, None] = "" ): - super().__init__( - username=username, - password=password, - credential_provider=lambda self: (self.username, self.password), - ) + super().__init__(username=username, password=password) diff --git a/tests/test_credentials.py b/tests/test_credentials.py index 2717ad4c3b..80c7642680 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -6,6 +6,7 @@ import redis from redis import AuthenticationError, DataError, ResponseError from redis.credentials import CredentialProvider, StaticCredentialProvider +from redis.utils import str_if_bytes from tests.conftest import _get_client, skip_if_redis_enterprise @@ -179,6 +180,34 @@ def test_password_and_username_together_with_cred_provider_raise_error( "'credential_provider'." ) + @pytest.mark.onlynoncluster + def test_change_username_password_on_existing_connection(self, r, request): + username = "origin_username" + password = "origin_password" + new_username = "new_username" + new_password = "new_password" + init_acl_user(r, request, username, password) + r2 = _get_client( + redis.Redis, request, flushdb=False, username=username, password=password + ) + assert r2.ping() + conn = r2.connection_pool.get_connection("_") + conn.send_command("PING") + assert str_if_bytes(conn.read_response()) == "PONG" + assert conn.username == username + assert conn.password == password + init_acl_user(r, request, new_username, new_password) + conn.password = new_password + conn.username = new_username + assert conn.credential_provider.password == new_password + assert conn.credential_provider.username == new_username + conn.send_command("PING") + assert str_if_bytes(conn.read_response()) == "PONG" + conn.username = None + assert conn.credential_provider.username == "" + conn.password = None + assert conn.credential_provider is None + class TestStaticCredentialProvider: def test_static_credential_provider_acl_user_and_pass(self, r, request): From c37e0f1bd10d506c4858a02fbaf772554a0675d2 Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Mon, 3 Oct 2022 10:01:49 +0300 Subject: [PATCH 14/23] Reverting last commit and adding backward compatibility to 'username' and 'password' inside on_connect function --- redis/connection.py | 88 ++++++--------------------------------- tests/test_credentials.py | 4 -- 2 files changed, 12 insertions(+), 80 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index d560992a84..e3c915c1b0 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -8,7 +8,6 @@ from itertools import chain from queue import Empty, Full, LifoQueue from time import time -from typing import Optional from urllib.parse import parse_qs, unquote, urlparse from redis.backoff import NoBackoff @@ -527,10 +526,8 @@ def __init__( ) self.credential_provider = credential_provider - if username or password: - # Keep backward compatibility by creating a static credential provider - # for the passed username and password - self.credential_provider = StaticCredentialProvider(username, password) + self.password = password + self.username = username self.socket_timeout = socket_timeout self.socket_connect_timeout = socket_connect_timeout or socket_timeout self.socket_keepalive = socket_keepalive @@ -563,38 +560,6 @@ def __init__( self._connect_callbacks = [] self._buffer_cutoff = 6000 - @property - def password(self) -> Optional[str]: - if self.credential_provider is not None: - return self.credential_provider.password - else: - return None - - @password.setter - def password(self, value: Optional[str]): - if value is None: - # Delete the credential provider - self.credential_provider = None - return - if self.credential_provider is not None: - self.credential_provider.password = value - else: - self.credential_provider = StaticCredentialProvider(password=value) - - @property - def username(self) -> Optional[str]: - if self.credential_provider is not None: - return self.credential_provider.username - else: - return None - - @username.setter - def username(self, value: Optional[str]): - if self.credential_provider is not None: - self.credential_provider.username = value - else: - self.credential_provider = StaticCredentialProvider(username=value) - def __repr__(self): repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()]) return f"{self.__class__.__name__}<{repr_args}>" @@ -721,9 +686,14 @@ def on_connect(self): "Initialize the connection, authenticate and select a database" self._parser.on_connect(self) - # if credentials provider is set, authenticate - if self.credential_provider is not None: - auth_args = self.credential_provider.get_credentials() + # if credential provider or username and/or password are set, authenticate + if self.credential_provider or (self.username or self.password): + cred_provider = ( + self.credential_provider + if self.credential_provider + else StaticCredentialProvider(self.username, self.password) + ) + auth_args = cred_provider.get_credentials() # avoid checking health here -- PING will fail if we try # to check the health prior to the AUTH self.send_command("AUTH", *auth_args, check_health=False) @@ -1118,10 +1088,8 @@ def __init__( "2. 'credential_provider'" ) self.credential_provider = credential_provider - if username or password: - # Keep backward compatibility by creating a static credential provider - # for the passed username and password - self.credential_provider = StaticCredentialProvider(username, password) + self.password = password + self.username = username self.socket_timeout = socket_timeout self.retry_on_timeout = retry_on_timeout if retry_on_error is SENTINEL: @@ -1150,38 +1118,6 @@ def __init__( self._connect_callbacks = [] self._buffer_cutoff = 6000 - @property - def password(self) -> Optional[str]: - if self.credential_provider is not None: - return self.credential_provider.password - else: - return None - - @password.setter - def password(self, value: Optional[str]): - if value is None: - # Delete the credential provider - self.credential_provider = None - return - if self.credential_provider is not None: - self.credential_provider.password = value - else: - self.credential_provider = StaticCredentialProvider(password=value) - - @property - def username(self) -> Optional[str]: - if self.credential_provider is not None: - return self.credential_provider.username - else: - return None - - @username.setter - def username(self, value: Optional[str]): - if self.credential_provider is not None: - self.credential_provider.username = value - else: - self.credential_provider = StaticCredentialProvider(username=value) - def repr_pieces(self): pieces = [("path", self.path), ("db", self.db)] if self.client_name: diff --git a/tests/test_credentials.py b/tests/test_credentials.py index 80c7642680..3b69239763 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -199,14 +199,10 @@ def test_change_username_password_on_existing_connection(self, r, request): init_acl_user(r, request, new_username, new_password) conn.password = new_password conn.username = new_username - assert conn.credential_provider.password == new_password - assert conn.credential_provider.username == new_username conn.send_command("PING") assert str_if_bytes(conn.read_response()) == "PONG" conn.username = None - assert conn.credential_provider.username == "" conn.password = None - assert conn.credential_provider is None class TestStaticCredentialProvider: From abe613773b070754f54e7e797d3a7af022673700 Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Wed, 2 Nov 2022 16:06:52 +0200 Subject: [PATCH 15/23] Refactored CredentialProvider class --- docs/examples/connection_examples.ipynb | 50 +++++++----- redis/__init__.py | 4 +- redis/connection.py | 11 +-- redis/credentials.py | 77 ++++-------------- tests/test_commands.py | 5 +- tests/test_credentials.py | 102 +++++++++++++----------- 6 files changed, 107 insertions(+), 142 deletions(-) diff --git a/docs/examples/connection_examples.ipynb b/docs/examples/connection_examples.ipynb index db6feac35e..ca8dd443c6 100644 --- a/docs/examples/connection_examples.ipynb +++ b/docs/examples/connection_examples.ipynb @@ -100,7 +100,7 @@ { "cell_type": "markdown", "source": [ - "## Connecting to a redis instance with static credential provider" + "## Connecting to a redis instance with username and password credential provider" ], "metadata": {} }, @@ -111,7 +111,7 @@ "source": [ "import redis\n", "\n", - "creds_provider = redis.StaticCredentialProvider(\"username\", \"password\")\n", + "creds_provider = redis.UsernamePasswordCredentialProvider(\"username\", \"password\")\n", "user_connection = redis.Redis(host=\"localhost\", port=6379, credential_provider=creds_provider)\n", "user_connection.ping()" ], @@ -131,11 +131,19 @@ "execution_count": null, "outputs": [], "source": [ + "from typing import Tuple\n", "import redis\n", "\n", "creds_map = {\"user_1\": \"pass_1\",\n", " \"user_2\": \"pass_2\"}\n", "\n", + "class UserMapCredentialProvider(redis.CredentialProvider):\n", + " def __init__(self, username: str):\n", + " self.username = username\n", + "\n", + " def get_credentials(self) -> Tuple[str, str]:\n", + " return self.username, creds_map.get(self.username)\n", + "\n", "# Create a default connection to set the ACL user\n", "default_connection = redis.Redis(host=\"localhost\", port=6379)\n", "default_connection.acl_setuser(\n", @@ -146,11 +154,8 @@ " commands=[\"+ping\", \"+command\", \"+info\", \"+select\", \"+flushdb\"],\n", ")\n", "\n", - "def creds_provider(self):\n", - " return self.username, creds_map.get(self.username)\n", - "\n", - "# Create a CredentialProvider instance for user_1\n", - "creds_provider = redis.CredentialProvider(username=\"user_1\", supplier=creds_provider)\n", + "# Create a UserMapCredentialProvider instance for user_1\n", + "creds_provider = UserMapCredentialProvider(\"user_1\")\n", "# Initiate user connection with the credential provider\n", "user_connection = redis.Redis(host=\"localhost\", port=6379,\n", " credential_provider=creds_provider)\n", @@ -172,24 +177,27 @@ "execution_count": null, "outputs": [], "source": [ + "from typing import Union\n", "import redis\n", "\n", - "def call_external_supplier():\n", - " # Call to an external credential supplier\n", - " raise NotImplementedError\n", + "class InitCredsSetCredentialProvider(redis.CredentialProvider):\n", + " def __init__(self, username, password):\n", + " self.username = username\n", + " self.password = password\n", + " self.call_supplier = False\n", + "\n", + " def call_external_supplier(self) -> Union[Tuple[str], Tuple[str, str]]:\n", + " # Call to an external credential supplier\n", + " raise NotImplementedError\n", "\n", - "def creds_supplier(self):\n", - " call_supplier = self.supplier_kwargs.get(\"call_supplier\", True)\n", - " if call_supplier:\n", - " return call_external_supplier()\n", - " # Use the init set only for the first time\n", - " self.kwargs.update({\"call_supplier\": True})\n", - " return self.username, self.password\n", + " def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]:\n", + " if self.call_supplier:\n", + " return self.call_external_supplier()\n", + " # Use the init set only for the first time\n", + " self.call_supplier = True\n", + " return self.username, self.password\n", "\n", - "cred_provider = redis.CredentialProvider(username=\"init_user\",\n", - " password=\"init_pass\",\n", - " call_supplier=False,\n", - " supplier=creds_supplier)" + "cred_provider = InitCredsSetCredentialProvider(username=\"init_user\", password=\"init_pass\")" ], "metadata": {} } diff --git a/redis/__init__.py b/redis/__init__.py index d4ce090562..5201fe22d4 100644 --- a/redis/__init__.py +++ b/redis/__init__.py @@ -9,7 +9,7 @@ SSLConnection, UnixDomainSocketConnection, ) -from redis.credentials import CredentialProvider, StaticCredentialProvider +from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider from redis.exceptions import ( AuthenticationError, AuthenticationWrongNumberOfArgsError, @@ -78,7 +78,7 @@ def int_or_str(value): "SentinelManagedConnection", "SentinelManagedSSLConnection", "SSLConnection", - "StaticCredentialProvider", + "UsernamePasswordCredentialProvider", "StrictRedis", "TimeoutError", "UnixDomainSocketConnection", diff --git a/redis/connection.py b/redis/connection.py index e3c915c1b0..47dfdb79a8 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -11,7 +11,7 @@ from urllib.parse import parse_qs, unquote, urlparse from redis.backoff import NoBackoff -from redis.credentials import CredentialProvider, StaticCredentialProvider +from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider from redis.exceptions import ( AuthenticationError, AuthenticationWrongNumberOfArgsError, @@ -690,8 +690,7 @@ def on_connect(self): if self.credential_provider or (self.username or self.password): cred_provider = ( self.credential_provider - if self.credential_provider - else StaticCredentialProvider(self.username, self.password) + or UsernamePasswordCredentialProvider(self.username, self.password) ) auth_args = cred_provider.get_credentials() # avoid checking health here -- PING will fail if we try @@ -705,11 +704,7 @@ def on_connect(self): # server seems to be < 6.0.0 which expects a single password # arg. retry auth with just the password. # https://github.com/andymccurdy/redis-py/issues/1274 - self.send_command( - "AUTH", - auth_args[-1], - check_health=False, - ) + self.send_command("AUTH", auth_args[-1], check_health=False) auth_response = self.read_response() if str_if_bytes(auth_response) != "OK": diff --git a/redis/credentials.py b/redis/credentials.py index 76635488ba..def93e5ab2 100644 --- a/redis/credentials.py +++ b/redis/credentials.py @@ -1,73 +1,26 @@ -from typing import Callable, Optional, Union +from typing import Optional, Tuple, Union class CredentialProvider: - def __init__( - self, - username: Union[str, None] = "", - password: Union[str, None] = "", - supplier: Optional[Callable] = None, - *supplier_args, - **supplier_kwargs, - ): - """ - Initialize a new Credentials Provider. - :param supplier: a supplier function that returns the username and password. - def supplier(self, arg1, arg2, ...) -> (username, password) - See examples/connection_examples.ipynb - :param supplier_args: arguments to pass to the supplier function - :param supplier_kwargs: keyword arguments to pass to the supplier function - """ - self._username_ = username if username is not None else "" - self._password_ = password if password is not None else "" - self.supplier = supplier - self.supplier_args = supplier_args - self.supplier_kwargs = supplier_kwargs - - def get_credentials(self): - if self.supplier is not None: - self.username, self.password = self.supplier( - self, *self.supplier_args, **self.supplier_kwargs - ) - - return ( - (self._username_, self._password_) - if self._username_ - else (self._password_,) - ) - - @property - def password(self): - if self.supplier is not None and not self._password_: - self.username, self.password = self.supplier( - self, *self.supplier_args, **self.supplier_kwargs - ) - return self._password_ - - @password.setter - def password(self, value): - self._password_ = value if value is not None else "" - - @property - def username(self): - if self.supplier is not None and not self._username_: - self.username, self.password = self.supplier( - self, *self.supplier_args, **self.supplier_kwargs - ) - return self._username_ + """ + Credentials Provider. + """ - @username.setter - def username(self, value): - self._username_ = value if value is not None else "" + def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]: + raise NotImplementedError("get_credentials must be implemented") -class StaticCredentialProvider(CredentialProvider): +class UsernamePasswordCredentialProvider(CredentialProvider): """ Simple implementation of CredentialProvider that just wraps static username and password. """ - def __init__( - self, username: Union[str, None] = "", password: Union[str, None] = "" - ): - super().__init__(username=username, password=password) + def __init__(self, username: Optional[str] = None, password: Optional[str] = None): + self.username = username + self.password = password + + def get_credentials(self): + if self.username: + return self.username, self.password + return (self.password,) diff --git a/tests/test_commands.py b/tests/test_commands.py index 929b64f6df..1c9a5c27eb 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -10,7 +10,6 @@ import redis from redis import exceptions from redis.client import parse_info -from redis.credentials import StaticCredentialProvider from .conftest import ( _get_client, @@ -96,9 +95,7 @@ def teardown(): # error when switching to the db 9 because we're not authenticated yet # setting the password on the connection itself triggers the # authentication in the connection's `on_connect` method - r.connection.credential_provider = StaticCredentialProvider( - password=temp_pass - ) + r.connection.password = temp_pass except AttributeError: # connection field is not set in Redis Cluster, but that's ok # because the problem discussed above does not apply to Redis Cluster diff --git a/tests/test_credentials.py b/tests/test_credentials.py index 3b69239763..9875a28e1e 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -1,15 +1,42 @@ +import functools import random import string +from typing import Optional, Tuple, Union import pytest import redis from redis import AuthenticationError, DataError, ResponseError -from redis.credentials import CredentialProvider, StaticCredentialProvider +from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider from redis.utils import str_if_bytes from tests.conftest import _get_client, skip_if_redis_enterprise +class NoPassCredProvider(CredentialProvider): + def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]: + return "username", "" + + +class RandomAuthCredProvider(CredentialProvider): + def __init__(self, user: Optional[str], endpoint: str): + self.user = user + self.endpoint = endpoint + + @functools.lru_cache(maxsize=10) + def get_credentials(self) -> Union[tuple[str, str], tuple[str]]: + def get_random_string(length): + letters = string.ascii_lowercase + result_str = "".join(random.choice(letters) for i in range(length)) + return result_str + + if self.user: + auth_token: str = get_random_string(5) + self.user + "_" + self.endpoint + return self.user, auth_token + else: + auth_token: str = get_random_string(5) + self.endpoint + return (auth_token,) + + def init_acl_user(r, request, username, password): # reset the user r.acl_deluser(username) @@ -73,55 +100,42 @@ def teardown(): class TestCredentialsProvider: @skip_if_redis_enterprise() - def test_credential_provider_without_supplier_only_pass(self, r, request): + def test_only_pass_without_creds_provider(self, r, request): # test for default user (`username` is supposed to be optional) password = "password" - creds_provider = CredentialProvider(password=password) init_required_pass(r, request, password) - assert r.auth(creds_provider.password) is True + assert r.auth(password) is True - r2 = _get_client( - redis.Redis, request, flushdb=False, credential_provider=creds_provider - ) + r2 = _get_client(redis.Redis, request, flushdb=False, password=password) - assert r2.ping() + assert r2.ping() is True @skip_if_redis_enterprise() - def test_credential_provider_without_supplier_acl_user_and_pass(self, r, request): + def test_user_and_pass_without_creds_provider(self, r, request): + """ + Test backward compatibility with username and password + """ # test for other users username = "username" password = "password" init_acl_user(r, request, username, password) - creds_provider = CredentialProvider(username, password) r2 = _get_client( - redis.Redis, request, flushdb=False, credential_provider=creds_provider + redis.Redis, request, flushdb=False, username=username, password=password ) assert r2.ping() is True - @pytest.mark.parametrize("username", ["username", ""]) + @pytest.mark.parametrize("username", ["username", None]) @skip_if_redis_enterprise() @pytest.mark.onlynoncluster def test_credential_provider_with_supplier(self, r, request, username): - import functools - - @functools.lru_cache(maxsize=10) - def auth_supplier(self, user, endpoint): - def get_random_string(length): - letters = string.ascii_lowercase - result_str = "".join(random.choice(letters) for i in range(length)) - return result_str - - auth_token = get_random_string(5) + user + "_" + endpoint - return user, auth_token - - creds_provider = CredentialProvider( - supplier=auth_supplier, + creds_provider = RandomAuthCredProvider( user=username, endpoint="localhost", ) - password = creds_provider.password + + password = creds_provider.get_credentials()[-1] if username: init_acl_user(r, request, username, password) @@ -135,26 +149,24 @@ def get_random_string(length): assert r2.ping() is True def test_credential_provider_no_password_success(self, r, request): - def creds_provider(self): - return "username", "" - init_acl_user(r, request, "username", "") - creds_provider = CredentialProvider(supplier=creds_provider) r2 = _get_client( - redis.Redis, request, flushdb=False, credential_provider=creds_provider + redis.Redis, + request, + flushdb=False, + credential_provider=NoPassCredProvider(), ) assert r2.ping() is True @pytest.mark.onlynoncluster def test_credential_provider_no_password_error(self, r, request): - def bad_creds_provider(self): - return "username", "" - init_acl_user(r, request, "username", "password") - creds_provider = CredentialProvider(supplier=bad_creds_provider) with pytest.raises(AuthenticationError) as e: _get_client( - redis.Redis, request, flushdb=False, credential_provider=creds_provider + redis.Redis, + request, + flushdb=False, + credential_provider=NoPassCredProvider(), ) assert e.match("invalid username-password") @@ -163,7 +175,7 @@ def test_password_and_username_together_with_cred_provider_raise_error( self, r, request ): init_acl_user(r, request, "username", "password") - cred_provider = StaticCredentialProvider( + cred_provider = UsernamePasswordCredentialProvider( username="username", password="password" ) with pytest.raises(DataError) as e: @@ -190,7 +202,7 @@ def test_change_username_password_on_existing_connection(self, r, request): r2 = _get_client( redis.Redis, request, flushdb=False, username=username, password=password ) - assert r2.ping() + assert r2.ping() is True conn = r2.connection_pool.get_connection("_") conn.send_command("PING") assert str_if_bytes(conn.read_response()) == "PONG" @@ -205,11 +217,11 @@ def test_change_username_password_on_existing_connection(self, r, request): conn.password = None -class TestStaticCredentialProvider: - def test_static_credential_provider_acl_user_and_pass(self, r, request): +class TestUsernamePasswordCredentialProvider: + def test_user_pass_credential_provider_acl_user_and_pass(self, r, request): username = "username" password = "password" - provider = StaticCredentialProvider(username, password) + provider = UsernamePasswordCredentialProvider(username, password) assert provider.username == username assert provider.password == password assert provider.get_credentials() == (username, password) @@ -219,10 +231,10 @@ def test_static_credential_provider_acl_user_and_pass(self, r, request): ) assert r2.ping() is True - def test_static_credential_provider_only_password(self, r, request): + def test_user_pass_provider_only_password(self, r, request): password = "password" - provider = StaticCredentialProvider(password=password) - assert provider.username == "" + provider = UsernamePasswordCredentialProvider(password=password) + assert provider.username is None assert provider.password == password assert provider.get_credentials() == (password,) From 6303243db51e5a543c7e5527fe0b602a943a7be6 Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Wed, 2 Nov 2022 16:29:55 +0200 Subject: [PATCH 16/23] Fixing tuple type to Tuple --- tests/test_credentials.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_credentials.py b/tests/test_credentials.py index 9875a28e1e..5bb1ddb786 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -23,7 +23,7 @@ def __init__(self, user: Optional[str], endpoint: str): self.endpoint = endpoint @functools.lru_cache(maxsize=10) - def get_credentials(self) -> Union[tuple[str, str], tuple[str]]: + def get_credentials(self) -> Union[Tuple[str, str], Tuple[str]]: def get_random_string(length): letters = string.ascii_lowercase result_str = "".join(random.choice(letters) for i in range(length)) From 057ed82437a87d235edce01ca10beea85c0cd5b8 Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Wed, 2 Nov 2022 16:39:41 +0200 Subject: [PATCH 17/23] Fixing optional string members in UsernamePasswordCredentialProvider --- redis/client.py | 3 ++- redis/connection.py | 5 +++-- redis/credentials.py | 4 ++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/redis/client.py b/redis/client.py index fb7bb60509..1148b30bec 100755 --- a/redis/client.py +++ b/redis/client.py @@ -5,6 +5,7 @@ import time import warnings from itertools import chain +from typing import Optional from redis.commands import ( CoreCommands, @@ -939,7 +940,7 @@ def __init__( username=None, retry=None, redis_connect_func=None, - credential_provider: CredentialProvider = None, + credential_provider: Optional[CredentialProvider] = None, ): """ Initialize a new Redis client. diff --git a/redis/connection.py b/redis/connection.py index 47dfdb79a8..6cd79ffe6f 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -8,6 +8,7 @@ from itertools import chain from queue import Empty, Full, LifoQueue from time import time +from typing import Optional from urllib.parse import parse_qs, unquote, urlparse from redis.backoff import NoBackoff @@ -503,7 +504,7 @@ def __init__( username=None, retry=None, redis_connect_func=None, - credential_provider: CredentialProvider = None, + credential_provider: Optional[CredentialProvider] = None, ): """ Initialize a new Connection. @@ -1062,7 +1063,7 @@ def __init__( client_name=None, retry=None, redis_connect_func=None, - credential_provider: CredentialProvider = None, + credential_provider: Optional[CredentialProvider] = None, ): """ Initialize a new UnixDomainSocketConnection. diff --git a/redis/credentials.py b/redis/credentials.py index def93e5ab2..7ba26dcde1 100644 --- a/redis/credentials.py +++ b/redis/credentials.py @@ -17,8 +17,8 @@ class UsernamePasswordCredentialProvider(CredentialProvider): """ def __init__(self, username: Optional[str] = None, password: Optional[str] = None): - self.username = username - self.password = password + self.username = username or "" + self.password = password or "" def get_credentials(self): if self.username: From ba91b0f70dc2a0e19eb1703838dd689fdab4eda5 Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Wed, 2 Nov 2022 16:50:36 +0200 Subject: [PATCH 18/23] Fixed credential test --- tests/test_credentials.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_credentials.py b/tests/test_credentials.py index 5bb1ddb786..9aeb1ef1d5 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -213,8 +213,6 @@ def test_change_username_password_on_existing_connection(self, r, request): conn.username = new_username conn.send_command("PING") assert str_if_bytes(conn.read_response()) == "PONG" - conn.username = None - conn.password = None class TestUsernamePasswordCredentialProvider: @@ -234,7 +232,7 @@ def test_user_pass_credential_provider_acl_user_and_pass(self, r, request): def test_user_pass_provider_only_password(self, r, request): password = "password" provider = UsernamePasswordCredentialProvider(password=password) - assert provider.username is None + assert provider.username == "" assert provider.password == password assert provider.get_credentials() == (password,) From 622390101ddf124fddd32d79f9be41089942f9e3 Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Wed, 9 Nov 2022 18:16:05 +0200 Subject: [PATCH 19/23] Added credential provider support to AsyncRedis --- docs/examples/asyncio_examples.ipynb | 129 ++++++----- redis/asyncio/client.py | 3 + redis/asyncio/cluster.py | 7 +- redis/asyncio/connection.py | 43 ++-- redis/connection.py | 19 +- redis/credentials.py | 5 +- tests/test_asyncio/test_credentials.py | 286 +++++++++++++++++++++++++ 7 files changed, 413 insertions(+), 79 deletions(-) create mode 100644 tests/test_asyncio/test_credentials.py diff --git a/docs/examples/asyncio_examples.ipynb b/docs/examples/asyncio_examples.ipynb index dab7a96ae9..855255c88d 100644 --- a/docs/examples/asyncio_examples.ipynb +++ b/docs/examples/asyncio_examples.ipynb @@ -21,11 +21,6 @@ { "cell_type": "code", "execution_count": 1, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, "outputs": [ { "name": "stdout", @@ -41,27 +36,29 @@ "connection = redis.Redis()\n", "print(f\"Ping successful: {await connection.ping()}\")\n", "await connection.close()" - ] + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } }, { "cell_type": "markdown", + "source": [ + "If you supply a custom `ConnectionPool` that is supplied to several `Redis` instances, you may want to disconnect the connection pool explicitly. Disconnecting the connection pool simply disconnects all connections hosted in the pool." + ], "metadata": { + "collapsed": false, "pycharm": { "name": "#%% md\n" } - }, - "source": [ - "If you supply a custom `ConnectionPool` that is supplied to several `Redis` instances, you may want to disconnect the connection pool explicitly. Disconnecting the connection pool simply disconnects all connections hosted in the pool." - ] + } }, { "cell_type": "code", "execution_count": 2, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, "outputs": [], "source": [ "import redis.asyncio as redis\n", @@ -70,15 +67,16 @@ "await connection.close()\n", "# Or: await connection.close(close_connection_pool=False)\n", "await connection.connection_pool.disconnect()" - ] - }, - { - "cell_type": "markdown", + ], "metadata": { + "collapsed": false, "pycharm": { - "name": "#%% md\n" + "name": "#%%\n" } - }, + } + }, + { + "cell_type": "markdown", "source": [ "## Transactions (Multi/Exec)\n", "\n", @@ -87,16 +85,17 @@ "The commands will not be reflected in Redis until execute() is called & awaited.\n", "\n", "Usually, when performing a bulk operation, taking advantage of a “transaction” (e.g., Multi/Exec) is to be desired, as it will also add a layer of atomicity to your bulk operation." - ] + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } }, { "cell_type": "code", "execution_count": 3, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, "outputs": [], "source": [ "import redis.asyncio as redis\n", @@ -106,25 +105,31 @@ " ok1, ok2 = await (pipe.set(\"key1\", \"value1\").set(\"key2\", \"value2\").execute())\n", "assert ok1\n", "assert ok2" - ] + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } }, { "cell_type": "markdown", - "metadata": {}, "source": [ "## Pub/Sub Mode\n", "\n", "Subscribing to specific channels:" - ] + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } }, { "cell_type": "code", "execution_count": 4, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, "outputs": [ { "name": "stdout", @@ -165,23 +170,29 @@ " await r.publish(\"channel:1\", STOPWORD)\n", "\n", " await future" - ] + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } }, { "cell_type": "markdown", - "metadata": {}, "source": [ "Subscribing to channels matching a glob-style pattern:" - ] + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } }, { "cell_type": "code", "execution_count": 5, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, "outputs": [ { "name": "stdout", @@ -223,11 +234,16 @@ " await r.publish(\"channel:1\", STOPWORD)\n", "\n", " await future" - ] + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } }, { "cell_type": "markdown", - "metadata": {}, "source": [ "## Sentinel Client\n", "\n", @@ -236,16 +252,17 @@ "Calling aioredis.sentinel.Sentinel.master_for or aioredis.sentinel.Sentinel.slave_for methods will return Redis clients connected to specified services monitored by Sentinel.\n", "\n", "Sentinel client will detect failover and reconnect Redis clients automatically." - ] + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } }, { "cell_type": "code", "execution_count": null, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, "outputs": [], "source": [ "import asyncio\n", @@ -260,7 +277,13 @@ "assert ok\n", "val = await r.get(\"key\")\n", "assert val == b\"value\"" - ] + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } } ], "metadata": { @@ -284,4 +307,4 @@ }, "nbformat": 4, "nbformat_minor": 1 -} +} \ No newline at end of file diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 0e40ed70f8..86ce83de33 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -46,6 +46,7 @@ list_or_args, ) from redis.compat import Protocol, TypedDict +from redis.credentials import CredentialProvider from redis.exceptions import ( ConnectionError, ExecAbortError, @@ -174,6 +175,7 @@ def __init__( retry: Optional[Retry] = None, auto_close_connection_pool: bool = True, redis_connect_func=None, + credential_provider: Optional[CredentialProvider] = None, ): """ Initialize a new Redis client. @@ -199,6 +201,7 @@ def __init__( "db": db, "username": username, "password": password, + "credential_provider": credential_provider, "socket_timeout": socket_timeout, "encoding": encoding, "encoding_errors": encoding_errors, diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 8d34b9ad21..1300f9c41c 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -39,6 +39,7 @@ ) from redis.commands import READ_COMMANDS, AsyncRedisClusterCommands from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot +from redis.credentials import CredentialProvider from redis.exceptions import ( AskError, BusyLoadingError, @@ -215,10 +216,11 @@ def __init__( reinitialize_steps: int = 10, cluster_error_retry_attempts: int = 3, connection_error_retry_attempts: int = 5, - max_connections: int = 2**31, + max_connections: int = 2 ** 31, # Client related kwargs db: Union[str, int] = 0, path: Optional[str] = None, + credential_provider: Optional[CredentialProvider] = None, username: Optional[str] = None, password: Optional[str] = None, client_name: Optional[str] = None, @@ -265,6 +267,7 @@ def __init__( "connection_class": Connection, "parser_class": ClusterParser, # Client related kwargs + "credential_provider": credential_provider, "username": username, "password": password, "client_name": client_name, @@ -792,7 +795,7 @@ def __init__( port: Union[str, int], server_type: Optional[str] = None, *, - max_connections: int = 2**31, + max_connections: int = 2 ** 31, connection_class: Type[Connection] = Connection, **connection_kwargs: Any, ) -> None: diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index b64bd125eb..96c89ba293 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -29,6 +29,7 @@ from redis.asyncio.retry import Retry from redis.backoff import NoBackoff from redis.compat import Protocol, TypedDict +from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider from redis.exceptions import ( AuthenticationError, AuthenticationWrongNumberOfArgsError, @@ -416,6 +417,7 @@ class Connection: "db", "username", "client_name", + "credential_provider", "password", "socket_timeout", "socket_connect_timeout", @@ -465,14 +467,23 @@ def __init__( retry: Optional[Retry] = None, redis_connect_func: Optional[ConnectCallbackT] = None, encoder_class: Type[Encoder] = Encoder, + credential_provider: Optional[CredentialProvider] = None, ): + if (username or password) and credential_provider is not None: + raise DataError( + "'username' and 'password' cannot be passed along with 'credential_" + "provider'. Please provide only one of the following arguments: \n" + "1. 'password' and (optional) 'username'\n" + "2. 'credential_provider'" + ) self.pid = os.getpid() self.host = host self.port = int(port) self.db = db - self.username = username self.client_name = client_name + self.credential_provider = credential_provider self.password = password + self.username = username self.socket_timeout = socket_timeout self.socket_connect_timeout = socket_connect_timeout or socket_timeout or None self.socket_keepalive = socket_keepalive @@ -637,14 +648,13 @@ async def on_connect(self) -> None: """Initialize the connection, authenticate and select a database""" self._parser.on_connect(self) - # if username and/or password are set, authenticate - if self.username or self.password: - auth_args: Union[Tuple[str], Tuple[str, str]] - if self.username: - auth_args = (self.username, self.password or "") - else: - # Mypy bug: https://github.com/python/mypy/issues/10944 - auth_args = (self.password or "",) + # if credential provider or username and/or password are set, authenticate + if self.credential_provider or (self.username or self.password): + cred_provider = ( + self.credential_provider + or UsernamePasswordCredentialProvider(self.username, self.password) + ) + auth_args = cred_provider.get_credentials() # avoid checking health here -- PING will fail if we try # to check the health prior to the AUTH await self.send_command("AUTH", *auth_args, check_health=False) @@ -656,7 +666,7 @@ async def on_connect(self) -> None: # server seems to be < 6.0.0 which expects a single password # arg. retry auth with just the password. # https://github.com/andymccurdy/redis-py/issues/1274 - await self.send_command("AUTH", self.password, check_health=False) + await self.send_command("AUTH", auth_args[-1], check_health=False) auth_response = await self.read_response() if str_if_bytes(auth_response) != "OK": @@ -1014,18 +1024,27 @@ def __init__( client_name: str = None, retry: Optional[Retry] = None, redis_connect_func=None, + credential_provider: Optional[CredentialProvider] = None, ): """ Initialize a new UnixDomainSocketConnection. To specify a retry policy, first set `retry_on_timeout` to `True` then set `retry` to a valid `Retry` object """ + if (username or password) and credential_provider is not None: + raise DataError( + "'username' and 'password' cannot be passed along with 'credential_" + "provider'. Please provide only one of the following arguments: \n" + "1. 'password' and (optional) 'username'\n" + "2. 'credential_provider'" + ) self.pid = os.getpid() self.path = path self.db = db - self.username = username self.client_name = client_name + self.credential_provider = credential_provider self.password = password + self.username = username self.socket_timeout = socket_timeout self.socket_connect_timeout = socket_connect_timeout or socket_timeout or None self.retry_on_timeout = retry_on_timeout @@ -1235,7 +1254,7 @@ def __init__( max_connections: Optional[int] = None, **connection_kwargs, ): - max_connections = max_connections or 2**31 + max_connections = max_connections or 2 ** 31 if not isinstance(max_connections, int) or max_connections < 0: raise ValueError('"max_connections" must be a positive integer') diff --git a/redis/connection.py b/redis/connection.py index 6cd79ffe6f..28eac5681f 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -513,11 +513,6 @@ def __init__( `retry` to a valid `Retry` object. To retry on TimeoutError, `retry_on_timeout` can also be set to `True`. """ - self.pid = os.getpid() - self.host = host - self.port = int(port) - self.db = db - self.client_name = client_name if (username or password) and credential_provider is not None: raise DataError( "'username' and 'password' cannot be passed along with 'credential_" @@ -525,7 +520,11 @@ def __init__( "1. 'password' and (optional) 'username'\n" "2. 'credential_provider'" ) - + self.pid = os.getpid() + self.host = host + self.port = int(port) + self.db = db + self.client_name = client_name self.credential_provider = credential_provider self.password = password self.username = username @@ -1072,10 +1071,6 @@ def __init__( `retry` to a valid `Retry` object. To retry on TimeoutError, `retry_on_timeout` can also be set to `True`. """ - self.pid = os.getpid() - self.path = path - self.db = db - self.client_name = client_name if (username or password) and credential_provider is not None: raise DataError( "'username' and 'password' cannot be passed along with 'credential_" @@ -1083,6 +1078,10 @@ def __init__( "1. 'password' and (optional) 'username'\n" "2. 'credential_provider'" ) + self.pid = os.getpid() + self.path = path + self.db = db + self.client_name = client_name self.credential_provider = credential_provider self.password = password self.username = username diff --git a/redis/credentials.py b/redis/credentials.py index 7ba26dcde1..50c8e32676 100644 --- a/redis/credentials.py +++ b/redis/credentials.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union +from typing import Coroutine, Optional, Tuple, Union class CredentialProvider: @@ -16,7 +16,8 @@ class UsernamePasswordCredentialProvider(CredentialProvider): username and password. """ - def __init__(self, username: Optional[str] = None, password: Optional[str] = None): + def __init__(self, username: Optional[str] = None, + password: Optional[str] = None): self.username = username or "" self.password = password or "" diff --git a/tests/test_asyncio/test_credentials.py b/tests/test_asyncio/test_credentials.py new file mode 100644 index 0000000000..4c1103870e --- /dev/null +++ b/tests/test_asyncio/test_credentials.py @@ -0,0 +1,286 @@ +import asyncio +import functools +import random +import string +from typing import Optional, Tuple, Union + +import pytest +import pytest_asyncio + +import redis +from redis import AuthenticationError, DataError, ResponseError +from redis.asyncio.connection import HiredisParser +from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider +from redis.utils import str_if_bytes +from tests.conftest import skip_if_redis_enterprise + + +@pytest_asyncio.fixture() +async def r_acl_teardown(r: redis.Redis): + """ + A special fixture which removes the provided names from the database after use + """ + usernames = [] + + def factory(username): + usernames.append(username) + return r + + yield factory + for username in usernames: + await r.acl_deluser(username) + + +@pytest_asyncio.fixture() +async def r_required_pass_teardown(r: redis.Redis): + """ + A special fixture which removes the provided password from the database after use + """ + passwords = [] + + def factory(username): + passwords.append(username) + return r + + yield factory + for password in passwords: + try: + await r.auth(password) + except (ResponseError, AuthenticationError): + await r.auth("default", "") + await r.config_set("requirepass", "") + + +class NoPassCredProvider(CredentialProvider): + def get_credentials(self) -> Union[Tuple[str], Tuple[str, str]]: + return "username", "" + + +class AsyncRandomAuthCredProvider(CredentialProvider): + def __init__(self, user: Optional[str], endpoint: str): + self.user = user + self.endpoint = endpoint + + @functools.lru_cache(maxsize=10) + def get_credentials(self) -> Union[Tuple[str, str], Tuple[str]]: + def get_random_string(length): + letters = string.ascii_lowercase + result_str = "".join(random.choice(letters) for i in range(length)) + return result_str + + if self.user: + auth_token: str = get_random_string(5) + self.user + "_" + self.endpoint + return self.user, auth_token + else: + auth_token: str = get_random_string(5) + self.endpoint + return (auth_token,) + + +async def init_acl_user(r, username, password): + # reset the user + await r.acl_deluser(username) + if password: + assert ( + await r.acl_setuser( + username, + enabled=True, + passwords=["+" + password], + keys="~*", + commands=[ + "+ping", + "+command", + "+info", + "+select", + "+flushdb", + "+cluster", + ], + ) + is True + ) + else: + assert ( + await r.acl_setuser( + username, + enabled=True, + keys="~*", + commands=[ + "+ping", + "+command", + "+info", + "+select", + "+flushdb", + "+cluster", + ], + nopass=True, + ) + is True + ) + + +async def init_required_pass(r, password): + await r.config_set("requirepass", password) + + +@pytest.mark.asyncio +class TestCredentialsProvider: + @skip_if_redis_enterprise() + async def test_only_pass_without_creds_provider( + self, r_required_pass_teardown, create_redis + ): + # test for default user (`username` is supposed to be optional) + password = "password" + r = r_required_pass_teardown(password) + await init_required_pass(r, password) + assert await r.auth(password) is True + + r2 = await create_redis(flushdb=False, password=password) + + assert await r2.ping() is True + + @skip_if_redis_enterprise() + async def test_user_and_pass_without_creds_provider( + self, r_acl_teardown, create_redis + ): + """ + Test backward compatibility with username and password + """ + # test for other users + username = "username" + password = "password" + r = r_acl_teardown(username) + await init_acl_user(r, username, password) + r2 = await create_redis(flushdb=False, username=username, password=password) + + assert await r2.ping() is True + + @pytest.mark.parametrize("username", ["username", None]) + @skip_if_redis_enterprise() + @pytest.mark.onlynoncluster + async def test_credential_provider_with_supplier( + self, r_acl_teardown, r_required_pass_teardown, create_redis, username + ): + creds_provider = AsyncRandomAuthCredProvider( + user=username, + endpoint="localhost", + ) + + auth_args = creds_provider.get_credentials() + password = auth_args[-1] + + if username: + r = r_acl_teardown(username) + await init_acl_user(r, username, password) + else: + r = r_required_pass_teardown(password) + await init_required_pass(r, password) + + r2 = await create_redis(flushdb=False, credential_provider=creds_provider) + + assert await r2.ping() is True + + async def test_async_credential_provider_no_password_success( + self, r_acl_teardown, create_redis + ): + username = "username" + r = r_acl_teardown(username) + await init_acl_user(r, username, "") + r2 = await create_redis( + flushdb=False, + credential_provider=NoPassCredProvider(), + ) + assert await r2.ping() is True + + @pytest.mark.onlynoncluster + async def test_credential_provider_no_password_error( + self, r_acl_teardown, create_redis + ): + username = "username" + r = r_acl_teardown(username) + await init_acl_user(r, username, "password") + with pytest.raises(AuthenticationError) as e: + await create_redis( + flushdb=False, + credential_provider=NoPassCredProvider(), + single_connection_client=True, + ) + assert e.match("invalid username-password") + assert await r.acl_deluser(username) + + @pytest.mark.onlynoncluster + async def test_password_and_username_together_with_cred_provider_raise_error( + self, r_acl_teardown, create_redis + ): + username = "username" + r = r_acl_teardown(username) + await init_acl_user(r, username, "password") + cred_provider = UsernamePasswordCredentialProvider( + username="username", password="password" + ) + with pytest.raises(DataError) as e: + await create_redis( + flushdb=False, + username="username", + password="password", + credential_provider=cred_provider, + single_connection_client=True, + ) + assert e.match( + "'username' and 'password' cannot be passed along with " + "'credential_provider'." + ) + + @pytest.mark.onlynoncluster + async def test_change_username_password_on_existing_connection( + self, r_acl_teardown, create_redis + ): + username = "origin_username" + password = "origin_password" + new_username = "new_username" + new_password = "new_password" + r = r_acl_teardown(username) + await init_acl_user(r, username, password) + r2 = await create_redis(flushdb=False, username=username, password=password) + assert await r2.ping() is True + conn = await r2.connection_pool.get_connection("_") + await conn.send_command("PING") + assert str_if_bytes(await conn.read_response()) == "PONG" + assert conn.username == username + assert conn.password == password + await init_acl_user(r, new_username, new_password) + conn.password = new_password + conn.username = new_username + await conn.send_command("PING") + assert str_if_bytes(await conn.read_response()) == "PONG" + + +@pytest.mark.asyncio +class TestUsernamePasswordCredentialProvider: + async def test_user_pass_credential_provider_acl_user_and_pass( + self, r_acl_teardown, create_redis + ): + username = "username" + password = "password" + r = r_acl_teardown(username) + provider = UsernamePasswordCredentialProvider(username, password) + assert provider.username == username + assert provider.password == password + assert provider.get_credentials() == (username, password) + await init_acl_user(r, provider.username, provider.password) + r2 = await create_redis(flushdb=False, credential_provider=provider) + assert await r2.ping() is True + + async def test_user_pass_provider_only_password( + self, r_required_pass_teardown, create_redis + ): + password = "password" + provider = UsernamePasswordCredentialProvider(password=password) + r = r_required_pass_teardown(password) + assert provider.username == "" + assert provider.password == password + assert provider.get_credentials() == (password,) + + await init_required_pass(r, password) + + r2 = await create_redis(flushdb=False, credential_provider=provider) + assert await r2.auth(provider.password) is True + assert await r2.ping() is True From b951e19b4605efccd123bc471b5ebb61028dc59c Mon Sep 17 00:00:00 2001 From: dvora-h <67596500+dvora-h@users.noreply.github.com> Date: Thu, 10 Nov 2022 09:35:59 +0200 Subject: [PATCH 20/23] linters --- tests/test_asyncio/test_credentials.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_asyncio/test_credentials.py b/tests/test_asyncio/test_credentials.py index 4c1103870e..a37c69f236 100644 --- a/tests/test_asyncio/test_credentials.py +++ b/tests/test_asyncio/test_credentials.py @@ -9,7 +9,6 @@ import redis from redis import AuthenticationError, DataError, ResponseError -from redis.asyncio.connection import HiredisParser from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider from redis.utils import str_if_bytes from tests.conftest import skip_if_redis_enterprise From 72c366d4e44785df38394e7dc24b28da6c4cec6f Mon Sep 17 00:00:00 2001 From: dvora-h <67596500+dvora-h@users.noreply.github.com> Date: Thu, 10 Nov 2022 09:40:34 +0200 Subject: [PATCH 21/23] linters --- redis/credentials.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/credentials.py b/redis/credentials.py index 50c8e32676..4dd523fd9c 100644 --- a/redis/credentials.py +++ b/redis/credentials.py @@ -1,4 +1,4 @@ -from typing import Coroutine, Optional, Tuple, Union +from typing import Optional, Tuple, Union class CredentialProvider: From 4b35cb28ab9eaa176fe48203b55b56948e55e94d Mon Sep 17 00:00:00 2001 From: dvora-h <67596500+dvora-h@users.noreply.github.com> Date: Thu, 10 Nov 2022 09:40:57 +0200 Subject: [PATCH 22/23] linters --- tests/test_asyncio/test_credentials.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_asyncio/test_credentials.py b/tests/test_asyncio/test_credentials.py index a37c69f236..8e213cdb26 100644 --- a/tests/test_asyncio/test_credentials.py +++ b/tests/test_asyncio/test_credentials.py @@ -1,4 +1,3 @@ -import asyncio import functools import random import string From 4c82551995adc84e025ef28d1ee4a60a40178841 Mon Sep 17 00:00:00 2001 From: dvora-h Date: Thu, 10 Nov 2022 11:47:06 +0200 Subject: [PATCH 23/23] linters - black --- redis/asyncio/cluster.py | 4 ++-- redis/asyncio/connection.py | 2 +- redis/credentials.py | 3 +-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index ffcdbacb9e..57aafbd69f 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -217,7 +217,7 @@ def __init__( reinitialize_steps: int = 10, cluster_error_retry_attempts: int = 3, connection_error_retry_attempts: int = 5, - max_connections: int = 2 ** 31, + max_connections: int = 2**31, # Client related kwargs db: Union[str, int] = 0, path: Optional[str] = None, @@ -862,7 +862,7 @@ def __init__( port: Union[str, int], server_type: Optional[str] = None, *, - max_connections: int = 2 ** 31, + max_connections: int = 2**31, connection_class: Type[Connection] = Connection, **connection_kwargs: Any, ) -> None: diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index a1072ad7e9..df066c4763 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -1254,7 +1254,7 @@ def __init__( max_connections: Optional[int] = None, **connection_kwargs, ): - max_connections = max_connections or 2 ** 31 + max_connections = max_connections or 2**31 if not isinstance(max_connections, int) or max_connections < 0: raise ValueError('"max_connections" must be a positive integer') diff --git a/redis/credentials.py b/redis/credentials.py index 4dd523fd9c..7ba26dcde1 100644 --- a/redis/credentials.py +++ b/redis/credentials.py @@ -16,8 +16,7 @@ class UsernamePasswordCredentialProvider(CredentialProvider): username and password. """ - def __init__(self, username: Optional[str] = None, - password: Optional[str] = None): + def __init__(self, username: Optional[str] = None, password: Optional[str] = None): self.username = username or "" self.password = password or ""