Skip to content

Commit

Permalink
feat(data-masking): add custom mask functionalities (#5837)
Browse files Browse the repository at this point in the history
* add custom mask functionalities

* change flags name to more intuitive

* fix type check error

* add draft documentation

* change doc examples

* style: format code with black

* fix format base

* add tests for new masks

* sub header for custom mask in docs

* masking rules to handle complex nest

* add test for masking rules

* modifications based on the feedback

* mypy and tests modification

* create more tests

* Refactoring tests

* Refactoring tests

* Refactoring tests

* Adding docstring + arg parameter

* Adding docstring + arg parameter

* Removing unnecessary code

* Removing unnecessary code

* Removing unnecessary code

---------

Co-authored-by: Leandro Damascena <lcdama@amazon.pt>
  • Loading branch information
anafalcao and leandrodamascena authored Feb 11, 2025
1 parent 8ff6cac commit 6ff9f11
Show file tree
Hide file tree
Showing 10 changed files with 981 additions and 256 deletions.
235 changes: 204 additions & 31 deletions aws_lambda_powertools/utilities/data_masking/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import functools
import logging
import warnings
from typing import TYPE_CHECKING, Any, Callable, Mapping, Sequence, overload
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Callable, Mapping, Sequence

from jsonpath_ng.ext import parse

Expand All @@ -18,6 +19,7 @@
DataMaskingUnsupportedTypeError,
)
from aws_lambda_powertools.utilities.data_masking.provider import BaseProvider
from aws_lambda_powertools.warnings import PowertoolsUserWarning

if TYPE_CHECKING:
from numbers import Number
Expand Down Expand Up @@ -67,11 +69,39 @@ def encrypt(
provider_options: dict | None = None,
**encryption_context: str,
) -> str:
"""
Encrypt data using the configured encryption provider.
Parameters
----------
data : dict, Mapping, Sequence, or Number
The data to encrypt.
provider_options : dict, optional
Provider-specific options for encryption.
**encryption_context : str
Additional key-value pairs for encryption context.
Returns
-------
str
The encrypted data as a base64-encoded string.
Example
--------
encryption_provider = AWSEncryptionSDKProvider(keys=[KMS_KEY_ARN])
data_masker = DataMasking(provider=encryption_provider)
encrypted = data_masker.encrypt({"secret": "value"})
"""
return self._apply_action(
data=data,
fields=None,
action=self.provider.encrypt,
provider_options=provider_options or {},
dynamic_mask=None,
custom_mask=None,
regex_pattern=None,
mask_format=None,
**encryption_context,
)

Expand All @@ -81,37 +111,104 @@ def decrypt(
provider_options: dict | None = None,
**encryption_context: str,
) -> Any:
"""
Decrypt data using the configured encryption provider.
Parameters
----------
data : dict, Mapping, Sequence, or Number
The data to decrypt.
provider_options : dict, optional
Provider-specific options for decryption.
**encryption_context : str
Additional key-value pairs for encryption context.
Returns
-------
Any
The decrypted data.
Example
--------
encryption_provider = AWSEncryptionSDKProvider(keys=[KMS_KEY_ARN])
data_masker = DataMasking(provider=encryption_provider)
decrypted = data_masker.decrypt(encrypted_data)
"""

return self._apply_action(
data=data,
fields=None,
action=self.provider.decrypt,
provider_options=provider_options or {},
dynamic_mask=None,
custom_mask=None,
regex_pattern=None,
mask_format=None,
**encryption_context,
)

@overload
def erase(self, data, fields: None) -> str: ...

@overload
def erase(self, data: list, fields: list[str]) -> list[str]: ...

