Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace MultiTableMetadata with Metadata #2146

Merged
8 changes: 4 additions & 4 deletions sdv/io/local/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import pandas as pd

from sdv.metadata import MultiTableMetadata
from sdv.metadata import Metadata


class BaseLocalHandler:
Expand All @@ -25,11 +25,11 @@ def create_metadata(self, data):
Dictionary of table names to dataframes.

Returns:
MultiTableMetadata:
An ``sdv.metadata.MultiTableMetadata`` object with the detected metadata
Metadata:
An ``sdv.metadata.Metadata`` object with the detected metadata
properties from the data.
"""
metadata = MultiTableMetadata()
metadata = Metadata()
metadata.detect_from_dataframes(data)
return metadata

Expand Down
4 changes: 4 additions & 0 deletions sdv/metadata/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@
LOGGER = logging.getLogger(__name__)
MULTITABLEMETADATA_LOGGER = get_sdv_logger('MultiTableMetadata')
WARNINGS_COLUMN_ORDER = ['Table Name', 'Column Name', 'sdtype', 'datetime_format']
DEPRECATION_MSG = (
"The 'MultiTableMetadata' is deprecated. Please use the new "
"'Metadata' class for synthesizers."
)


class MultiTableMetadata:
Expand Down
2 changes: 1 addition & 1 deletion sdv/metadata/single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1406,7 +1406,7 @@ def upgrade_metadata(cls, filepath):
if len(tables) > 1:
raise InvalidMetadataError(
'There are multiple tables specified in the JSON. '
'Try using the MultiTableMetadata class to upgrade this file.'
'Try using the Metadata class to upgrade this file.'
)

else:
Expand Down
13 changes: 9 additions & 4 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
SynthesizerInputError,
)
from sdv.logging import disable_single_table_logger, get_sdv_logger
from sdv.metadata.metadata import Metadata
from sdv.metadata.multi_table import DEPRECATION_MSG, MultiTableMetadata
from sdv.single_table.copulas import GaussianCopulaSynthesizer

SYNTHESIZER_LOGGER = get_sdv_logger('MultiTableSynthesizer')
Expand All @@ -38,9 +40,8 @@ class BaseMultiTableSynthesizer:
multi table synthesizers need to implement, as well as common functionality.
Args:
metadata (sdv.metadata.multi_table.MultiTableMetadata):
Multi table metadata representing the data tables that this synthesizer will be used
for.
metadata (sdv.metadata.Metadata):
Metadata representing the data tables that this synthesizer will utilize.
locales (list or str):
The default locale(s) to use for AnonymizedFaker transformers.
Defaults to ``['en_US']``.
Expand Down Expand Up @@ -71,8 +72,9 @@ def _initialize_models(self):
with disable_single_table_logger():
for table_name, table_metadata in self.metadata.tables.items():
synthesizer_parameters = self._table_parameters.get(table_name, {})
metadata = Metadata.load_from_dict(table_metadata.to_dict())
self._table_synthesizers[table_name] = self._synthesizer(
metadata=table_metadata, locales=self.locales, **synthesizer_parameters
metadata=metadata, locales=self.locales, **synthesizer_parameters
)
self._table_synthesizers[table_name]._data_processor.table_name = table_name

Expand All @@ -97,6 +99,9 @@ def _check_metadata_updated(self):

def __init__(self, metadata, locales=['en_US'], synthesizer_kwargs=None):
self.metadata = metadata
if type(metadata) is MultiTableMetadata:
self.metadata = Metadata().load_from_dict(metadata.to_dict())
warnings.warn(DEPRECATION_MSG, FutureWarning)
with warnings.catch_warnings():
warnings.filterwarnings('ignore', message=r'.*column relationship.*')
self.metadata.validate()
Expand Down
4 changes: 2 additions & 2 deletions sdv/multi_table/hma.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class HMASynthesizer(BaseHierarchicalSampler, BaseMultiTableSynthesizer):
"""Hierarchical Modeling Algorithm One.

Args:
metadata (sdv.metadata.multi_table.MultiTableMetadata):
metadata (sdv.metadata.Metadata):
Multi table metadata representing the data tables that this synthesizer will be used
for.
locales (list or str):
Expand All @@ -47,7 +47,7 @@ def _get_num_data_columns(metadata):
"""Get the number of data columns, ie colums that are not id, for each table.

Args:
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
"""
columns_per_table = {}
Expand Down
26 changes: 13 additions & 13 deletions sdv/multi_table/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def _simplify_relationships_and_tables(metadata, tables_to_drop):
Removes the tables that are not direct child or grandchild of the root table.

Args:
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
tables_to_drop (set):
Set of the tables that relationships will be removed.
Expand All @@ -149,7 +149,7 @@ def _simplify_grandchildren(metadata, grandchildren):
- Drop all modelables columns.

