diff --git a/sdv/io/local/local.py b/sdv/io/local/local.py index ccba2f51d..82a992d13 100644 --- a/sdv/io/local/local.py +++ b/sdv/io/local/local.py @@ -7,7 +7,7 @@ import pandas as pd -from sdv.metadata import MultiTableMetadata +from sdv.metadata import Metadata class BaseLocalHandler: @@ -25,11 +25,11 @@ def create_metadata(self, data): Dictionary of table names to dataframes. Returns: - MultiTableMetadata: - An ``sdv.metadata.MultiTableMetadata`` object with the detected metadata + Metadata: + An ``sdv.metadata.Metadata`` object with the detected metadata properties from the data. """ - metadata = MultiTableMetadata() + metadata = Metadata() metadata.detect_from_dataframes(data) return metadata diff --git a/sdv/metadata/multi_table.py b/sdv/metadata/multi_table.py index 26efe2b94..9bc93b115 100644 --- a/sdv/metadata/multi_table.py +++ b/sdv/metadata/multi_table.py @@ -26,6 +26,10 @@ LOGGER = logging.getLogger(__name__) MULTITABLEMETADATA_LOGGER = get_sdv_logger('MultiTableMetadata') WARNINGS_COLUMN_ORDER = ['Table Name', 'Column Name', 'sdtype', 'datetime_format'] +DEPRECATION_MSG = ( + "The 'MultiTableMetadata' is deprecated. Please use the new " + "'Metadata' class for synthesizers." +) class MultiTableMetadata: diff --git a/sdv/metadata/single_table.py b/sdv/metadata/single_table.py index f0fe05910..8ff57388e 100644 --- a/sdv/metadata/single_table.py +++ b/sdv/metadata/single_table.py @@ -1406,7 +1406,7 @@ def upgrade_metadata(cls, filepath): if len(tables) > 1: raise InvalidMetadataError( 'There are multiple tables specified in the JSON. ' - 'Try using the MultiTableMetadata class to upgrade this file.' + 'Try using the Metadata class to upgrade this file.' ) else: diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 07d527a75..17861dbca 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -26,6 +26,8 @@ SynthesizerInputError, ) from sdv.logging import disable_single_table_logger, get_sdv_logger +from sdv.metadata.metadata import Metadata +from sdv.metadata.multi_table import DEPRECATION_MSG, MultiTableMetadata from sdv.single_table.copulas import GaussianCopulaSynthesizer SYNTHESIZER_LOGGER = get_sdv_logger('MultiTableSynthesizer') @@ -38,9 +40,8 @@ class BaseMultiTableSynthesizer: multi table synthesizers need to implement, as well as common functionality. Args: - metadata (sdv.metadata.multi_table.MultiTableMetadata): - Multi table metadata representing the data tables that this synthesizer will be used - for. + metadata (sdv.metadata.Metadata): + Metadata representing the data tables that this synthesizer will utilize. locales (list or str): The default locale(s) to use for AnonymizedFaker transformers. Defaults to ``['en_US']``. @@ -71,8 +72,9 @@ def _initialize_models(self): with disable_single_table_logger(): for table_name, table_metadata in self.metadata.tables.items(): synthesizer_parameters = self._table_parameters.get(table_name, {}) + metadata = Metadata.load_from_dict(table_metadata.to_dict()) self._table_synthesizers[table_name] = self._synthesizer( - metadata=table_metadata, locales=self.locales, **synthesizer_parameters + metadata=metadata, locales=self.locales, **synthesizer_parameters ) self._table_synthesizers[table_name]._data_processor.table_name = table_name @@ -97,6 +99,9 @@ def _check_metadata_updated(self): def __init__(self, metadata, locales=['en_US'], synthesizer_kwargs=None): self.metadata = metadata + if type(metadata) is MultiTableMetadata: + self.metadata = Metadata().load_from_dict(metadata.to_dict()) + warnings.warn(DEPRECATION_MSG, FutureWarning) with warnings.catch_warnings(): warnings.filterwarnings('ignore', message=r'.*column relationship.*') self.metadata.validate() diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index 5df4dc358..790fdaeba 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -23,7 +23,7 @@ class HMASynthesizer(BaseHierarchicalSampler, BaseMultiTableSynthesizer): """Hierarchical Modeling Algorithm One. Args: - metadata (sdv.metadata.multi_table.MultiTableMetadata): + metadata (sdv.metadata.Metadata): Multi table metadata representing the data tables that this synthesizer will be used for. locales (list or str): @@ -47,7 +47,7 @@ def _get_num_data_columns(metadata): """Get the number of data columns, ie colums that are not id, for each table. Args: - metadata (MultiTableMetadata): + metadata (Metadata): Metadata of the datasets. """ columns_per_table = {} diff --git a/sdv/multi_table/utils.py b/sdv/multi_table/utils.py index 9cdb5297f..bc9e99497 100644 --- a/sdv/multi_table/utils.py +++ b/sdv/multi_table/utils.py @@ -123,7 +123,7 @@ def _simplify_relationships_and_tables(metadata, tables_to_drop): Removes the tables that are not direct child or grandchild of the root table. Args: - metadata (MultiTableMetadata): + metadata (Metadata): Metadata of the datasets. tables_to_drop (set): Set of the tables that relationships will be removed. @@ -149,7 +149,7 @@ def _simplify_grandchildren(metadata, grandchildren): - Drop all modelables columns. Args: - metadata (MultiTableMetadata): + metadata (Metadata): Metadata of the datasets. grandchildren (set): Set of the grandchildren of the root table. @@ -174,7 +174,7 @@ def _get_num_column_to_drop(metadata, child_table, max_col_per_relationships): - minimum number of column to drop = n + k - sqrt(k^2 + 1 + 2m) Args: - metadata (MultiTableMetadata): + metadata (Metadata): Metadata of the datasets. child_table (str): Name of the child table. @@ -232,7 +232,7 @@ def _simplify_child(metadata, child_table, max_col_per_relationships): """Simplify the child table. Args: - metadata (MultiTableMetadata): + metadata (Metadata): Metadata of the datasets. child_table (str): Name of the child table. @@ -252,7 +252,7 @@ def _simplify_children(metadata, children, root_table, num_data_column): - Drop some modelable columns to have at most MAX_NUMBER_OF_COLUMNS columns to model. Args: - metadata (MultiTableMetadata): + metadata (Metadata): Metadata of the datasets. children (set): Set of the children of the root table. @@ -288,11 +288,11 @@ def _simplify_metadata(metadata): - Drop some modelable columns in the children to have at most 1000 columns to model. Args: - metadata (MultiTableMetadata): + metadata (Metadata): Metadata of the datasets. Returns: - MultiTableMetadata: + Metadata: Simplified metadata. """ simplified_metadata = deepcopy(metadata) @@ -330,7 +330,7 @@ def _simplify_data(data, metadata): data (dict): Dictionary that maps each table name (string) to the data for that table (pandas.DataFrame). - metadata (MultiTableMetadata): + metadata (Metadata): Metadata of the datasets. Returns: @@ -375,7 +375,7 @@ def _get_rows_to_drop(data, metadata): This ensures that we preserve the referential integrity between all the relationships. Args: - metadata (MultiTableMetadata): + metadata (Metadata): Metadata of the datasets. data (dict): Dictionary that maps each table name (string) to the data for that @@ -470,7 +470,7 @@ def _subsample_table_and_descendants(data, metadata, table, num_rows): data (dict): Dictionary that maps each table name (string) to the data for that table (pandas.DataFrame). - metadata (MultiTableMetadata): + metadata (Metadata): Metadata of the datasets. table (str): Name of the table. @@ -496,7 +496,7 @@ def _get_primary_keys_referenced(data, metadata): data (dict): Dictionary that maps each table name (string) to the data for that table (pandas.DataFrame). - metadata (MultiTableMetadata): + metadata (Metadata): Metadata of the datasets. Returns: @@ -568,7 +568,7 @@ def _subsample_ancestors(data, metadata, table, primary_keys_referenced): data (dict): Dictionary that maps each table name (string) to the data for that table (pandas.DataFrame). - metadata (MultiTableMetadata): + metadata (Metadata): Metadata of the datasets. table (str): Name of the table. @@ -604,7 +604,7 @@ def _subsample_data(data, metadata, main_table_name, num_rows): referenced by the descendants and some unreferenced rows. Args: - metadata (MultiTableMetadata): + metadata (Metadata): Metadata of the datasets. data (dict): Dictionary that maps each table name (string) to the data for that diff --git a/sdv/sampling/hierarchical_sampler.py b/sdv/sampling/hierarchical_sampler.py index 641b6ccbf..cc134585c 100644 --- a/sdv/sampling/hierarchical_sampler.py +++ b/sdv/sampling/hierarchical_sampler.py @@ -12,7 +12,7 @@ class BaseHierarchicalSampler: """Hierarchical sampler mixin. Args: - metadata (sdv.metadata.multi_table.MultiTableMetadata): + metadata (sdv.metadata.Metadata): Multi-table metadata representing the data tables that this sampler will be used for. table_synthesizers (dict): Dictionary mapping each table to a synthesizer. Should be instantiated and passed to diff --git a/sdv/sampling/independent_sampler.py b/sdv/sampling/independent_sampler.py index b510e8726..2f4116094 100644 --- a/sdv/sampling/independent_sampler.py +++ b/sdv/sampling/independent_sampler.py @@ -10,7 +10,7 @@ class BaseIndependentSampler: """Independent sampler mixin. Args: - metadata (sdv.metadata.multi_table.MultiTableMetadata): + metadata (sdv.metadata.Metadata): Multi-table metadata representing the data tables that this sampler will be used for. table_synthesizers (dict): Dictionary mapping each table to a synthesizer. Should be instantiated and passed to diff --git a/sdv/utils/poc.py b/sdv/utils/poc.py index 682895bad..4f55e63d0 100644 --- a/sdv/utils/poc.py +++ b/sdv/utils/poc.py @@ -40,7 +40,7 @@ def simplify_schema(data, metadata, verbose=True): data (dict): Dictionary that maps each table name (string) to the data for that table (pandas.DataFrame). - metadata (MultiTableMetadata): + metadata (Metadata): Metadata of the datasets. verbose (bool): If True, print information about the simplification process. @@ -50,7 +50,7 @@ def simplify_schema(data, metadata, verbose=True): tuple: dict: Dictionary with the simplified dataframes. - MultiTableMetadata: + Metadata: Simplified metadata. """ try: @@ -93,7 +93,7 @@ def get_random_subset(data, metadata, main_table_name, num_rows, verbose=True): data (dict): Dictionary that maps each table name (string) to the data for that table (pandas.DataFrame). - metadata (MultiTableMetadata): + metadata (Metadata): Metadata of the datasets. main_table_name (str): Name of the main table. diff --git a/sdv/utils/utils.py b/sdv/utils/utils.py index 340761ca8..293abca71 100644 --- a/sdv/utils/utils.py +++ b/sdv/utils/utils.py @@ -18,7 +18,7 @@ def drop_unknown_references(data, metadata, drop_missing_values=True, verbose=Tr data (dict): Dictionary that maps each table name (string) to the data for that table (pandas.DataFrame). - metadata (MultiTableMetadata): + metadata (Metadata): Metadata of the datasets. drop_missing_values (bool): Boolean describing whether or not to also drop foreign keys with missing values diff --git a/tests/integration/metadata/test_metadata.py b/tests/integration/metadata/test_metadata.py index a8cede88b..bde8b15c7 100644 --- a/tests/integration/metadata/test_metadata.py +++ b/tests/integration/metadata/test_metadata.py @@ -1,9 +1,13 @@ +import os +import re + import pytest from sdv.datasets.demo import download_demo from sdv.metadata.metadata import DEFAULT_TABLE_NAME, Metadata from sdv.metadata.multi_table import MultiTableMetadata from sdv.metadata.single_table import SingleTableMetadata +from sdv.multi_table.hma import HMASynthesizer from sdv.single_table.copulas import GaussianCopulaSynthesizer @@ -118,7 +122,7 @@ def test_detect_from_csvs(tmp_path): metadata = Metadata() for table_name, dataframe in real_data.items(): - csv_path = tmp_path / f'{table_name}.csv' + csv_path = os.path.join(tmp_path, f'{table_name}.csv') dataframe.to_csv(csv_path, index=False) # Run @@ -181,11 +185,11 @@ def test_detect_table_from_csv(tmp_path): metadata = Metadata() for table_name, dataframe in real_data.items(): - csv_path = tmp_path / f'{table_name}.csv' + csv_path = os.path.join(tmp_path, f'{table_name}.csv') dataframe.to_csv(csv_path, index=False) # Run - metadata.detect_table_from_csv('hotels', tmp_path / 'hotels.csv') + metadata.detect_table_from_csv('hotels', os.path.join(tmp_path, 'hotels.csv')) # Assert metadata.update_column( @@ -254,12 +258,12 @@ def test_single_table_compatibility(tmp_path): with pytest.warns(FutureWarning, match=warn_msg): synthesizer = GaussianCopulaSynthesizer(metadata) synthesizer.fit(data) - model_path = tmp_path / 'synthesizer.pkl' + model_path = os.path.join(tmp_path, 'synthesizer.pkl') synthesizer.save(model_path) # Assert - assert model_path.exists() - assert model_path.is_file() + assert os.path.exists(model_path) + assert os.path.isfile(model_path) loaded_synthesizer = GaussianCopulaSynthesizer.load(model_path) assert isinstance(synthesizer, GaussianCopulaSynthesizer) assert loaded_synthesizer.get_info() == synthesizer.get_info() @@ -274,3 +278,93 @@ def test_single_table_compatibility(tmp_path): metadata_sample = synthesizer.sample(10) assert loaded_synthesizer.metadata.to_dict() == synthesizer_2.metadata.to_dict() assert metadata_sample.columns.to_list() == loaded_sample.columns.to_list() + + +def test_multi_table_compatibility(tmp_path): + """Test if MultiTableMetadata still has compatibility with multi table synthesizers.""" + # Setup + data, _ = download_demo('multi_table', 'fake_hotels') + warn_msg = re.escape( + "The 'MultiTableMetadata' is deprecated. Please use the new " + "'Metadata' class for synthesizers." + ) + + multi_dict = { + 'tables': { + 'guests': { + 'primary_key': 'guest_email', + 'columns': { + 'guest_email': {'sdtype': 'email', 'pii': True}, + 'hotel_id': {'sdtype': 'id', 'regex_format': '[A-Za-z]{5}'}, + 'has_rewards': {'sdtype': 'boolean'}, + 'room_type': {'sdtype': 'categorical'}, + 'amenities_fee': {'sdtype': 'numerical', 'computer_representation': 'Float'}, + 'checkin_date': {'sdtype': 'datetime', 'datetime_format': '%d %b %Y'}, + 'checkout_date': {'sdtype': 'datetime', 'datetime_format': '%d %b %Y'}, + 'room_rate': {'sdtype': 'numerical', 'computer_representation': 'Float'}, + 'billing_address': {'sdtype': 'address', 'pii': True}, + 'credit_card_number': {'sdtype': 'credit_card_number', 'pii': True}, + }, + }, + 'hotels': { + 'primary_key': 'hotel_id', + 'columns': { + 'hotel_id': {'sdtype': 'id', 'regex_format': 'HID_[0-9]{3}'}, + 'city': {'sdtype': 'categorical'}, + 'state': {'sdtype': 'categorical'}, + 'rating': {'sdtype': 'numerical', 'computer_representation': 'Float'}, + 'classification': {'sdtype': 'categorical'}, + }, + }, + }, + 'relationships': [ + { + 'parent_table_name': 'hotels', + 'parent_primary_key': 'hotel_id', + 'child_table_name': 'guests', + 'child_foreign_key': 'hotel_id', + } + ], + 'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1', + } + metadata = MultiTableMetadata.load_from_dict(multi_dict) + assert type(metadata) is MultiTableMetadata + + # Run + with pytest.warns(FutureWarning, match=warn_msg): + synthesizer = HMASynthesizer(metadata) + + synthesizer.fit(data) + model_path = os.path.join(tmp_path, 'synthesizer.pkl') + synthesizer.save(model_path) + + # Assert + assert os.path.exists(model_path) + assert os.path.isfile(model_path) + + # Load HMASynthesizer + loaded_synthesizer = HMASynthesizer.load(model_path) + + # Asserts + assert isinstance(synthesizer, HMASynthesizer) + assert loaded_synthesizer.get_info() == synthesizer.get_info() + assert isinstance(loaded_synthesizer.metadata, Metadata) + + # Load Metadata + expected_metadata = metadata.to_dict() + expected_metadata['METADATA_SPEC_VERSION'] = 'V1' + + # Asserts + assert loaded_synthesizer.metadata.to_dict() == expected_metadata + + # Sample from loaded synthesizer + loaded_sample = loaded_synthesizer.sample(10) + synthesizer.validate(loaded_sample) + + # Run against Metadata + synthesizer_2 = HMASynthesizer(Metadata._convert_to_unified_metadata(metadata)) + synthesizer_2.fit(data) + metadata_sample = synthesizer.sample(10) + assert loaded_synthesizer.metadata.to_dict() == synthesizer_2.metadata.to_dict() + for table in metadata_sample: + assert metadata_sample[table].columns.to_list() == loaded_sample[table].columns.to_list() diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 493a648f4..fcb4910e2 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -18,6 +18,7 @@ from sdv.datasets.local import load_csvs from sdv.errors import SamplingError, SynthesizerInputError, VersionError from sdv.evaluation.multi_table import evaluate_quality, get_column_pair_plot, get_column_plot +from sdv.metadata.metadata import Metadata from sdv.metadata.multi_table import MultiTableMetadata from sdv.multi_table import HMASynthesizer from tests.integration.single_table.custom_constraints import MyConstraint @@ -393,6 +394,7 @@ def test_save_and_load(self, tmp_path): """Test saving and loading a multi-table synthesizer.""" # Setup _, _, metadata = self.get_custom_constraint_data_and_metadata() + metadata = Metadata.load_from_dict(metadata.to_dict()) synthesizer = HMASynthesizer(metadata) model_path = tmp_path / 'synthesizer.pkl' @@ -460,6 +462,7 @@ def test_synthesize_multiple_tables_using_hma(self, tmp_path): """ # Loading the demo data real_data, metadata = download_demo(modality='multi_table', dataset_name='fake_hotels') + metadata = Metadata.load_from_dict(metadata.to_dict()) # Creating a Synthesizer synthesizer = HMASynthesizer(metadata) @@ -1250,9 +1253,8 @@ def test_metadata_updated_no_warning(self, tmp_path): instance.fit(data) # Assert - for warning in captured_warnings: - assert warning.category is FutureWarning - assert len(captured_warnings) == 3 + assert len(captured_warnings) == 1 + assert captured_warnings[0].category is FutureWarning # Run 2 metadata_detect = MultiTableMetadata() @@ -1271,9 +1273,8 @@ def test_metadata_updated_no_warning(self, tmp_path): instance.fit(data) # Assert - for warning in captured_warnings: - assert warning.category is FutureWarning - assert len(captured_warnings) == 3 + assert len(captured_warnings) == 1 + assert captured_warnings[0].category is FutureWarning # Run 3 instance = HMASynthesizer(metadata_detect) @@ -1297,7 +1298,7 @@ def test_metadata_updated_warning_detect(self): """ # Setup data, metadata = download_demo('multi_table', 'got_families') - metadata_detect = MultiTableMetadata() + metadata_detect = Metadata() metadata_detect.detect_from_dataframes(data) metadata_detect.relationships = metadata.relationships @@ -1316,16 +1317,7 @@ def test_metadata_updated_warning_detect(self): instance.fit(data) # Assert - future_warnings = 0 - user_warnings = 0 - for warning in record: - if warning.category is FutureWarning: - future_warnings += 1 - if warning.category is UserWarning: - user_warnings += 1 - assert future_warnings == 3 - assert user_warnings == 1 - assert len(record) == 4 + assert len(record) == 1 def test_null_foreign_keys(self): """Test that the synthesizer crashes when there are null foreign keys.""" @@ -1595,7 +1587,7 @@ def test_metadata_updated_warning(method, kwargs): The warning should be raised during synthesizer initialization. """ - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'tables': { 'departure': { 'primary_key': 'id', diff --git a/tests/unit/io/local/test_local.py b/tests/unit/io/local/test_local.py index 48395f276..05e7b41e3 100644 --- a/tests/unit/io/local/test_local.py +++ b/tests/unit/io/local/test_local.py @@ -36,7 +36,7 @@ def test_create_metadata(self): # Assert assert isinstance(metadata, MultiTableMetadata) assert metadata.to_dict() == { - 'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1', + 'METADATA_SPEC_VERSION': 'V1', 'relationships': [], 'tables': { 'guests': { diff --git a/tests/unit/metadata/test_single_table.py b/tests/unit/metadata/test_single_table.py index c5b979d4f..698ea57bc 100644 --- a/tests/unit/metadata/test_single_table.py +++ b/tests/unit/metadata/test_single_table.py @@ -3163,7 +3163,7 @@ def test_upgrade_metadata_multiple_tables_fails( # Run message = ( 'There are multiple tables specified in the JSON. ' - 'Try using the MultiTableMetadata class to upgrade this file.' + 'Try using the Metadata class to upgrade this file.' ) with pytest.raises(InvalidMetadataError, match=message): SingleTableMetadata.upgrade_metadata('old') diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index 22f903e33..219b862bf 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -18,6 +18,7 @@ SynthesizerInputError, VersionError, ) +from sdv.metadata.metadata import Metadata from sdv.metadata.multi_table import MultiTableMetadata from sdv.metadata.single_table import SingleTableMetadata from sdv.multi_table.base import BaseMultiTableSynthesizer @@ -53,14 +54,22 @@ def test__initialize_models(self): } instance._synthesizer.assert_has_calls([ call( - metadata=instance.metadata.tables['nesreca'], + metadata=ANY, default_distribution='gamma', locales=locales, ), - call(metadata=instance.metadata.tables['oseba'], locales=locales), - call(metadata=instance.metadata.tables['upravna_enota'], locales=locales), + call(metadata=ANY, locales=locales), + call(metadata=ANY, locales=locales), ]) + expected_call_0 = instance.metadata.tables['nesreca'].to_dict() + expected_call_1 = instance.metadata.tables['oseba'].to_dict() + expected_call_2 = instance.metadata.tables['upravna_enota'].to_dict() + call_list = instance._synthesizer.call_args_list + assert call_list[0][1]['metadata'].tables['default_table_name'].to_dict() == expected_call_0 + assert call_list[1][1]['metadata'].tables['default_table_name'].to_dict() == expected_call_1 + assert call_list[2][1]['metadata'].tables['default_table_name'].to_dict() == expected_call_2 + def test__get_pbar_args(self): """Test that ``_get_pbar_args`` returns a dictionary with disable opposite to verbose.""" # Setup @@ -115,6 +124,7 @@ def test___init__( mock_generate_synthesizer_id.return_value = synthesizer_id mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183' metadata = get_multi_table_metadata() + metadata = Metadata.load_from_dict(metadata.to_dict()) metadata.validate = Mock() # Run @@ -122,6 +132,7 @@ def test___init__( instance = BaseMultiTableSynthesizer(metadata) # Assert + assert type(instance.metadata) is Metadata assert instance.metadata == metadata assert isinstance(instance._table_synthesizers['nesreca'], GaussianCopulaSynthesizer) assert isinstance(instance._table_synthesizers['oseba'], GaussianCopulaSynthesizer) @@ -370,12 +381,14 @@ def test_get_metadata(self): """Test that the metadata object is returned.""" # Setup metadata = get_multi_table_metadata() + metadata = Metadata.load_from_dict(metadata.to_dict()) instance = BaseMultiTableSynthesizer(metadata) # Run result = instance.get_metadata() # Assert + assert type(metadata) is Metadata assert metadata == result def test_validate(self): @@ -903,9 +916,16 @@ def test_preprocess_warning(self, mock_warnings): synth_nesreca._preprocess.assert_called_once_with(data['nesreca']) synth_oseba._preprocess.assert_called_once_with(data['oseba']) synth_upravna_enota._preprocess.assert_called_once_with(data['upravna_enota']) - mock_warnings.warn.assert_called_once_with( - 'This model has already been fitted. To use the new preprocessed data, ' - "please refit the model using 'fit' or 'fit_processed_data'." + + arg_list = mock_warnings.warn.call_args_list + assert arg_list[0][0][0] == ( + "The 'MultiTableMetadata' is deprecated. " + "Please use the new 'Metadata' class for synthesizers." + ) + assert arg_list[1][0][0] == ( + 'This model has already been fitted. ' + 'To use the new preprocessed data, please ' + "refit the model using 'fit' or 'fit_processed_data'." ) @patch('sdv.multi_table.base.datetime') diff --git a/tests/unit/multi_table/test_hma.py b/tests/unit/multi_table/test_hma.py index 51d5aa1a4..47b9b2566 100644 --- a/tests/unit/multi_table/test_hma.py +++ b/tests/unit/multi_table/test_hma.py @@ -6,6 +6,7 @@ import pytest from sdv.errors import SynthesizerInputError +from sdv.metadata.metadata import Metadata from sdv.metadata.multi_table import MultiTableMetadata from sdv.multi_table.hma import HMASynthesizer from sdv.single_table.copulas import GaussianCopulaSynthesizer @@ -17,6 +18,7 @@ def test___init__(self): """Test the default initialization of the ``HMASynthesizer``.""" # Run metadata = get_multi_table_metadata() + metadata = Metadata.load_from_dict(metadata.to_dict()) metadata.validate = Mock() instance = HMASynthesizer(metadata)