Skip to content

Commit

Permalink
Unify all gets to return None if no values present in Redis (#336)
Browse files Browse the repository at this point in the history
* Unify all gets to return None if no values present in Redis

- Improve documentation, clarifying arbitrary data formats

* Update notifications_utils/clients/redis/redis_client.py

Co-authored-by: Andrew <andrew.leith@cds-snc.ca>

---------

Co-authored-by: Andrew <andrew.leith@cds-snc.ca>
  • Loading branch information
whabanks and andrewleith authored Nov 14, 2024
1 parent acfde00 commit b344e5a
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .github/actions/waffles/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
docopt==0.6.2
Flask==2.3.3
markupsafe==2.1.5
git+https://github.com/cds-snc/notifier-utils.git@52.3.8#egg=notifications-utils
git+https://github.com/cds-snc/notifier-utils.git@52.3.9#egg=notifications-utils
67 changes: 59 additions & 8 deletions notifications_utils/clients/redis/annual_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,16 @@ def annual_limit_status_key(service_id):
return f"annual-limit:{service_id}:status"


def decode_byte_dict(dict: dict, value_type=str):
def decode_byte_dict(byte_dict: dict, value_type=str):
"""
Redis-py returns byte strings for keys and values. This function decodes them to UTF-8 strings.
"""
# Check if expected_value_type is one of the allowed types
if value_type not in {int, float, str}:
raise ValueError("expected_value_type must be int, float, or str")
return {key.decode("utf-8"): value_type(value.decode("utf-8")) for key, value in dict.items() if dict.items()}
if byte_dict is None or not byte_dict.items():
return None
return {key.decode("utf-8"): value_type(value.decode("utf-8")) for key, value in byte_dict.items()}


class RedisAnnualLimit:
Expand All @@ -75,13 +77,21 @@ def init_app(self, app, *args, **kwargs):
pass

def increment_notification_count(self, service_id: str, field: str):
"""Increments the specified daily notification count field for a service.
Fields that can be set: `sms_delivered`, `email_delivered`, `sms_failed`, `email_failed`
Args:
service_id (str): _description_
field (str): _description_
"""
self._redis_client.increment_hash_value(annual_limit_notifications_key(service_id), field)

def get_notification_count(self, service_id: str, field: str):
"""
Retrieves the specified daily notification count for a service. (e.g. SMS_DELIVERED, EMAIL_FAILED, etc.)
"""
return int(self._redis_client.get_hash_field(annual_limit_notifications_key(service_id), field))
count = self._redis_client.get_hash_field(annual_limit_notifications_key(service_id), field)
return count and int(count.decode("utf-8"))

def get_all_notification_counts(self, service_id: str):
"""
Expand All @@ -90,9 +100,11 @@ def get_all_notification_counts(self, service_id: str):
return decode_byte_dict(self._redis_client.get_all_from_hash(annual_limit_notifications_key(service_id)), int)

def reset_all_notification_counts(self, service_ids=None):
"""
Resets all daily notification metrics.
:param: service_ids: list of service_ids to reset, if None, resets all services
"""Resets all daily notification metrics.
Args:
service_ids (Optional): A list of service_ids to reset notification counts for. Resets all services if None.
"""
hashes = (
annual_limit_notifications_key("*")
Expand All @@ -103,6 +115,22 @@ def reset_all_notification_counts(self, service_ids=None):
self._redis_client.delete_hash_fields(hashes=hashes)

def seed_annual_limit_notifications(self, service_id: str, mapping: dict):
"""Seeds annual limit notifications for a service.
Args:
service_id (str): Service to seed annual limit notifications for.
mapping (dict): A dict used to map notification counts to their respective fields formatted as follows
Examples:
`mapping` format:
{
"sms_delivered": int,
"email_delivered": int,
"sms_failed": int,
"email_failed": int
}
"""
self._redis_client.bulk_set_hash_fields(key=annual_limit_notifications_key(service_id), mapping=mapping)

def was_seeded_today(self, service_id):
Expand All @@ -124,18 +152,41 @@ def clear_notification_counts(self, service_id: str):

def set_annual_limit_status(self, service_id: str, field: str, value: datetime):
"""
Sets the status (e.g., 'nearing_limit', 'over_limit') in the annual limits Redis hash.
Sets the specified status field in the annual limits Redis hash for a service.
Fields that can be set: `near_sms_limit`, `near_email_limit`, `over_sms_limit`, `over_email_limit`, `seeded_at`
Args:
service_id (str): The service to set the annual limit status field for
field (str): The field to set in the annual limit status hash.
value (datetime): The date to set the status to
"""
self._redis_client.set_hash_value(annual_limit_status_key(service_id), field, value.strftime("%Y-%m-%d"))

def get_annual_limit_status(self, service_id: str, field: str):
"""
Retrieves the value of a specific annual limit status from the Redis hash.
Fields that can be fetched: `near_sms_limit`, `near_email_limit`, `over_sms_limit`, `over_email_limit`, `seeded_at`
Args:
service_id (str): The service to fetch the annual limit status field for
field (str): The field to fetch from the annual limit status hash values:
`near_sms_limit`, `near_email_limit`, `over_sms_limit`, `over_email_limit`, `seeded_at`
Returns:
str | None: The date the status was set, or None if the status has not been set
"""
response = self._redis_client.get_hash_field(annual_limit_status_key(service_id), field)
return response.decode("utf-8") if response is not None else None
return response and response.decode("utf-8")

def get_all_annual_limit_statuses(self, service_id: str):
"""Retrieves all annual limit status fields for a specified service from Redis
Args:
service_id (str): The service to fetch annual limit statuses for
Returns:
dict | None: A dictionary of annual limit statuses or None if no statuses are found
"""
return decode_byte_dict(self._redis_client.get_all_from_hash(annual_limit_status_key(service_id)))

def clear_annual_limit_statuses(self, service_id: str):
Expand Down
2 changes: 1 addition & 1 deletion notifications_utils/clients/redis/redis_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def bulk_set_hash_fields(self, mapping, pattern=None, key=None, raise_exception=
"""
Bulk set hash fields.
:param pattern: the pattern to match keys
:param mappting: the mapping of fields to set
:param mapping: the mapping of fields to set
:param raise_exception: True if we should allow the exception to bubble up
"""
if self.active:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "notifications-utils"
version = "52.3.8"
version = "52.3.9"
description = "Shared python code for Notification - Provides logging utils etc."
authors = ["Canadian Digital Service"]
license = "MIT license"
Expand Down
33 changes: 30 additions & 3 deletions tests/test_annual_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,18 +113,26 @@ def test_get_notification_count(mock_annual_limit_client, mocked_service_id):
assert result == 1


def test_get_notification_count_returns_none_when_field_does_not_exist(mock_annual_limit_client, mocked_service_id):
assert mock_annual_limit_client.get_notification_count(mocked_service_id, SMS_DELIVERED) is None


def test_get_all_notification_counts(mock_annual_limit_client, mock_notification_count_types, mocked_service_id):
for field in mock_notification_count_types:
mock_annual_limit_client.increment_notification_count(mocked_service_id, field)
assert len(mock_annual_limit_client.get_all_notification_counts(mocked_service_id)) == 4


def test_get_all_notification_counts_returns_none_if_fields_do_not_exist(mock_annual_limit_client, mocked_service_id):
assert mock_annual_limit_client.get_all_notification_counts(mocked_service_id) is None


def test_clear_notification_counts(mock_annual_limit_client, mock_notification_count_types, mocked_service_id):
for field in mock_notification_count_types:
mock_annual_limit_client.increment_notification_count(mocked_service_id, field)
assert len(mock_annual_limit_client.get_all_notification_counts(mocked_service_id)) == 4
mock_annual_limit_client.clear_notification_counts(mocked_service_id)
assert len(mock_annual_limit_client.get_all_notification_counts(mocked_service_id)) == 0
assert mock_annual_limit_client.get_all_notification_counts(mocked_service_id) is None


@pytest.mark.parametrize(
Expand All @@ -147,7 +155,7 @@ def test_bulk_reset_notification_counts(mock_annual_limit_client, mock_notificat
mock_annual_limit_client.reset_all_notification_counts()

for service_id in service_ids:
assert len(mock_annual_limit_client.get_all_notification_counts(service_id)) == 0
assert mock_annual_limit_client.get_all_notification_counts(service_id) is None


def test_set_annual_limit_status(mock_annual_limit_client, mocked_service_id):
Expand All @@ -164,13 +172,28 @@ def test_get_annual_limit_status(mock_annual_limit_client, mocked_service_id):
assert result == near_limit_date.strftime("%Y-%m-%d")


def test_get_annual_limit_status_returns_none_when_fields_do_not_exist(mock_annual_limit_client, mocked_service_id):
assert mock_annual_limit_client.get_annual_limit_status(mocked_service_id, NEAR_SMS_LIMIT) is None


@freeze_time("2024-10-25 12:00:00.000000")
def test_get_all_annual_limit_statuses(mock_annual_limit_client, mock_annual_limit_statuses, mocked_service_id):
for status in mock_annual_limit_statuses:
mock_annual_limit_client.set_annual_limit_status(mocked_service_id, status, datetime.utcnow())
assert len(mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id)) == 4


def test_get_all_annual_limit_statuses_returns_none_when_fields_do_not_exist(mock_annual_limit_client, mocked_service_id):
assert mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id) is None


@freeze_time("2024-10-25 12:00:00.000000")
def test_clear_annual_limit_statuses(mock_annual_limit_client, mock_annual_limit_statuses, mocked_service_id):
for status in mock_annual_limit_statuses:
mock_annual_limit_client.set_annual_limit_status(mocked_service_id, status, datetime.utcnow())
assert len(mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id)) == 4
mock_annual_limit_client.clear_annual_limit_statuses(mocked_service_id)
assert len(mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id)) == 0
assert mock_annual_limit_client.get_all_annual_limit_statuses(mocked_service_id) is None


@freeze_time("2024-10-25 12:00:00.000000")
Expand All @@ -196,6 +219,10 @@ def test_get_seeded_at(mock_annual_limit_client, seeded_at_value, expected_value
assert result == expected_value


def test_get_seeded_at_returns_none_when_field_does_not_exist(mock_annual_limit_client, mocked_service_id):
assert mock_annual_limit_client.get_seeded_at(mocked_service_id) is None


@freeze_time("2024-10-25 12:00:00.000000")
def test_set_nearing_sms_limit(mock_annual_limit_client, mocked_service_id):
mock_annual_limit_client.set_nearing_sms_limit(mocked_service_id)
Expand Down

0 comments on commit b344e5a

Please sign in to comment.