Args:
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
grandchildren (set):
Set of the grandchildren of the root table.
Expand All @@ -174,7 +174,7 @@ def _get_num_column_to_drop(metadata, child_table, max_col_per_relationships):
- minimum number of column to drop = n + k - sqrt(k^2 + 1 + 2m)

Args:
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
child_table (str):
Name of the child table.
Expand Down Expand Up @@ -232,7 +232,7 @@ def _simplify_child(metadata, child_table, max_col_per_relationships):
"""Simplify the child table.

Args:
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
child_table (str):
Name of the child table.
Expand All @@ -252,7 +252,7 @@ def _simplify_children(metadata, children, root_table, num_data_column):
- Drop some modelable columns to have at most MAX_NUMBER_OF_COLUMNS columns to model.

Args:
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
children (set):
Set of the children of the root table.
Expand Down Expand Up @@ -288,11 +288,11 @@ def _simplify_metadata(metadata):
- Drop some modelable columns in the children to have at most 1000 columns to model.

Args:
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.

Returns:
MultiTableMetadata:
Metadata:
Simplified metadata.
"""
simplified_metadata = deepcopy(metadata)
Expand Down Expand Up @@ -330,7 +330,7 @@ def _simplify_data(data, metadata):
data (dict):
Dictionary that maps each table name (string) to the data for that
table (pandas.DataFrame).
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.

Returns:
Expand Down Expand Up @@ -375,7 +375,7 @@ def _get_rows_to_drop(data, metadata):
This ensures that we preserve the referential integrity between all the relationships.

Args:
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
data (dict):
Dictionary that maps each table name (string) to the data for that
Expand Down Expand Up @@ -470,7 +470,7 @@ def _subsample_table_and_descendants(data, metadata, table, num_rows):
data (dict):
Dictionary that maps each table name (string) to the data for that
table (pandas.DataFrame).
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
table (str):
Name of the table.
Expand All @@ -496,7 +496,7 @@ def _get_primary_keys_referenced(data, metadata):
data (dict):
Dictionary that maps each table name (string) to the data for that
table (pandas.DataFrame).
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.

Returns:
Expand Down Expand Up @@ -568,7 +568,7 @@ def _subsample_ancestors(data, metadata, table, primary_keys_referenced):
data (dict):
Dictionary that maps each table name (string) to the data for that
table (pandas.DataFrame).
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
table (str):
Name of the table.
Expand Down Expand Up @@ -604,7 +604,7 @@ def _subsample_data(data, metadata, main_table_name, num_rows):
referenced by the descendants and some unreferenced rows.

Args:
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
data (dict):
Dictionary that maps each table name (string) to the data for that
Expand Down
2 changes: 1 addition & 1 deletion sdv/sampling/hierarchical_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class BaseHierarchicalSampler:
"""Hierarchical sampler mixin.

Args:
metadata (sdv.metadata.multi_table.MultiTableMetadata):
metadata (sdv.metadata.Metadata):
Multi-table metadata representing the data tables that this sampler will be used for.
table_synthesizers (dict):
Dictionary mapping each table to a synthesizer. Should be instantiated and passed to
Expand Down
2 changes: 1 addition & 1 deletion sdv/sampling/independent_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class BaseIndependentSampler:
"""Independent sampler mixin.

Args:
metadata (sdv.metadata.multi_table.MultiTableMetadata):
metadata (sdv.metadata.Metadata):
Multi-table metadata representing the data tables that this sampler will be used for.
table_synthesizers (dict):
Dictionary mapping each table to a synthesizer. Should be instantiated and passed to
Expand Down
6 changes: 3 additions & 3 deletions sdv/utils/poc.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def simplify_schema(data, metadata, verbose=True):
data (dict):
Dictionary that maps each table name (string) to the data for that
table (pandas.DataFrame).
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
verbose (bool):
If True, print information about the simplification process.
Expand All @@ -50,7 +50,7 @@ def simplify_schema(data, metadata, verbose=True):
tuple:
dict:
Dictionary with the simplified dataframes.
MultiTableMetadata:
Metadata:
Simplified metadata.
"""
try:
Expand Down Expand Up @@ -93,7 +93,7 @@ def get_random_subset(data, metadata, main_table_name, num_rows, verbose=True):
data (dict):
Dictionary that maps each table name (string) to the data for that
table (pandas.DataFrame).
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
main_table_name (str):
Name of the main table.
Expand Down
2 changes: 1 addition & 1 deletion sdv/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def drop_unknown_references(data, metadata, drop_missing_values=True, verbose=Tr
data (dict):
Dictionary that maps each table name (string) to the data for that
table (pandas.DataFrame).
metadata (MultiTableMetadata):
metadata (Metadata):
Metadata of the datasets.
drop_missing_values (bool):
Boolean describing whether or not to also drop foreign keys with missing values
Expand Down
Loading
Loading