@overload
def erase(self, data: tuple, fields: list[str]) -> tuple[str]: ...
def erase(
self,
data: Any,
fields: list[str] | None = None,
*,
dynamic_mask: bool | None = None,
custom_mask: str | None = None,
regex_pattern: str | None = None,
mask_format: str | None = None,
masking_rules: dict | None = None,
) -> Any:
"""
Erase or mask sensitive data in the input.
@overload
def erase(self, data: dict, fields: list[str]) -> dict: ...
Parameters
----------
data : Any
The data to be erased or masked.
fields : list of str, optional
List of field names to be erased or masked.
dynamic_mask : bool, optional
Whether to use dynamic masking.
custom_mask : str, optional
Custom mask to apply instead of the default.
regex_pattern : str, optional
Regular expression pattern for identifying data to mask.
mask_format : str, optional
Format string for the mask.
masking_rules : dict, optional
Dictionary of custom masking rules.
def erase(self, data: Sequence | Mapping, fields: list[str] | None = None) -> str | list[str] | tuple[str] | dict:
return self._apply_action(data=data, fields=fields, action=self.provider.erase)
Returns
-------
Any
The data with sensitive information erased or masked.
"""
if masking_rules:
return self._apply_masking_rules(data=data, masking_rules=masking_rules)
else:
return self._apply_action(
data=data,
fields=fields,
action=self.provider.erase,
dynamic_mask=dynamic_mask,
custom_mask=custom_mask,
regex_pattern=regex_pattern,
mask_format=mask_format,
)

