Skip to content

Commit

Permalink
Add warning to sdtype text
Browse files Browse the repository at this point in the history
  • Loading branch information
pvk-developer committed Sep 12, 2024
1 parent 73d1a78 commit 2a8d014
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 10 deletions.
21 changes: 11 additions & 10 deletions rdt/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import importlib
import inspect
import warnings
from collections import defaultdict
from copy import deepcopy
from functools import lru_cache
Expand Down Expand Up @@ -32,6 +33,7 @@
AnonymizedFaker,
PseudoAnonymizedFaker,
)
from rdt.transformers.utils import WarnDict

__all__ = [
'BaseTransformer',
Expand Down Expand Up @@ -88,16 +90,15 @@ def get_transformer_name(transformer):
for transformer in BaseTransformer.get_subclasses()
}


DEFAULT_TRANSFORMERS = {
'boolean': UniformEncoder(),
'categorical': UniformEncoder(),
'datetime': UnixTimestampEncoder(),
'id': RegexGenerator(),
'numerical': FloatFormatter(),
'pii': AnonymizedFaker(),
'text': RegexGenerator(),
}
DEFAULT_TRANSFORMERS = WarnDict(
boolean=UniformEncoder(),
categorical=UniformEncoder(),
datetime=UnixTimestampEncoder(),
id=RegexGenerator(),
numerical=FloatFormatter(),
pii=AnonymizedFaker(),
text=RegexGenerator(),
)


@lru_cache()
Expand Down
34 changes: 34 additions & 0 deletions rdt/transformers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,37 @@ def learn_rounding_digits(data):
name,
)
return None


class WarnDict(dict):
"""Custom dictionary to raise a deprecation warning."""

def get(self, key):
"""Retrun the value for key if key is in the dictionary, else default.
If the key is `text` raises a `DeprecationWarning` stating that it will be
phased out.
"""
if key == 'text':
warnings.warn(
"The sdtype 'text' is deprecated and will be phased out. "
"Please use 'id' instead.",
DeprecationWarning,
)

return super().get(key)

def __getitem__(self, key):
"""Retrun the value for key if key is in the dictionary.
If the key is `text` raises a `DeprecationWarning` stating that it will be
phased out.
"""
if key == 'text':
warnings.warn(
"The sdtype 'text' is deprecated and will be phased out. "
"Please use 'id' instead.",
DeprecationWarning,
)

return super().__getitem__(key)
23 changes: 23 additions & 0 deletions tests/unit/transformers/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest

from rdt.transformers.utils import (
WarnDict,
_any,
_max_repeat,
check_nan_in_transform,
Expand Down Expand Up @@ -295,3 +296,25 @@ def test_learn_rounding_digits_nullable_numerical_pandas_dtypes():
for column in data.columns:
output = learn_rounding_digits(data[column])
assert output == expected_output[column]


def test_warn_dict():
"""Test that ``WarnDict`` will raise a warning when called with `text`."""
# Setup
instance = WarnDict()
instance['text'] = 'text_transformer'

# Run
warning_msg = (
"The sdtype 'text' is deprecated and will be phased out. Please use 'id' instead."
)

with pytest.warns(DeprecationWarning, match=warning_msg):
result_access = instance['text']

with pytest.warns(DeprecationWarning, match=warning_msg):
result_get = instance.get('text')

# Assert
assert result_access == 'text_transformer'
assert result_get == 'text_transformer'

0 comments on commit 2a8d014

Please sign in to comment.