Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
pvk-developer committed Sep 17, 2024
1 parent 9f703c5 commit e15a077
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 18 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ exclude = [
'.tox',
'.git',
'__pycache__',
'*.ipynb',
'.ipynb_checkpoints',
'tasks.py',
'tests/contributing.py'
Expand Down Expand Up @@ -235,4 +236,4 @@ convention = "google"

[tool.ruff.lint.pycodestyle]
max-doc-length = 100
max-line-length = 100
max-line-length = 100
38 changes: 24 additions & 14 deletions rdt/transformers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import string
import sys
import warnings
from collections import defaultdict

import numpy as np
import pandas as pd
Expand All @@ -14,6 +15,7 @@
LOGGER = logging.getLogger(__name__)

MAX_DECIMALS = sys.float_info.dig - 1
DEPRECATED_SDTYPES_MAPPING = {'text': 'id'}


def _literal(character, max_repeat):
Expand Down Expand Up @@ -279,32 +281,40 @@ def learn_rounding_digits(data):
class WarnDict(dict):
"""Custom dictionary to raise a deprecation warning."""

def get(self, key):
"""Retrun the value for key if key is in the dictionary, else default.
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._warned = defaultdict()

If the key is `text` raises a `DeprecationWarning` stating that it will be
def get(self, sdtype):
"""Retrun the value for sdtype if sdtype is in the dictionary, else default.
If the sdtype is `text` raises a `DeprecationWarning` stating that it will be
phased out.
"""
if key == 'text':
if sdtype in DEPRECATED_SDTYPES_MAPPING and not self._warned.get(sdtype):
new_sdtype = DEPRECATED_SDTYPES_MAPPING.get(sdtype)
warnings.warn(
"The sdtype 'text' is deprecated and will be phased out. "
"Please use 'id' instead.",
f"The sdtype '{sdtype}' is deprecated and will be phased out. "
f"Please use '{new_sdtype}' instead.",
DeprecationWarning,
)
self._warned[sdtype] = True

return super().get(key)
return super().get(sdtype)

def __getitem__(self, key):
"""Retrun the value for key if key is in the dictionary.
def __getitem__(self, sdtype):
"""Retrun the value for sdtype if sdtype is in the dictionary.
If the key is `text` raises a `DeprecationWarning` stating that it will be
If the sdtype is `text` raises a `DeprecationWarning` stating that it will be
phased out.
"""
if key == 'text':
if sdtype in DEPRECATED_SDTYPES_MAPPING and not self._warned.get(sdtype):
new_sdtype = DEPRECATED_SDTYPES_MAPPING.get(sdtype)
warnings.warn(
"The sdtype 'text' is deprecated and will be phased out. "
"Please use 'id' instead.",
f"The sdtype '{sdtype}' is deprecated and will be phased out. "
f"Please use '{new_sdtype}' instead.",
DeprecationWarning,
)
self._warned[sdtype] = True

return super().__getitem__(key)
return super().get(sdtype)
11 changes: 8 additions & 3 deletions tests/unit/transformers/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sre_parse
import warnings
from sre_constants import MAXREPEAT
from unittest.mock import patch

Expand Down Expand Up @@ -309,9 +310,13 @@ def test_warn_dict():
with pytest.warns(DeprecationWarning, match=warning_msg):
result_access = instance['text']

with pytest.warns(DeprecationWarning, match=warning_msg):
result_get = instance.get('text')
# Run second time and no warning gets shown
with warnings.catch_warnings(record=True) as record:
result_access_no_warn = instance['text']
result_get_no_warn = instance.get('text')

# Assert
assert len(record) == 0
assert result_access == 'text_transformer'
assert result_get == 'text_transformer'
assert result_access_no_warn == 'text_transformer'
assert result_get_no_warn == 'text_transformer'

0 comments on commit e15a077

Please sign in to comment.