Update RDT to v1.0.0 #923

Merged · 22 commits · Aug 2, 2022
sdv/constraints/base.py: 7 changes (4 additions, 3 deletions)
@@ -9,6 +9,7 @@
 from copulas.multivariate.gaussian import GaussianMultivariate
 from copulas.univariate import GaussianUnivariate
 from rdt import HyperTransformer
+from rdt.transformers import OneHotEncoder

 from sdv.constraints.errors import MissingConstraintColumnError
 from sdv.errors import ConstraintsNotMetError
@@ -323,9 +324,9 @@ def fit(self, table_data):
                 Table data.
         """
         data_to_model = table_data[self.constraint_columns]
-        self._hyper_transformer = HyperTransformer(
-            default_data_type_transformers={'categorical': 'OneHotEncodingTransformer'}
-        )
+        self._hyper_transformer = HyperTransformer()
+        self._hyper_transformer.detect_initial_config(data_to_model)
+        self._hyper_transformer.update_transformers_by_sdtype({'categorical': OneHotEncoder})
         transformed_data = self._hyper_transformer.fit_transform(data_to_model)
         self._model = GaussianMultivariate(distribution=GaussianUnivariate)
         self._model.fit(transformed_data)
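For context on the new `fit` body above: in RDT v1.0.0 the `HyperTransformer` no longer accepts `default_data_type_transformers` at construction. The config is detected from the data and then adjusted before fitting, which is what `detect_initial_config` plus the `OneHotEncoder` override do. A minimal standalone sketch of that workflow (the DataFrame and column names are illustrative, not taken from this PR):

```python
import pandas as pd
from rdt import HyperTransformer

# Illustrative data; any mix of numerical and categorical columns works.
data = pd.DataFrame({
    'age': [21, 35, 44, 28],
    'city': ['NYC', 'LA', 'NYC', 'SF'],
})

ht = HyperTransformer()
ht.detect_initial_config(data)      # infers sdtypes and default transformers
# The PR then swaps the detected categorical transformer for OneHotEncoder
# (via update_transformers_by_sdtype) before fitting.
numeric = ht.fit_transform(data)    # fully numeric output for GaussianMultivariate
restored = ht.reverse_transform(numeric)
```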
sdv/constraints/tabular.py: 4 changes (2 additions, 2 deletions)
@@ -682,7 +682,7 @@ def _fit(self, table_data):
        """Fit the constraint.

        The fit process consists in generating the ``transformed_column`` name and determine
-        whether or not the data is ``datetime``.
+        whether or not the data is ``UnixTimestampEncoder``.

        Args:
            table_data (pandas.DataFrame):
@@ -845,7 +845,7 @@ def _get_is_datetime(self, table_data):
         return is_datetime

     def _fit(self, table_data):
-        """Learn whether or not the ``column_name`` is ``datetime``.
+        """Learn whether or not the ``column_name`` is ``UnixTimestampEncoder``.

         Args:
             table_data (pandas.DataFrame):
sdv/lite/tabular.py: 40 changes (18 additions, 22 deletions)
@@ -66,9 +66,9 @@ def __init__(self, name=None, metadata=None, constraints=None):
         self._model = GaussianCopula(
             table_metadata=metadata,
             constraints=constraints,
-            categorical_transformer='categorical_fuzzy',
+            categorical_transformer='FrequencyEncoder_noised',
             default_distribution='gaussian',
-            rounding=None,
+            learn_rounding_scheme=False,
         )

         # Decide if transformers should model the null column or not.
@@ -78,31 +78,27 @@ def __init__(self, name=None, metadata=None, constraints=None):

         # If transformers should model the null column, pass None to let each transformer
         # decide if it's necessary or not.
-        transformer_null_column = None if self._null_column else False
+        transformer_model_missing_values = bool(self._null_column)

         dtype_transformers = {
-            'i': rdt.transformers.NumericalTransformer(
-                dtype=np.int64,
-                nan='mean' if self._null_column else None,
-                null_column=transformer_null_column,
-                min_value='auto',
-                max_value='auto',
+            'i': rdt.transformers.FloatFormatter(
+                missing_value_replacement='mean' if self._null_column else None,
+                model_missing_values=transformer_model_missing_values,
+                enforce_min_max_values=True,
             ),
-            'f': rdt.transformers.NumericalTransformer(
-                dtype=np.float64,
-                nan='mean' if self._null_column else None,
-                null_column=transformer_null_column,
-                min_value='auto',
-                max_value='auto',
+            'f': rdt.transformers.FloatFormatter(
+                missing_value_replacement='mean' if self._null_column else None,
+                model_missing_values=transformer_model_missing_values,
+                enforce_min_max_values=True,
             ),
-            'O': rdt.transformers.CategoricalTransformer(fuzzy=True),
-            'b': rdt.transformers.BooleanTransformer(
-                nan=-1 if self._null_column else None,
-                null_column=transformer_null_column,
+            'O': rdt.transformers.FrequencyEncoder(add_noise=True),
+            'b': rdt.transformers.BinaryEncoder(
+                missing_value_replacement=-1 if self._null_column else None,
+                model_missing_values=transformer_model_missing_values,
             ),
-            'M': rdt.transformers.DatetimeTransformer(
-                nan='mean' if self._null_column else None,
-                null_column=transformer_null_column,
+            'M': rdt.transformers.UnixTimestampEncoder(
+                missing_value_replacement='mean' if self._null_column else None,
+                model_missing_values=transformer_model_missing_values,
             ),
         }
         self._model._metadata._dtype_transformers.update(dtype_transformers)
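As a quick reference for the renames above, here is roughly how the old transformers map onto their RDT v1.0.0 counterparts. This is a sketch distilled from the diff; the parameter values are fixed illustrative choices rather than the conditional expressions used in `__init__`:

```python
import rdt

# NumericalTransformer -> FloatFormatter (nan -> missing_value_replacement,
#                                         null_column -> model_missing_values,
#                                         min_value/max_value -> enforce_min_max_values)
# CategoricalTransformer(fuzzy=True) -> FrequencyEncoder(add_noise=True)
# BooleanTransformer   -> BinaryEncoder
# DatetimeTransformer  -> UnixTimestampEncoder
dtype_transformers = {
    'i': rdt.transformers.FloatFormatter(
        missing_value_replacement='mean',
        model_missing_values=False,
        enforce_min_max_values=True,
    ),
    'O': rdt.transformers.FrequencyEncoder(add_noise=True),
    'b': rdt.transformers.BinaryEncoder(missing_value_replacement=-1),
    'M': rdt.transformers.UnixTimestampEncoder(missing_value_replacement='mean'),
}
```

The `GaussianCopula` keyword arguments follow the same pattern: the `'categorical_fuzzy'` alias becomes `'FrequencyEncoder_noised'`, and `rounding=None` becomes `learn_rounding_scheme=False`.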
sdv/metadata/dataset.py: 89 changes (5 additions, 84 deletions)
@@ -8,7 +8,7 @@

 import numpy as np
 import pandas as pd
-from rdt import HyperTransformer, transformers
+from rdt import HyperTransformer

 from sdv.constraints import Constraint
 from sdv.metadata import visualization
@@ -401,83 +401,6 @@ def get_dtypes(self, table_name, ids=False, errors=None):

         return dtypes

-    def _get_pii_fields(self, table_name):
-        """Get the ``pii_category`` for each field of the table that contains PII.
-
-        Args:
-            table_name (str):
-                Table name for which to get the pii fields.
-
-        Returns:
-            dict:
-                pii field names and categories.
-        """
-        pii_fields = dict()
-        for name, field in self.get_table_meta(table_name)['fields'].items():
-            if field['type'] == 'categorical' and field.get('pii', False):
-                pii_fields[name] = field['pii_category']
-
-        return pii_fields
-
-    @staticmethod
-    def _get_transformers(dtypes, pii_fields):
-        """Create the transformer instances needed to process the given dtypes.
-
-        Temporary drop-in replacement of ``HyperTransformer._analyze`` method,
-        before RDT catches up.
-
-        Args:
-            dtypes (dict):
-                mapping of field names and dtypes.
-            pii_fields (dict):
-                mapping of pii field names and categories.
-
-        Returns:
-            dict:
-                mapping of field names and transformer instances.
-        """
-        transformers_dict = dict()
-        for name, dtype in dtypes.items():
-            dtype = np.dtype(dtype)
-            if dtype.kind == 'i':
-                transformer = transformers.NumericalTransformer(dtype=int)
-            elif dtype.kind == 'f':
-                transformer = transformers.NumericalTransformer(dtype=float)
-            elif dtype.kind == 'O':
-                anonymize = pii_fields.get(name)
-                transformer = transformers.CategoricalTransformer(anonymize=anonymize)
-            elif dtype.kind == 'b':
-                transformer = transformers.BooleanTransformer()
-            elif dtype.kind == 'M':
-                transformer = transformers.DatetimeTransformer()
-            else:
-                raise ValueError('Unsupported dtype: {}'.format(dtype))
-
-            LOGGER.info('Loading transformer %s for field %s',
-                        transformer.__class__.__name__, name)
-            transformers_dict[name] = transformer
-
-        return transformers_dict
-
-    def _load_hyper_transformer(self, table_name):
-        """Create and return a new ``rdt.HyperTransformer`` instance for a table.
-
-        First get the ``dtypes`` and ``pii fields`` from a given table, then use
-        those to build a transformer dictionary to be used by the ``HyperTransformer``.
-
-        Args:
-            table_name (str):
-                Table name for which to load the HyperTransformer.
-
-        Returns:
-            rdt.HyperTransformer:
-                Instance of ``rdt.HyperTransformer`` for the given table.
-        """
-        dtypes = self.get_dtypes(table_name)
-        pii_fields = self._get_pii_fields(table_name)
-        transformers_dict = self._get_transformers(dtypes, pii_fields)
-        return HyperTransformer(field_transformers=transformers_dict)
-
     def transform(self, table_name, data):
         """Transform data for a given table.

@@ -495,14 +418,12 @@ def transform(self, table_name, data):
         """
         hyper_transformer = self._hyper_transformers.get(table_name)
         if hyper_transformer is None:
-            hyper_transformer = self._load_hyper_transformer(table_name)
-            fields = list(hyper_transformer.transformers.keys())
-            hyper_transformer.fit(data[fields])
+            hyper_transformer = HyperTransformer()
+            hyper_transformer.detect_initial_config(data)
+            hyper_transformer.fit(data)
             self._hyper_transformers[table_name] = hyper_transformer

         hyper_transformer = self._hyper_transformers.get(table_name)
-        fields = list(hyper_transformer.transformers.keys())
-        return hyper_transformer.transform(data[fields])
+        return hyper_transformer.transform(data)

     def reverse_transform(self, table_name, data):
         """Reverse the transformed data for a given table.
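With `_get_pii_fields`, `_get_transformers`, and `_load_hyper_transformer` gone, the per-table flow in `transform` reduces to detect, fit, transform. A self-contained sketch of that flow (the DataFrame here is made up for illustration):

```python
import pandas as pd
from rdt import HyperTransformer

# Stand-in for one table's data.
data = pd.DataFrame({
    'amount': [10.5, 20.0, 15.3],
    'paid': [True, False, True],
})

hyper_transformer = HyperTransformer()
hyper_transformer.detect_initial_config(data)  # replaces the removed _get_transformers logic
hyper_transformer.fit(data)

transformed = hyper_transformer.transform(data)             # whole frame, no manual field list
original = hyper_transformer.reverse_transform(transformed)
```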