Text to id changes #1338

Merged: 4 commits, Mar 28, 2023
Changes from 3 commits
118 changes: 46 additions & 72 deletions sdv/data_processing/data_processor.py
@@ -1,6 +1,5 @@
"""Single table data processing."""

import itertools
import json
import logging
import warnings
@@ -68,7 +67,7 @@ class DataProcessor:
missing_value_replacement='mean',
model_missing_values=False,
),
'text': rdt.transformers.RegexGenerator()
'id': rdt.transformers.RegexGenerator()
}
_DTYPE_TO_SDTYPE = {
'i': 'numerical',
@@ -107,7 +106,6 @@ def __init__(self, metadata, enforce_rounding=True, enforce_min_max_values=True,
self._primary_key = self.metadata.primary_key
self._prepared_for_fitting = False
self._keys = deepcopy(self.metadata.alternate_keys)
self._keys_generators = {}
if self._primary_key:
self._keys.append(self._primary_key)

@@ -343,7 +341,7 @@ def _update_transformers_by_sdtypes(self, sdtype, transformer):
self._transformers_by_sdtype[sdtype] = transformer

@staticmethod
def create_anonymized_transformer(sdtype, column_metadata):
def create_anonymized_transformer(sdtype, column_metadata, enforce_uniqueness):
"""Create an instance of an ``AnonymizedFaker``.

Read the extra keyword arguments from the ``column_metadata`` and use them to create
@@ -354,6 +352,9 @@ def create_anonymized_transformer(sdtype, column_metadata):
Semantic data type or a ``Faker`` function name.
column_metadata (dict):
A dictionary representing the rest of the metadata for the given ``sdtype``.
enforce_uniqueness (bool):
If ``True``, pass ``enforce_uniqueness=True`` to the created transformer so that
primary key values are generated uniquely.

Returns:
Instance of ``rdt.transformers.pii.AnonymizedFaker``.
@@ -363,42 +364,42 @@ def create_anonymized_transformer(sdtype, column_metadata):
if key not in ['pii', 'sdtype']:
kwargs[key] = value

if enforce_uniqueness:
kwargs['enforce_uniqueness'] = True

return get_anonymized_transformer(sdtype, kwargs)

def create_key_transformer(self, column_name, sdtype, column_metadata):
"""Create an instance for the primary key or alternate key transformer.
def create_regex_generator(self, column_name, sdtype, column_metadata, is_numeric):
"""Create a ``RegexGenerator`` for the ``id`` columns.

Read the keyword arguments from the ``column_metadata`` and use them to create
an instance of an ``RegexGenerator`` or ``AnonymizedFaker`` transformer with
``enforce_uniqueness`` set to ``True``.
an instance of a ``RegexGenerator``. If ``regex_format`` is not present in the
metadata, a default of ``[0-1a-z]{5}`` is used for object-like data and ``\d{30}``
is used for numerical data. If the column is a primary key or an alternate key,
the generated values are also enforced to be unique.

Args:
column_name (str):
Name of the column.
sdtype (str):
Semantic data type or a ``Faker`` function name.
column_metadata (dict):
A dictionary representing the rest of the metadata for the given ``sdtype``.
is_numeric (bool):
Whether the column's data type is numeric.

Returns:
transformer:
Instance of ``rdt.transformers.text.RegexGenerator``, with ``enforce_uniqueness``
set to ``True`` when the column is a primary key or alternate key.
"""
if sdtype == 'numerical':
self._keys_generators[column_name] = itertools.count()
return None

if sdtype == 'text':
regex_format = column_metadata.get('regex_format', '[A-Za-z]{5}')
transformer = rdt.transformers.RegexGenerator(
regex_format=regex_format,
enforce_uniqueness=True
)

else:
kwargs = deepcopy(column_metadata)
kwargs['enforce_uniqueness'] = True
transformer = self.create_anonymized_transformer(sdtype, kwargs)
default_regex_format = r'\d{30}' if is_numeric else '[0-1a-z]{5}'
regex_format = column_metadata.get('regex_format', default_regex_format)
transformer = rdt.transformers.RegexGenerator(
regex_format=regex_format,
enforce_uniqueness=(column_name in self._keys)
)

return transformer
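
A minimal sketch of how the new method behaves, assuming this branch is installed; the metadata and column name below are hypothetical, and the ``RegexGenerator`` attributes are the ones passed in the diff above:

from sdv.data_processing import DataProcessor
from sdv.metadata import SingleTableMetadata

# Hypothetical table: a numeric id column used as primary key,
# with no explicit regex_format in the metadata.
metadata = SingleTableMetadata()
metadata.add_column('guest_id', sdtype='id')
metadata.set_primary_key('guest_id')

dp = DataProcessor(metadata)
transformer = dp.create_regex_generator(
    'guest_id', 'id', metadata.columns['guest_id'], is_numeric=True)

print(transformer.regex_format)        # \d{30} -- the numeric default
print(transformer.enforce_uniqueness)  # True -- 'guest_id' is a key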

@@ -451,11 +452,23 @@ def _create_config(self, data, columns_created_by_constraints):
pii = column_metadata.get('pii', sdtype not in self._DEFAULT_TRANSFORMERS_BY_SDTYPE)
sdtypes[column] = 'pii' if pii else sdtype

if column in self._keys:
transformers[column] = self.create_key_transformer(column, sdtype, column_metadata)
if sdtype == 'id':
is_numeric = pd.api.types.is_numeric_dtype(data[column].dtype)
transformers[column] = self.create_regex_generator(
column,
sdtype,
column_metadata,
is_numeric
)
sdtypes[column] = 'text'

elif pii:
transformers[column] = self.create_anonymized_transformer(sdtype, column_metadata)
enforce_uniqueness = bool(column in self._keys)
transformers[column] = self.create_anonymized_transformer(
sdtype,
column_metadata,
enforce_uniqueness
)

elif sdtype in self._transformers_by_sdtype:
transformers[column] = self._get_transformer_instance(sdtype, column_metadata)
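
Sketching the net effect of this branch on the resulting config, for a hypothetical table with an ``id`` primary key and a pii alternate key (the exact ``AnonymizedFaker`` arguments are resolved internally by ``get_anonymized_transformer``; only the shape is illustrated):

from rdt.transformers.pii import AnonymizedFaker
from rdt.transformers.text import RegexGenerator

# 'id' columns always get a RegexGenerator, and their sdtype is remapped
# to 'text' so that RDT accepts the assignment.
sdtypes = {'guest_id': 'text', 'ssn': 'pii'}
transformers = {
    'guest_id': RegexGenerator(regex_format=r'\d{30}', enforce_uniqueness=True),
    # pii columns that are keys now get enforce_uniqueness=True as well.
    'ssn': AnonymizedFaker(function_name='ssn', enforce_uniqueness=True),
}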
@@ -608,15 +621,7 @@ def fit(self, data):

def reset_sampling(self):
"""Reset the sampling state for the anonymized columns and primary keys."""
# Resetting the transformers manually until fixed on RDT
for transformer in self._hyper_transformer.field_transformers.values():
if transformer is not None:
transformer.reset_randomization()

self._keys_generators = {
key: itertools.count()
for key in self._keys_generators
}
self._hyper_transformer.reset_randomization()

def generate_keys(self, num_rows, reset_keys=False):
"""Generate the columns that are identified as ``keys``.
@@ -631,32 +636,11 @@
pandas.DataFrame:
A dataframe with the newly generated primary keys of the size ``num_rows``.
"""
anonymized_keys = []
dataframes = {}
for key in self._keys:
if self._hyper_transformer.field_transformers.get(key) is None:
if reset_keys:
self._keys_generators[key] = itertools.count()

dataframes[key] = pd.DataFrame({
key: [next(self._keys_generators[key]) for _ in range(num_rows)]
})

else:
anonymized_keys.append(key)

# Add ``reset_keys`` for RDT once the version is updated.
if anonymized_keys:
anonymized_dataframe = self._hyper_transformer.create_anonymized_columns(
num_rows=num_rows,
column_names=anonymized_keys,
)
if dataframes:
return pd.concat(list(dataframes.values()) + [anonymized_dataframe], axis=1)

return anonymized_dataframe

return pd.concat(dataframes.values(), axis=1)
generated_keys = self._hyper_transformer.create_anonymized_columns(
num_rows=num_rows,
column_names=self._keys,
)
return generated_keys

def transform(self, data, is_condition=False):
"""Transform the given data.
@@ -683,17 +667,7 @@

LOGGER.debug(f'Transforming table {self.table_name}')
if self._keys and not is_condition:
keys_to_drop = []
for key in self._keys:
if key == self._primary_key:
drop_primary_key = bool(self._keys_generators.get(key))
data = data.set_index(self._primary_key, drop=drop_primary_key)

elif self._keys_generators.get(key):
keys_to_drop.append(key)

if keys_to_drop:
data = data.drop(keys_to_drop, axis=1)
data = data.set_index(self._primary_key, drop=False)

try:
transformed = self._hyper_transformer.transform_subset(data)
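With the ``itertools`` counters gone, key generation and reset both delegate to RDT. A minimal end-to-end sketch under this branch; the table, column names, and values are made up:

import pandas as pd

from sdv.data_processing import DataProcessor
from sdv.metadata import SingleTableMetadata

metadata = SingleTableMetadata()
metadata.add_column('guest_id', sdtype='id', regex_format='ID_[0-9]{3}')
metadata.add_column('amount', sdtype='numerical')
metadata.set_primary_key('guest_id')

data = pd.DataFrame({'guest_id': ['ID_001', 'ID_002'], 'amount': [10.0, 20.0]})

dp = DataProcessor(metadata)
dp.fit(data)

# All keys now come back from HyperTransformer.create_anonymized_columns.
keys = dp.generate_keys(num_rows=3)
print(keys)  # a DataFrame with three unique 'guest_id' values

# reset_sampling is now a single delegation to reset_randomization, so the
# generator should restart and repeat the same sequence.
dp.reset_sampling()
print(dp.generate_keys(num_rows=3))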
2 changes: 1 addition & 1 deletion sdv/datasets/demo.py
@@ -39,7 +39,7 @@ def _validate_output_folder(output_folder_name):


def _download(modality, dataset_name):
dataset_url = BUCKET_URL + '/' + modality.upper() + '/' + dataset_name + '.zip'
dataset_url = f'{BUCKET_URL}/{modality.upper()}/{dataset_name}.zip'
LOGGER.info(f'Downloading dataset {dataset_name} from {dataset_url}')
try:
response = urllib.request.urlopen(dataset_url)
10 changes: 5 additions & 5 deletions sdv/metadata/metadata_upgrader.py
@@ -31,14 +31,14 @@ def _upgrade_columns_and_keys(old_metadata):
column_meta['datetime_format'] = datetime_format

elif old_type == 'id':
column_meta['sdtype'] = 'id'
if subtype == 'integer':
column_meta['sdtype'] = 'numerical'

regex_format = r'\d{30}'
else:
column_meta['sdtype'] = 'text'
regex_format = field_meta.get('regex', '[A-Za-z]{5}')
if regex_format:
column_meta['regex_format'] = regex_format

if not field_meta.get('pii'):
column_meta['regex_format'] = regex_format

if field != primary_key and field_meta.get('ref') is None:
alternate_keys.append(field)
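The upgrade behavior, sketched on hypothetical old-format (v0) fields following the branching above:

# Integer ids get the numeric default regex.
old = {'type': 'id', 'subtype': 'integer'}
new = {'sdtype': 'id', 'regex_format': '\\d{30}'}

# String ids keep an explicit regex (or fall back to '[A-Za-z]{5}').
old = {'type': 'id', 'subtype': 'string', 'regex': 'ID_[0-9]{5}'}
new = {'sdtype': 'id', 'regex_format': 'ID_[0-9]{5}'}

# pii ids get no regex_format at all, per the `if not field_meta.get('pii')` guard.
old = {'type': 'id', 'subtype': 'string', 'pii': True}
new = {'sdtype': 'id'}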
27 changes: 14 additions & 13 deletions sdv/metadata/single_table.py
@@ -24,7 +24,7 @@ class SingleTableMetadata:
'datetime': frozenset(['datetime_format']),
'categorical': frozenset(['order', 'order_by']),
'boolean': frozenset([]),
'text': frozenset(['regex_format']),
'id': frozenset(['regex_format']),
}

_DTYPES_TO_SDTYPES = {
@@ -98,13 +98,13 @@ def _validate_categorical(column_name, **kwargs):
)

@staticmethod
def _validate_text(column_name, **kwargs):
def _validate_id(column_name, **kwargs):
regex = kwargs.get('regex_format', '')
try:
re.compile(regex)
except Exception as exception:
raise InvalidMetadataError(
f"Invalid regex format string '{regex}' for text column '{column_name}'."
f"Invalid regex format string '{regex}' for id column '{column_name}'."
) from exception

@staticmethod
@@ -151,8 +151,8 @@ def _validate_column(self, column_name, sdtype, **kwargs):
self._validate_numerical(column_name, **kwargs)
elif sdtype == 'datetime':
self._validate_datetime(column_name, **kwargs)
elif sdtype == 'text':
self._validate_text(column_name, **kwargs)
elif sdtype == 'id':
self._validate_id(column_name, **kwargs)
elif 'pii' in kwargs:
self._validate_pii(column_name, **kwargs)

@@ -291,15 +291,16 @@ def _validate_datatype(column_name):
isinstance(column_name, tuple) and all(isinstance(i, str) for i in column_name)

def _validate_keys_sdtype(self, keys, key_type):
"""Validate that no key is of type 'categorical'."""
bad_sdtypes = ('boolean', 'categorical')
categorical_keys = sorted(
{key for key in keys if self.columns[key]['sdtype'] in bad_sdtypes}
)
if categorical_keys:
"""Validate that each key is of type 'id' or a valid Faker function."""
bad_keys = set()
for key in keys:
if not (self.columns[key]['sdtype'] == 'id' or
is_faker_function(self.columns[key]['sdtype'])):
bad_keys.add(key)
if bad_keys:
raise InvalidMetadataError(
f"The {key_type}_keys {categorical_keys} cannot be type 'categorical' or "
"'boolean'."
f"The {key_type}_keys {sorted(bad_keys)} must be type 'id' or "
'a valid Faker function.'
)

def _validate_key(self, column_name, key_type):
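A sketch of what the new key validation accepts and rejects, assuming keys are validated at assignment time; the column names are hypothetical and the import path for ``InvalidMetadataError`` follows this module's imports:

from sdv.metadata import SingleTableMetadata
from sdv.metadata.errors import InvalidMetadataError

metadata = SingleTableMetadata()
metadata.add_column('user_id', sdtype='id')
metadata.add_column('email', sdtype='email', pii=True)
metadata.add_column('active', sdtype='boolean')

metadata.set_primary_key('user_id')     # accepted: sdtype is 'id'
metadata.add_alternate_keys(['email'])  # accepted: 'email' is a Faker function

try:
    metadata.add_alternate_keys(['active'])  # rejected: boolean key
except InvalidMetadataError as error:
    print(error)
    # The alternate_keys ['active'] must be type 'id' or a valid Faker function.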
15 changes: 8 additions & 7 deletions tests/integration/data_processing/test_data_processor.py
@@ -3,7 +3,8 @@

import numpy as np
from rdt.transformers import (
AnonymizedFaker, BinaryEncoder, FloatFormatter, LabelEncoder, UnixTimestampEncoder)
AnonymizedFaker, BinaryEncoder, FloatFormatter, LabelEncoder, RegexGenerator,
UnixTimestampEncoder)

from sdv.data_processing import DataProcessor
from sdv.data_processing.datetime_formatter import DatetimeFormatter
@@ -91,7 +92,7 @@ def test_data_processor_with_anonymized_columns_and_primary_key():
metadata.update_column('occupation', sdtype='job', pii=True)

# Add primary key field
metadata.add_column('id', sdtype='text', regex_format='ID_\\d{4}[0-9]')
metadata.add_column('id', sdtype='id', regex_format='ID_\\d{4}[0-9]')
metadata.set_primary_key('id')

# Add id
@@ -147,7 +148,7 @@ def test_data_processor_with_primary_key_numerical():
adult_metadata.detect_from_dataframe(data=data)

# Add primary key field
adult_metadata.add_column('id', sdtype='numerical')
adult_metadata.add_column('id', sdtype='id')
adult_metadata.set_primary_key('id')

# Add id
@@ -188,11 +189,11 @@ def test_data_processor_with_alternate_keys():
adult_metadata.detect_from_dataframe(data=data)

# Add primary key field
adult_metadata.add_column('id', sdtype='numerical')
adult_metadata.add_column('id', sdtype='id')
adult_metadata.set_primary_key('id')

adult_metadata.add_column('secondary_id', sdtype='numerical')
adult_metadata.update_column('fnlwgt', sdtype='text', regex_format='ID_\\d{4}[0-9]')
adult_metadata.add_column('secondary_id', sdtype='id')
adult_metadata.update_column('fnlwgt', sdtype='id', regex_format='ID_\\d{4}[0-9]')

adult_metadata.add_alternate_keys(['secondary_id', 'fnlwgt'])

@@ -249,7 +250,7 @@ def test_data_processor_prepare_for_fitting():
'mba_spec': LabelEncoder,
'employability_perc': FloatFormatter,
'placed': LabelEncoder,
'student_id': None,
'student_id': RegexGenerator,
'experience_years': FloatFormatter,
'duration': LabelEncoder,
'salary': FloatFormatter,
2 changes: 1 addition & 1 deletion tests/integration/datasets/test_demo.py
@@ -67,4 +67,4 @@ def test_get_available_demos_multi_table():
]
})
expected_table['size_MB'] = expected_table['size_MB'].astype(float).round(2)
assert len(expected_table.merge(tables_info)) == len(expected_table)
assert len(expected_table.merge(tables_info, on='dataset_name')) == len(expected_table)
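
The explicit ``on=`` matters because ``DataFrame.merge`` defaults to joining on every shared column, so any float rounding difference in ``size_MB`` would silently drop rows. A self-contained illustration with made-up values:

import pandas as pd

expected = pd.DataFrame({'dataset_name': ['adult'], 'size_MB': [3.91]})
actual = pd.DataFrame({'dataset_name': ['adult'], 'size_MB': [3.907]})

# The default merge joins on both shared columns, so nothing matches.
print(len(expected.merge(actual)))                     # 0
# Joining on the name alone keeps the row regardless of size rounding.
print(len(expected.merge(actual, on='dataset_name')))  # 1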
10 changes: 5 additions & 5 deletions tests/integration/metadata/test_multi_table.py
Expand Up @@ -91,19 +91,19 @@ def test_upgrade_metadata(tmp_path):
'nesreca': {
'primary_key': 'id_nesreca',
'columns': {
'upravna_enota': {'sdtype': 'numerical'},
'id_nesreca': {'sdtype': 'numerical'}
'upravna_enota': {'sdtype': 'id', 'regex_format': r'\d{30}'},
'id_nesreca': {'sdtype': 'id', 'regex_format': r'\d{30}'}
}
},
'oseba': {
'columns': {
'upravna_enota': {'sdtype': 'numerical'},
'id_nesreca': {'sdtype': 'numerical'}
'upravna_enota': {'sdtype': 'id', 'regex_format': r'\d{30}'},
'id_nesreca': {'sdtype': 'id', 'regex_format': r'\d{30}'}
}
},
'upravna_enota': {
'primary_key': 'id_upravna_enota',
'columns': {'id_upravna_enota': {'sdtype': 'numerical'}}
'columns': {'id_upravna_enota': {'sdtype': 'id', 'regex_format': r'\d{30}'}}
}
},
'relationships': [