def _apply_action(
self,
data,
fields: list[str] | None,
action: Callable,
provider_options: dict | None = None,
**encryption_context: str,
):
dynamic_mask: bool | None = None,
custom_mask: str | None = None,
regex_pattern: str | None = None,
mask_format: str | None = None,
**kwargs: Any,
) -> Any:
"""
Helper method to determine whether to apply a given action to the entire input data
or to specific fields if the 'fields' argument is specified.
Expand All @@ -127,8 +224,6 @@ def _apply_action(
and returns the modified value.
provider_options : dict
Provider specific keyword arguments to propagate; used as an escape hatch.
encryption_context: str
Encryption context to use in encrypt and decrypt operations.
Returns
-------
Expand All @@ -143,18 +238,34 @@ def _apply_action(
fields=fields,
action=action,
provider_options=provider_options,
**encryption_context,
dynamic_mask=dynamic_mask,
custom_mask=custom_mask,
regex_pattern=regex_pattern,
mask_format=mask_format,
**kwargs,
)
else:
logger.debug(f"Running action {action.__name__} with the entire data")
return action(data=data, provider_options=provider_options, **encryption_context)
return action(
data=data,
provider_options=provider_options,
dynamic_mask=dynamic_mask,
custom_mask=custom_mask,
regex_pattern=regex_pattern,
mask_format=mask_format,
**kwargs,
)

def _apply_action_to_fields(
self,
data: dict | str,
fields: list,
action: Callable,
provider_options: dict | None = None,
dynamic_mask: bool | None = None,
custom_mask: str | None = None,
regex_pattern: str | None = None,
mask_format: str | None = None,
**encryption_context: str,
) -> dict | str:
"""
Expand Down Expand Up @@ -201,8 +312,10 @@ def _apply_action_to_fields(
new_dict = {'a': {'b': {'c': '*****'}}, 'x': {'y': '*****'}}
```
"""
if not fields:
raise ValueError("Fields parameter cannot be empty")

data_parsed: dict = self._normalize_data_to_parse(fields, data)
data_parsed: dict = self._normalize_data_to_parse(data)

# For in-place updates, json_parse accepts a callback function
# this function must receive 3 args: field_value, fields, field_name
Expand All @@ -211,6 +324,10 @@ def _apply_action_to_fields(
self._call_action,
action=action,
provider_options=provider_options,
dynamic_mask=dynamic_mask,
custom_mask=custom_mask,
regex_pattern=regex_pattern,
mask_format=mask_format,
**encryption_context, # type: ignore[arg-type]
)

Expand All @@ -232,12 +349,6 @@ def _apply_action_to_fields(
# For in-place updates, json_parse accepts a callback function
# that receives 3 args: field_value, fields, field_name
# We create a partial callback to pre-populate known provider options (action, provider opts, enc ctx)
update_callback = functools.partial(
self._call_action,
action=action,
provider_options=provider_options,
**encryption_context, # type: ignore[arg-type]
)

json_parse.update(
data_parsed,
Expand All @@ -246,13 +357,70 @@ def _apply_action_to_fields(

return data_parsed

def _apply_masking_rules(self, data: dict, masking_rules: dict) -> dict:
    """
    Apply per-path masking rules to a deep copy of the data.

    Each key of ``masking_rules`` is prefixed with ``$.`` and parsed as a JSONPath
    expression, so both simple field names and nested path expressions work.
    Every non-None match is converted to ``str`` and replaced with the result of
    ``self.provider.erase``.

    Parameters
    ----------
    data : dict
        The dictionary containing the data to mask. It is deep-copied, so the
        caller's object is never modified.
    masking_rules : dict
        Maps a field name or path expression to a dict of keyword arguments
        that is unpacked into ``self.provider.erase`` for every match.

    Returns
    -------
    dict
        The masked copy of ``data``.

    Notes
    -----
    This method is best-effort: a path that fails to process, a path with no
    matches, or a value that fails to mask is reported as a
    ``PowertoolsUserWarning`` and skipped instead of raising.
    """
    result = deepcopy(data)

    for path, rule in masking_rules.items():
        try:
            jsonpath_expr = parse(f"$.{path}")
            matches = jsonpath_expr.find(result)

            if not matches:
                # Same warning category as the other warnings in this method,
                # so callers can filter all of them uniformly.
                warnings.warn(
                    f"No matches found for path: {path}",
                    category=PowertoolsUserWarning,
                    stacklevel=2,
                )
                continue

            for match in matches:
                try:
                    value = match.value
                    # None values are left as-is; there is nothing to mask.
                    if value is not None:
                        masked_value = self.provider.erase(str(value), **rule)
                        match.full_path.update(result, masked_value)

                except Exception as e:
                    # A failure on one match must not stop the remaining matches.
                    warnings.warn(
                        f"Error masking value for path {path}: {str(e)}",
                        category=PowertoolsUserWarning,
                        stacklevel=2,
                    )

        except Exception as e:
            # A failure on one path must not stop the remaining rules.
            warnings.warn(f"Error processing path {path}: {str(e)}", category=PowertoolsUserWarning, stacklevel=2)

    return result

def _mask_nested_field(self, data: dict, field_path: str, mask_function):
keys = field_path.split(".")
current = data
for key in keys[:-1]:
current = current.get(key, {})
if not isinstance(current, dict):
return
if keys[-1] in current:
current[keys[-1]] = self.provider.erase(current[keys[-1]], **mask_function)

@staticmethod
def _call_action(
field_value: Any,
fields: dict[str, Any],
field_name: str,
action: Callable,
provider_options: dict[str, Any] | None = None,
dynamic_mask: bool | None = None,
custom_mask: str | None = None,
regex_pattern: str | None = None,
mask_format: str | None = None,
**encryption_context,
) -> None:
"""
Expand All @@ -270,13 +438,18 @@ def _call_action(
Returns:
- fields[field_name]: Returns the processed field value
"""
fields[field_name] = action(field_value, provider_options=provider_options, **encryption_context)
fields[field_name] = action(
field_value,
provider_options=provider_options,
dynamic_mask=dynamic_mask,
custom_mask=custom_mask,
regex_pattern=regex_pattern,
mask_format=mask_format,
**encryption_context,
)
return fields[field_name]

def _normalize_data_to_parse(self, fields: list, data: str | dict) -> dict:
if not fields:
raise ValueError("No fields specified.")

def _normalize_data_to_parse(self, data: str | dict) -> dict:
if isinstance(data, str):
# Parse JSON string as dictionary
data_parsed = self.json_deserializer(data)
Expand Down
Loading

0 comments on commit 6ff9f11

Please sign in to comment.