diff --git a/sdv/metadata/table.py b/sdv/metadata/table.py
deleted file mode 100644
index 92ababba3..000000000
--- a/sdv/metadata/table.py
+++ /dev/null
@@ -1,833 +0,0 @@
-"""Metadata for a single table."""
-
-import copy
-import json
-import logging
-import warnings
-
-import numpy as np
-import pandas as pd
-import rdt
-from faker import Faker
-
-from sdv.constraints import Constraint
-from sdv.constraints.errors import (
-    AggregateConstraintsError, FunctionError, MissingConstraintColumnError)
-from sdv.metadata.errors import InvalidMetadataError, MetadataNotFittedError
-from sdv.metadata.utils import strings_from_regex
-
-LOGGER = logging.getLogger(__name__)
-
-
-class Table:
-    """Table Metadata.
-
-    The Metadata class provides a unified layer of abstraction over the metadata
-    of a single Table, which includes all the necessary details to handle the
-    table of this data, including the data types, the fields with pii information
-    and the constraints that affect this data.
-
-    Args:
-        name (str):
-            Name of this table. Optional.
-        field_names (list[str]):
-            List of names of the fields that need to be modeled
-            and included in the generated output data. Any additional
-            fields found in the data will be ignored and will not be
-            included in the generated output.
-            If ``None``, all the fields found in the data are used.
-        field_types (dict[str, dict]):
-            Dictinary specifying the data types and subtypes
-            of the fields that will be modeled. Field types and subtypes
-            combinations must be compatible with the SDV Metadata Schema.
-        field_transformers (dict[str, str]):
-            Dictinary specifying which transformers to use for each field.
-            Available transformers are:
-
-            * ``FloatFormatter``: Uses a ``FloatFormatter`` for numerical data.
-            * ``FrequencyEncoder``: Uses a ``FrequencyEncoder`` without gaussian noise.
-            * ``FrequencyEncoder_noised``: Uses a ``FrequencyEncoder`` adding gaussian noise.
-            * ``OneHotEncoder``: Uses a ``OneHotEncoder``.
-            * ``LabelEncoder``: Uses a ``LabelEncoder``.
-            * ``BinaryEncoder``: Uses a ``BinaryEncoder``.
-            * ``UnixTimestampEncoder``: Uses a ``UnixTimestampEncoder``.
-
-        anonymize_fields (dict[str, str]):
-            Dict specifying which fields to anonymize and what faker
-            category they belong to.
-        primary_key (str):
-            Name of the field which is the primary key of the table.
-        constraints (list[Constraint, dict]):
-            List of Constraint objects or dicts.
-        dtype_transformers (dict):
-            Dictionary of transformer templates to be used for the
-            different data types. The keys must be any of the `dtype.kind`
-            values, `i`, `f`, `O`, `b` or `M`, and the values must be
-            either RDT Transformer classes or RDT Transformer instances.
-        model_kwargs (dict):
-            Dictionary specifiying the kwargs that need to be used in
-            each tabular model when working on this table. This dictionary
-            contains as keys the name of the TabularModel class and as
-            values a dictionary containing the keyword arguments to use.
-            This argument exists mostly to ensure that the models are
-            fitted using the same arguments when the same Table is used
-            to fit different model instances on different slices of the
-            same table.
-        sequence_index (str):
-            Name of the column that acts as the order index of each
-            sequence. The sequence index column can be of any type that can
-            be sorted, such as integer values or datetimes.
-        entity_columns (list[str]):
-            Names of the columns which identify different time series
-            sequences. These will be used to group the data in separated
-            training examples.
-        context_columns (list[str]):
-            The columns in the dataframe which are constant within each
-            group/entity. These columns will be provided at sampling time
-            (i.e. the samples will be conditioned on the context variables).
-        learn_rounding_scheme (bool):
-            Define rounding scheme for ``FloatFormatter``. If ``True``, the data returned by
-            ``reverse_transform`` will be rounded to that place. Defaults to ``True``.
-        enforce_min_max_values (bool):
-            Specify whether or not to clip the data returned by ``reverse_transform`` of
-            the numerical transformer, ``FloatFormatter``, to the min and max values seen
-            during ``fit``. Defaults to ``True``.
-    """
-
-    _hyper_transformer = None
-    _fields_metadata = None
-    fitted = False
-
-    _ANONYMIZATION_MAPPINGS = {}
-    _TRANSFORMER_TEMPLATES = {
-        'FloatFormatter': rdt.transformers.FloatFormatter(
-            learn_rounding_scheme=True,
-            enforce_min_max_values=True,
-            missing_value_replacement='mean',
-            model_missing_values=True,
-        ),
-        'FrequencyEncoder': rdt.transformers.FrequencyEncoder,
-        'FrequencyEncoder_noised': rdt.transformers.FrequencyEncoder(add_noise=True),
-        'OneHotEncoder': rdt.transformers.OneHotEncoder,
-        'LabelEncoder': rdt.transformers.LabelEncoder,
-        'LabelEncoder_noised': rdt.transformers.LabelEncoder(add_noise=True),
-        'BinaryEncoder': rdt.transformers.BinaryEncoder(
-            missing_value_replacement=-1,
-            model_missing_values=True
-        ),
-        'UnixTimestampEncoder': rdt.transformers.UnixTimestampEncoder(
-            missing_value_replacement='mean',
-            model_missing_values=True,
-        )
-    }
-    _DTYPE_TRANSFORMERS = {
-        'i': 'FloatFormatter',
-        'f': 'FloatFormatter',
-        'O': 'OneHotEncoder',
-        'b': 'BinaryEncoder',
-        'M': 'UnixTimestampEncoder',
-    }
-    _DTYPES_TO_TYPES = {
-        'i': {
-            'type': 'numerical',
-            'subtype': 'integer',
-        },
-        'f': {
-            'type': 'numerical',
-            'subtype': 'float',
-        },
-        'O': {
-            'type': 'categorical',
-        },
-        'b': {
-            'type': 'boolean',
-        },
-        'M': {
-            'type': 'datetime',
-        }
-    }
-    _TYPES_TO_DTYPES = {
-        ('categorical', None): 'object',
-        ('boolean', None): 'bool',
-        ('numerical', None): 'float',
-        ('numerical', 'float'): 'float',
-        ('numerical', 'integer'): 'int',
-        ('datetime', None): 'datetime64',
-        ('id', None): 'int',
-        ('id', 'integer'): 'int',
-        ('id', 'string'): 'str'
-    }
-
-    @staticmethod
-    def _get_faker(field_metadata):
-        """Return the faker object with localisaton set if specified in field_metadata.
-
-        Args:
-            field_metadata (dict):
-                Metadata for field to read localisation from if set in `pii_locales`.
-
-        Returns:
-            Faker object:
-                The Faker object to anonymize the data in the field using its functions.
-        """
-        pii_locales = field_metadata.get('pii_locales', None)
-        return Faker(locale=pii_locales)
-
-    @staticmethod
-    def _get_faker_method(faker, category):
-        """Return the faker function to anonymize data.
-
-        Args:
-            faker (Faker object):
-                The faker object created to get functions from.
-            category (str or tuple):
-                Fake category to use. If a tuple is passed, the first element is
-                the category and the rest are additional arguments for the Faker.
-
-        Returns:
-            function:
-                Faker function to generate new fake data instances.
-
-        Raises:
-            ValueError:
-                A ``ValueError`` is raised if the faker category we want don't exist.
- """ - if isinstance(category, (tuple, list)): - category, *args = category - else: - args = () - - try: - if args: - def _faker(): - return getattr(faker, category)(*args) - - else: - def _faker(): - return getattr(faker, category)() - - return _faker - except AttributeError: - raise ValueError(f'Category "{category}" couldn\'t be found on faker') - - @staticmethod - def _get_fake_values(field_metadata, num_values): - """Return the anonymized values from Faker. - - Args: - field_metadata (dict): - Metadata for field to read localisation from if set in `pii_locales`. - And to read the faker category from `pii_category`. - num_values (int): - Number of values to create. - - Returns: - generator: - Generator containing the anonymized values. - """ - faker = Table._get_faker(field_metadata) - faker_method = Table._get_faker_method(faker, field_metadata['pii_category']) - return ( - faker_method() - for _ in range(num_values) - ) - - def _update_transformer_templates(self, learn_rounding_scheme, enforce_min_max_values): - custom_float_formatter = rdt.transformers.FloatFormatter( - missing_value_replacement='mean', - model_missing_values=True, - learn_rounding_scheme=learn_rounding_scheme, - enforce_min_max_values=enforce_min_max_values - ) - self._transformer_templates.update({ - 'FloatFormatter': custom_float_formatter, - }) - - @staticmethod - def _load_constraints(constraints): - constraints = constraints or [] - loaded_constraints = [] - for constraint in constraints: - if isinstance(constraint, dict): - loaded_constraints.append(Constraint.from_dict(constraint)) - else: - loaded_constraints.append(constraint) - - return loaded_constraints - - def __init__(self, name=None, field_names=None, field_types=None, field_transformers=None, - anonymize_fields=None, primary_key=None, constraints=None, - dtype_transformers=None, model_kwargs=None, sequence_index=None, - entity_columns=None, context_columns=None, - learn_rounding_scheme=True, enforce_min_max_values=True): - self.name = name - self._field_names = field_names - self._field_types = field_types or {} - self._field_transformers = field_transformers or {} - self._anonymize_fields = anonymize_fields or {} - self._model_kwargs = model_kwargs or {} - - self._primary_key = primary_key - self._sequence_index = sequence_index - self._entity_columns = entity_columns or [] - self._context_columns = context_columns or [] - self._constraints = self._load_constraints(constraints) - self._constraints_to_reverse = [] - self._dtype_transformers = self._DTYPE_TRANSFORMERS.copy() - self._transformer_templates = self._TRANSFORMER_TEMPLATES.copy() - self._update_transformer_templates(learn_rounding_scheme, enforce_min_max_values) - - if dtype_transformers: - self._dtype_transformers.update(dtype_transformers) - - def __repr__(self): - return f'Table(name={self.name}, field_names={self._field_names})' - - def get_model_kwargs(self, model_name): - """Return the required model kwargs for the indicated model. - - Args: - model_name (str): - Qualified Name of the model for which model kwargs - are needed. - - Returns: - dict: - Keyword arguments to use on the indicated model. 
- """ - return copy.deepcopy(self._model_kwargs.get(model_name)) - - def set_model_kwargs(self, model_name, model_kwargs): - """Set the model kwargs used for the indicated model.""" - self._model_kwargs[model_name] = model_kwargs - - def _get_field_dtype(self, field_name, field_metadata): - field_type = field_metadata['type'] - field_subtype = field_metadata.get('subtype') - dtype = self._TYPES_TO_DTYPES.get((field_type, field_subtype)) - if not dtype: - raise InvalidMetadataError( - 'Invalid type and subtype combination for field ' - f'{field_name}: ({field_type}, {field_subtype})' - ) - - return dtype - - def get_fields(self): - """Get fields metadata. - - Returns: - dict: - Dictionary of fields metadata for this table. - """ - return copy.deepcopy(self._fields_metadata) - - def get_dtypes(self, ids=False): - """Get a ``dict`` with the ``dtypes`` for each field of the table. - - Args: - ids (bool): - Whether or not to include the id fields. Defaults to ``False``. - - Returns: - dict: - Dictionary that contains the field names and data types. - """ - dtypes = {} - for name, field_meta in self._fields_metadata.items(): - field_type = field_meta['type'] - - if ids or (field_type != 'id'): - dtypes[name] = self._get_field_dtype(name, field_meta) - - return dtypes - - def _build_fields_metadata(self, data): - """Build all the fields metadata. - - Args: - data (pandas.DataFrame): - Data to be analyzed. - - Returns: - dict: - Dict of valid fields. - - Raises: - ValueError: - If a column from the data analyzed is an unsupported data type - """ - fields_metadata = {} - for field_name in self._field_names: - if field_name not in data: - raise ValueError(f'Field {field_name} not found in given data') - - field_meta = self._field_types.get(field_name) - if field_meta: - dtype = self._get_field_dtype(field_name, field_meta) - else: - dtype = data[field_name].dtype - field_template = self._DTYPES_TO_TYPES.get(dtype.kind) - if field_template is None: - raise ValueError(f'Unsupported dtype {dtype} in column {field_name}') - - field_meta = copy.deepcopy(field_template) - - field_transformer = self._field_transformers.get(field_name) - if field_transformer: - field_meta['transformer'] = field_transformer - else: - field_meta['transformer'] = self._dtype_transformers.get(np.dtype(dtype).kind) - - anonymize_category = self._anonymize_fields.get(field_name) - if anonymize_category: - field_meta['pii'] = True - field_meta['pii_category'] = anonymize_category - - fields_metadata[field_name] = field_meta - - return fields_metadata - - def _get_hypertransformer_config(self, dtypes): - """Create the transformer instances needed to process the given dtypes. - - Args: - dtypes (dict): - mapping of field names and dtypes. - - Returns: - dict: - A dict containing the ``sdtypes`` and ``transformers`` config for the - ``rdt.HyperTransformer``. 
- """ - transformers = {} - sdtypes = {} - for name, dtype in dtypes.items(): - dtype = np.dtype(dtype).kind - field_metadata = self._fields_metadata.get(name, {}) - transformer_template = field_metadata.get( - 'transformer', self._dtype_transformers[dtype]) - - if transformer_template is None: - sdtypes[name] = self._DTYPES_TO_TYPES.get(dtype, {}).get('type', 'categorical') - transformers[name] = None - continue - - field_metadata['transformer'] = transformer_template - if isinstance(transformer_template, str): - transformer_template = self._transformer_templates[transformer_template] - - if isinstance(transformer_template, type): - transformer = transformer_template() - else: - transformer = copy.deepcopy(transformer_template) - - LOGGER.debug('Loading transformer %s for field %s', - transformer.__class__.__name__, name) - transformers[name] = transformer - sdtypes[name] = self._DTYPES_TO_TYPES.get(dtype, {}).get('type', 'categorical') - - return {'sdtypes': sdtypes, 'transformers': transformers} - - def _fit_constraints(self, data): - errors = [] - for constraint in self._constraints: - try: - constraint.fit(data) - except Exception as e: - errors.append(e) - - if errors: - raise AggregateConstraintsError(errors) - - def _transform_constraints(self, data, is_condition=False): - errors = [] - if not is_condition: - self._constraints_to_reverse = [] - - for constraint in self._constraints: - try: - data = constraint.transform(data) - if not is_condition: - self._constraints_to_reverse.append(constraint) - - except (MissingConstraintColumnError, FunctionError) as e: - if isinstance(e, MissingConstraintColumnError): - warnings.warn( - f'{constraint.__class__.__name__} cannot be transformed because columns: ' - f'{e.missing_columns} were not found. Using the reject sampling approach ' - 'instead.' - ) - else: - warnings.warn( - f'Error transforming {constraint.__class__.__name__}. ' - 'Using the reject sampling approach instead.' - ) - if is_condition: - indices_to_drop = data.columns.isin(constraint.constraint_columns) - columns_to_drop = data.columns.where(indices_to_drop).dropna() - data = data.drop(columns_to_drop, axis=1) - - except Exception as e: - errors.append(e) - - if errors: - raise AggregateConstraintsError(errors) - - return data - - def _fit_transform_constraints(self, data): - # Fit and validate all constraints first because `transform` might change columns - # making the following constraints invalid - self._fit_constraints(data) - data = self._transform_constraints(data) - - return data - - def _fit_hyper_transformer(self, data, extra_columns): - """Create and return a new ``rdt.HyperTransformer`` instance. - - First get the ``dtypes`` and then use them to build a transformer dictionary - to be used by the ``HyperTransformer``. - - Args: - data (pandas.DataFrame): - Data to transform. - extra_columns (set): - Names of columns that are not in the metadata but that should also - be transformed. In most cases, these are the fields that were added - by previous transformations which the data underwent. 
-
-        Returns:
-            rdt.HyperTransformer
-        """
-        meta_dtypes = self.get_dtypes(ids=False)
-        dtypes = {}
-        numerical_extras = []
-        for column in data.columns:
-            if column in meta_dtypes:
-                dtypes[column] = meta_dtypes[column]
-            elif column in extra_columns:
-                dtype_kind = data[column].dtype.kind
-                if dtype_kind in ('i', 'f'):
-                    numerical_extras.append(column)
-                else:
-                    dtypes[column] = dtype_kind
-
-        ht_config = self._get_hypertransformer_config(dtypes)
-        for column in numerical_extras:
-            dtypes[column] = 'numerical'
-            ht_config['sdtypes'][column] = 'numerical'
-            ht_config['transformers'][column] = rdt.transformers.FloatFormatter(
-                missing_value_replacement='mean',
-                model_missing_values=True,
-            )
-
-        self._hyper_transformer = rdt.HyperTransformer()
-        self._hyper_transformer.set_config(ht_config)
-        fit_columns = list(dtypes)
-        if not data[fit_columns].empty:
-            self._hyper_transformer.fit(data[fit_columns])
-
-    @staticmethod
-    def _get_key_subtype(field_meta):
-        """Get the appropriate key subtype."""
-        field_type = field_meta['type']
-
-        if field_type == 'categorical':
-            field_subtype = 'string'
-
-        elif field_type in ('numerical', 'id'):
-            field_subtype = field_meta['subtype']
-            if field_subtype not in ('integer', 'string'):
-                raise ValueError(f'Invalid field "subtype" for key field: "{field_subtype}"')
-
-        else:
-            raise ValueError('Invalid field "type" for key field: "{field_type}"')
-
-        return field_subtype
-
-    def set_primary_key(self, primary_key):
-        """Set the primary key of this table.
-
-        The field must exist and either be an integer or categorical field.
-
-        Args:
-            primary_key (str or list):
-                Name of the field(s) to be used as the new primary key.
-
-        Raises:
-            ValueError:
-                If the table or the field do not exist or if the field has an
-                invalid type or subtype.
-        """
-        if primary_key is not None:
-            fields = primary_key if isinstance(primary_key, list) else [primary_key]
-            for field_name in fields:
-                if field_name not in self._fields_metadata:
-                    raise ValueError(f'Field "{field_name}" does not exist in this table')
-
-                field_metadata = self._fields_metadata[field_name]
-                if field_metadata['type'] != 'id':
-                    field_subtype = self._get_key_subtype(field_metadata)
-
-                    field_metadata.update({
-                        'type': 'id',
-                        'subtype': field_subtype
-                    })
-
-        self._primary_key = primary_key
-
-    def _make_anonymization_mappings(self, data):
-        mappings = {}
-        for name, field_metadata in self._fields_metadata.items():
-            if field_metadata['type'] != 'id' and field_metadata.get('pii'):
-                uniques = data[name].unique()
-                mappings[name] = dict(
-                    zip(uniques, Table._get_fake_values(field_metadata, len(uniques)))
-                )
-
-        self._ANONYMIZATION_MAPPINGS[id(self)] = mappings
-
-    def _anonymize(self, data):
-        anonymization_mappings = self._ANONYMIZATION_MAPPINGS.get(id(self))
-        if anonymization_mappings:
-            data = data.copy()
-            for name, mapping in anonymization_mappings.items():
-                if name in data:
-                    data[name] = data[name].map(mapping)
-
-        return data
-
-    def fit(self, data):
-        """Fit this metadata to the given data.
-
-        Args:
-            data (pandas.DataFrame):
-                Table to be analyzed.
- """ - LOGGER.info('Fitting table %s metadata', self.name) - if not self._field_names: - self._field_names = list(data.columns) - elif isinstance(self._field_names, set): - self._field_names = [field for field in data.columns if field in self._field_names] - - self._dtypes = data[self._field_names].dtypes - - if not self._fields_metadata: - self._fields_metadata = self._build_fields_metadata(data) - - # Re-set the primary key to validate its name and type - self.set_primary_key(self._primary_key) - - self._make_anonymization_mappings(data) - LOGGER.info('Anonymizing table %s', self.name) - data = self._anonymize(data) - - LOGGER.info('Fitting constraints for table %s', self.name) - constrained = self._fit_transform_constraints(data) - extra_columns = set(constrained.columns) - set(data.columns) - - LOGGER.info('Fitting HyperTransformer for table %s', self.name) - self._fit_hyper_transformer(constrained, extra_columns) - self.fitted = True - - def transform(self, data, is_condition=False): - """Transform the given data. - - Args: - data (pandas.DataFrame): - Table data. - - Returns: - pandas.DataFrame: - Transformed data. - """ - if not self.fitted: - raise MetadataNotFittedError() - - fields = [field for field in self.get_dtypes(ids=False) if field in data.columns] - LOGGER.debug('Anonymizing table %s', self.name) - data = self._anonymize(data[fields]) - - LOGGER.debug('Transforming constraints for table %s', self.name) - data = self._transform_constraints(data, is_condition) - - LOGGER.debug('Transforming table %s', self.name) - try: - return self._hyper_transformer.transform_subset(data) - except (rdt.errors.NotFittedError, rdt.errors.ConfigNotSetError): - return data - - @classmethod - def _make_ids(cls, field_metadata, length): - field_subtype = field_metadata.get('subtype', 'integer') - if field_subtype == 'string': - regex = field_metadata.get('regex', '[a-zA-Z]+') - generator, max_size = strings_from_regex(regex) - if max_size < length: - raise ValueError(( - f'Unable to generate {length} unique values for regex {regex}, the ' - f'maximum number of unique values is {max_size}.' - )) - values = [next(generator) for _ in range(length)] - - return pd.Series(list(values)[:length]) - else: - return pd.Series(np.arange(length)) - - def reverse_transform(self, data): - """Reverse the transformed data to the original format. - - Args: - data (pandas.DataFrame): - Data to be reverse transformed. 
-
-        Returns:
-            pandas.DataFrame
-        """
-        if not self.fitted:
-            raise MetadataNotFittedError()
-
-        reversible_columns = [
-            column
-            for column in self._hyper_transformer._output_columns
-            if column in data.columns
-        ]
-        reversed_data = data
-        try:
-            if not data.empty:
-                reversed_data = self._hyper_transformer.reverse_transform_subset(
-                    data[reversible_columns]
-                )
-        except rdt.errors.NotFittedError:
-            LOGGER.info('HyperTransformer has not been fitted for table %s', self.name)
-
-        for constraint in reversed(self._constraints_to_reverse):
-            reversed_data = constraint.reverse_transform(reversed_data)
-
-        for name, field_metadata in self._fields_metadata.items():
-            field_type = field_metadata['type']
-            if field_type == 'id' and name not in reversed_data:
-                field_data = self._make_ids(field_metadata, len(reversed_data))
-            elif field_metadata.get('pii', False):
-                field_data = pd.Series(Table._get_fake_values(field_metadata, len(reversed_data)))
-            else:
-                field_data = reversed_data[name]
-
-            if field_metadata['type'] == 'numerical' and field_metadata['subtype'] == 'integer':
-                field_data = field_data.round()
-
-            reversed_data[name] = field_data[field_data.notna()].astype(self._dtypes[name])
-
-        return reversed_data[self._field_names]
-
-    def filter_valid(self, data):
-        """Filter the data using the constraints and return only the valid rows.
-
-        Args:
-            data (pandas.DataFrame):
-                Table data.
-
-        Returns:
-            pandas.DataFrame:
-                Table containing only the valid rows.
-        """
-        for constraint in self._constraints:
-            data = constraint.filter_valid(data)
-
-        return data
-
-    def make_ids_unique(self, data):
-        """Repopulate any id fields in provided data to guarantee uniqueness.
-
-        Args:
-            data (pandas.DataFrame):
-                Table data.
-
-        Returns:
-            pandas.DataFrame:
-                Table where all id fields are unique.
-        """
-        for name, field_metadata in self._fields_metadata.items():
-            if field_metadata['type'] == 'id' and not data[name].is_unique:
-                ids = self._make_ids(field_metadata, len(data))
-                ids.index = data.index.copy()
-                data[name] = ids
-
-        return data
-
-    # ###################### #
-    # Metadata Serialization #
-    # ###################### #
-
-    def to_dict(self):
-        """Get a dict representation of this metadata.
-
-        Returns:
-            dict:
-                dict representation of this metadata.
-        """
-        return {
-            'fields': copy.deepcopy(self._fields_metadata),
-            'constraints': [
-                constraint if isinstance(constraint, dict) else constraint.to_dict()
-                for constraint in self._constraints
-            ],
-            'model_kwargs': copy.deepcopy(self._model_kwargs),
-            'name': self.name,
-            'primary_key': self._primary_key,
-            'sequence_index': self._sequence_index,
-            'entity_columns': self._entity_columns,
-            'context_columns': self._context_columns,
-        }
-
-    def to_json(self, path):
-        """Dump this metadata into a JSON file.
-
-        Args:
-            path (str):
-                Path of the JSON file where this metadata will be stored.
-        """
-        with open(path, 'w') as out_file:
-            json.dump(self.to_dict(), out_file, indent=4)
-
-    @classmethod
-    def from_dict(cls, metadata_dict, dtype_transformers=None):
-        """Load a Table from a metadata dict.
-
-        Args:
-            metadata_dict (dict):
-                Dict metadata to load.
-            dtype_transformers (dict):
-                If passed, set the dtype_transformers on the new instance.
- """ - metadata_dict = copy.deepcopy(metadata_dict) - fields = metadata_dict['fields'] or {} - instance = cls( - name=metadata_dict.get('name'), - field_names=set(fields.keys()), - field_types=fields, - constraints=metadata_dict.get('constraints') or [], - model_kwargs=metadata_dict.get('model_kwargs') or {}, - primary_key=metadata_dict.get('primary_key'), - sequence_index=metadata_dict.get('sequence_index'), - entity_columns=metadata_dict.get('entity_columns') or [], - context_columns=metadata_dict.get('context_columns') or [], - dtype_transformers=dtype_transformers, - enforce_min_max_values=metadata_dict.get('enforce_min_max_values', True), - learn_rounding_scheme=metadata_dict.get('learn_rounding_scheme', True), - ) - instance._fields_metadata = fields - return instance - - @classmethod - def from_json(cls, path): - """Load a Table from a JSON. - - Args: - path (str): - Path of the JSON file to load - """ - with open(path, 'r') as in_file: - return cls.from_dict(json.load(in_file)) diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 3ba9e889e..b0ea6ba84 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -90,7 +90,7 @@ def get_parameters(self, table_name): """ return self._table_synthesizers.get(table_name).get_parameters() - def update_table_parameters(self, table_name, table_parameters): + def set_table_parameters(self, table_name, table_parameters): """Update the table's synthesizer instantiation parameters. Args: diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index fe40b5c4b..ede101a62 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -31,7 +31,7 @@ def __init__(self, metadata, synthesizer_kwargs=None): self._max_child_rows = {} self._modeled_tables = [] for table_name in self.metadata._tables: - self.update_table_parameters(table_name, self._synthesizer_kwargs) + self.set_table_parameters(table_name, self._synthesizer_kwargs) def _get_extension(self, child_name, child_table, foreign_key): """Generate the extension columns for this child table. diff --git a/sdv/tabular/ctgan.py b/sdv/tabular/ctgan.py deleted file mode 100644 index 5948ed512..000000000 --- a/sdv/tabular/ctgan.py +++ /dev/null @@ -1,309 +0,0 @@ -"""Wrapper around CTGAN model.""" - -import ctgan -import numpy as np - -from sdv.tabular.base import BaseTabularModel - - -class CTGANModel(BaseTabularModel): - """Base class for all the CTGAN models. - - The ``CTGANModel`` class provides a wrapper for all the CTGAN models. - """ - - _MODEL_CLASS = None - _model_kwargs = None - - _DTYPE_TRANSFORMERS = { - 'O': None - } - - def _build_model(self): - return self._MODEL_CLASS(**self._model_kwargs) - - def _fit(self, table_data): - """Fit the model to the table. - - Args: - table_data (pandas.DataFrame): - Data to be learned. 
- """ - self._model = self._build_model() - - categoricals = [] - fields_before_transform = self._metadata.get_fields() - for field in table_data.columns: - if field in fields_before_transform: - meta = fields_before_transform[field] - if meta['type'] == 'categorical': - categoricals.append(field) - - else: - field_data = table_data[field].dropna() - if set(field_data.unique()) == {0.0, 1.0}: - # booleans encoded as float values must be modeled as bool - field_data = field_data.astype(bool) - - dtype = field_data.infer_objects().dtype - try: - kind = np.dtype(dtype).kind - except TypeError: - # probably category - kind = 'O' - if kind in ['O', 'b']: - categoricals.append(field) - - self._model.fit( - table_data, - discrete_columns=categoricals - ) - - def _sample(self, num_rows, conditions=None): - """Sample the indicated number of rows from the model. - - Args: - num_rows (int): - Amount of rows to sample. - conditions (dict): - If specified, this dictionary maps column names to the column - value. Then, this method generates `num_rows` samples, all of - which are conditioned on the given variables. - - Returns: - pandas.DataFrame: - Sampled data. - """ - if conditions is None: - return self._model.sample(num_rows) - - raise NotImplementedError(f"{self._MODEL_CLASS} doesn't support conditional sampling.") - - def _set_random_state(self, random_state): - """Set the random state of the model's random number generator. - - Args: - random_state (int, tuple[np.random.RandomState, torch.Generator], or None): - Seed or tuple of random states to use. - """ - self._model.set_random_state(random_state) - - -class CTGAN(CTGANModel): - """Model wrapping ``CTGAN`` model. - - Args: - field_names (list[str]): - List of names of the fields that need to be modeled - and included in the generated output data. Any additional - fields found in the data will be ignored and will not be - included in the generated output. - If ``None``, all the fields found in the data are used. - field_types (dict[str, dict]): - Dictinary specifying the data types and subtypes - of the fields that will be modeled. Field types and subtypes - combinations must be compatible with the SDV Metadata Schema. - field_transformers (dict[str, str]): - Dictinary specifying which transformers to use for each field. - Available transformers are: - - * ``FloatFormatter``: Uses a ``FloatFormatter`` for numerical data. - * ``FrequencyEncoder``: Uses a ``FrequencyEncoder`` without gaussian noise. - * ``FrequencyEncoder_noised``: Uses a ``FrequencyEncoder`` adding gaussian noise. - * ``OneHotEncoder``: Uses a ``OneHotEncoder``. - * ``LabelEncoder``: Uses a ``LabelEncoder`` without gaussian nose. - * ``LabelEncoder_noised``: Uses a ``LabelEncoder`` adding gaussian noise. - * ``BinaryEncoder``: Uses a ``BinaryEncoder``. - * ``UnixTimestampEncoder``: Uses a ``UnixTimestampEncoder``. - - anonymize_fields (dict[str, str]): - Dict specifying which fields to anonymize and what faker - category they belong to. - primary_key (str): - Name of the field which is the primary key of the table. - constraints (list[Constraint, dict]): - List of Constraint objects or dicts. - table_metadata (dict or metadata.Table): - Table metadata instance or dict representation. - If given alongside any other metadata-related arguments, an - exception will be raised. - If not given at all, it will be built using the other - arguments or learned from the data. - embedding_dim (int): - Size of the random sample passed to the Generator. Defaults to 128. 
-        generator_dim (tuple or list of ints):
-            Size of the output samples for each one of the Residuals. A Residual Layer
-            will be created for each one of the values provided. Defaults to (256, 256).
-        discriminator_dim (tuple or list of ints):
-            Size of the output samples for each one of the Discriminator Layers. A Linear Layer
-            will be created for each one of the values provided. Defaults to (256, 256).
-        generator_lr (float):
-            Learning rate for the generator. Defaults to 2e-4.
-        generator_decay (float):
-            Generator weight decay for the Adam Optimizer. Defaults to 1e-6.
-        discriminator_lr (float):
-            Learning rate for the discriminator. Defaults to 2e-4.
-        discriminator_decay (float):
-            Discriminator weight decay for the Adam Optimizer. Defaults to 1e-6.
-        batch_size (int):
-            Number of data samples to process in each step.
-        discriminator_steps (int):
-            Number of discriminator updates to do for each generator update.
-            From the WGAN paper: https://arxiv.org/abs/1701.07875. WGAN paper
-            default is 5. Default used is 1 to match original CTGAN implementation.
-        log_frequency (boolean):
-            Whether to use log frequency of categorical levels in conditional
-            sampling. Defaults to ``True``.
-        verbose (boolean):
-            Whether to have print statements for progress results. Defaults to ``False``.
-        epochs (int):
-            Number of training epochs. Defaults to 300.
-        pac (int):
-            Number of samples to group together when applying the discriminator.
-            Defaults to 10.
-        cuda (bool or str):
-            If ``True``, use CUDA. If a ``str``, use the indicated device.
-            If ``False``, do not use cuda at all.
-        learn_rounding_scheme (bool):
-            Define rounding scheme for ``FloatFormatter``. If ``True``, the data returned by
-            ``reverse_transform`` will be rounded to that place. Defaults to ``True``.
-        enforce_min_max_values (bool):
-            Specify whether or not to clip the data returned by ``reverse_transform`` of
-            the numerical transformer, ``FloatFormatter``, to the min and max values seen
-            during ``fit``. Defaults to ``True``.
-    """
-
-    _MODEL_CLASS = ctgan.CTGAN
-
-    def __init__(self, field_names=None, field_types=None, field_transformers=None,
-                 anonymize_fields=None, primary_key=None, constraints=None, table_metadata=None,
-                 embedding_dim=128, generator_dim=(256, 256), discriminator_dim=(256, 256),
-                 generator_lr=2e-4, generator_decay=1e-6, discriminator_lr=2e-4,
-                 discriminator_decay=1e-6, batch_size=500, discriminator_steps=1,
-                 log_frequency=True, verbose=False, epochs=300, pac=10, cuda=True,
-                 learn_rounding_scheme=True, enforce_min_max_values=True):
-        super().__init__(
-            field_names=field_names,
-            primary_key=primary_key,
-            field_types=field_types,
-            field_transformers=field_transformers,
-            anonymize_fields=anonymize_fields,
-            constraints=constraints,
-            table_metadata=table_metadata,
-            learn_rounding_scheme=learn_rounding_scheme,
-            enforce_min_max_values=enforce_min_max_values
-        )
-
-        self._model_kwargs = {
-            'embedding_dim': embedding_dim,
-            'generator_dim': generator_dim,
-            'discriminator_dim': discriminator_dim,
-            'generator_lr': generator_lr,
-            'generator_decay': generator_decay,
-            'discriminator_lr': discriminator_lr,
-            'discriminator_decay': discriminator_decay,
-            'batch_size': batch_size,
-            'discriminator_steps': discriminator_steps,
-            'log_frequency': log_frequency,
-            'verbose': verbose,
-            'epochs': epochs,
-            'pac': pac,
-            'cuda': cuda
-        }
-
-
-class TVAE(CTGANModel):
-    """Model wrapping ``TVAE`` model.
-
-    Args:
-        field_names (list[str]):
-            List of names of the fields that need to be modeled
-            and included in the generated output data. Any additional
-            fields found in the data will be ignored and will not be
-            included in the generated output.
-            If ``None``, all the fields found in the data are used.
-        field_types (dict[str, dict]):
-            Dictinary specifying the data types and subtypes
-            of the fields that will be modeled. Field types and subtypes
-            combinations must be compatible with the SDV Metadata Schema.
-        field_transformers (dict[str, str]):
-            Dictinary specifying which transformers to use for each field.
-            Available transformers are:
-
-            * ``FloatFormatter``: Uses a ``FloatFormatter`` for numerical data.
-            * ``FrequencyEncoder``: Uses a ``FrequencyEncoder`` without gaussian noise.
-            * ``FrequencyEncoder_noised``: Uses a ``FrequencyEncoder`` adding gaussian noise.
-            * ``OneHotEncoder``: Uses a ``OneHotEncoder``.
-            * ``LabelEncoder``: Uses a ``LabelEncoder`` without gaussian nose.
-            * ``LabelEncoder_noised``: Uses a ``LabelEncoder`` adding gaussian noise.
-            * ``BinaryEncoder``: Uses a ``BinaryEncoder``.
-            * ``UnixTimestampEncoder``: Uses a ``UnixTimestampEncoder``.
-
-        anonymize_fields (dict[str, str]):
-            Dict specifying which fields to anonymize and what faker
-            category they belong to.
-        primary_key (str):
-            Name of the field which is the primary key of the table.
-        constraints (list[Constraint, dict]):
-            List of Constraint objects or dicts.
-        table_metadata (dict or metadata.Table):
-            Table metadata instance or dict representation.
-            If given alongside any other metadata-related arguments, an
-            exception will be raised.
-            If not given at all, it will be built using the other
-            arguments or learned from the data.
-        embedding_dim (int):
-            Size of the random sample passed to the Generator. Defaults to 128.
-        compress_dims (tuple or list of ints):
-            Size of each hidden layer in the encoder. Defaults to (128, 128).
-        decompress_dims (tuple or list of ints):
-            Size of each hidden layer in the decoder. Defaults to (128, 128).
-        l2scale (int):
-            Regularization term. Defaults to 1e-5.
-        batch_size (int):
-            Number of data samples to process in each step.
-        epochs (int):
-            Number of training epochs. Defaults to 300.
-        loss_factor (int):
-            Multiplier for the reconstruction error. Defaults to 2.
-        cuda (bool or str):
-            If ``True``, use CUDA. If a ``str``, use the indicated device.
-            If ``False``, do not use cuda at all.
-        learn_rounding_scheme (bool):
-            Define rounding scheme for ``FloatFormatter``. If ``True``, the data returned by
-            ``reverse_transform`` will be rounded to that place. Defaults to ``True``.
-        enforce_min_max_values (bool):
-            Specify whether or not to clip the data returned by ``reverse_transform`` of
-            the numerical transformer, ``FloatFormatter``, to the min and max values seen
-            during ``fit``. Defaults to ``True``.
- """ - - _MODEL_CLASS = ctgan.TVAE - - def __init__(self, field_names=None, field_types=None, field_transformers=None, - anonymize_fields=None, primary_key=None, constraints=None, table_metadata=None, - embedding_dim=128, compress_dims=(128, 128), decompress_dims=(128, 128), - l2scale=1e-5, batch_size=500, epochs=300, loss_factor=2, cuda=True, - learn_rounding_scheme=True, enforce_min_max_values=True): - super().__init__( - field_names=field_names, - primary_key=primary_key, - field_types=field_types, - field_transformers=field_transformers, - anonymize_fields=anonymize_fields, - constraints=constraints, - table_metadata=table_metadata, - learn_rounding_scheme=learn_rounding_scheme, - enforce_min_max_values=enforce_min_max_values - ) - - self._model_kwargs = { - 'embedding_dim': embedding_dim, - 'compress_dims': compress_dims, - 'decompress_dims': decompress_dims, - 'l2scale': l2scale, - 'batch_size': batch_size, - 'epochs': epochs, - 'loss_factor': loss_factor, - 'cuda': cuda - } diff --git a/tests/integration/datasets/test_demo.py b/tests/integration/datasets/test_demo.py index 1a78f5ea7..a4480aad3 100644 --- a/tests/integration/datasets/test_demo.py +++ b/tests/integration/datasets/test_demo.py @@ -14,17 +14,17 @@ def test_get_available_demos_single_table(): 'dataset_name': [ 'KRK_v1', 'adult', 'alarm', 'asia', 'census', 'census_extended', 'child', 'covtype', 'credit', - 'expedia_hotel_logs', 'grid', 'gridr', 'insurance', + 'expedia_hotel_logs', 'fake_companies', 'grid', 'gridr', 'insurance', 'intrusion', 'mnist12', 'mnist28', 'news', 'ring', 'student_placements', 'student_placements_pii' ], 'size_MB': [ '0.072128', '3.907448', '4.520128', '1.280128', '98.165608', '4.9494', '3.200128', '255.645408', '68.353808', '0.200128', - '0.320128', '0.320128', '3.340128', '162.039016', '81.200128', + '0.0', '0.320128', '0.320128', '3.340128', '162.039016', '81.200128', '439.600128', '18.712096', '0.320128', '0.026358', '0.028078' ], - 'num_tables': ['1'] * 20 + 'num_tables': ['1'] * 21 }) expected_table['size_MB'] = expected_table['size_MB'].astype(float).round(2) pd.testing.assert_frame_equal(tables_info, expected_table) diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index de853ab80..b71fd8cba 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -114,3 +114,29 @@ def test_get_info(): 'last_fit_date': today, 'fitted_sdv_version': version } + + +def test_hma_set_parameters(): + """Test the ``set_table_parameters``. + + Validate that the ``set_table_parameters`` sets new parameters to the synthesizers. 
+ """ + # Setup + data, metadata = download_demo('multi_table', 'got_families') + hmasynthesizer = HMASynthesizer(metadata) + + # Run + hmasynthesizer.set_table_parameters('characters', {'default_distribution': 'gamma'}) + hmasynthesizer.set_table_parameters('families', {'default_distribution': 'uniform'}) + hmasynthesizer.set_table_parameters('character_families', {'default_distribution': 'norm'}) + + # Assert + assert hmasynthesizer.get_table_parameters('characters') == {'default_distribution': 'gamma'} + assert hmasynthesizer.get_table_parameters('families') == {'default_distribution': 'uniform'} + assert hmasynthesizer.get_table_parameters('character_families') == { + 'default_distribution': 'norm' + } + + assert hmasynthesizer._table_synthesizers['characters'].default_distribution == 'gamma' + assert hmasynthesizer._table_synthesizers['families'].default_distribution == 'uniform' + assert hmasynthesizer._table_synthesizers['character_families'].default_distribution == 'norm' diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index 449b72a7d..b428e8bbe 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -112,7 +112,7 @@ def test_get_parameters(self): 'numerical_distributions': {} } - def test_update_table_parameters(self): + def test_set_table_parameters(self): """Test that the table's parameters are being updated. This test should ensure that the ``self._table_parameters`` for the given table @@ -124,7 +124,7 @@ def test_update_table_parameters(self): instance = BaseMultiTableSynthesizer(metadata) # Run - instance.update_table_parameters('oseba', {'default_distribution': 'gamma'}) + instance.set_table_parameters('oseba', {'default_distribution': 'gamma'}) # Assert assert instance._table_parameters['oseba'] == {'default_distribution': 'gamma'}