From a100edcb1079d8be84ab7cad2c1ab94334ffa049 Mon Sep 17 00:00:00 2001 From: R-Palazzo <116157184+R-Palazzo@users.noreply.github.com> Date: Wed, 18 Sep 2024 20:23:41 +0200 Subject: [PATCH] For single-table use cases, make it frictionless to update Metadata (#2228) --- sdv/metadata/metadata.py | 72 +++++++++- sdv/metadata/multi_table.py | 8 +- sdv/multi_table/base.py | 3 +- .../data_processing/test_data_processor.py | 24 ++-- .../evaluation/test_single_table.py | 2 +- tests/integration/metadata/test_metadata.py | 130 +++++++++++------- .../integration/metadata/test_multi_table.py | 12 +- .../metadata/test_visualization.py | 6 +- tests/integration/multi_table/test_hma.py | 102 +++++++------- tests/integration/sequential/test_par.py | 31 +++-- tests/integration/single_table/test_base.py | 32 ++--- .../single_table/test_constraints.py | 34 ++--- .../integration/single_table/test_copulas.py | 6 +- tests/integration/single_table/test_ctgan.py | 12 +- tests/unit/evaluation/test_single_table.py | 48 +++---- tests/unit/metadata/test_metadata.py | 51 +++++++ tests/unit/metadata/test_multi_table.py | 6 +- tests/unit/multi_table/test_base.py | 20 +-- tests/unit/multi_table/test_hma.py | 12 +- tests/unit/multi_table/test_utils.py | 2 +- tests/unit/sequential/test_par.py | 34 +++-- tests/unit/single_table/test_base.py | 37 ++--- tests/unit/single_table/test_copulagan.py | 10 +- tests/unit/single_table/test_copulas.py | 8 +- tests/unit/single_table/test_ctgan.py | 14 +- 25 files changed, 439 insertions(+), 277 deletions(-) diff --git a/sdv/metadata/metadata.py b/sdv/metadata/metadata.py index 053ef5896..06b9c2ddf 100644 --- a/sdv/metadata/metadata.py +++ b/sdv/metadata/metadata.py @@ -161,7 +161,18 @@ def _convert_to_single_table(self): return next(iter(self.tables.values()), SingleTableMetadata()) - def set_sequence_index(self, table_name, column_name): + def _handle_table_name(self, table_name): + if table_name is None: + if len(self.tables) == 1: + table_name = next(iter(self.tables)) + else: + raise ValueError( + 'Metadata contains more than one table, please specify the `table_name`.' + ) + + return table_name + + def set_sequence_index(self, column_name, table_name=None): """Set the sequence index of a table. Args: @@ -170,18 +181,21 @@ def set_sequence_index(self, table_name, column_name): column_name (str): Name of the sequence index column. """ + table_name = self._handle_table_name(table_name) self._validate_table_exists(table_name) self.tables[table_name].set_sequence_index(column_name) - def set_sequence_key(self, table_name, column_name): + def set_sequence_key(self, column_name, table_name=None): """Set the sequence key of a table. Args: - table_name (str): - Name of the table to set the sequence key. column_name (str, tulple[str]): Name (or tuple of names) of the sequence key column(s). + table_name (str): + Name of the table to set the sequence key. + Defaults to None. """ + table_name = self._handle_table_name(table_name) self._validate_table_exists(table_name) self.tables[table_name].set_sequence_key(column_name) @@ -206,3 +220,53 @@ def validate_table(self, data, table_name=None): ) return self.validate_data({table_name: data}) + + def get_column_names(self, table_name=None, **kwargs): + """Return a list of column names that match the given metadata keyword arguments.""" + table_name = self._handle_table_name(table_name) + return super().get_column_names(table_name, **kwargs) + + def update_column(self, column_name, table_name=None, **kwargs): + """Update an existing column for a table in the ``Metadata``.""" + table_name = self._handle_table_name(table_name) + super().update_column(table_name, column_name, **kwargs) + + def update_columns(self, column_names, table_name=None, **kwargs): + """Update the metadata of multiple columns.""" + table_name = self._handle_table_name(table_name) + super().update_columns(table_name, column_names, **kwargs) + + def update_columns_metadata(self, column_metadata, table_name=None): + """Update the metadata of multiple columns.""" + table_name = self._handle_table_name(table_name) + super().update_columns_metadata(table_name, column_metadata) + + def add_column(self, column_name, table_name=None, **kwargs): + """Add a column to the metadata.""" + table_name = self._handle_table_name(table_name) + super().add_column(table_name, column_name, **kwargs) + + def add_column_relationship( + self, + relationship_type, + column_names, + table_name=None, + ): + """Add a column relationship to the metadata.""" + table_name = self._handle_table_name(table_name) + super().add_column_relationship(table_name, relationship_type, column_names) + + def set_primary_key(self, column_name, table_name=None): + """Set the primary key of a table.""" + table_name = self._handle_table_name(table_name) + super().set_primary_key(table_name, column_name) + + def remove_primary_key(self, table_name=None): + """Remove the primary key of a table.""" + table_name = self._handle_table_name(table_name) + super().remove_primary_key(table_name) + + def add_alternate_keys(self, column_names, table_name=None): + """Add alternate keys to a table.""" + table_name = self._handle_table_name(table_name) + super().add_alternate_keys(table_name, column_names) diff --git a/sdv/metadata/multi_table.py b/sdv/metadata/multi_table.py index cb38bdb12..201a8c8c6 100644 --- a/sdv/metadata/multi_table.py +++ b/sdv/metadata/multi_table.py @@ -515,14 +515,18 @@ def _detect_relationships(self, data=None): try: original_foreign_key_sdtype = child_meta.columns[primary_key]['sdtype'] if original_foreign_key_sdtype != 'id': - self.update_column(child_candidate, primary_key, sdtype='id') + self.update_column( + table_name=child_candidate, column_name=primary_key, sdtype='id' + ) self.add_relationship( parent_candidate, child_candidate, primary_key, primary_key ) except InvalidMetadataError: self.update_column( - child_candidate, primary_key, sdtype=original_foreign_key_sdtype + table_name=child_candidate, + column_name=primary_key, + sdtype=original_foreign_key_sdtype, ) continue diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 0588936cf..843653d06 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -76,7 +76,8 @@ def _set_temp_numpy_seed(self): def _initialize_models(self): with disable_single_table_logger(): for table_name, table_metadata in self.metadata.tables.items(): - synthesizer_parameters = self._table_parameters.get(table_name, {}) + synthesizer_parameters = {'locales': self.locales} + synthesizer_parameters.update(self._table_parameters.get(table_name, {})) metadata_dict = {'tables': {table_name: table_metadata.to_dict()}} metadata = Metadata.load_from_dict(metadata_dict) self._table_synthesizers[table_name] = self._synthesizer( diff --git a/tests/integration/data_processing/test_data_processor.py b/tests/integration/data_processing/test_data_processor.py index 65f865d2f..bb87029b4 100644 --- a/tests/integration/data_processing/test_data_processor.py +++ b/tests/integration/data_processing/test_data_processor.py @@ -53,7 +53,7 @@ def test_with_anonymized_columns(self): data, metadata = download_demo('single_table', 'adult') # Add anonymized field - metadata.update_column('adult', 'occupation', sdtype='job', pii=True) + metadata.update_column('occupation', 'adult', sdtype='job', pii=True) # Instance ``DataProcessor`` dp = DataProcessor(metadata._convert_to_single_table()) @@ -101,11 +101,11 @@ def test_with_anonymized_columns_and_primary_key(self): data, metadata = download_demo('single_table', 'adult') # Add anonymized field - metadata.update_column('adult', 'occupation', sdtype='job', pii=True) + metadata.update_column('occupation', 'adult', sdtype='job', pii=True) # Add primary key field - metadata.add_column('adult', 'id', sdtype='id', regex_format='ID_\\d{4}[0-9]') - metadata.set_primary_key('adult', 'id') + metadata.add_column('id', 'adult', sdtype='id', regex_format='ID_\\d{4}[0-9]') + metadata.set_primary_key('id', 'adult') # Add id size = len(data) @@ -159,8 +159,8 @@ def test_with_primary_key_numerical(self): adult_metadata = Metadata.detect_from_dataframes({'adult': data}) # Add primary key field - adult_metadata.add_column('adult', 'id', sdtype='id') - adult_metadata.set_primary_key('adult', 'id') + adult_metadata.add_column('id', 'adult', sdtype='id') + adult_metadata.set_primary_key('id', 'adult') # Add id size = len(data) @@ -198,13 +198,13 @@ def test_with_alternate_keys(self): adult_metadata = Metadata.detect_from_dataframes({'adult': data}) # Add primary key field - adult_metadata.add_column('adult', 'id', sdtype='id') - adult_metadata.set_primary_key('adult', 'id') + adult_metadata.add_column('id', 'adult', sdtype='id') + adult_metadata.set_primary_key('id', 'adult') - adult_metadata.add_column('adult', 'secondary_id', sdtype='id') - adult_metadata.update_column('adult', 'fnlwgt', sdtype='id', regex_format='ID_\\d{4}[0-9]') + adult_metadata.add_column('secondary_id', 'adult', sdtype='id') + adult_metadata.update_column('fnlwgt', 'adult', sdtype='id', regex_format='ID_\\d{4}[0-9]') - adult_metadata.add_alternate_keys('adult', ['secondary_id', 'fnlwgt']) + adult_metadata.add_alternate_keys(['secondary_id', 'fnlwgt'], 'adult') # Add id size = len(data) @@ -345,7 +345,7 @@ def test_localized_anonymized_columns(self): """Test data processor uses the default locale for anonymized columns.""" # Setup data, metadata = download_demo('single_table', 'adult') - metadata.update_column('adult', 'occupation', sdtype='job', pii=True) + metadata.update_column('occupation', 'adult', sdtype='job', pii=True) dp = DataProcessor(metadata._convert_to_single_table(), locales=['en_CA', 'fr_CA']) diff --git a/tests/integration/evaluation/test_single_table.py b/tests/integration/evaluation/test_single_table.py index 8b3cf23fa..d943b02d3 100644 --- a/tests/integration/evaluation/test_single_table.py +++ b/tests/integration/evaluation/test_single_table.py @@ -12,7 +12,7 @@ def test_evaluation(): data = pd.DataFrame({'col': [1, 2, 3]}) metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'col', sdtype='numerical') + metadata.add_column('col', 'table', sdtype='numerical') synthesizer = GaussianCopulaSynthesizer(metadata, default_distribution='truncnorm') # Run and Assert diff --git a/tests/integration/metadata/test_metadata.py b/tests/integration/metadata/test_metadata.py index 5467dfb49..fd5909b24 100644 --- a/tests/integration/metadata/test_metadata.py +++ b/tests/integration/metadata/test_metadata.py @@ -1,5 +1,6 @@ import os import re +from copy import deepcopy import pytest @@ -241,56 +242,6 @@ def test_detect_from_csvs(tmp_path): assert metadata.to_dict() == expected_metadata -def test_detect_table_from_csv(tmp_path): - """Test the ``detect_table_from_csv`` method.""" - # Setup - real_data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') - - metadata = Metadata() - - for table_name, dataframe in real_data.items(): - csv_path = tmp_path / f'{table_name}.csv' - dataframe.to_csv(csv_path, index=False) - - # Run - metadata.detect_table_from_csv('hotels', tmp_path / 'hotels.csv') - - # Assert - metadata.update_column( - table_name='hotels', - column_name='city', - sdtype='categorical', - ) - metadata.update_column( - table_name='hotels', - column_name='state', - sdtype='categorical', - ) - metadata.update_column( - table_name='hotels', - column_name='classification', - sdtype='categorical', - ) - expected_metadata = { - 'tables': { - 'hotels': { - 'columns': { - 'hotel_id': {'sdtype': 'id'}, - 'city': {'sdtype': 'categorical'}, - 'state': {'sdtype': 'categorical'}, - 'rating': {'sdtype': 'numerical'}, - 'classification': {'sdtype': 'categorical'}, - }, - 'primary_key': 'hotel_id', - } - }, - 'relationships': [], - 'METADATA_SPEC_VERSION': 'V1', - } - - assert metadata.to_dict() == expected_metadata - - def test_single_table_compatibility(tmp_path): """Test if SingleTableMetadata still has compatibility with single table synthesizers.""" # Setup @@ -432,3 +383,82 @@ def test_multi_table_compatibility(tmp_path): assert expected_metadata == synthesizer_2.metadata.to_dict() for table in metadata_sample: assert metadata_sample[table].columns.to_list() == loaded_sample[table].columns.to_list() + + +params = [ + ('update_column', ['column_name'], {'column_name': 'has_rewards', 'sdtype': 'categorical'}), + ( + 'update_columns', + ['column_names'], + {'column_names': ['has_rewards', 'billing_address'], 'sdtype': 'categorical'}, + ), + ( + 'update_columns_metadata', + ['column_metadata'], + {'column_metadata': {'has_rewards': {'sdtype': 'categorical'}}}, + ), + ('add_column', ['column_name'], {'column_name': 'has_rewards_2', 'sdtype': 'categorical'}), + ('set_primary_key', ['column_name'], {'column_name': 'billing_address'}), + ('remove_primary_key', [], {}), + ( + 'add_column_relationship', + ['relationship_type', 'column_names'], + {'column_names': ['billing_address'], 'relationship_type': 'address'}, + ), + ('add_alternate_keys', ['column_names'], {'column_names': ['billing_address']}), + ('set_sequence_key', ['column_name'], {'column_name': 'billing_address'}), + ('get_column_names', [], {'sdtype': 'datetime'}), +] + + +@pytest.mark.parametrize('method, args, kwargs', params) +def test_any_metadata_update_single_table(method, args, kwargs): + """Test that any method that updates metadata works for single-table case.""" + # Setup + _, metadata = download_demo('single_table', 'fake_hotel_guests') + metadata.update_column( + table_name='fake_hotel_guests', column_name='billing_address', sdtype='street_address' + ) + parameter = [kwargs[arg] for arg in args] + remaining_kwargs = {key: value for key, value in kwargs.items() if key not in args} + metadata_before = deepcopy(metadata).to_dict() + + # Run + result = getattr(metadata, method)(*parameter, **remaining_kwargs) + + # Assert + expected_dict = metadata.to_dict() + if method != 'get_column_names': + assert expected_dict != metadata_before + else: + assert result == ['checkin_date', 'checkout_date'] + + +@pytest.mark.parametrize('method, args, kwargs', params) +def test_any_metadata_update_multi_table(method, args, kwargs): + """Test that any method that updates metadata works for multi-table case.""" + # Setup + _, metadata = download_demo('multi_table', 'fake_hotels') + metadata.update_column( + table_name='guests', column_name='billing_address', sdtype='street_address' + ) + parameter = [kwargs[arg] for arg in args] + remaining_kwargs = {key: value for key, value in kwargs.items() if key not in args} + metadata_before = deepcopy(metadata).to_dict() + expected_error = re.escape( + 'Metadata contains more than one table, please specify the `table_name`.' + ) + + # Run + with pytest.raises(ValueError, match=expected_error): + getattr(metadata, method)(*parameter, **remaining_kwargs) + + parameter.append('guests') + result = getattr(metadata, method)(*parameter, **remaining_kwargs) + + # Assert + expected_dict = metadata.to_dict() + if method != 'get_column_names': + assert expected_dict != metadata_before + else: + assert result == ['checkin_date', 'checkout_date'] diff --git a/tests/integration/metadata/test_multi_table.py b/tests/integration/metadata/test_multi_table.py index 1384eaf74..23e16de42 100644 --- a/tests/integration/metadata/test_multi_table.py +++ b/tests/integration/metadata/test_multi_table.py @@ -34,11 +34,11 @@ def _validate_sdtypes(cls, columns_to_sdtypes): mock_rdt_transformers.address.RandomLocationGenerator = RandomLocationGeneratorMock _, instance = download_demo('multi_table', 'fake_hotels') - instance.update_column('hotels', 'city', sdtype='city') - instance.update_column('hotels', 'state', sdtype='state') + instance.update_column('city', 'hotels', sdtype='city') + instance.update_column('state', 'hotels', sdtype='state') # Run - instance.add_column_relationship('hotels', 'address', ['city', 'state']) + instance.add_column_relationship('address', ['city', 'state'], 'hotels') # Assert instance.validate() @@ -303,9 +303,9 @@ def test_get_table_metadata(): """Test the ``get_table_metadata`` method.""" # Setup metadata = get_multi_table_metadata() - metadata.add_column('nesreca', 'latitude', sdtype='latitude') - metadata.add_column('nesreca', 'longitude', sdtype='longitude') - metadata.add_column_relationship('nesreca', 'gps', ['latitude', 'longitude']) + metadata.add_column('latitude', 'nesreca', sdtype='latitude') + metadata.add_column('longitude', 'nesreca', sdtype='longitude') + metadata.add_column_relationship('gps', ['latitude', 'longitude'], 'nesreca') # Run table_metadata = metadata.get_table_metadata('nesreca') diff --git a/tests/integration/metadata/test_visualization.py b/tests/integration/metadata/test_visualization.py index 2d801ea5d..23a7aee17 100644 --- a/tests/integration/metadata/test_visualization.py +++ b/tests/integration/metadata/test_visualization.py @@ -26,9 +26,9 @@ def test_visualize_graph_for_multi_table(): data2 = pd.DataFrame({'\\|=/bla@#$324%^,"&*()><...': ['a', 'b', 'c']}) tables = {'1': data1, '2': data2} metadata = Metadata.detect_from_dataframes(tables) - metadata.update_column('1', '\\|=/bla@#$324%^,"&*()><...', sdtype='id') - metadata.update_column('2', '\\|=/bla@#$324%^,"&*()><...', sdtype='id') - metadata.set_primary_key('1', '\\|=/bla@#$324%^,"&*()><...') + metadata.update_column('\\|=/bla@#$324%^,"&*()><...', '1', sdtype='id') + metadata.update_column('\\|=/bla@#$324%^,"&*()><...', '2', sdtype='id') + metadata.set_primary_key('\\|=/bla@#$324%^,"&*()><...', '1') metadata.add_relationship( '1', '2', '\\|=/bla@#$324%^,"&*()><...', '\\|=/bla@#$324%^,"&*()><...' ) diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 5efc300bb..288e470ad 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -87,8 +87,8 @@ def test_hma_reset_sampling(self): faker = Faker() data, metadata = download_demo('multi_table', 'got_families') metadata.add_column( - 'characters', 'ssn', + 'characters', sdtype='ssn', ) data['characters']['ssn'] = [faker.lexify() for _ in range(len(data['characters']))] @@ -126,7 +126,7 @@ def test_get_info(self): today = datetime.datetime.today().strftime('%Y-%m-%d') metadata = Metadata() metadata.add_table('tab') - metadata.add_column('tab', 'col', sdtype='numerical') + metadata.add_column('col', 'tab', sdtype='numerical') synthesizer = HMASynthesizer(metadata) # Run @@ -221,12 +221,12 @@ def get_custom_constraint_data_and_metadata(self): metadata = Metadata() metadata.detect_table_from_dataframe('parent', parent_data) - metadata.update_column('parent', 'primary_key', sdtype='id') + metadata.update_column('primary_key', 'parent', sdtype='id') metadata.detect_table_from_dataframe('child', child_data) - metadata.update_column('child', 'user_id', sdtype='id') - metadata.update_column('child', 'id', sdtype='id') - metadata.set_primary_key('parent', 'primary_key') - metadata.set_primary_key('child', 'id') + metadata.update_column('user_id', 'child', sdtype='id') + metadata.update_column('id', 'child', sdtype='id') + metadata.set_primary_key('primary_key', 'parent') + metadata.set_primary_key('id', 'child') metadata.add_relationship( parent_primary_key='primary_key', parent_table_name='parent', @@ -361,10 +361,10 @@ def test_hma_with_inequality_constraint(self): metadata = Metadata() metadata.detect_table_from_dataframe(table_name='parent_table', data=parent_table) - metadata.update_column('parent_table', 'id', sdtype='id') + metadata.update_column('id', 'parent_table', sdtype='id') metadata.detect_table_from_dataframe(table_name='child_table', data=child_table) - metadata.update_column('child_table', 'id', sdtype='id') - metadata.update_column('child_table', 'parent_id', sdtype='id') + metadata.update_column('id', 'child_table', sdtype='id') + metadata.update_column('parent_id', 'child_table', sdtype='id') metadata.set_primary_key(table_name='parent_table', column_name='id') metadata.set_primary_key(table_name='child_table', column_name='id') @@ -452,14 +452,14 @@ def test_hma_primary_key_and_foreign_key_only(self): for table_name, table in data.items(): metadata.detect_table_from_dataframe(table_name, table) - metadata.update_column('users', 'user_id', sdtype='id') - metadata.update_column('sessions', 'session_id', sdtype='id') - metadata.update_column('games', 'game_id', sdtype='id') - metadata.update_column('games', 'session_id', sdtype='id') - metadata.update_column('games', 'user_id', sdtype='id') - metadata.set_primary_key('users', 'user_id') - metadata.set_primary_key('sessions', 'session_id') - metadata.set_primary_key('games', 'game_id') + metadata.update_column('user_id', 'users', sdtype='id') + metadata.update_column('session_id', 'sessions', sdtype='id') + metadata.update_column('game_id', 'games', sdtype='id') + metadata.update_column('session_id', 'games', sdtype='id') + metadata.update_column('user_id', 'games', sdtype='id') + metadata.set_primary_key('user_id', 'users') + metadata.set_primary_key('session_id', 'sessions') + metadata.set_primary_key('game_id', 'games') metadata.add_relationship('users', 'games', 'user_id', 'user_id') metadata.add_relationship('sessions', 'games', 'session_id', 'session_id') @@ -1350,25 +1350,25 @@ def test_null_foreign_keys(self): # Setup metadata = Metadata() metadata.add_table('parent_table1') - metadata.add_column('parent_table1', 'id', sdtype='id') - metadata.set_primary_key('parent_table1', 'id') + metadata.add_column('id', 'parent_table1', sdtype='id') + metadata.set_primary_key('id', 'parent_table1') metadata.add_table('parent_table2') - metadata.add_column('parent_table2', 'id', sdtype='id') - metadata.set_primary_key('parent_table2', 'id') + metadata.add_column('id', 'parent_table2', sdtype='id') + metadata.set_primary_key('id', 'parent_table2') metadata.add_table('child_table1') - metadata.add_column('child_table1', 'id', sdtype='id') - metadata.set_primary_key('child_table1', 'id') - metadata.add_column('child_table1', 'fk1', sdtype='id') - metadata.add_column('child_table1', 'fk2', sdtype='id') + metadata.add_column('id', 'child_table1', sdtype='id') + metadata.set_primary_key('id', 'child_table1') + metadata.add_column('fk1', 'child_table1', sdtype='id') + metadata.add_column('fk2', 'child_table1', sdtype='id') metadata.add_table('child_table2') - metadata.add_column('child_table2', 'id', sdtype='id') - metadata.set_primary_key('child_table2', 'id') - metadata.add_column('child_table2', 'fk1', sdtype='id') - metadata.add_column('child_table2', 'fk2', sdtype='id') - metadata.add_column('child_table2', 'cat_type', sdtype='categorical') + metadata.add_column('id', 'child_table2', sdtype='id') + metadata.set_primary_key('id', 'child_table2') + metadata.add_column('fk1', 'child_table2', sdtype='id') + metadata.add_column('fk2', 'child_table2', sdtype='id') + metadata.add_column('cat_type', 'child_table2', sdtype='categorical') metadata.add_relationship( parent_table_name='parent_table1', @@ -1841,7 +1841,7 @@ def test_fit_and_sample_numerical_col_names(): } ] metadata = Metadata.load_from_dict(metadata_dict) - metadata.set_primary_key('0', '1') + metadata.set_primary_key('1', '0') # Run synth = HMASynthesizer(metadata) @@ -1874,11 +1874,11 @@ def test_detect_from_dataframe_numerical_col(): metadata = Metadata() metadata.detect_table_from_dataframe('parent_data', parent_data) metadata.detect_table_from_dataframe('child_data', child_data) - metadata.update_column('parent_data', '1', sdtype='id') - metadata.update_column('child_data', '3', sdtype='id') - metadata.update_column('child_data', '4', sdtype='id') - metadata.set_primary_key('parent_data', '1') - metadata.set_primary_key('child_data', '4') + metadata.update_column('1', 'parent_data', sdtype='id') + metadata.update_column('3', 'child_data', sdtype='id') + metadata.update_column('4', 'child_data', sdtype='id') + metadata.set_primary_key('1', 'parent_data') + metadata.set_primary_key('4', 'child_data') metadata.add_relationship( parent_primary_key='1', parent_table_name='parent_data', @@ -1887,11 +1887,11 @@ def test_detect_from_dataframe_numerical_col(): ) test_metadata = Metadata.detect_from_dataframes(data) - test_metadata.update_column('parent_data', '1', sdtype='id') - test_metadata.update_column('child_data', '3', sdtype='id') - test_metadata.update_column('child_data', '4', sdtype='id') - test_metadata.set_primary_key('parent_data', '1') - test_metadata.set_primary_key('child_data', '4') + test_metadata.update_column('1', 'parent_data', sdtype='id') + test_metadata.update_column('3', 'child_data', sdtype='id') + test_metadata.update_column('4', 'child_data', sdtype='id') + test_metadata.set_primary_key('1', 'parent_data') + test_metadata.set_primary_key('4', 'child_data') test_metadata.add_relationship( parent_primary_key='1', parent_table_name='parent_data', @@ -2004,13 +2004,13 @@ def test_hma_synthesizer_with_fixed_combinations(): # Creating metadata for the dataset metadata = Metadata.detect_from_dataframes(data) - metadata.update_column('users', 'user_id', sdtype='id') - metadata.update_column('records', 'record_id', sdtype='id') - metadata.update_column('records', 'user_id', sdtype='id') - metadata.update_column('records', 'location_id', sdtype='id') - metadata.update_column('locations', 'location_id', sdtype='id') - metadata.set_primary_key('users', 'user_id') - metadata.set_primary_key('locations', 'location_id') + metadata.update_column('user_id', 'users', sdtype='id') + metadata.update_column('record_id', 'records', sdtype='id') + metadata.update_column('user_id', 'records', sdtype='id') + metadata.update_column('location_id', 'records', sdtype='id') + metadata.update_column('location_id', 'locations', sdtype='id') + metadata.set_primary_key('user_id', 'users') + metadata.set_primary_key('location_id', 'locations') metadata.add_relationship('users', 'records', 'user_id', 'user_id') metadata.add_relationship('locations', 'records', 'location_id', 'location_id') @@ -2052,8 +2052,8 @@ def test_fit_int_primary_key_regex_includes_zero(regex): 'child_data': child_data, } metadata = Metadata.detect_from_dataframes(data) - metadata.update_column('parent_data', 'parent_id', sdtype='id', regex_format=regex) - metadata.set_primary_key('parent_data', 'parent_id') + metadata.update_column('parent_id', 'parent_data', sdtype='id', regex_format=regex) + metadata.set_primary_key('parent_id', 'parent_data') # Run and Assert instance = HMASynthesizer(metadata) diff --git a/tests/integration/sequential/test_par.py b/tests/integration/sequential/test_par.py index 6ae0af54e..620a95cf2 100644 --- a/tests/integration/sequential/test_par.py +++ b/tests/integration/sequential/test_par.py @@ -22,9 +22,11 @@ def _get_par_data_and_metadata(): 'context': ['a', 'a', 'b', 'b'], }) metadata = Metadata.detect_from_dataframes({'table': data}) - metadata.update_column('table', 'entity', sdtype='id') - metadata.set_sequence_key('table', 'entity') - metadata.set_sequence_index('table', 'date') + metadata.update_column('entity', 'table', sdtype='id') + metadata.set_sequence_key('entity', 'table') + + metadata.set_sequence_index('date', 'table') + return data, metadata @@ -34,9 +36,11 @@ def test_par(): data = load_demo() data['date'] = pd.to_datetime(data['date']) metadata = Metadata.detect_from_dataframes({'table': data}) - metadata.update_column('table', 'store_id', sdtype='id') - metadata.set_sequence_key('table', 'store_id') - metadata.set_sequence_index('table', 'date') + metadata.update_column('store_id', 'table', sdtype='id') + metadata.set_sequence_key('store_id', 'table') + + metadata.set_sequence_index('date', 'table') + model = PARSynthesizer( metadata=metadata, context_columns=['region'], @@ -67,9 +71,10 @@ def test_column_after_date_simple(): 'col2': ['hello', 'world'], }) metadata = Metadata.detect_from_dataframes({'table': data}) - metadata.update_column('table', 'col', sdtype='id') - metadata.set_sequence_key('table', 'col') - metadata.set_sequence_index('table', 'date') + metadata.update_column('col', 'table', sdtype='id') + metadata.set_sequence_key('col', 'table') + + metadata.set_sequence_index('date', 'table') # Run model = PARSynthesizer(metadata=metadata, epochs=1) @@ -347,8 +352,10 @@ def test_par_unique_sequence_index_with_enforce_min_max(): ) metadata = Metadata.detect_from_dataframes({'table': test_df}) metadata.update_column(table_name='table', column_name='s_key', sdtype='id') - metadata.set_sequence_key('table', 's_key') - metadata.set_sequence_index('table', 'visits') + metadata.set_sequence_key('s_key', 'table') + + metadata.set_sequence_index('visits', 'table') + synthesizer = PARSynthesizer( metadata, enforce_min_max_values=True, enforce_rounding=False, epochs=100, verbose=True ) @@ -441,7 +448,7 @@ def test_par_categorical_column_represented_by_floats(): # Setup data, metadata = download_demo('sequential', 'nasdaq100_2019') data['category'] = [100.0 if i % 2 == 0 else 50.0 for i in data.index] - metadata.add_column('nasdaq100_2019', 'category', sdtype='categorical') + metadata.add_column('category', 'nasdaq100_2019', sdtype='categorical') # Run synth = PARSynthesizer(metadata) diff --git a/tests/integration/single_table/test_base.py b/tests/integration/single_table/test_base.py index 7ab9bb27c..88265ec1f 100644 --- a/tests/integration/single_table/test_base.py +++ b/tests/integration/single_table/test_base.py @@ -93,9 +93,9 @@ def test_sample_from_conditions_with_batch_size(): metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'column1', sdtype='numerical') - metadata.add_column('table', 'column2', sdtype='numerical') - metadata.add_column('table', 'column3', sdtype='numerical') + metadata.add_column('column1', 'table', sdtype='numerical') + metadata.add_column('column2', 'table', sdtype='numerical') + metadata.add_column('column3', 'table', sdtype='numerical') model = GaussianCopulaSynthesizer(metadata) model.fit(data) @@ -120,9 +120,9 @@ def test_sample_from_conditions_negative_float(): metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'column1', sdtype='numerical') - metadata.add_column('table', 'column2', sdtype='numerical') - metadata.add_column('table', 'column3', sdtype='numerical') + metadata.add_column('column1', 'table', sdtype='numerical') + metadata.add_column('column2', 'table', sdtype='numerical') + metadata.add_column('column3', 'table', sdtype='numerical') model = GaussianCopulaSynthesizer(metadata) model.fit(data) @@ -203,7 +203,7 @@ def test_sample_keys_are_scrambled(): """Test that the keys are scrambled in the sampled data.""" # Setup data, metadata = download_demo(modality='single_table', dataset_name='fake_hotel_guests') - metadata.update_column('fake_hotel_guests', 'guest_email', sdtype='id', regex_format='[A-Z]{3}') + metadata.update_column('guest_email', 'fake_hotel_guests', sdtype='id', regex_format='[A-Z]{3}') synthesizer = GaussianCopulaSynthesizer(metadata) synthesizer.fit(data) @@ -234,9 +234,9 @@ def test_multiple_fits(): }) metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'city', sdtype='categorical') - metadata.add_column('table', 'state', sdtype='categorical') - metadata.add_column('table', 'measurement', sdtype='numerical') + metadata.add_column('city', 'table', sdtype='categorical') + metadata.add_column('state', 'table', sdtype='categorical') + metadata.add_column('measurement', 'table', sdtype='numerical') constraint = { 'constraint_class': 'FixedCombinations', 'constraint_parameters': {'column_names': ['city', 'state']}, @@ -338,7 +338,7 @@ def test_transformers_correctly_auto_assigned(): metadata.update_column( table_name='table', column_name='primary_key', sdtype='id', regex_format='user-[0-9]{3}' ) - metadata.set_primary_key('table', 'primary_key') + metadata.set_primary_key('primary_key', 'table') metadata.update_column(table_name='table', column_name='pii_col', sdtype='address', pii=True) synthesizer = GaussianCopulaSynthesizer( metadata, enforce_min_max_values=False, enforce_rounding=False @@ -428,7 +428,7 @@ def test_auto_assign_transformers_and_update_with_pii(): # Run metadata.update_column(table_name='table', column_name='id', sdtype='first_name') metadata.update_column(table_name='table', column_name='name', sdtype='name') - metadata.set_primary_key('table', 'id') + metadata.set_primary_key('id', 'table') synthesizer = GaussianCopulaSynthesizer(metadata) synthesizer.auto_assign_transformers(data) @@ -457,8 +457,8 @@ def test_refitting_a_model(): metadata = Metadata.detect_from_dataframes({'table': data}) metadata.update_column(table_name='table', column_name='name', sdtype='name') - metadata.update_column('table', 'id', sdtype='id') - metadata.set_primary_key('table', 'id') + metadata.update_column('id', 'table', sdtype='id') + metadata.set_primary_key('id', 'table') synthesizer = GaussianCopulaSynthesizer(metadata) synthesizer.fit(data) @@ -483,7 +483,7 @@ def test_get_info(): today = datetime.datetime.today().strftime('%Y-%m-%d') metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'col', sdtype='numerical') + metadata.add_column('col', 'table', sdtype='numerical') synthesizer = GaussianCopulaSynthesizer(metadata) # Run @@ -628,7 +628,7 @@ def test_metadata_updated_no_warning(mock__fit, tmp_path): # Run 3 instance = BaseSingleTableSynthesizer(metadata_detect) - metadata_detect.update_column('mock_table', 'col 1', sdtype='categorical') + metadata_detect.update_column('col 1', 'mock_table', sdtype='categorical') file_name = tmp_path / 'singletable_2.json' metadata_detect.save_to_json(file_name) with warnings.catch_warnings(record=True) as captured_warnings: diff --git a/tests/integration/single_table/test_constraints.py b/tests/integration/single_table/test_constraints.py index 3477d4557..6e7a8c9e0 100644 --- a/tests/integration/single_table/test_constraints.py +++ b/tests/integration/single_table/test_constraints.py @@ -74,9 +74,9 @@ def test_fit_with_unique_constraint_on_data_with_only_index_column(): metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'key', sdtype='id') - metadata.add_column('table', 'index', sdtype='categorical') - metadata.set_primary_key('table', 'key') + metadata.add_column('key', 'table', sdtype='id') + metadata.add_column('index', 'table', sdtype='categorical') + metadata.set_primary_key('key', 'table') model = GaussianCopulaSynthesizer(metadata) constraint = { @@ -139,10 +139,10 @@ def test_fit_with_unique_constraint_on_data_which_has_index_column(): metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'key', sdtype='id') - metadata.add_column('table', 'index', sdtype='categorical') - metadata.add_column('table', 'test_column', sdtype='categorical') - metadata.set_primary_key('table', 'key') + metadata.add_column('key', 'table', sdtype='id') + metadata.add_column('index', 'table', sdtype='categorical') + metadata.add_column('test_column', 'table', sdtype='categorical') + metadata.set_primary_key('key', 'table') model = GaussianCopulaSynthesizer(metadata) constraint = { @@ -198,9 +198,9 @@ def test_fit_with_unique_constraint_on_data_subset(): metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'key', sdtype='id') - metadata.add_column('table', 'test_column', sdtype='categorical') - metadata.set_primary_key('table', 'key') + metadata.add_column('key', 'table', sdtype='id') + metadata.add_column('test_column', 'table', sdtype='categorical') + metadata.set_primary_key('key', 'table') test_df = test_df.iloc[[1, 3, 4]] constraint = { @@ -295,9 +295,9 @@ def test_conditional_sampling_constraint_uses_reject_sampling(gm_mock, isinstanc metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'city', sdtype='categorical') - metadata.add_column('table', 'state', sdtype='categorical') - metadata.add_column('table', 'age', sdtype='numerical') + metadata.add_column('city', 'table', sdtype='categorical') + metadata.add_column('state', 'table', sdtype='categorical') + metadata.add_column('age', 'table', sdtype='numerical') model = GaussianCopulaSynthesizer(metadata) @@ -815,8 +815,8 @@ def reverse_transform(column_names, data): 'other': [7, 8, 9], }) metadata = Metadata.detect_from_dataframes({'table': data}) - metadata.update_column('table', 'key', sdtype='id', regex_format=r'\w_\d') - metadata.set_primary_key('table', 'key') + metadata.update_column('key', 'table', sdtype='id', regex_format=r'\w_\d') + metadata.set_primary_key('key', 'table') synth = GaussianCopulaSynthesizer(metadata) synth.add_custom_constraint_class(custom_constraint, 'custom') @@ -845,8 +845,8 @@ def test_timezone_aware_constraints(): metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'col1', sdtype='datetime') - metadata.add_column('table', 'col2', sdtype='datetime') + metadata.add_column('col1', 'table', sdtype='datetime') + metadata.add_column('col2', 'table', sdtype='datetime') my_constraint = { 'constraint_class': 'Inequality', diff --git a/tests/integration/single_table/test_copulas.py b/tests/integration/single_table/test_copulas.py index 6a25fe5dc..f22427104 100644 --- a/tests/integration/single_table/test_copulas.py +++ b/tests/integration/single_table/test_copulas.py @@ -278,8 +278,8 @@ def test_update_transformers_with_id_generator(): data = pd.DataFrame({'user_id': list(range(4)), 'user_cat': ['a', 'b', 'c', 'd']}) stm = Metadata.detect_from_dataframes({'table': data}) - stm.update_column('table', 'user_id', sdtype='id') - stm.set_primary_key('table', 'user_id') + stm.update_column('user_id', 'table', sdtype='id') + stm.set_primary_key('user_id', 'table') gc = GaussianCopulaSynthesizer(stm) custom_id = IDGenerator(starting_value=min_value_id) @@ -434,7 +434,7 @@ def test_unknown_sdtype(): }) metadata = Metadata.detect_from_dataframes({'table': data}) - metadata.update_column('table', 'unknown', sdtype='unknown') + metadata.update_column('unknown', 'table', sdtype='unknown') synthesizer = GaussianCopulaSynthesizer(metadata) diff --git a/tests/integration/single_table/test_ctgan.py b/tests/integration/single_table/test_ctgan.py index 9f892f878..0f87c5fe7 100644 --- a/tests/integration/single_table/test_ctgan.py +++ b/tests/integration/single_table/test_ctgan.py @@ -17,12 +17,12 @@ def test__estimate_num_columns(): # Setup metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'numerical', sdtype='numerical') - metadata.add_column('table', 'categorical', sdtype='categorical') - metadata.add_column('table', 'categorical2', sdtype='categorical') - metadata.add_column('table', 'categorical3', sdtype='categorical') - metadata.add_column('table', 'datetime', sdtype='datetime') - metadata.add_column('table', 'boolean', sdtype='boolean') + metadata.add_column('numerical', 'table', sdtype='numerical') + metadata.add_column('categorical', 'table', sdtype='categorical') + metadata.add_column('categorical2', 'table', sdtype='categorical') + metadata.add_column('categorical3', 'table', sdtype='categorical') + metadata.add_column('datetime', 'table', sdtype='datetime') + metadata.add_column('boolean', 'table', sdtype='boolean') data = pd.DataFrame({ 'numerical': [0.1, 0.2, 0.3], 'datetime': ['2020-01-01', '2020-01-02', '2020-01-03'], diff --git a/tests/unit/evaluation/test_single_table.py b/tests/unit/evaluation/test_single_table.py index f669a6e40..08106923f 100644 --- a/tests/unit/evaluation/test_single_table.py +++ b/tests/unit/evaluation/test_single_table.py @@ -23,7 +23,7 @@ def test_evaluate_quality(): data2 = pd.DataFrame({'col': [2, 1, 3]}) metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'col', sdtype='numerical') + metadata.add_column('col', 'table', sdtype='numerical') QualityReport.generate = Mock() # Run @@ -59,7 +59,7 @@ def test_run_diagnostic(): data2 = pd.DataFrame({'col': [2, 1, 3]}) metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'col', sdtype='numerical') + metadata.add_column('col', 'table', sdtype='numerical') DiagnosticReport.generate = Mock(return_value=123) # Run @@ -100,7 +100,7 @@ def test_get_column_plot_continuous_data(mock_get_plot): data2 = pd.DataFrame({'col': [2, 1, 3]}) metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'col', sdtype='numerical') + metadata.add_column('col', 'table', sdtype='numerical') # Run plot = get_column_plot(data1, data2, metadata, 'col') @@ -143,7 +143,7 @@ def test_get_column_plot_discrete_data(mock_get_plot): data2 = pd.DataFrame({'col': ['a', 'b', 'c']}) metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'col', sdtype='categorical') + metadata.add_column('col', 'table', sdtype='categorical') # Run plot = get_column_plot(data1, data2, metadata, 'col') @@ -187,7 +187,7 @@ def test_get_column_plot_discrete_data_with_distplot(mock_get_plot): data2 = pd.DataFrame({'col': ['a', 'b', 'c']}) metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'col', sdtype='categorical') + metadata.add_column('col', 'table', sdtype='categorical') # Run plot = get_column_plot(data1, data2, metadata, 'col', plot_type='distplot') @@ -231,7 +231,7 @@ def test_get_column_plot_invalid_sdtype(mock_get_plot): data2 = pd.DataFrame({'col': ['a', 'b', 'c']}) metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'col', sdtype='id') + metadata.add_column('col', 'table', sdtype='id') # Run and Assert error_msg = re.escape( @@ -276,7 +276,7 @@ def test_get_column_plot_invalid_sdtype_with_plot_type(mock_get_plot): data2 = pd.DataFrame({'col': ['a', 'b', 'c']}) metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'col', sdtype='id') + metadata.add_column('col', 'table', sdtype='id') # Run plot = get_column_plot(data1, data2, metadata, 'col', plot_type='bar') @@ -319,7 +319,7 @@ def test_get_column_plot_with_datetime_sdtype(mock_get_plot): synthetic_data = pd.DataFrame({'datetime': ['2023-02-21', '2022-12-13']}) metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'datetime', sdtype='datetime', datetime_format='%Y-%m-%d') + metadata.add_column('datetime', 'table', sdtype='datetime', datetime_format='%Y-%m-%d') # Run plot = get_column_plot(real_data, synthetic_data, metadata, 'datetime') @@ -354,8 +354,8 @@ def test_get_column_pair_plot_with_continous_data(mock_get_plot): }) metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'amount', sdtype='numerical') - metadata.add_column('table', 'date', sdtype='datetime') + metadata.add_column('amount', 'table', sdtype='numerical') + metadata.add_column('date', 'table', sdtype='datetime') # Run plot = get_column_pair_plot(real_data, synthetic_data, metadata, columns) @@ -389,8 +389,8 @@ def test_get_column_pair_plot_with_discrete_data(mock_get_plot): synthetic_data = pd.DataFrame({'name': ['John', 'Johanna'], 'subscriber': [False, False]}) metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'name', sdtype='categorical') - metadata.add_column('table', 'subscriber', sdtype='boolean') + metadata.add_column('name', 'table', sdtype='categorical') + metadata.add_column('subscriber', 'table', sdtype='boolean') # Run plot = get_column_pair_plot(real_data, synthetic_data, metadata, columns) @@ -416,8 +416,8 @@ def test_get_column_pair_plot_with_mixed_data(mock_get_plot): synthetic_data = pd.DataFrame({'name': ['John', 'Johanna'], 'counts': [3, 1]}) metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'name', sdtype='categorical') - metadata.add_column('table', 'counts', sdtype='numerical') + metadata.add_column('name', 'table', sdtype='categorical') + metadata.add_column('counts', 'table', sdtype='numerical') # Run plot = get_column_pair_plot(real_data, synthetic_data, metadata, columns) @@ -449,8 +449,8 @@ def test_get_column_pair_plot_with_forced_plot_type(mock_get_plot): }) metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'amount', sdtype='numerical') - metadata.add_column('table', 'date', sdtype='datetime') + metadata.add_column('amount', 'table', sdtype='numerical') + metadata.add_column('date', 'table', sdtype='datetime') # Run plot = get_column_pair_plot(real_data, synthetic_data, metadata, columns, plot_type='heatmap') @@ -491,8 +491,8 @@ def test_get_column_pair_plot_with_invalid_sdtype(mock_get_plot): }) metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'amount', sdtype='numerical') - metadata.add_column('table', 'id', sdtype='id') + metadata.add_column('amount', 'table', sdtype='numerical') + metadata.add_column('id', 'table', sdtype='id') # Run and Assert error_msg = re.escape( @@ -522,8 +522,8 @@ def test_get_column_pair_plot_with_invalid_sdtype_and_plot_type(mock_get_plot): }) metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'amount', sdtype='numerical') - metadata.add_column('table', 'id', sdtype='id') + metadata.add_column('amount', 'table', sdtype='numerical') + metadata.add_column('id', 'table', sdtype='id') # Run plot = get_column_pair_plot(real_data, synthetic_data, metadata, columns, plot_type='heatmap') @@ -551,8 +551,8 @@ def test_get_column_pair_plot_with_sample_size(mock_get_plot): }) metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'amount', sdtype='numerical') - metadata.add_column('table', 'price', sdtype='numerical') + metadata.add_column('amount', 'table', sdtype='numerical') + metadata.add_column('price', 'table', sdtype='numerical') # Run get_column_pair_plot(real_data, synthetic_data, metadata, columns, sample_size=2) @@ -614,8 +614,8 @@ def test_get_column_pair_plot_with_sample_size_too_big(mock_get_plot): }) metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'amount', sdtype='numerical') - metadata.add_column('table', 'price', sdtype='numerical') + metadata.add_column('amount', 'table', sdtype='numerical') + metadata.add_column('price', 'table', sdtype='numerical') # Run plot = get_column_pair_plot(real_data, synthetic_data, metadata, columns, sample_size=10) diff --git a/tests/unit/metadata/test_metadata.py b/tests/unit/metadata/test_metadata.py index 1524c0757..c7f38052e 100644 --- a/tests/unit/metadata/test_metadata.py +++ b/tests/unit/metadata/test_metadata.py @@ -643,3 +643,54 @@ def test_detect_from_dataframe_raises_error_if_not_dataframe(self): expected_message = 'The provided data must be a pandas DataFrame object.' with pytest.raises(ValueError, match=expected_message): Metadata.detect_from_dataframe(Mock()) + + def test__handle_table_name(self): + """Test the ``_handle_table_name`` method.""" + # Setup + metadata = Metadata() + metadata.tables = ['table_name'] + expected_error = re.escape( + 'Metadata contains more than one table, please specify the `table_name`.' + ) + + # Run + result_none = metadata._handle_table_name(None) + result_table_1 = metadata._handle_table_name('table_1') + metadata.tables = ['table_1', 'table_2'] + with pytest.raises(ValueError, match=expected_error): + metadata._handle_table_name(None) + + result_table_2 = metadata._handle_table_name('table_2') + + # Assert + assert result_none == 'table_name' + assert result_table_1 == 'table_1' + assert result_table_2 == 'table_2' + + params = [ + ('update_column', ['column_name']), + ('update_columns', ['column_names']), + ('update_columns_metadata', ['column_metadata']), + ('add_column', ['column_name']), + ('set_primary_key', ['column_name']), + ('remove_primary_key', []), + ('add_column_relationship', ['relationship_type', 'column_names']), + ('add_alternate_keys', ['column_names']), + ('get_column_names', []), + ] + + @pytest.mark.parametrize('method, args', params) + def test_update_methods(self, method, args): + """Test that all update methods call the superclass method with the resolved arguments.""" + # Setup + metadata = Metadata() + metadata._handle_table_name = Mock(return_value='table_name') + superclass = Metadata.__bases__[0] + + with patch.object(superclass, method) as mock_super_method: + # Run + getattr(metadata, method)(*args, 'table_name') + + # Assert + metadata._handle_table_name.assert_called_once_with('table_name') + mock_super_method.assert_called_once_with('table_name', *args) diff --git a/tests/unit/metadata/test_multi_table.py b/tests/unit/metadata/test_multi_table.py index cd20b7434..aefe8eade 100644 --- a/tests/unit/metadata/test_multi_table.py +++ b/tests/unit/metadata/test_multi_table.py @@ -1396,11 +1396,11 @@ def test_validate_data_datetime_warning(self): '20220826', '20220826', ]) - metadata.add_column('upravna_enota', 'warning_date_str', sdtype='datetime') + metadata.add_column('warning_date_str', 'upravna_enota', sdtype='datetime') metadata.add_column( - 'upravna_enota', 'valid_date', sdtype='datetime', datetime_format='%Y%m%d%H%M%S%f' + 'valid_date', 'upravna_enota', sdtype='datetime', datetime_format='%Y%m%d%H%M%S%f' ) - metadata.add_column('upravna_enota', 'datetime', sdtype='datetime') + metadata.add_column('datetime', 'upravna_enota', sdtype='datetime') # Run and Assert warning_df = pd.DataFrame({ diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index 8e0c60629..446c92531 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -169,10 +169,10 @@ def test__init__column_relationship_warning(self, mock_is_faker_function): # Setup mock_is_faker_function.return_value = True metadata = get_multi_table_metadata() - metadata.add_column('nesreca', 'lat', sdtype='latitude') - metadata.add_column('nesreca', 'lon', sdtype='longitude') + metadata.add_column('lat', 'nesreca', sdtype='latitude') + metadata.add_column('lon', 'nesreca', sdtype='longitude') - metadata.add_column_relationship('nesreca', 'gps', ['lat', 'lon']) + metadata.add_column_relationship('gps', ['lat', 'lon'], 'nesreca') expected_warning = ( "The metadata contains a column relationship of type 'gps' " @@ -540,7 +540,7 @@ def test_validate_constraints_not_met(self): metadata = get_multi_table_metadata() data = get_multi_table_data() data['nesreca']['val'] = list(range(4)) - metadata.add_column('nesreca', 'val', sdtype='numerical') + metadata.add_column('val', 'nesreca', sdtype='numerical') instance = BaseMultiTableSynthesizer(metadata) inequality_constraint = { 'constraint_class': 'Inequality', @@ -1344,8 +1344,8 @@ def test_add_constraints(self): # Setup metadata = get_multi_table_metadata() instance = BaseMultiTableSynthesizer(metadata) - metadata.add_column('nesreca', 'positive_int', sdtype='numerical') - metadata.add_column('oseba', 'negative_int', sdtype='numerical') + metadata.add_column('positive_int', 'nesreca', sdtype='numerical') + metadata.add_column('negative_int', 'oseba', sdtype='numerical') positive_constraint = { 'constraint_class': 'Positive', 'table_name': 'nesreca', @@ -1401,8 +1401,8 @@ def test_get_constraints(self): # Setup metadata = get_multi_table_metadata() instance = BaseMultiTableSynthesizer(metadata) - metadata.add_column('nesreca', 'positive_int', sdtype='numerical') - metadata.add_column('oseba', 'negative_int', sdtype='numerical') + metadata.add_column('positive_int', 'nesreca', sdtype='numerical') + metadata.add_column('negative_int', 'oseba', sdtype='numerical') positive_constraint = { 'constraint_class': 'Positive', 'table_name': 'nesreca', @@ -1527,7 +1527,7 @@ def test_get_info(self, mock_version): data = {'tab': pd.DataFrame({'col': [1, 2, 3]})} metadata = Metadata() metadata.add_table('tab') - metadata.add_column('tab', 'col', sdtype='numerical') + metadata.add_column('col', 'tab', sdtype='numerical') mock_version.public = '1.0.0' mock_version.enterprise = None @@ -1575,7 +1575,7 @@ def test_get_info_with_enterprise(self, mock_version): data = {'tab': pd.DataFrame({'col': [1, 2, 3]})} metadata = Metadata() metadata.add_table('tab') - metadata.add_column('tab', 'col', sdtype='numerical') + metadata.add_column('col', 'tab', sdtype='numerical') mock_version.public = '1.0.0' mock_version.enterprise = '1.1.0' diff --git a/tests/unit/multi_table/test_hma.py b/tests/unit/multi_table/test_hma.py index 72abc0e25..727b40f5a 100644 --- a/tests/unit/multi_table/test_hma.py +++ b/tests/unit/multi_table/test_hma.py @@ -166,9 +166,9 @@ def test__augment_table(self): # Setup metadata = get_multi_table_metadata() instance = HMASynthesizer(metadata) - metadata.add_column('nesreca', 'value', sdtype='numerical') - metadata.add_column('oseba', 'oseba_value', sdtype='numerical') - metadata.add_column('upravna_enota', 'name', sdtype='categorical') + metadata.add_column('value', 'nesreca', sdtype='numerical') + metadata.add_column('oseba_value', 'oseba', sdtype='numerical') + metadata.add_column('name', 'upravna_enota', sdtype='categorical') data = get_multi_table_data() data['nesreca']['value'] = [0, 1, 2, 3] @@ -712,9 +712,9 @@ def test_get_learned_distributions_raises_an_error(self): """Test that ``get_learned_distributions`` raises an error.""" # Setup metadata = get_multi_table_metadata() - metadata.add_column('nesreca', 'value', sdtype='numerical') - metadata.add_column('oseba', 'value', sdtype='numerical') - metadata.add_column('upravna_enota', 'a_value', sdtype='numerical') + metadata.add_column('value', 'nesreca', sdtype='numerical') + metadata.add_column('value', 'oseba', sdtype='numerical') + metadata.add_column('a_value', 'upravna_enota', sdtype='numerical') instance = HMASynthesizer(metadata) # Run and Assert diff --git a/tests/unit/multi_table/test_utils.py b/tests/unit/multi_table/test_utils.py index 57b52fa08..a08ddfb38 100644 --- a/tests/unit/multi_table/test_utils.py +++ b/tests/unit/multi_table/test_utils.py @@ -1947,7 +1947,7 @@ def test__subsample_data( def test__subsample_data_with_null_foreing_keys(): """Test the ``_subsample_data`` method when there are null foreign keys.""" # Setup - metadata = MultiTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'tables': { 'parent': { 'columns': { diff --git a/tests/unit/sequential/test_par.py b/tests/unit/sequential/test_par.py index d7dae5832..e1062fcfd 100644 --- a/tests/unit/sequential/test_par.py +++ b/tests/unit/sequential/test_par.py @@ -21,15 +21,15 @@ class TestPARSynthesizer: def get_metadata(self, add_sequence_key=True, add_sequence_index=False): metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'time', sdtype='datetime') - metadata.add_column('table', 'gender', sdtype='categorical') - metadata.add_column('table', 'name', sdtype='id') - metadata.add_column('table', 'measurement', sdtype='numerical') + metadata.add_column('time', 'table', sdtype='datetime') + metadata.add_column('gender', 'table', sdtype='categorical') + metadata.add_column('name', 'table', sdtype='id') + metadata.add_column('measurement', 'table', sdtype='numerical') if add_sequence_key: - metadata.set_sequence_key('table', 'name') + metadata.set_sequence_key('name', 'table') if add_sequence_index: - metadata.set_sequence_index('table', 'time') + metadata.set_sequence_index('time', 'table') return metadata @@ -261,11 +261,12 @@ def test_validate_context_columns_unique_per_sequence_key(self): }) metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'sk_col1', sdtype='id') - metadata.add_column('table', 'sk_col2', sdtype='id') - metadata.add_column('table', 'ct_col1', sdtype='numerical') - metadata.add_column('table', 'ct_col2', sdtype='numerical') - metadata.set_sequence_key('table', 'sk_col1') + metadata.add_column('sk_col1', 'table', sdtype='id') + metadata.add_column('sk_col2', 'table', sdtype='id') + metadata.add_column('ct_col1', 'table', sdtype='numerical') + metadata.add_column('ct_col2', 'table', sdtype='numerical') + metadata.set_sequence_key('sk_col1', 'table') + instance = PARSynthesizer(metadata=metadata, context_columns=['ct_col1', 'ct_col2']) # Run and Assert @@ -529,7 +530,8 @@ def test_auto_assign_transformers_without_enforce_min_max(self, mock_get_transfo 'measurement': [55, 60, 65], }) metadata = self.get_metadata() - metadata.set_sequence_index('table', 'time') + metadata.set_sequence_index('time', 'table') + mock_get_transfomers.return_value = {'time': FloatFormatter} # Run @@ -611,7 +613,7 @@ def test__fit_sequence_columns_with_categorical_float( data = self.get_data() data['measurement'] = data['measurement'].astype(float) metadata = self.get_metadata() - metadata.update_column('table', 'measurement', sdtype='categorical') + metadata.update_column('measurement', 'table', sdtype='categorical') par = PARSynthesizer(metadata=metadata, context_columns=['gender']) sequences = [ {'context': np.array(['M'], dtype=object), 'data': [['2020-01-03'], [65.0]]}, @@ -649,7 +651,8 @@ def test__fit_sequence_columns_with_sequence_index(self, assemble_sequences_mock 'measurement': [55, 60, 65, 65, 70], }) metadata = self.get_metadata() - metadata.set_sequence_index('table', 'time') + metadata.set_sequence_index('time', 'table') + par = PARSynthesizer(metadata=metadata, context_columns=['gender']) sequences = [ {'context': np.array(['F'], dtype=object), 'data': [[1, 1], [55, 60], [1, 1]]}, @@ -840,7 +843,8 @@ def test__sample_from_par_with_sequence_index(self, tqdm_mock): """ # Setup metadata = self.get_metadata() - metadata.set_sequence_index('table', 'time') + metadata.set_sequence_index('time', 'table') + par = PARSynthesizer(metadata=metadata, context_columns=['gender']) model_mock = Mock() par._model = model_mock diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py index 158f71325..5541030f7 100644 --- a/tests/unit/single_table/test_base.py +++ b/tests/unit/single_table/test_base.py @@ -709,10 +709,11 @@ def test_update_transformers_invalid_keys(self): column_name_to_transformer = {'col2': RegexGenerator(), 'col3': FloatFormatter()} metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'col2', sdtype='id') - metadata.add_column('table', 'col3', sdtype='id') - metadata.set_sequence_key('table', 'col2') - metadata.add_alternate_keys('table', ['col3']) + metadata.add_column('col2', 'table', sdtype='id') + metadata.add_column('col3', 'table', sdtype='id') + metadata.set_sequence_key('col2', 'table') + + metadata.add_alternate_keys(['col3'], 'table') instance = BaseSingleTableSynthesizer(metadata) # Run and Assert @@ -731,8 +732,8 @@ def test_update_transformers_already_fitted(self): column_name_to_transformer = {'col1': BinaryEncoder(), 'col2': fitted_transformer} metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'col1', sdtype='boolean') - metadata.add_column('table', 'col2', sdtype='numerical') + metadata.add_column('col1', 'table', sdtype='boolean') + metadata.add_column('col2', 'table', sdtype='numerical') instance = BaseSingleTableSynthesizer(metadata) # Run and Assert @@ -746,8 +747,8 @@ def test_update_transformers_warns_gaussian_copula(self): column_name_to_transformer = {'col1': OneHotEncoder(), 'col2': FloatFormatter()} metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'col1', sdtype='categorical') - metadata.add_column('table', 'col2', sdtype='numerical') + metadata.add_column('col1', 'table', sdtype='categorical') + metadata.add_column('col2', 'table', sdtype='numerical') instance = GaussianCopulaSynthesizer(metadata) instance._data_processor.fit(pd.DataFrame({'col1': [1, 2], 'col2': [1, 2]})) @@ -774,8 +775,8 @@ def test_update_transformers_warns_models(self): column_name_to_transformer = {'col1': OneHotEncoder(), 'col2': FloatFormatter()} metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'col1', sdtype='categorical') - metadata.add_column('table', 'col2', sdtype='numerical') + metadata.add_column('col1', 'table', sdtype='categorical') + metadata.add_column('col2', 'table', sdtype='numerical') # NOTE: when PARSynthesizer is implemented, add it here as well for model in [CTGANSynthesizer, CopulaGANSynthesizer, TVAESynthesizer]: @@ -800,8 +801,8 @@ def test_update_transformers_warns_fitted(self): column_name_to_transformer = {'col1': GaussianNormalizer(), 'col2': GaussianNormalizer()} metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'col1', sdtype='numerical') - metadata.add_column('table', 'col2', sdtype='numerical') + metadata.add_column('col1', 'table', sdtype='numerical') + metadata.add_column('col2', 'table', sdtype='numerical') instance = BaseSingleTableSynthesizer(metadata) instance._data_processor.fit(pd.DataFrame({'col1': [1, 2], 'col2': [1, 2]})) instance._fitted = True @@ -819,8 +820,8 @@ def test_update_transformers(self): column_name_to_transformer = {'col1': GaussianNormalizer(), 'col2': GaussianNormalizer()} metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'col1', sdtype='numerical') - metadata.add_column('table', 'col2', sdtype='numerical') + metadata.add_column('col1', 'table', sdtype='numerical') + metadata.add_column('col2', 'table', sdtype='numerical') instance = BaseSingleTableSynthesizer(metadata) instance._data_processor.fit(pd.DataFrame({'col1': [1, 2], 'col2': [1, 2]})) @@ -2053,7 +2054,7 @@ def test_add_constraints(self): # Setup metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'col', sdtype='numerical') + metadata.add_column('col', 'table', sdtype='numerical') instance = BaseSingleTableSynthesizer(metadata) positive_constraint = { 'constraint_class': 'Positive', @@ -2084,7 +2085,7 @@ def test_get_constraints(self): # Setup metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'col', sdtype='numerical') + metadata.add_column('col', 'table', sdtype='numerical') instance = BaseSingleTableSynthesizer(metadata) positive_constraint = { 'constraint_class': 'Positive', @@ -2119,7 +2120,7 @@ def test_get_info_no_enterprise(self, mock_sdv_version): mock_sdv_version.enterprise = None metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'col', sdtype='numerical') + metadata.add_column('col', 'table', sdtype='numerical') with patch('sdv.single_table.base.datetime.datetime') as mock_date: mock_date.today.return_value = datetime(2023, 1, 23) @@ -2167,7 +2168,7 @@ def test_get_info_with_enterprise(self, mock_sdv_version): mock_sdv_version.enterprise = '1.2.0' metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'col', sdtype='numerical') + metadata.add_column('col', 'table', sdtype='numerical') with patch('sdv.single_table.base.datetime.datetime') as mock_date: mock_date.today.return_value = datetime(2023, 1, 23) diff --git a/tests/unit/single_table/test_copulagan.py b/tests/unit/single_table/test_copulagan.py index e7c531a26..354ca841c 100644 --- a/tests/unit/single_table/test_copulagan.py +++ b/tests/unit/single_table/test_copulagan.py @@ -90,7 +90,7 @@ def test___init__custom(self): # Setup metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'field', sdtype='numerical') + metadata.add_column('field', 'table', sdtype='numerical') enforce_min_max_values = False enforce_rounding = False embedding_dim = 64 @@ -226,9 +226,9 @@ def test__create_gaussian_normalizer_config(self, mock_rdt): numerical_distributions = {'age': 'gamma'} metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'name', sdtype='categorical') - metadata.add_column('table', 'age', sdtype='numerical') - metadata.add_column('table', 'account', sdtype='numerical') + metadata.add_column('name', 'table', sdtype='categorical') + metadata.add_column('age', 'table', sdtype='numerical') + metadata.add_column('account', 'table', sdtype='numerical') instance = CopulaGANSynthesizer(metadata, numerical_distributions=numerical_distributions) processed_data = pd.DataFrame({ @@ -275,7 +275,7 @@ def test__fit_logging(self, mock_rdt, mock_ctgansynthesizer__fit, mock_logger): # Setup metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'col', sdtype='numerical') + metadata.add_column('col', 'table', sdtype='numerical') numerical_distributions = {'col': 'gamma'} instance = CopulaGANSynthesizer(metadata, numerical_distributions=numerical_distributions) processed_data = pd.DataFrame() diff --git a/tests/unit/single_table/test_copulas.py b/tests/unit/single_table/test_copulas.py index be523196d..57fc1dd08 100644 --- a/tests/unit/single_table/test_copulas.py +++ b/tests/unit/single_table/test_copulas.py @@ -92,7 +92,7 @@ def test___init__custom(self): # Setup metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'field', sdtype='numerical') + metadata.add_column('field', 'table', sdtype='numerical') enforce_min_max_values = False enforce_rounding = False numerical_distributions = {'field': 'gamma'} @@ -169,7 +169,7 @@ def test__fit_logging(self, mock_logger): # Setup metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'col', sdtype='numerical') + metadata.add_column('col', 'table', sdtype='numerical') numerical_distributions = {'col': 'gamma'} instance = GaussianCopulaSynthesizer( metadata, numerical_distributions=numerical_distributions @@ -197,8 +197,8 @@ def test__fit(self, mock_multivariate, mock_warnings): # Setup metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'name', sdtype='numerical') - metadata.add_column('table', 'user.id', sdtype='numerical') + metadata.add_column('name', 'table', sdtype='numerical') + metadata.add_column('user.id', 'table', sdtype='numerical') numerical_distributions = {'name': 'uniform', 'user.id': 'gamma'} processed_data = pd.DataFrame({ diff --git a/tests/unit/single_table/test_ctgan.py b/tests/unit/single_table/test_ctgan.py index 5cb3e6000..a3b946f4b 100644 --- a/tests/unit/single_table/test_ctgan.py +++ b/tests/unit/single_table/test_ctgan.py @@ -189,9 +189,9 @@ def test__estimate_num_columns(self): # Setup metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'id', sdtype='numerical') - metadata.add_column('table', 'name', sdtype='categorical') - metadata.add_column('table', 'surname', sdtype='categorical') + metadata.add_column('id', 'table', sdtype='numerical') + metadata.add_column('name', 'table', sdtype='categorical') + metadata.add_column('surname', 'table', sdtype='categorical') data = pd.DataFrame({ 'id': np.random.rand(1_001), 'name': [f'cat_{i}' for i in range(1_001)], @@ -215,8 +215,8 @@ def test_preprocessing_many_categories(self, capfd): # Setup metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'name_longer_than_Original_Column_Name', sdtype='numerical') - metadata.add_column('table', 'categorical', sdtype='categorical') + metadata.add_column('name_longer_than_Original_Column_Name', 'table', sdtype='numerical') + metadata.add_column('categorical', 'table', sdtype='categorical') data = pd.DataFrame({ 'name_longer_than_Original_Column_Name': np.random.rand(1_001), 'categorical': [f'cat_{i}' for i in range(1_001)], @@ -246,8 +246,8 @@ def test_preprocessing_few_categories(self, capfd): # Setup metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'name_longer_than_Original_Column_Name', sdtype='numerical') - metadata.add_column('table', 'categorical', sdtype='categorical') + metadata.add_column('name_longer_than_Original_Column_Name', 'table', sdtype='numerical') + metadata.add_column('categorical', 'table', sdtype='categorical') data = pd.DataFrame({ 'name_longer_than_Original_Column_Name': np.random.rand(10), 'categorical': [f'cat_{i}' for i in range(10)],