From 089f376c9710a7c030c3115e2a80fa05882246d3 Mon Sep 17 00:00:00 2001 From: John La Date: Wed, 17 Jul 2024 11:47:50 -0500 Subject: [PATCH 01/16] Add Metadata Class (#2135) --- sdv/metadata/__init__.py | 2 + sdv/metadata/metadata.py | 73 +++ tests/integration/metadata/test_metadata.py | 218 ++++++++ tests/unit/metadata/test_metadata.py | 545 ++++++++++++++++++++ 4 files changed, 838 insertions(+) create mode 100644 sdv/metadata/metadata.py create mode 100644 tests/integration/metadata/test_metadata.py create mode 100644 tests/unit/metadata/test_metadata.py diff --git a/sdv/metadata/__init__.py b/sdv/metadata/__init__.py index 71d689727..5d0ca7e1e 100644 --- a/sdv/metadata/__init__.py +++ b/sdv/metadata/__init__.py @@ -4,9 +4,11 @@ from sdv.metadata.errors import InvalidMetadataError, MetadataNotFittedError from sdv.metadata.multi_table import MultiTableMetadata from sdv.metadata.single_table import SingleTableMetadata +from sdv.metadata.metadata import Metadata __all__ = ( 'InvalidMetadataError', + 'Metadata', 'MetadataNotFittedError', 'MultiTableMetadata', 'SingleTableMetadata', diff --git a/sdv/metadata/metadata.py b/sdv/metadata/metadata.py new file mode 100644 index 000000000..1b628c9b9 --- /dev/null +++ b/sdv/metadata/metadata.py @@ -0,0 +1,73 @@ +"""Metadata.""" + +from pathlib import Path + +from sdv.metadata.multi_table import MultiTableMetadata +from sdv.metadata.single_table import SingleTableMetadata +from sdv.metadata.utils import read_json + + +class Metadata(MultiTableMetadata): + """Metadata class that handles all metadata.""" + + METADATA_SPEC_VERSION = 'V1' + + @classmethod + def load_from_json(cls, filepath): + """Create a ``Metadata`` instance from a ``json`` file. + + Args: + filepath (str): + String that represents the ``path`` to the ``json`` file. + + Raises: + - An ``Error`` if the path does not exist. + - An ``Error`` if the ``json`` file does not contain the ``METADATA_SPEC_VERSION``. + + Returns: + A ``Metadata`` instance. + """ + filename = Path(filepath).stem + metadata = read_json(filepath) + return cls.load_from_dict(metadata, filename) + + @classmethod + def load_from_dict(cls, metadata_dict, single_table_name=None): + """Create a ``Metadata`` instance from a python ``dict``. + + Args: + metadata_dict (dict): + Python dictionary representing a ``MultiTableMetadata`` + or ``SingleTableMetadata`` object. + single_table_name (string): + If the python dictionary represents a ``SingleTableMetadata`` then + this arg is used for the name of the table. + + Returns: + Instance of ``Metadata``. + """ + instance = cls() + instance._set_metadata_dict(metadata_dict, single_table_name) + return instance + + def _set_metadata_dict(self, metadata, single_table_name=None): + """Set a ``metadata`` dictionary to the current instance. + + Checks to see if the metadata is in the ``SingleTableMetadata`` or + ``MultiTableMetadata`` format and converts it to a standard + ``MultiTableMetadata`` format if necessary. + + Args: + metadata (dict): + Python dictionary representing a ``MultiTableMetadata`` or + ``SingleTableMetadata`` object. 
+        """
+        is_multi_table = 'tables' in metadata
+
+        if is_multi_table:
+            super()._set_metadata_dict(metadata)
+        else:
+            if single_table_name is None:
+                single_table_name = 'default_table_name'
+
+            self.tables[single_table_name] = SingleTableMetadata.load_from_dict(metadata)
diff --git a/tests/integration/metadata/test_metadata.py b/tests/integration/metadata/test_metadata.py
new file mode 100644
index 000000000..6adde98ed
--- /dev/null
+++ b/tests/integration/metadata/test_metadata.py
@@ -0,0 +1,218 @@
+from sdv.datasets.demo import download_demo
+from sdv.metadata.metadata import Metadata
+
+
+def test_metadata():
+    """Test ``Metadata``."""
+    # Create an instance
+    instance = Metadata()
+
+    # To dict
+    result = instance.to_dict()
+
+    # Assert
+    assert result == {'tables': {}, 'relationships': [], 'METADATA_SPEC_VERSION': 'V1'}
+    assert instance.tables == {}
+    assert instance.relationships == []
+
+
+def test_detect_from_dataframes_multi_table():
+    """Test the ``detect_from_dataframes`` method works with multi-table."""
+    # Setup
+    real_data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels')
+
+    metadata = Metadata()
+
+    # Run
+    metadata.detect_from_dataframes(real_data)
+
+    # Assert
+    metadata.update_column(
+        table_name='hotels',
+        column_name='classification',
+        sdtype='categorical',
+    )
+
+    expected_metadata = {
+        'tables': {
+            'hotels': {
+                'columns': {
+                    'hotel_id': {'sdtype': 'id'},
+                    'city': {'sdtype': 'city', 'pii': True},
+                    'state': {'sdtype': 'administrative_unit', 'pii': True},
+                    'rating': {'sdtype': 'numerical'},
+                    'classification': {'sdtype': 'categorical'},
+                },
+                'primary_key': 'hotel_id',
+            },
+            'guests': {
+                'columns': {
+                    'guest_email': {'sdtype': 'email', 'pii': True},
+                    'hotel_id': {'sdtype': 'id'},
+                    'has_rewards': {'sdtype': 'categorical'},
+                    'room_type': {'sdtype': 'categorical'},
+                    'amenities_fee': {'sdtype': 'numerical'},
+                    'checkin_date': {'sdtype': 'datetime', 'datetime_format': '%d %b %Y'},
+                    'checkout_date': {'sdtype': 'datetime', 'datetime_format': '%d %b %Y'},
+                    'room_rate': {'sdtype': 'numerical'},
+                    'billing_address': {'sdtype': 'unknown', 'pii': True},
+                    'credit_card_number': {'sdtype': 'credit_card_number', 'pii': True},
+                },
+                'primary_key': 'guest_email',
+            },
+        },
+        'relationships': [
+            {
+                'parent_table_name': 'hotels',
+                'child_table_name': 'guests',
+                'parent_primary_key': 'hotel_id',
+                'child_foreign_key': 'hotel_id',
+            }
+        ],
+        'METADATA_SPEC_VERSION': 'V1',
+    }
+    assert metadata.to_dict() == expected_metadata
+
+
+def test_detect_from_dataframes_single_table():
+    """Test the ``detect_from_dataframes`` method works with a single table."""
+    # Setup
+    data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels')
+
+    metadata = Metadata()
+    metadata.detect_from_dataframes({'table_1': data['hotels']})
+
+    # Run
+    metadata.validate()
+
+    # Assert
+    expected_metadata = {
+        'METADATA_SPEC_VERSION': 'V1',
+        'tables': {
+            'table_1': {
+                'columns': {
+                    'hotel_id': {'sdtype': 'id'},
+                    'city': {'sdtype': 'city', 'pii': True},
+                    'state': {'sdtype': 'administrative_unit', 'pii': True},
+                    'rating': {'sdtype': 'numerical'},
+                    'classification': {'sdtype': 'unknown', 'pii': True},
+                },
+                'primary_key': 'hotel_id',
+            }
+        },
+        'relationships': [],
+    }
+    assert metadata.to_dict() == expected_metadata
+
+
+def test_detect_from_csvs(tmp_path):
+    """Test the ``detect_from_csvs`` method."""
+    # Setup
+    real_data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels')
+
+    metadata = Metadata()
+
+    for table_name, 
dataframe in real_data.items(): + csv_path = tmp_path / f'{table_name}.csv' + dataframe.to_csv(csv_path, index=False) + + # Run + metadata.detect_from_csvs(folder_name=tmp_path) + + # Assert + metadata.update_column( + table_name='hotels', + column_name='classification', + sdtype='categorical', + ) + + expected_metadata = { + 'tables': { + 'hotels': { + 'columns': { + 'hotel_id': {'sdtype': 'id'}, + 'city': {'sdtype': 'city', 'pii': True}, + 'state': {'sdtype': 'administrative_unit', 'pii': True}, + 'rating': {'sdtype': 'numerical'}, + 'classification': {'sdtype': 'categorical'}, + }, + 'primary_key': 'hotel_id', + }, + 'guests': { + 'columns': { + 'guest_email': {'sdtype': 'email', 'pii': True}, + 'hotel_id': {'sdtype': 'id'}, + 'has_rewards': {'sdtype': 'categorical'}, + 'room_type': {'sdtype': 'categorical'}, + 'amenities_fee': {'sdtype': 'numerical'}, + 'checkin_date': {'sdtype': 'datetime', 'datetime_format': '%d %b %Y'}, + 'checkout_date': {'sdtype': 'datetime', 'datetime_format': '%d %b %Y'}, + 'room_rate': {'sdtype': 'numerical'}, + 'billing_address': {'sdtype': 'unknown', 'pii': True}, + 'credit_card_number': {'sdtype': 'credit_card_number', 'pii': True}, + }, + 'primary_key': 'guest_email', + }, + }, + 'relationships': [ + { + 'parent_table_name': 'hotels', + 'child_table_name': 'guests', + 'parent_primary_key': 'hotel_id', + 'child_foreign_key': 'hotel_id', + } + ], + 'METADATA_SPEC_VERSION': 'V1', + } + + assert metadata.to_dict() == expected_metadata + + +def test_detect_table_from_csv(tmp_path): + """Test the ``detect_table_from_csv`` method.""" + # Setup + real_data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') + + metadata = Metadata() + + for table_name, dataframe in real_data.items(): + csv_path = tmp_path / f'{table_name}.csv' + dataframe.to_csv(csv_path, index=False) + + # Run + metadata.detect_table_from_csv('hotels', tmp_path / 'hotels.csv') + + # Assert + metadata.update_column( + table_name='hotels', + column_name='city', + sdtype='categorical', + ) + metadata.update_column( + table_name='hotels', + column_name='state', + sdtype='categorical', + ) + metadata.update_column( + table_name='hotels', + column_name='classification', + sdtype='categorical', + ) + expected_metadata = { + 'tables': { + 'hotels': { + 'columns': { + 'hotel_id': {'sdtype': 'id'}, + 'city': {'sdtype': 'categorical'}, + 'state': {'sdtype': 'categorical'}, + 'rating': {'sdtype': 'numerical'}, + 'classification': {'sdtype': 'categorical'}, + }, + 'primary_key': 'hotel_id', + } + }, + 'relationships': [], + 'METADATA_SPEC_VERSION': 'V1', + } + + assert metadata.to_dict() == expected_metadata diff --git a/tests/unit/metadata/test_metadata.py b/tests/unit/metadata/test_metadata.py new file mode 100644 index 000000000..284bad8d6 --- /dev/null +++ b/tests/unit/metadata/test_metadata.py @@ -0,0 +1,545 @@ +from unittest.mock import patch + +import pytest + +from sdv.metadata.metadata import Metadata +from tests.utils import get_multi_table_data, get_multi_table_metadata + + +class TestMetadataClass: + """Test ``Metadata`` class.""" + + def get_multi_table_metadata(self): + """Set the tables and relationships for metadata.""" + metadata = {} + metadata['tables'] = { + 'users': { + 'columns': {'id': {'sdtype': 'id'}, 'country': {'sdtype': 'categorical'}}, + 'primary_key': 'id', + }, + 'payments': { + 'columns': { + 'payment_id': {'sdtype': 'id'}, + 'user_id': {'sdtype': 'id'}, + 'date': {'sdtype': 'datetime'}, + }, + 'primary_key': 'payment_id', + }, + 'sessions': { + 'columns': { + 
'session_id': {'sdtype': 'id'}, + 'user_id': {'sdtype': 'id'}, + 'device': {'sdtype': 'categorical'}, + }, + 'primary_key': 'session_id', + }, + 'transactions': { + 'columns': { + 'transaction_id': {'sdtype': 'id'}, + 'session_id': {'sdtype': 'id'}, + 'timestamp': {'sdtype': 'datetime'}, + }, + 'primary_key': 'transaction_id', + }, + } + + metadata['relationships'] = [ + { + 'parent_table_name': 'users', + 'parent_primary_key': 'id', + 'child_table_name': 'sessions', + 'child_foreign_key': 'user_id', + }, + { + 'parent_table_name': 'sessions', + 'parent_primary_key': 'session_id', + 'child_table_name': 'transactions', + 'child_foreign_key': 'session_id', + }, + { + 'parent_table_name': 'users', + 'parent_primary_key': 'id', + 'child_table_name': 'payments', + 'child_foreign_key': 'user_id', + }, + ] + + return Metadata.load_from_dict(metadata) + + @patch('sdv.metadata.utils.Path') + def test_load_from_json_path_does_not_exist(self, mock_path): + """Test the ``load_from_json`` method. + + Test that the method raises a ``ValueError`` when the specified path does not + exist. + + Mock: + - Mock the ``Path`` library in order to return ``False``, that the file does not exist. + + Input: + - String representing a filepath. + + Side Effects: + - A ``ValueError`` is raised pointing that the ``file`` does not exist. + """ + # Setup + mock_path.return_value.exists.return_value = False + mock_path.return_value.name = 'filepath.json' + + # Run / Assert + error_msg = ( + "A file named 'filepath.json' does not exist. Please specify a different filename." + ) + with pytest.raises(ValueError, match=error_msg): + Metadata.load_from_json('filepath.json') + + @patch('sdv.metadata.utils.Path') + @patch('sdv.metadata.utils.json') + def test_load_from_json_single_table(self, mock_json, mock_path): + """Test the ``load_from_json`` method. + + Test that ``load_from_json`` function creates an instance with the contents returned by the + ``json`` load function when passing in a single table metadata json. + + Mock: + - Mock the ``Path`` library in order to return ``True``. + - Mock the ``json`` library in order to use a custom return. + + Input: + - String representing a filepath. + + Output: + - ``SingleTableMetadata`` instance with the custom configuration from the ``json`` + file (``json.load`` return value) + """ + # Setup + instance = Metadata() + mock_path.return_value.exists.return_value = True + mock_path.return_value.name = 'filepath.json' + mock_json.load.return_value = { + 'columns': {'animals': {'type': 'categorical'}}, + 'primary_key': 'animals', + 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1', + } + + # Run + instance = Metadata.load_from_json('filepath.json') + + # Assert + assert list(instance.tables.keys()) == ['filepath'] + assert instance.tables['filepath'].columns == {'animals': {'type': 'categorical'}} + assert instance.tables['filepath'].primary_key == 'animals' + assert instance.tables['filepath'].sequence_key is None + assert instance.tables['filepath'].alternate_keys == [] + assert instance.tables['filepath'].sequence_index is None + assert instance.tables['filepath']._version == 'SINGLE_TABLE_V1' + + @patch('sdv.metadata.utils.Path') + @patch('sdv.metadata.utils.json') + def test_load_from_json_multi_table(self, mock_json, mock_path): + """Test the ``load_from_json`` method. + + Test that ``load_from_json`` function creates an instance with the contents returned by the + ``json`` load function when passing in a multi-table metadata json. 
+
+        Mock:
+            - Mock the ``Path`` library in order to return ``True``.
+            - Mock the ``json`` library in order to use a custom return.
+
+        Input:
+            - String representing a filepath.
+
+        Output:
+            - ``Metadata`` instance with the custom configuration from the ``json``
+              file (``json.load`` return value)
+        """
+        # Setup
+        instance = Metadata()
+        mock_path.return_value.exists.return_value = True
+        mock_path.return_value.name = 'filepath.json'
+        mock_json.load.return_value = {
+            'tables': {
+                'table1': {
+                    'columns': {'animals': {'type': 'categorical'}},
+                    'primary_key': 'animals',
+                    'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1',
+                }
+            },
+            'relationships': {},
+        }
+
+        # Run
+        instance = Metadata.load_from_json('filepath.json')
+
+        # Asserts
+        assert list(instance.tables.keys()) == ['table1']
+        assert instance.tables['table1'].columns == {'animals': {'type': 'categorical'}}
+        assert instance.tables['table1'].primary_key == 'animals'
+        assert instance.tables['table1'].sequence_key is None
+        assert instance.tables['table1'].alternate_keys == []
+        assert instance.tables['table1'].sequence_index is None
+        assert instance.tables['table1']._version == 'SINGLE_TABLE_V1'
+
+    @patch('sdv.metadata.multi_table.SingleTableMetadata')
+    def test_load_from_dict_multi_table(self, mock_singletablemetadata):
+        """Test that ``load_from_dict`` returns an instance of multi-table ``Metadata``.
+
+        Test that calling the ``load_from_dict`` method creates a new instance with the
+        passed python ``dict`` details.
+
+        Setup:
+            - A dict representing a multi-table ``Metadata``.
+
+        Mock:
+            - Mock ``SingleTableMetadata`` from ``sdv.metadata.multi_table``
+
+        Output:
+            - ``instance`` that contains ``instance.tables`` and ``instance.relationships``.
+
+        Side Effects:
+            - ``SingleTableMetadata.load_from_dict`` has been called.
+        """
+        # Setup
+        multitable_metadata = {
+            'tables': {
+                'accounts': {
+                    'id': {'sdtype': 'numerical'},
+                    'branch_id': {'sdtype': 'numerical'},
+                    'amount': {'sdtype': 'numerical'},
+                    'start_date': {'sdtype': 'datetime'},
+                    'owner': {'sdtype': 'id'},
+                },
+                'branches': {
+                    'id': {'sdtype': 'numerical'},
+                    'name': {'sdtype': 'id'},
+                },
+            },
+            'relationships': [
+                {
+                    'parent_table_name': 'accounts',
+                    'parent_primary_key': 'id',
+                    'child_table_name': 'branches',
+                    'child_foreign_key': 'branch_id',
+                }
+            ],
+        }
+
+        single_table_accounts = object()
+        single_table_branches = object()
+        mock_singletablemetadata.load_from_dict.side_effect = [
+            single_table_accounts,
+            single_table_branches,
+        ]
+
+        # Run
+        instance = Metadata.load_from_dict(multitable_metadata)
+
+        # Assert
+        assert instance.tables == {
+            'accounts': single_table_accounts,
+            'branches': single_table_branches,
+        }
+
+        assert instance.relationships == [
+            {
+                'parent_table_name': 'accounts',
+                'parent_primary_key': 'id',
+                'child_table_name': 'branches',
+                'child_foreign_key': 'branch_id',
+            }
+        ]
+
+    @patch('sdv.metadata.multi_table.SingleTableMetadata')
+    def test_load_from_dict_integer_multi_table(self, mock_singletablemetadata):
+        """Test that ``load_from_dict`` returns an instance of multi-table ``Metadata``.
+
+        Test that calling the ``load_from_dict`` method creates a new instance with the
+        passed python ``dict`` details. Make sure that integers passed in are
+        turned into strings to ensure metadata is properly typed.
+
+        Setup:
+            - A dict representing a multi-table ``Metadata``.
+
+        Mock:
+            - Mock ``SingleTableMetadata`` from ``sdv.metadata.multi_table``
+
+        Output:
+            - ``instance`` that contains ``instance.tables`` and ``instance.relationships``.
+
+        Side Effects:
+            - ``SingleTableMetadata.load_from_dict`` has been called.
+        """
+        # Setup
+        multitable_metadata = {
+            'tables': {
+                'accounts': {
+                    1: {'sdtype': 'numerical'},
+                    2: {'sdtype': 'numerical'},
+                    'amount': {'sdtype': 'numerical'},
+                    'start_date': {'sdtype': 'datetime'},
+                    'owner': {'sdtype': 'id'},
+                },
+                'branches': {
+                    1: {'sdtype': 'numerical'},
+                    'name': {'sdtype': 'id'},
+                },
+            },
+            'relationships': [
+                {
+                    'parent_table_name': 'accounts',
+                    'parent_primary_key': 1,
+                    'child_table_name': 'branches',
+                    'child_foreign_key': 1,
+                }
+            ],
+        }
+
+        single_table_accounts = {
+            '1': {'sdtype': 'numerical'},
+            '2': {'sdtype': 'numerical'},
+            'amount': {'sdtype': 'numerical'},
+            'start_date': {'sdtype': 'datetime'},
+            'owner': {'sdtype': 'id'},
+        }
+        single_table_branches = {
+            '1': {'sdtype': 'numerical'},
+            'name': {'sdtype': 'id'},
+        }
+        mock_singletablemetadata.load_from_dict.side_effect = [
+            single_table_accounts,
+            single_table_branches,
+        ]
+
+        # Run
+        instance = Metadata.load_from_dict(multitable_metadata)
+
+        # Assert
+        assert instance.tables == {
+            'accounts': single_table_accounts,
+            'branches': single_table_branches,
+        }
+
+        assert instance.relationships == [
+            {
+                'parent_table_name': 'accounts',
+                'parent_primary_key': '1',
+                'child_table_name': 'branches',
+                'child_foreign_key': '1',
+            }
+        ]
+
+    def test_load_from_dict_single_table(self):
+        """Test that ``load_from_dict`` returns an instance of single-table ``Metadata``.
+
+        Test that calling the ``load_from_dict`` method creates a new instance with the
+        passed python ``dict`` details.
+        """
+        # Setup
+        my_metadata = {
+            'columns': {'my_column': 'value'},
+            'primary_key': 'pk',
+            'alternate_keys': [],
+            'sequence_key': None,
+            'sequence_index': None,
+            'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1',
+        }
+
+        # Run
+        instance = Metadata.load_from_dict(my_metadata)
+
+        # Assert
+        assert list(instance.tables.keys()) == ['default_table_name']
+        assert instance.tables['default_table_name'].columns == {'my_column': 'value'}
+        assert instance.tables['default_table_name'].primary_key == 'pk'
+        assert instance.tables['default_table_name'].sequence_key is None
+        assert instance.tables['default_table_name'].alternate_keys == []
+        assert instance.tables['default_table_name'].sequence_index is None
+        assert instance.tables['default_table_name']._version == 'SINGLE_TABLE_V1'
+
+    def test_load_from_dict_integer_single_table(self):
+        """Test that ``load_from_dict`` returns an instance of single-table ``Metadata``.
+
+        Test that calling the ``load_from_dict`` method creates a new instance with the
+        passed python ``dict`` details. Make sure that integers passed in are
+        turned into strings to ensure metadata is properly typed.
+        """
+
+        # Setup
+        my_metadata = {
+            'columns': {1: 'value'},
+            'primary_key': 'pk',
+            'alternate_keys': [],
+            'sequence_key': None,
+            'sequence_index': None,
+            'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1',
+        }
+
+        # Run
+        instance = Metadata.load_from_dict(my_metadata)
+
+        # Assert
+        assert list(instance.tables.keys()) == ['default_table_name']
+        assert instance.tables['default_table_name'].columns == {'1': 'value'}
+        assert instance.tables['default_table_name'].primary_key == 'pk'
+        assert instance.tables['default_table_name'].sequence_key is None
+        assert instance.tables['default_table_name'].alternate_keys == []
+        assert instance.tables['default_table_name'].sequence_index is None
+
+    @patch('sdv.metadata.multi_table.SingleTableMetadata')
+    def test__set_metadata_multi_table(self, mock_singletablemetadata):
+        """Test the ``_set_metadata_dict`` method for ``Metadata``.
+
+        Setup:
+            - instance of ``Metadata``.
+            - A dict representing a ``MultiTableMetadata``.
+
+        Mock:
+            - Mock ``SingleTableMetadata`` from ``sdv.metadata.multi_table``
+
+        Side Effects:
+            - ``instance`` now contains ``instance.tables`` and ``instance.relationships``.
+            - ``SingleTableMetadata.load_from_dict`` has been called.
+        """
+        # Setup
+        multitable_metadata = {
+            'tables': {
+                'accounts': {
+                    'id': {'sdtype': 'numerical'},
+                    'branch_id': {'sdtype': 'numerical'},
+                    'amount': {'sdtype': 'numerical'},
+                    'start_date': {'sdtype': 'datetime'},
+                    'owner': {'sdtype': 'id'},
+                },
+                'branches': {
+                    'id': {'sdtype': 'numerical'},
+                    'name': {'sdtype': 'id'},
+                },
+            },
+            'relationships': [
+                {
+                    'parent_table_name': 'accounts',
+                    'parent_primary_key': 'id',
+                    'child_table_name': 'branches',
+                    'child_foreign_key': 'branch_id',
+                }
+            ],
+        }
+
+        single_table_accounts = object()
+        single_table_branches = object()
+        mock_singletablemetadata.load_from_dict.side_effect = [
+            single_table_accounts,
+            single_table_branches,
+        ]
+
+        instance = Metadata()
+
+        # Run
+        instance._set_metadata_dict(multitable_metadata)
+
+        # Assert
+        assert instance.tables == {
+            'accounts': single_table_accounts,
+            'branches': single_table_branches,
+        }
+
+        assert instance.relationships == [
+            {
+                'parent_table_name': 'accounts',
+                'parent_primary_key': 'id',
+                'child_table_name': 'branches',
+                'child_foreign_key': 'branch_id',
+            }
+        ]
+
+    def test__set_metadata_single_table(self):
+        """Test the ``_set_metadata_dict`` method for ``Metadata``.
+
+        Setup:
+            - instance of ``Metadata``.
+            - A dict representing a ``SingleTableMetadata``.
+
+        Mock:
+            - Mock ``SingleTableMetadata`` from ``sdv.metadata.multi_table``
+
+        Side Effects:
+            - ``SingleTableMetadata.load_from_dict`` has been called.
+        """
+        # Setup
+        multitable_metadata = {
+            'columns': {'my_column': 'value'},
+            'primary_key': 'pk',
+            'alternate_keys': [],
+            'sequence_key': None,
+            'sequence_index': None,
+            'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1',
+        }
+
+        instance = Metadata()
+
+        # Run
+        instance._set_metadata_dict(multitable_metadata)
+
+        # Assert
+        assert instance.tables['default_table_name'].columns == {'my_column': 'value'}
+        assert instance.tables['default_table_name'].primary_key == 'pk'
+        assert instance.tables['default_table_name'].alternate_keys == []
+        assert instance.tables['default_table_name'].sequence_key is None
+        assert instance.tables['default_table_name'].sequence_index is None
+        assert instance.tables['default_table_name'].METADATA_SPEC_VERSION == 'SINGLE_TABLE_V1'
+
+    def test_validate(self):
+        """Test the method ``validate``.
+ + Test that when a valid ``Metadata`` has been provided no errors are being raised. + + Setup: + - Instance of ``Metadata`` with all valid tables and relationships. + """ + # Setup + instance = self.get_multi_table_metadata() + + # Run + instance.validate() + + def test_validate_no_relationships(self): + """Test the method ``validate`` without relationships. + + Test that when a valid ``Metadata`` has been provided no errors are being raised. + + Setup: + - Instance of ``Metadata`` with all valid tables and no relationships. + """ + # Setup + metadata = self.get_multi_table_metadata() + metadata_no_relationships = metadata.to_dict() + del metadata_no_relationships['relationships'] + test_metadata = Metadata.load_from_dict(metadata_no_relationships) + + # Run + test_metadata.validate() + assert test_metadata.METADATA_SPEC_VERSION == 'V1' + + def test_validate_data(self): + """Test that no error is being raised when the data is valid.""" + # Setup + metadata_dict = get_multi_table_metadata().to_dict() + metadata = Metadata.load_from_dict(metadata_dict) + data = get_multi_table_data() + + # Run and Assert + metadata.validate_data(data) + assert metadata.METADATA_SPEC_VERSION == 'V1' + + def test_validate_data_no_relationships(self): + """Test that no error is being raised when the data is valid but has no relationships.""" + # Setup + metadata_dict = get_multi_table_metadata().to_dict() + del metadata_dict['relationships'] + del metadata_dict['METADATA_SPEC_VERSION'] + metadata = Metadata.load_from_dict(metadata_dict) + data = get_multi_table_data() + + # Run and Assert + metadata.validate_data(data) + assert metadata.METADATA_SPEC_VERSION == 'V1' From e074e5a927bfa0fef9df7492dc26d7caef15ae00 Mon Sep 17 00:00:00 2001 From: John La Date: Thu, 1 Aug 2024 12:22:13 -0500 Subject: [PATCH 02/16] Make `Metadata` work with single table synthesizers (#2140) Co-authored-by: Andrew Montanez --- sdv/lite/single_table.py | 2 +- sdv/metadata/metadata.py | 10 +++- sdv/multi_table/base.py | 1 + sdv/single_table/base.py | 20 ++++++- sdv/single_table/ctgan.py | 8 +-- tests/integration/multi_table/test_hma.py | 5 ++ tests/integration/single_table/test_base.py | 23 ++++--- tests/unit/lite/test_single_table.py | 9 ++- tests/unit/multi_table/test_base.py | 11 +++- tests/unit/sequential/test_par.py | 54 ++++++++++++++++- tests/unit/single_table/test_base.py | 66 +++++++++++++++++++-- tests/unit/single_table/test_copulagan.py | 37 ++++++++++++ tests/unit/single_table/test_copulas.py | 28 +++++++++ tests/unit/single_table/test_ctgan.py | 32 ++++++++++ 14 files changed, 278 insertions(+), 28 deletions(-) diff --git a/sdv/lite/single_table.py b/sdv/lite/single_table.py index 08232ac60..7ad82bd1e 100644 --- a/sdv/lite/single_table.py +++ b/sdv/lite/single_table.py @@ -65,7 +65,7 @@ def add_constraints(self, constraints): self._synthesizer.add_constraints(constraints) def get_metadata(self): - """Return the ``SingleTableMetadata`` for this synthesizer.""" + """Return the ``Metadata`` for this synthesizer.""" warnings.warn(DEPRECATION_MSG, FutureWarning) return self._synthesizer.get_metadata() diff --git a/sdv/metadata/metadata.py b/sdv/metadata/metadata.py index 1b628c9b9..d8a9b463c 100644 --- a/sdv/metadata/metadata.py +++ b/sdv/metadata/metadata.py @@ -2,6 +2,7 @@ from pathlib import Path +from sdv.metadata.errors import InvalidMetadataError from sdv.metadata.multi_table import MultiTableMetadata from sdv.metadata.single_table import SingleTableMetadata from sdv.metadata.utils import read_json @@ -69,5 +70,12 @@ def 
_set_metadata_dict(self, metadata, single_table_name=None): else: if single_table_name is None: single_table_name = 'default_table_name' - self.tables[single_table_name] = SingleTableMetadata.load_from_dict(metadata) + + def _convert_to_single_table(self): + if len(self.tables) > 1: + raise InvalidMetadataError( + 'Metadata contains more than one table, use a MultiTableSynthesizer instead.' + ) + + return next(iter(self.tables.values()), SingleTableMetadata()) diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 83c80ae25..9a72be003 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -25,6 +25,7 @@ SynthesizerInputError, ) from sdv.logging import disable_single_table_logger, get_sdv_logger +from sdv.metadata.metadata import Metadata from sdv.single_table.base import INT_REGEX_ZERO_ERROR_MESSAGE from sdv.single_table.copulas import GaussianCopulaSynthesizer diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index 7645c6986..d618a4c48 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -35,6 +35,8 @@ SynthesizerInputError, ) from sdv.logging import get_sdv_logger +from sdv.metadata.metadata import Metadata +from sdv.metadata.single_table import SingleTableMetadata from sdv.single_table.utils import check_num_rows, handle_sampling_error, validate_file_path LOGGER = logging.getLogger(__name__) @@ -47,6 +49,11 @@ 'or update it to correspond to valid ints.' ) +DEPRECATION_MSG = ( + "The 'SingleTableMetadata' is deprecated. Please use the new " + "'Metadata' class for synthesizers." +) + class BaseSynthesizer: """Base class for all ``Synthesizers``. @@ -55,8 +62,9 @@ class BaseSynthesizer: ``Synthesizers`` need to implement, as well as common functionality. Args: - metadata (sdv.metadata.SingleTableMetadata): + metadata (sdv.metadata.Metadata): Single table metadata representing the data that this synthesizer will be used for. + * sdv.metadata.SingleTableMetadata can be used but will be deprecated. enforce_min_max_values (bool): Specify whether or not to clip the data returned by ``reverse_transform`` of the numerical transformer, ``FloatFormatter``, to the min and max values seen @@ -104,6 +112,12 @@ def __init__( ): self._validate_inputs(enforce_min_max_values, enforce_rounding) self.metadata = metadata + if isinstance(metadata, Metadata): + self.metadata = metadata._convert_to_single_table() + elif isinstance(metadata, SingleTableMetadata): + warnings.warn(DEPRECATION_MSG, FutureWarning) + + self._validate_inputs(enforce_min_max_values, enforce_rounding) self.metadata.validate() self._check_metadata_updated() self.enforce_min_max_values = enforce_min_max_values @@ -269,8 +283,8 @@ def get_parameters(self): return instantiated_parameters def get_metadata(self): - """Return the ``SingleTableMetadata`` for this synthesizer.""" - return self.metadata + """Return the ``Metadata`` for this synthesizer.""" + return Metadata.load_from_dict(self.metadata.to_dict()) def load_custom_constraint_classes(self, filepath, class_names): """Load a custom constraint class for the current synthesizer. 
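
Taken together, patches 01 and 02 let a single-table workflow run end to end on the
unified class. The following is a minimal sketch of that flow, not part of the patch
itself: the table name 'users', its columns, and the toy data are invented for
illustration, and GaussianCopulaSynthesizer simply stands in for any synthesizer that
inherits the BaseSynthesizer.__init__ shown above.

import pandas as pd

from sdv.metadata import Metadata
from sdv.single_table import GaussianCopulaSynthesizer

# A plain single-table dict (no 'tables' key) goes through the single-table
# branch of Metadata._set_metadata_dict().
metadata = Metadata.load_from_dict(
    {
        'columns': {
            'user_id': {'sdtype': 'id'},
            'age': {'sdtype': 'numerical'},
        },
        'primary_key': 'user_id',
    },
    single_table_name='users',  # hypothetical table name
)

data = pd.DataFrame({'user_id': ['id_1', 'id_2', 'id_3'], 'age': [25, 32, 47]})

# BaseSynthesizer.__init__ detects the Metadata instance and converts it
# internally via _convert_to_single_table(); passing a SingleTableMetadata
# would instead emit the FutureWarning defined in DEPRECATION_MSG above.
synthesizer = GaussianCopulaSynthesizer(metadata)
synthesizer.fit(data)
synthetic = synthesizer.sample(num_rows=10)

# get_metadata() now returns a unified Metadata object rather than the
# internal SingleTableMetadata.
assert isinstance(synthesizer.get_metadata(), Metadata)
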
diff --git a/sdv/single_table/ctgan.py b/sdv/single_table/ctgan.py index b918784fb..30174f43a 100644 --- a/sdv/single_table/ctgan.py +++ b/sdv/single_table/ctgan.py @@ -285,9 +285,7 @@ def _fit(self, processed_data): _validate_no_category_dtype(processed_data) transformers = self._data_processor._hyper_transformer.field_transformers - discrete_columns = detect_discrete_columns( - self.get_metadata(), processed_data, transformers - ) + discrete_columns = detect_discrete_columns(self.metadata, processed_data, transformers) self._model = CTGAN(**self._model_kwargs) with warnings.catch_warnings(): warnings.filterwarnings('ignore', message='.*Attempting to run cuBLAS.*') @@ -402,9 +400,7 @@ def _fit(self, processed_data): _validate_no_category_dtype(processed_data) transformers = self._data_processor._hyper_transformer.field_transformers - discrete_columns = detect_discrete_columns( - self.get_metadata(), processed_data, transformers - ) + discrete_columns = detect_discrete_columns(self.metadata, processed_data, transformers) self._model = TVAE(**self._model_kwargs) self._model.fit(processed_data, discrete_columns=discrete_columns) diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 133a53681..cc22db1c2 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -1246,6 +1246,11 @@ def test_metadata_updated_no_warning(self, tmp_path): # Run 1 with warnings.catch_warnings(record=True) as captured_warnings: + warnings.filterwarnings( + 'ignore', + message=".*The 'SingleTableMetadata' is deprecated.*", + category=DeprecationWarning, + ) warnings.simplefilter('always') instance = HMASynthesizer(metadata) instance.fit(data) diff --git a/tests/integration/single_table/test_base.py b/tests/integration/single_table/test_base.py index 27ab09d06..d08abe502 100644 --- a/tests/integration/single_table/test_base.py +++ b/tests/integration/single_table/test_base.py @@ -13,6 +13,7 @@ from sdv.datasets.demo import download_demo from sdv.errors import SamplingError, SynthesizerInputError, VersionError from sdv.metadata import SingleTableMetadata +from sdv.metadata.metadata import Metadata from sdv.sampling import Condition from sdv.single_table import ( CopulaGANSynthesizer, @@ -587,7 +588,7 @@ def test_metadata_updated_no_warning(mock__fit, tmp_path): initialization, but is saved to a file before fitting. 
""" # Setup - metadata_from_dict = SingleTableMetadata().load_from_dict({ + metadata_from_dict = Metadata().load_from_dict({ 'columns': { 'col 1': {'sdtype': 'numerical'}, 'col 2': {'sdtype': 'numerical'}, @@ -610,8 +611,8 @@ def test_metadata_updated_no_warning(mock__fit, tmp_path): assert len(captured_warnings) == 0 # Run 2 - metadata_detect = SingleTableMetadata() - metadata_detect.detect_from_dataframe(data) + metadata_detect = Metadata() + metadata_detect.detect_from_dataframes({'mock_table': data}) file_name = tmp_path / 'singletable.json' metadata_detect.save_to_json(file_name) with warnings.catch_warnings(record=True) as captured_warnings: @@ -624,7 +625,7 @@ def test_metadata_updated_no_warning(mock__fit, tmp_path): # Run 3 instance = BaseSingleTableSynthesizer(metadata_detect) - metadata_detect.update_column('col 1', sdtype='categorical') + metadata_detect.update_column('mock_table', 'col 1', sdtype='categorical') file_name = tmp_path / 'singletable_2.json' metadata_detect.save_to_json(file_name) with warnings.catch_warnings(record=True) as captured_warnings: @@ -650,18 +651,26 @@ def test_metadata_updated_warning_detect(mock__fit): }) metadata = SingleTableMetadata() metadata.detect_from_dataframe(data) - expected_message = re.escape( + expected_user_message = ( "We strongly recommend saving the metadata using 'save_to_json' for replicability" ' in future SDV versions.' ) + expected_deprecation_message = ( + "The 'SingleTableMetadata' is deprecated. " + "Please use the new 'Metadata' class for synthesizers." + ) # Run - with pytest.warns(UserWarning, match=expected_message) as record: + with warnings.catch_warnings(record=True) as record: instance = BaseSingleTableSynthesizer(metadata) instance.fit(data) # Assert - assert len(record) == 1 + assert len(record) == 2 + assert record[0].category is FutureWarning + assert str(record[0].message) == expected_deprecation_message + assert record[1].category is UserWarning + assert str(record[1].message) == expected_user_message parametrization = [ diff --git a/tests/unit/lite/test_single_table.py b/tests/unit/lite/test_single_table.py index aff2453ee..c51d99b83 100644 --- a/tests/unit/lite/test_single_table.py +++ b/tests/unit/lite/test_single_table.py @@ -8,6 +8,7 @@ from sdv.lite import SingleTablePreset from sdv.metadata import SingleTableMetadata +from sdv.metadata.metadata import Metadata from sdv.single_table import GaussianCopulaSynthesizer from tests.utils import DataFrameMatcher @@ -79,14 +80,18 @@ def test_get_parameters(self, mock_data_processor): def test_get_metadata(self, mock_data_processor): """Test that it returns the ``metadata`` object.""" # Setup - metadata = Mock() + metadata = Mock(spec=SingleTableMetadata) + metadata._updated = False + metadata.columns = {} + metadata.to_dict.return_value = {'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1'} instance = SingleTablePreset(metadata, 'FAST_ML') # Run result = instance.get_metadata() # Assert - assert result == metadata + assert isinstance(result, Metadata) + assert result._convert_to_single_table().to_dict() == metadata.to_dict() def test_fit(self): """Test that the synthesizer's fit method is called with the expected args.""" diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index 48a32cd30..a1975da2f 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -19,6 +19,7 @@ SynthesizerInputError, VersionError, ) +from sdv.metadata.metadata import Metadata from sdv.metadata.multi_table import 
MultiTableMetadata from sdv.metadata.single_table import SingleTableMetadata from sdv.multi_table.base import BaseMultiTableSynthesizer @@ -57,14 +58,18 @@ def test__initialize_models(self): } instance._synthesizer.assert_has_calls([ call( - metadata=instance.metadata.tables['nesreca'], + metadata=ANY, default_distribution='gamma', locales=['en_US'], ), - call(metadata=instance.metadata.tables['oseba'], locales=locales), - call(metadata=instance.metadata.tables['upravna_enota'], locales=locales), + call(metadata=ANY, locales=locales), + call(metadata=ANY, locales=locales), ]) + for call_args in instance._synthesizer.call_args_list: + metadata_arg = call_args[1].get('metadata', None) + assert isinstance(metadata_arg, Metadata) + def test__get_pbar_args(self): """Test that ``_get_pbar_args`` returns a dictionary with disable opposite to verbose.""" # Setup diff --git a/tests/unit/sequential/test_par.py b/tests/unit/sequential/test_par.py index b235c33f1..21a40922c 100644 --- a/tests/unit/sequential/test_par.py +++ b/tests/unit/sequential/test_par.py @@ -9,6 +9,8 @@ from sdv.data_processing.data_processor import DataProcessor from sdv.data_processing.errors import InvalidConstraintsError from sdv.errors import InvalidDataError, NotFittedError, SamplingError, SynthesizerInputError +from sdv.metadata.errors import InvalidMetadataError +from sdv.metadata.metadata import Metadata from sdv.metadata.single_table import SingleTableMetadata from sdv.sampling import Condition from sdv.sequential.par import PARSynthesizer @@ -236,7 +238,8 @@ def test_get_metadata(self): result = instance.get_metadata() # Assert - assert result == metadata + assert result._convert_to_single_table().to_dict() == metadata.to_dict() + assert isinstance(result, Metadata) def test_validate_context_columns_unique_per_sequence_key(self): """Test error is raised if context column values vary for each tuple of sequence keys. @@ -1024,3 +1027,52 @@ def test___init___error_sequence_key_in_context(self): metadata=metadata, context_columns=['name'], ) + + def test___init__with_unified_metadata(self): + """Test initialization with unified metadata.""" + # Setup + metadata = Metadata.load_from_dict({ + 'tables': { + 'table_1': { + 'columns': { + 'time': {'sdtype': 'datetime'}, + 'gender': {'sdtype': 'categorical'}, + 'name': {'sdtype': 'id'}, + 'measurement': {'sdtype': 'numerical'}, + }, + 'sequence_key': 'name', + } + } + }) + + multi_metadata = Metadata.load_from_dict({ + 'tables': { + 'table_1': { + 'columns': { + 'time': {'sdtype': 'datetime'}, + 'gender': {'sdtype': 'categorical'}, + 'name': {'sdtype': 'id'}, + 'measurement': {'sdtype': 'numerical'}, + }, + 'sequence_key': 'name', + }, + 'table_2': { + 'columns': { + 'time': {'sdtype': 'datetime'}, + 'gender': {'sdtype': 'categorical'}, + 'name': {'sdtype': 'id'}, + 'measurement': {'sdtype': 'numerical'}, + }, + 'sequence_key': 'name', + }, + } + }) + + # Run and Assert + PARSynthesizer(metadata) + error_msg = re.escape( + 'Metadata contains more than one table, use a MultiTableSynthesizer instead.' 
+        )
+
+        with pytest.raises(InvalidMetadataError, match=error_msg):
+            PARSynthesizer(multi_metadata)
diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py
index eae85337f..5fd111960 100644
--- a/tests/unit/single_table/test_base.py
+++ b/tests/unit/single_table/test_base.py
@@ -25,6 +25,8 @@
     SynthesizerInputError,
     VersionError,
 )
+from sdv.metadata.errors import InvalidMetadataError
+from sdv.metadata.metadata import Metadata
 from sdv.metadata.single_table import SingleTableMetadata
 from sdv.sampling.tabular import Condition
 from sdv.single_table import (
@@ -119,6 +121,59 @@
             'SYNTHESIZER ID': 'BaseSingleTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5',
         })
 
+    def test__init__with_old_metadata_future_warning(self):
+        """Test that a future warning is thrown when using ``SingleTableMetadata``."""
+        # Setup
+        metadata = SingleTableMetadata.load_from_dict({
+            'columns': {
+                'a': {'sdtype': 'categorical'},
+            }
+        })
+        warn_msg = re.escape(
+            "The 'SingleTableMetadata' is deprecated. Please use the new "
+            "'Metadata' class for synthesizers."
+        )
+        # Run and Assert
+        with pytest.warns(FutureWarning, match=warn_msg):
+            BaseSingleTableSynthesizer(metadata)
+
+    def test___init__with_unified_metadata(self):
+        """Test initialization with unified metadata."""
+        # Setup
+        metadata = Metadata.load_from_dict({
+            'tables': {
+                'table_1': {
+                    'columns': {
+                        'id': {'sdtype': 'id'},
+                    },
+                }
+            }
+        })
+
+        multi_metadata = Metadata.load_from_dict({
+            'tables': {
+                'table_1': {
+                    'columns': {
+                        'id': {'sdtype': 'id'},
+                    },
+                },
+                'table_2': {
+                    'columns': {
+                        'id': {'sdtype': 'id'},
+                    },
+                },
+            }
+        })
+
+        # Run and Assert
+        BaseSingleTableSynthesizer(metadata)
+        error_msg = re.escape(
+            'Metadata contains more than one table, use a MultiTableSynthesizer instead.'
+ ) + + with pytest.raises(InvalidMetadataError, match=error_msg): + BaseSingleTableSynthesizer(multi_metadata) + @patch('sdv.single_table.base.DataProcessor') def test___init__custom(self, mock_data_processor): """Test that instantiating with custom parameters are properly stored in the instance.""" @@ -199,19 +254,22 @@ def test_get_parameters(self, mock_data_processor): } @patch('sdv.single_table.base.DataProcessor') - def test_get_metadata(self, mock_data_processor): + @patch('sdv.single_table.base.Metadata.load_from_dict') + def test_get_metadata(self, mock_load_from_dict, _): """Test that it returns the ``metadata`` object.""" # Setup - metadata = Mock() + metadata = Mock(spec=Metadata) instance = BaseSingleTableSynthesizer( metadata, enforce_min_max_values=False, enforce_rounding=False ) + mock_converted_metadata = Mock() + mock_load_from_dict.return_value = mock_converted_metadata # Run result = instance.get_metadata() # Assert - assert result == metadata + assert result == mock_converted_metadata def test_auto_assign_transformers(self): """Test that the ``DataProcessor.prepare_for_fitting`` is being called.""" diff --git a/tests/unit/single_table/test_copulagan.py b/tests/unit/single_table/test_copulagan.py index fa32359c6..e41fbdc83 100644 --- a/tests/unit/single_table/test_copulagan.py +++ b/tests/unit/single_table/test_copulagan.py @@ -8,6 +8,7 @@ from rdt.transformers import GaussianNormalizer from sdv.errors import SynthesizerInputError +from sdv.metadata.metadata import Metadata from sdv.metadata.single_table import SingleTableMetadata from sdv.single_table.copulagan import CopulaGANSynthesizer @@ -49,6 +50,42 @@ def test___init__(self): assert instance._numerical_distributions == {} assert instance._default_distribution == BetaUnivariate + def test___init__with_unified_metadata(self): + """Test creating an instance of ``CopulaGANSynthesizer`` with Metadata.""" + # Setup + metadata = Metadata() + enforce_min_max_values = True + enforce_rounding = True + + # Run + instance = CopulaGANSynthesizer( + metadata, + enforce_min_max_values=enforce_min_max_values, + enforce_rounding=enforce_rounding, + ) + + # Assert + assert instance.enforce_min_max_values is True + assert instance.enforce_rounding is True + assert instance.embedding_dim == 128 + assert instance.generator_dim == (256, 256) + assert instance.discriminator_dim == (256, 256) + assert instance.generator_lr == 2e-4 + assert instance.generator_decay == 1e-6 + assert instance.discriminator_lr == 2e-4 + assert instance.discriminator_decay == 1e-6 + assert instance.batch_size == 500 + assert instance.discriminator_steps == 1 + assert instance.log_frequency is True + assert instance.verbose is False + assert instance.epochs == 300 + assert instance.pac == 10 + assert instance.cuda is True + assert instance.numerical_distributions == {} + assert instance.default_distribution == 'beta' + assert instance._numerical_distributions == {} + assert instance._default_distribution == BetaUnivariate + def test___init__custom(self): """Test creating an instance of ``CopulaGANSynthesizer`` with custom parameters.""" # Setup diff --git a/tests/unit/single_table/test_copulas.py b/tests/unit/single_table/test_copulas.py index 0d83fa57d..130fb068d 100644 --- a/tests/unit/single_table/test_copulas.py +++ b/tests/unit/single_table/test_copulas.py @@ -8,6 +8,7 @@ from copulas.univariate import BetaUnivariate, GammaUnivariate, TruncatedGaussian, UniformUnivariate from sdv.errors import SynthesizerInputError +from sdv.metadata.metadata import 
Metadata from sdv.metadata.single_table import SingleTableMetadata from sdv.single_table.copulas import GaussianCopulaSynthesizer @@ -60,6 +61,33 @@ def test___init__(self): assert instance._numerical_distributions == {} assert instance._num_rows is None + def test___init__with_unified_metadata(self): + """Test creating an instance of ``GaussianCopulaSynthesizer`` with Metadata.""" + # Setup + metadata = Metadata() + enforce_min_max_values = True + enforce_rounding = True + numerical_distributions = None + default_distribution = None + + # Run + instance = GaussianCopulaSynthesizer( + metadata, + enforce_min_max_values=enforce_min_max_values, + enforce_rounding=enforce_rounding, + numerical_distributions=numerical_distributions, + default_distribution=default_distribution, + ) + + # Assert + assert instance.enforce_min_max_values is True + assert instance.enforce_rounding is True + assert instance.numerical_distributions == {} + assert instance.default_distribution == 'beta' + assert instance._default_distribution == BetaUnivariate + assert instance._numerical_distributions == {} + assert instance._num_rows is None + def test___init__custom(self): """Test creating an instance of ``GaussianCopulaSynthesizer`` with custom parameters.""" # Setup diff --git a/tests/unit/single_table/test_ctgan.py b/tests/unit/single_table/test_ctgan.py index 18245729c..0f81558e6 100644 --- a/tests/unit/single_table/test_ctgan.py +++ b/tests/unit/single_table/test_ctgan.py @@ -62,6 +62,38 @@ def test___init__(self): assert instance.pac == 10 assert instance.cuda is True + def test___init__with_unified_metadata(self): + """Test creating an instance of ``CTGANSynthesizer`` with Metadata.""" + # Setup + metadata = SingleTableMetadata() + enforce_min_max_values = True + enforce_rounding = True + + # Run + instance = CTGANSynthesizer( + metadata, + enforce_min_max_values=enforce_min_max_values, + enforce_rounding=enforce_rounding, + ) + + # Assert + assert instance.enforce_min_max_values is True + assert instance.enforce_rounding is True + assert instance.embedding_dim == 128 + assert instance.generator_dim == (256, 256) + assert instance.discriminator_dim == (256, 256) + assert instance.generator_lr == 2e-4 + assert instance.generator_decay == 1e-6 + assert instance.discriminator_lr == 2e-4 + assert instance.discriminator_decay == 1e-6 + assert instance.batch_size == 500 + assert instance.discriminator_steps == 1 + assert instance.log_frequency is True + assert instance.verbose is False + assert instance.epochs == 300 + assert instance.pac == 10 + assert instance.cuda is True + def test___init__custom(self): """Test creating an instance of ``CTGANSynthesizer`` with custom parameters.""" # Setup From 99e40a2648d81eb8e80b4e7aa76a862512215a7e Mon Sep 17 00:00:00 2001 From: John La Date: Tue, 6 Aug 2024 12:51:36 -0500 Subject: [PATCH 03/16] Deprecate MultiTableMetadata in favor for Metadata (#2172) --- sdv/multi_table/base.py | 11 +- tests/integration/metadata/test_metadata.py | 154 ++++++++++++++++++++ tests/integration/multi_table/test_hma.py | 34 ++++- tests/unit/multi_table/test_base.py | 23 ++- 4 files changed, 214 insertions(+), 8 deletions(-) diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 9a72be003..c2f09d6b9 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -26,10 +26,15 @@ ) from sdv.logging import disable_single_table_logger, get_sdv_logger from sdv.metadata.metadata import Metadata +from sdv.metadata.multi_table import MultiTableMetadata from sdv.single_table.base 
import INT_REGEX_ZERO_ERROR_MESSAGE from sdv.single_table.copulas import GaussianCopulaSynthesizer SYNTHESIZER_LOGGER = get_sdv_logger('MultiTableSynthesizer') +DEPRECATION_MSG = ( + "The 'MultiTableMetadata' is deprecated. Please use the new " + "'Metadata' class for synthesizers." +) class BaseMultiTableSynthesizer: @@ -99,6 +104,8 @@ def _check_metadata_updated(self): def __init__(self, metadata, locales=['en_US'], synthesizer_kwargs=None): self.metadata = metadata + if type(metadata) is MultiTableMetadata: + warnings.warn(DEPRECATION_MSG, FutureWarning) with warnings.catch_warnings(): warnings.filterwarnings('ignore', message=r'.*column relationship.*') self.metadata.validate() @@ -206,8 +213,8 @@ def set_table_parameters(self, table_name, table_parameters): self._table_parameters[table_name].update(deepcopy(table_parameters)) def get_metadata(self): - """Return the ``MultiTableMetadata`` for this synthesizer.""" - return self.metadata + """Return the ``Metadata`` for this synthesizer.""" + return Metadata.load_from_dict(self.metadata.to_dict()) def _validate_all_tables(self, data): """Validate every table of the data has a valid table/metadata pair.""" diff --git a/tests/integration/metadata/test_metadata.py b/tests/integration/metadata/test_metadata.py index 6adde98ed..7efe8421d 100644 --- a/tests/integration/metadata/test_metadata.py +++ b/tests/integration/metadata/test_metadata.py @@ -1,5 +1,16 @@ +import os +import re + +import pytest + from sdv.datasets.demo import download_demo from sdv.metadata.metadata import Metadata +from sdv.metadata.multi_table import MultiTableMetadata +from sdv.metadata.single_table import SingleTableMetadata +from sdv.multi_table.hma import HMASynthesizer +from sdv.single_table.copulas import GaussianCopulaSynthesizer + +DEFAULT_TABLE_NAME = 'default_table_name' def test_metadata(): @@ -216,3 +227,146 @@ def test_detect_table_from_csv(tmp_path): } assert metadata.to_dict() == expected_metadata + + +def test_single_table_compatibility(tmp_path): + """Test if SingleTableMetadata still has compatibility with single table synthesizers.""" + # Setup + data, _ = download_demo('single_table', 'fake_hotel_guests') + warn_msg = ( + "The 'SingleTableMetadata' is deprecated. Please use the new " + "'Metadata' class for synthesizers." 
+ ) + + single_table_metadata_dict = { + 'primary_key': 'guest_email', + 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1', + 'columns': { + 'guest_email': {'sdtype': 'email', 'pii': True}, + 'has_rewards': {'sdtype': 'boolean'}, + 'room_type': {'sdtype': 'categorical'}, + 'amenities_fee': {'sdtype': 'numerical', 'computer_representation': 'Float'}, + 'checkin_date': {'sdtype': 'datetime', 'datetime_format': '%d %b %Y'}, + 'checkout_date': {'sdtype': 'datetime', 'datetime_format': '%d %b %Y'}, + 'room_rate': {'sdtype': 'numerical', 'computer_representation': 'Float'}, + 'billing_address': {'sdtype': 'address', 'pii': True}, + 'credit_card_number': {'sdtype': 'credit_card_number', 'pii': True}, + }, + } + metadata = SingleTableMetadata.load_from_dict(single_table_metadata_dict) + assert isinstance(metadata, SingleTableMetadata) + + # Run + with pytest.warns(FutureWarning, match=warn_msg): + synthesizer = GaussianCopulaSynthesizer(metadata) + synthesizer.fit(data) + model_path = os.path.join(tmp_path, 'synthesizer.pkl') + synthesizer.save(model_path) + + # Assert + assert os.path.exists(model_path) + assert os.path.isfile(model_path) + loaded_synthesizer = GaussianCopulaSynthesizer.load(model_path) + assert isinstance(synthesizer, GaussianCopulaSynthesizer) + assert loaded_synthesizer.get_info() == synthesizer.get_info() + assert loaded_synthesizer.metadata.to_dict() == metadata.to_dict() + loaded_sample = loaded_synthesizer.sample(10) + synthesizer.validate(loaded_sample) + + # Run against Metadata + synthesizer_2 = GaussianCopulaSynthesizer(Metadata.load_from_dict(metadata.to_dict())) + synthesizer_2.fit(data) + metadata_sample = synthesizer.sample(10) + assert loaded_synthesizer.metadata.to_dict() == synthesizer_2.metadata.to_dict() + assert metadata_sample.columns.to_list() == loaded_sample.columns.to_list() + + +def test_multi_table_compatibility(tmp_path): + """Test if MultiTableMetadata still has compatibility with multi table synthesizers.""" + # Setup + data, _ = download_demo('multi_table', 'fake_hotels') + warn_msg = re.escape( + "The 'MultiTableMetadata' is deprecated. Please use the new " + "'Metadata' class for synthesizers." 
+ ) + + multi_dict = { + 'tables': { + 'guests': { + 'primary_key': 'guest_email', + 'columns': { + 'guest_email': {'sdtype': 'email', 'pii': True}, + 'hotel_id': {'sdtype': 'id', 'regex_format': '[A-Za-z]{5}'}, + 'has_rewards': {'sdtype': 'boolean'}, + 'room_type': {'sdtype': 'categorical'}, + 'amenities_fee': {'sdtype': 'numerical', 'computer_representation': 'Float'}, + 'checkin_date': {'sdtype': 'datetime', 'datetime_format': '%d %b %Y'}, + 'checkout_date': {'sdtype': 'datetime', 'datetime_format': '%d %b %Y'}, + 'room_rate': {'sdtype': 'numerical', 'computer_representation': 'Float'}, + 'billing_address': {'sdtype': 'address', 'pii': True}, + 'credit_card_number': {'sdtype': 'credit_card_number', 'pii': True}, + }, + }, + 'hotels': { + 'primary_key': 'hotel_id', + 'columns': { + 'hotel_id': {'sdtype': 'id', 'regex_format': 'HID_[0-9]{3}'}, + 'city': {'sdtype': 'categorical'}, + 'state': {'sdtype': 'categorical'}, + 'rating': {'sdtype': 'numerical', 'computer_representation': 'Float'}, + 'classification': {'sdtype': 'categorical'}, + }, + }, + }, + 'relationships': [ + { + 'parent_table_name': 'hotels', + 'parent_primary_key': 'hotel_id', + 'child_table_name': 'guests', + 'child_foreign_key': 'hotel_id', + } + ], + 'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1', + } + metadata = MultiTableMetadata.load_from_dict(multi_dict) + assert type(metadata) is MultiTableMetadata + + # Run + with pytest.warns(FutureWarning, match=warn_msg): + synthesizer = HMASynthesizer(metadata) + + synthesizer.fit(data) + model_path = os.path.join(tmp_path, 'synthesizer.pkl') + synthesizer.save(model_path) + + # Assert + assert os.path.exists(model_path) + assert os.path.isfile(model_path) + + # Load HMASynthesizer + loaded_synthesizer = HMASynthesizer.load(model_path) + + # Asserts + assert isinstance(synthesizer, HMASynthesizer) + assert loaded_synthesizer.get_info() == synthesizer.get_info() + + # Load Metadata + expected_metadata = metadata.to_dict() + + # Asserts + + assert loaded_synthesizer.metadata.to_dict() == expected_metadata + + # Sample from loaded synthesizer + loaded_sample = loaded_synthesizer.sample(10) + synthesizer.validate(loaded_sample) + + # Run against Metadata + synthesizer_2 = HMASynthesizer(Metadata.load_from_dict(metadata.to_dict())) + synthesizer_2.fit(data) + metadata_sample = synthesizer.sample(10) + expected_metadata = loaded_synthesizer.metadata.to_dict() + expected_metadata['METADATA_SPEC_VERSION'] = 'V1' + assert expected_metadata == synthesizer_2.metadata.to_dict() + for table in metadata_sample: + assert metadata_sample[table].columns.to_list() == loaded_sample[table].columns.to_list() diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index cc22db1c2..d161c6fa2 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -19,6 +19,7 @@ from sdv.datasets.local import load_csvs from sdv.errors import SamplingError, SynthesizerInputError, VersionError from sdv.evaluation.multi_table import evaluate_quality, get_column_pair_plot, get_column_plot +from sdv.metadata.metadata import Metadata from sdv.metadata.multi_table import MultiTableMetadata from sdv.multi_table import HMASynthesizer from tests.integration.single_table.custom_constraints import MyConstraint @@ -51,6 +52,32 @@ def test_hma(self): for normal_table, increased_table in zip(normal_sample.values(), increased_sample.values()): assert increased_table.size > normal_table.size + def test_hma_metadata(self): + """End to end 
integration tests with ``HMASynthesizer``.
+
+        The test consists of loading the demo data, converting the old metadata to the new
+        format and then fitting an ``HMASynthesizer``. After fitting, two samples are
+        generated: one with a 0.5 scale and one with a 1.5 scale.
+        """
+        # Setup
+        data, multi_metadata = download_demo('multi_table', 'got_families')
+        metadata = Metadata.load_from_dict(multi_metadata.to_dict())
+        hmasynthesizer = HMASynthesizer(metadata)
+
+        # Run
+        hmasynthesizer.fit(data)
+        normal_sample = hmasynthesizer.sample(0.5)
+        increased_sample = hmasynthesizer.sample(1.5)
+
+        # Assert
+        assert set(normal_sample) == {'characters', 'character_families', 'families'}
+        assert set(increased_sample) == {'characters', 'character_families', 'families'}
+        for table_name, table in normal_sample.items():
+            assert all(table.columns == data[table_name].columns)
+
+        for normal_table, increased_table in zip(normal_sample.values(), increased_sample.values()):
+            assert increased_table.size > normal_table.size
+
     def test_hma_reset_sampling(self):
         """End to end integration test that uses ``reset_sampling``.
@@ -1242,7 +1269,8 @@
         initialization, but is saved to a file before fitting.
         """
         # Setup
-        data, metadata = download_demo('multi_table', 'got_families')
+        data, multi_metadata = download_demo('multi_table', 'got_families')
+        metadata = Metadata.load_from_dict(multi_metadata.to_dict())
 
         # Run 1
         with warnings.catch_warnings(record=True) as captured_warnings:
@@ -1259,7 +1287,7 @@
         assert len(captured_warnings) == 0
 
         # Run 2
-        metadata_detect = MultiTableMetadata()
+        metadata_detect = Metadata()
         metadata_detect.detect_from_dataframes(data)
 
         metadata_detect.relationships = metadata.relationships
@@ -1299,7 +1327,7 @@
         """
         # Setup
        data, metadata = download_demo('multi_table', 'got_families')
-        metadata_detect = MultiTableMetadata()
+        metadata_detect = Metadata()
         metadata_detect.detect_from_dataframes(data)
 
         metadata_detect.relationships = metadata.relationships
diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py
index a1975da2f..3ca501aaf 100644
--- a/tests/unit/multi_table/test_base.py
+++ b/tests/unit/multi_table/test_base.py
@@ -124,7 +124,7 @@
         mock_generate_synthesizer_id.return_value = synthesizer_id
         mock_datetime.datetime.now.return_value = '2024-04-19 16:20:10.037183'
         metadata = get_multi_table_metadata()
-        metadata.validate = Mock()
+        metadata.validate = Mock(spec=Metadata)
 
         # Run
         with catch_sdv_logs(caplog, logging.INFO, 'MultiTableSynthesizer'):
@@ -147,6 +147,21 @@
             'SYNTHESIZER ID': 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5',
         })
 
+    def test___init___deprecated(self):
+        """Test that init with the old MultiTableMetadata gives a future warning."""
+        # Setup
+        metadata = get_multi_table_metadata()
+        metadata.validate = Mock()
+
+        deprecation_msg = re.escape(
+            "The 'MultiTableMetadata' is deprecated. Please use the new "
+            "'Metadata' class for synthesizers."
+ ) + + # Run + with pytest.warns(FutureWarning, match=deprecation_msg): + BaseMultiTableSynthesizer(metadata) + @patch('sdv.metadata.single_table.is_faker_function') def test__init__column_relationship_warning(self, mock_is_faker_function): """Test that a warning is raised only once when the metadata has column relationships.""" @@ -391,7 +406,9 @@ def test_get_metadata(self): result = instance.get_metadata() # Assert - assert metadata == result + expected_metadata = Metadata.load_from_dict(metadata.to_dict()) + assert type(result) is Metadata + assert expected_metadata.to_dict() == result.to_dict() def test_validate(self): """Test that no error is being raised when the data is valid.""" @@ -879,7 +896,7 @@ def test_preprocess_int_columns(self): def test_preprocess_warning(self, mock_warnings): """Test that ``preprocess`` warns the user if the model has already been fitted.""" # Setup - metadata = get_multi_table_metadata() + metadata = Metadata.load_from_dict(get_multi_table_metadata().to_dict()) instance = BaseMultiTableSynthesizer(metadata) instance.validate = Mock() data = { From d2d19b252bf59f623d65fdbae9a38e203ef009e5 Mon Sep 17 00:00:00 2001 From: Plamen Valentinov Kolev <41479552+pvk-developer@users.noreply.github.com> Date: Thu, 8 Aug 2024 08:42:37 +0200 Subject: [PATCH 04/16] Enable evaluation to work with new metadata (#2174) --- sdv/evaluation/multi_table.py | 10 +- sdv/evaluation/single_table.py | 21 ++- .../evaluation/test_multi_table.py | 54 ++++++ .../evaluation/test_single_table.py | 26 +++ tests/unit/evaluation/test_single_table.py | 175 ++++++++++++++++++ 5 files changed, 277 insertions(+), 9 deletions(-) diff --git a/sdv/evaluation/multi_table.py b/sdv/evaluation/multi_table.py index 34ed0b94e..302b971a6 100644 --- a/sdv/evaluation/multi_table.py +++ b/sdv/evaluation/multi_table.py @@ -15,7 +15,7 @@ def evaluate_quality(real_data, synthetic_data, metadata, verbose=True): Dictionary containing the real table data. synthetic_column (dict): Dictionary containing the synthetic table data. - metadata (MultiTableMetadata): + metadata (Metadata): The metadata object describing the real/synthetic data. verbose (bool): Whether or not to print report summary and progress. @@ -38,7 +38,7 @@ def run_diagnostic(real_data, synthetic_data, metadata, verbose=True): Dictionary containing the real table data. synthetic_column (dict): Dictionary containing the synthetic table data. - metadata (MultiTableMetadata): + metadata (Metadata): The metadata object describing the real/synthetic data. verbose (bool): Whether or not to print report summary and progress. @@ -61,7 +61,7 @@ def get_column_plot(real_data, synthetic_data, metadata, table_name, column_name Dictionary containing the real table data. synthetic_column (dict): Dictionary containing the synthetic table data. - metadata (MultiTableMetadata): + metadata (Metadata): Metadata describing the data. table_name (str): The name of the table. @@ -98,7 +98,7 @@ def get_column_pair_plot( Dictionary containing the real table data. synthetic_column (dict): Dictionary containing the synthetic table data. - metadata (MultiTableMetadata): + metadata (Metadata): Metadata describing the data. table_name (str): The name of the table. @@ -147,7 +147,7 @@ def get_cardinality_plot( The name of the parent table. child_foreign_key (string): The name of the foreign key column in the child table. - metadata (MultiTableMetadata): + metadata (Metadata): Metadata describing the data. plot_type (str): The plot type to use to plot the cardinality. 
Must be either 'bar' or 'distplot'. diff --git a/sdv/evaluation/single_table.py b/sdv/evaluation/single_table.py index d02b38a16..89bdad616 100644 --- a/sdv/evaluation/single_table.py +++ b/sdv/evaluation/single_table.py @@ -6,6 +6,7 @@ from sdmetrics.reports.single_table.quality_report import QualityReport from sdv.errors import VisualizationUnavailableError +from sdv.metadata.metadata import Metadata def evaluate_quality(real_data, synthetic_data, metadata, verbose=True): @@ -16,7 +17,7 @@ def evaluate_quality(real_data, synthetic_data, metadata, verbose=True): The table containing the real data. synthetic_data (pd.DataFrame): The table containing the synthetic data. - metadata (SingleTableMetadata): + metadata (Metadata): The metadata object describing the real/synthetic data. verbose (bool): Whether or not to print report summary and progress. @@ -27,6 +28,9 @@ def evaluate_quality(real_data, synthetic_data, metadata, verbose=True): Single table quality report object. """ quality_report = QualityReport() + if isinstance(metadata, Metadata): + metadata = metadata._convert_to_single_table() + quality_report.generate(real_data, synthetic_data, metadata.to_dict(), verbose) return quality_report @@ -39,7 +43,7 @@ def run_diagnostic(real_data, synthetic_data, metadata, verbose=True): The table containing the real data. synthetic_data (pd.DataFrame): The table containing the synthetic data. - metadata (SingleTableMetadata): + metadata (Metadata): The metadata object describing the real/synthetic data. verbose (bool): Whether or not to print report summary and progress. @@ -50,6 +54,9 @@ def run_diagnostic(real_data, synthetic_data, metadata, verbose=True): Single table diagnostic report object. """ diagnostic_report = DiagnosticReport() + if isinstance(metadata, Metadata): + metadata = metadata._convert_to_single_table() + diagnostic_report.generate(real_data, synthetic_data, metadata.to_dict(), verbose) return diagnostic_report @@ -62,7 +69,7 @@ def get_column_plot(real_data, synthetic_data, metadata, column_name, plot_type= The real table data. synthetic_data (pandas.DataFrame): The synthetic table data. - metadata (SingleTableMetadata): + metadata (Metadata): The table metadata. column_name (str): The name of the column. @@ -76,6 +83,9 @@ def get_column_plot(real_data, synthetic_data, metadata, column_name, plot_type= plotly.graph_objects._figure.Figure: 1D marginal distribution plot (i.e. a histogram) of the columns. """ + if isinstance(metadata, Metadata): + metadata = metadata._convert_to_single_table() + sdtype = metadata.columns.get(column_name)['sdtype'] if plot_type is None: if sdtype in ['datetime', 'numerical']: @@ -114,7 +124,7 @@ def get_column_pair_plot( The real table data. synthetic_column (pandas.Dataframe): The synthetic table data. - metadata (SingleTableMetadata): + metadata (Metadata): The table metadata. column_names (list[string]): The names of the two columns to plot. @@ -131,6 +141,9 @@ def get_column_pair_plot( plotly.graph_objects._figure.Figure: 2D bivariate distribution plot (i.e. a scatterplot) of the columns. 
""" + if isinstance(metadata, Metadata): + metadata = metadata._convert_to_single_table() + real_data = real_data.copy() synthetic_data = synthetic_data.copy() if plot_type is None: diff --git a/tests/integration/evaluation/test_multi_table.py b/tests/integration/evaluation/test_multi_table.py index 3367bd0fa..5f816fb87 100644 --- a/tests/integration/evaluation/test_multi_table.py +++ b/tests/integration/evaluation/test_multi_table.py @@ -1,6 +1,7 @@ import pandas as pd from sdv.evaluation.multi_table import evaluate_quality, run_diagnostic +from sdv.metadata.metadata import Metadata from sdv.metadata.multi_table import MultiTableMetadata @@ -55,3 +56,56 @@ def test_evaluation(): 'Score': [1.0, 1.0, 1.0], }), ) + + +def test_evaluation_metadata(): + """Test ``evaluate_quality`` and ``run_diagnostic`` with Metadata.""" + # Setup + table = pd.DataFrame({'id': [0, 1, 2, 3], 'col': [1, 2, 3, 4]}) + slightly_different_table = pd.DataFrame({'id': [0, 1, 2, 3], 'col': [1, 2, 3, 3.5]}) + data = { + 'table1': table, + 'table2': table, + } + samples = { + 'table1': table, + 'table2': slightly_different_table, + } + metadata = Metadata().load_from_dict({ + 'tables': { + 'table1': { + 'columns': { + 'id': {'sdtype': 'id'}, + 'col': {'sdtype': 'numerical'}, + }, + }, + 'table2': { + 'columns': { + 'id': {'sdtype': 'id'}, + 'col': {'sdtype': 'numerical'}, + }, + }, + }, + 'relationships': [ + { + 'parent_table_name': 'table1', + 'parent_primary_key': 'id', + 'child_table_name': 'table2', + 'child_foreign_key': 'id', + } + ], + }) + + # Run and Assert + score = evaluate_quality(data, samples, metadata).get_score() + assert score == 0.9566297110928815 + + report = run_diagnostic(data, samples, metadata) + assert report.get_score() == 1 + pd.testing.assert_frame_equal( + report.get_properties(), + pd.DataFrame({ + 'Property': ['Data Validity', 'Data Structure', 'Relationship Validity'], + 'Score': [1.0, 1.0, 1.0], + }), + ) diff --git a/tests/integration/evaluation/test_single_table.py b/tests/integration/evaluation/test_single_table.py index 5b9ee9123..95e1b671e 100644 --- a/tests/integration/evaluation/test_single_table.py +++ b/tests/integration/evaluation/test_single_table.py @@ -2,6 +2,7 @@ from sdv.datasets.demo import download_demo from sdv.evaluation.single_table import evaluate_quality, get_column_pair_plot, run_diagnostic +from sdv.metadata.metadata import Metadata from sdv.metadata.single_table import SingleTableMetadata from sdv.single_table.copulas import GaussianCopulaSynthesizer @@ -31,6 +32,31 @@ def test_evaluation(): ) +def test_evaluation_metadata(): + """Test ``evaluate_quality`` and ``run_diagnostic`` with Metadata.""" + # Setup + data = pd.DataFrame({'col': [1, 2, 3]}) + metadata_dict = {'columns': {'col': {'sdtype': 'numerical'}}} + metadata = Metadata.load_from_dict(metadata_dict) + synthesizer = GaussianCopulaSynthesizer(metadata, default_distribution='truncnorm') + + # Run and Assert + synthesizer.fit(data) + samples = synthesizer.sample(10) + score = evaluate_quality(data, samples, metadata).get_score() + assert score == 0.8666666666666667 + + report = run_diagnostic(data, samples, metadata) + assert report.get_score() == 1 + pd.testing.assert_frame_equal( + report.get_properties(), + pd.DataFrame({ + 'Property': ['Data Validity', 'Data Structure'], + 'Score': [1.0, 1.0], + }), + ) + + def test_column_pair_plot_sample_size_parameter(): """Test the sample_size parameter for the column pair plot.""" # Setup diff --git a/tests/unit/evaluation/test_single_table.py 
b/tests/unit/evaluation/test_single_table.py index 1baa9a18e..3a53e11e2 100644 --- a/tests/unit/evaluation/test_single_table.py +++ b/tests/unit/evaluation/test_single_table.py @@ -13,6 +13,7 @@ get_column_plot, run_diagnostic, ) +from sdv.metadata.metadata import Metadata from sdv.metadata.single_table import SingleTableMetadata @@ -32,6 +33,23 @@ def test_evaluate_quality(): QualityReport.generate.assert_called_once_with(data1, data2, metadata.to_dict(), True) +def test_evaluate_quality_metadata(): + """Test ``generate`` is called for the ``QualityReport`` object with Metadata.""" + # Setup + data1 = pd.DataFrame({'col': [1, 2, 3]}) + data2 = pd.DataFrame({'col': [2, 1, 3]}) + metadata_dict = {'columns': {'col': {'sdtype': 'numerical'}}} + metadata = Metadata.load_from_dict(metadata_dict) + QualityReport.generate = Mock() + + # Run + evaluate_quality(data1, data2, metadata) + + # Assert + expected_metadata = metadata.tables['default_table_name'].to_dict() + QualityReport.generate.assert_called_once_with(data1, data2, expected_metadata, True) + + def test_run_diagnostic(): """Test ``generate`` is called for the ``DiagnosticReport`` object.""" # Setup @@ -48,6 +66,23 @@ def test_run_diagnostic(): DiagnosticReport.generate.assert_called_once_with(data1, data2, metadata.to_dict(), True) +def test_run_diagnostic_metadata(): + """Test ``generate`` is called for the ``DiagnosticReport`` object with Metadata.""" + # Setup + data1 = pd.DataFrame({'col': [1, 2, 3]}) + data2 = pd.DataFrame({'col': [2, 1, 3]}) + metadata_dict = {'columns': {'col': {'sdtype': 'numerical'}}} + metadata = Metadata.load_from_dict(metadata_dict) + DiagnosticReport.generate = Mock(return_value=123) + + # Run + run_diagnostic(data1, data2, metadata) + + # Assert + expected_metadata = metadata.tables['default_table_name'].to_dict() + DiagnosticReport.generate.assert_called_once_with(data1, data2, expected_metadata, True) + + @patch('sdmetrics.visualization.get_column_plot') def test_get_column_plot_continuous_data(mock_get_plot): """Test the ``get_column_plot`` with continuous data. @@ -69,6 +104,27 @@ def test_get_column_plot_continuous_data(mock_get_plot): assert plot == mock_get_plot.return_value +@patch('sdmetrics.visualization.get_column_plot') +def test_get_column_plot_continuous_data_metadata(mock_get_plot): + """Test the ``get_column_plot`` with continuous data. + + Test that when we call ``get_column_plot`` with continuous data (datetime or numerical) + this will choose to use the ``distplot`` as ``plot_type``. Uses Metadata. + """ + # Setup + data1 = pd.DataFrame({'col': [1, 2, 3]}) + data2 = pd.DataFrame({'col': [2, 1, 3]}) + metadata_dict = {'columns': {'col': {'sdtype': 'numerical'}}} + metadata = Metadata.load_from_dict(metadata_dict) + + # Run + plot = get_column_plot(data1, data2, metadata, 'col') + + # Assert + mock_get_plot.assert_called_once_with(data1, data2, 'col', plot_type='distplot') + assert plot == mock_get_plot.return_value + + @patch('sdmetrics.visualization.get_column_plot') def test_get_column_plot_discrete_data(mock_get_plot): """Test the ``get_column_plot`` with discrete data. @@ -90,6 +146,27 @@ def test_get_column_plot_discrete_data(mock_get_plot): assert plot == mock_get_plot.return_value +@patch('sdmetrics.visualization.get_column_plot') +def test_get_column_plot_discrete_data_metadata(mock_get_plot): + """Test the ``get_column_plot`` with discrete data. 
+ + Test that when we call ``get_column_plot`` with discrete data (categorical or boolean) + this will choose to use the ``bar`` as ``plot_type``. Uses Metadata. + """ + # Setup + data1 = pd.DataFrame({'col': ['a', 'b', 'c']}) + data2 = pd.DataFrame({'col': ['a', 'b', 'c']}) + metadata_dict = {'columns': {'col': {'sdtype': 'categorical'}}} + metadata = Metadata.load_from_dict(metadata_dict) + + # Run + plot = get_column_plot(data1, data2, metadata, 'col') + + # Assert + mock_get_plot.assert_called_once_with(data1, data2, 'col', plot_type='bar') + assert plot == mock_get_plot.return_value + + @patch('sdmetrics.visualization.get_column_plot') def test_get_column_plot_discrete_data_with_distplot(mock_get_plot): """Test the ``get_column_plot`` with discrete data. @@ -112,6 +189,28 @@ def test_get_column_plot_discrete_data_with_distplot(mock_get_plot): assert plot == mock_get_plot.return_value +@patch('sdmetrics.visualization.get_column_plot') +def test_get_column_plot_discrete_data_with_distplot_metadata(mock_get_plot): + """Test the ``get_column_plot`` with discrete data. + + Test that when we call ``get_column_plot`` with discrete data (categorical or boolean) + and pass in the ``distplot`` it will call the ``sdmetrics.visualization.get_column_plot`` + with it and not switch to ``bar``. Uses Metadata. + """ + # Setup + data1 = pd.DataFrame({'col': ['a', 'b', 'c']}) + data2 = pd.DataFrame({'col': ['a', 'b', 'c']}) + metadata_dict = {'columns': {'col': {'sdtype': 'categorical'}}} + metadata = Metadata.load_from_dict(metadata_dict) + + # Run + plot = get_column_plot(data1, data2, metadata, 'col', plot_type='distplot') + + # Assert + mock_get_plot.assert_called_once_with(data1, data2, 'col', plot_type='distplot') + assert plot == mock_get_plot.return_value + + @patch('sdmetrics.visualization.get_column_plot') def test_get_column_plot_invalid_sdtype(mock_get_plot): """Test the ``get_column_plot`` with sdtype that can't be plotted. @@ -134,6 +233,28 @@ def test_get_column_plot_invalid_sdtype(mock_get_plot): get_column_plot(data1, data2, metadata, 'col') +@patch('sdmetrics.visualization.get_column_plot') +def test_get_column_plot_invalid_sdtype_metadata(mock_get_plot): + """Test the ``get_column_plot`` with sdtype that can't be plotted. + + Test that when we call ``get_column_plot`` with an sdtype that can't be plotted, this raises + an error. Uses Metadata. + """ + # Setup + data1 = pd.DataFrame({'col': ['a', 'b', 'c']}) + data2 = pd.DataFrame({'col': ['a', 'b', 'c']}) + metadata_dict = {'columns': {'col': {'sdtype': 'id'}}} + metadata = Metadata.load_from_dict(metadata_dict) + + # Run and Assert + error_msg = re.escape( + "The column 'col' has sdtype 'id', which does not have a " + "supported visualization. To visualize this data anyways, please add a 'plot_type'." + ) + with pytest.raises(VisualizationUnavailableError, match=error_msg): + get_column_plot(data1, data2, metadata, 'col') + + @patch('sdmetrics.visualization.get_column_plot') def test_get_column_plot_invalid_sdtype_with_plot_type(mock_get_plot): """Test the ``get_column_plot`` with sdtype that can't be plotted. @@ -155,6 +276,27 @@ def test_get_column_plot_invalid_sdtype_with_plot_type(mock_get_plot): assert plot == mock_get_plot.return_value +@patch('sdmetrics.visualization.get_column_plot') +def test_get_column_plot_invalid_sdtype_with_plot_type_metadata(mock_get_plot): + """Test the ``get_column_plot`` with sdtype that can't be plotted. 
+ + Test that when we call ``get_column_plot`` with an sdtype that can't be plotted, but passing + ``plot_type`` it will attempt to plot it using the ``sdmetrics.visualization.get_column_plot``. + """ + # Setup + data1 = pd.DataFrame({'col': ['a', 'b', 'c']}) + data2 = pd.DataFrame({'col': ['a', 'b', 'c']}) + metadata_dict = {'columns': {'col': {'sdtype': 'id'}}} + metadata = Metadata.load_from_dict(metadata_dict) + + # Run + plot = get_column_plot(data1, data2, metadata, 'col', plot_type='bar') + + # Assert + mock_get_plot.assert_called_once_with(data1, data2, 'col', plot_type='bar') + assert plot == mock_get_plot.return_value + + @patch('sdmetrics.visualization.get_column_plot') def test_get_column_plot_with_datetime_sdtype(mock_get_plot): """Test the ``get_column_plot`` with datetime sdtype. @@ -406,6 +548,39 @@ def test_get_column_pair_plot_with_sample_size(mock_get_plot): assert synthetic_subsample.isin(synthetic_data).all().all() +@patch('sdmetrics.visualization.get_column_pair_plot') +def test_get_column_pair_plot_with_sample_size_metadata(mock_get_plot): + """Test ``get_column_pair_plot`` with ``sample_size`` parameter with Metadata.""" + # Setup + columns = ['amount', 'price'] + real_data = pd.DataFrame({ + 'amount': [1, 2, 3], + 'price': [10, 20, 30], + }) + synthetic_data = pd.DataFrame({ + 'amount': [1.0, 2.0, 3.0], + 'price': [11.0, 22.0, 33.0], + }) + metadata_dict = { + 'columns': { + 'amount': {'sdtype': 'numerical'}, + 'price': {'sdtype': 'numerical'}, + } + } + metadata = Metadata.load_from_dict(metadata_dict) + + # Run + get_column_pair_plot(real_data, synthetic_data, metadata, columns, sample_size=2) + + # Assert + real_subsample = mock_get_plot.call_args[0][0] + synthetic_subsample = mock_get_plot.call_args[0][1] + assert len(real_subsample) == 2 + assert len(synthetic_subsample) == 2 + assert real_subsample.isin(real_data).all().all() + assert synthetic_subsample.isin(synthetic_data).all().all() + + @patch('sdmetrics.visualization.get_column_pair_plot') def test_get_column_pair_plot_with_sample_size_too_big(mock_get_plot): """Test ``get_column_pair_plot`` when ``sample_size`` is bigger than the length of the data.""" From e728c733806d6dde21a67c1ba341a3c62dd063c8 Mon Sep 17 00:00:00 2001 From: John La Date: Tue, 13 Aug 2024 09:07:38 -0500 Subject: [PATCH 05/16] Enable demos are using the new metadata (#2175) --- sdv/datasets/demo.py | 25 +++++----- sdv/evaluation/single_table.py | 3 ++ sdv/metadata/metadata.py | 7 +-- sdv/utils/utils.py | 8 +++- .../data_processing/test_data_processor.py | 22 ++++----- tests/integration/sequential/test_par.py | 4 +- tests/integration/single_table/test_base.py | 2 +- .../integration/single_table/test_copulas.py | 6 +-- tests/integration/single_table/test_ctgan.py | 2 +- tests/integration/utils/test_utils.py | 2 + tests/unit/datasets/test_demo.py | 46 +++++++++++++------ tests/unit/metadata/test_metadata.py | 2 +- 12 files changed, 77 insertions(+), 52 deletions(-) diff --git a/sdv/datasets/demo.py b/sdv/datasets/demo.py index 1ead34d39..5142679d0 100644 --- a/sdv/datasets/demo.py +++ b/sdv/datasets/demo.py @@ -4,6 +4,7 @@ import json import logging import os +import warnings from collections import defaultdict from pathlib import Path from zipfile import ZipFile @@ -15,8 +16,7 @@ from botocore.client import Config from botocore.exceptions import ClientError -from sdv.metadata.multi_table import MultiTableMetadata -from sdv.metadata.single_table import SingleTableMetadata +from sdv.metadata.metadata import Metadata LOGGER = 
logging.getLogger(__name__) BUCKET = 'sdv-demo-datasets' @@ -104,15 +104,20 @@ def _get_data(modality, output_folder_name, in_memory_directory): return data -def _get_metadata(modality, output_folder_name, in_memory_directory): - metadata = MultiTableMetadata() if modality == 'multi_table' else SingleTableMetadata() +def _get_metadata(output_folder_name, in_memory_directory, dataset_name): + metadata = Metadata() if output_folder_name: metadata_path = os.path.join(output_folder_name, METADATA_FILENAME) - metadata = metadata.load_from_json(metadata_path) + metadata = metadata.load_from_json(metadata_path, dataset_name) else: - metadict = json.loads(in_memory_directory['metadata_v1.json']) - metadata = metadata.load_from_dict(metadict) + metadata_path = 'metadata_v2.json' + if metadata_path not in in_memory_directory: + warnings.warn(f'Metadata for {dataset_name} is missing updated version v2.') + metadata_path = 'metadata_v1.json' + + metadict = json.loads(in_memory_directory[metadata_path]) + metadata = metadata.load_from_dict(metadict, dataset_name) return metadata @@ -134,22 +139,20 @@ def download_demo(modality, dataset_name, output_folder_name=None): tuple (data, metadata): If ``data`` is single table or sequential, it is a DataFrame. If ``data`` is multi table, it is a dictionary mapping table name to DataFrame. - If ``metadata`` is single table or sequential, it is a ``SingleTableMetadata`` object. - If ``metadata`` is multi table, it is a ``MultiTableMetadata`` object. + ``metadata`` is of class ``Metadata`` which can represent single table or multi table. Raises: Error: * If the ``dataset_name`` exists in the bucket but under a different modality. * If the ``dataset_name`` doesn't exist in the bucket. * If there is already a folder named ``output_folder_name``. - * If ``modality`` is not ``'single_table'``, ``'multi_table'`` or ``'sequential'``. """ _validate_modalities(modality) _validate_output_folder(output_folder_name) bytes_io = _download(modality, dataset_name) in_memory_directory = _extract_data(bytes_io, output_folder_name) data = _get_data(modality, output_folder_name, in_memory_directory) - metadata = _get_metadata(modality, output_folder_name, in_memory_directory) + metadata = _get_metadata(output_folder_name, in_memory_directory, dataset_name) return data, metadata diff --git a/sdv/evaluation/single_table.py b/sdv/evaluation/single_table.py index 89bdad616..eec9424c7 100644 --- a/sdv/evaluation/single_table.py +++ b/sdv/evaluation/single_table.py @@ -27,6 +27,9 @@ def evaluate_quality(real_data, synthetic_data, metadata, verbose=True): QualityReport: Single table quality report object. """ + if isinstance(metadata, Metadata): + metadata = metadata._convert_to_single_table() + quality_report = QualityReport() if isinstance(metadata, Metadata): metadata = metadata._convert_to_single_table() diff --git a/sdv/metadata/metadata.py b/sdv/metadata/metadata.py index d8a9b463c..dcb1e0b4c 100644 --- a/sdv/metadata/metadata.py +++ b/sdv/metadata/metadata.py @@ -1,7 +1,5 @@ """Metadata.""" -from pathlib import Path - from sdv.metadata.errors import InvalidMetadataError from sdv.metadata.multi_table import MultiTableMetadata from sdv.metadata.single_table import SingleTableMetadata @@ -14,7 +12,7 @@ class Metadata(MultiTableMetadata): METADATA_SPEC_VERSION = 'V1' @classmethod - def load_from_json(cls, filepath): + def load_from_json(cls, filepath, single_table_name=None): """Create a ``Metadata`` instance from a ``json`` file. 
Args: @@ -28,9 +26,8 @@ def load_from_json(cls, filepath): Returns: A ``Metadata`` instance. """ - filename = Path(filepath).stem metadata = read_json(filepath) - return cls.load_from_dict(metadata, filename) + return cls.load_from_dict(metadata, single_table_name) @classmethod def load_from_dict(cls, metadata_dict, single_table_name=None): diff --git a/sdv/utils/utils.py b/sdv/utils/utils.py index 72c43eb73..3e1b3c00b 100644 --- a/sdv/utils/utils.py +++ b/sdv/utils/utils.py @@ -8,6 +8,7 @@ from sdv._utils import _validate_foreign_keys_not_null from sdv.errors import InvalidDataError, SynthesizerInputError +from sdv.metadata.metadata import Metadata from sdv.multi_table.utils import _drop_rows @@ -75,8 +76,8 @@ def get_random_sequence_subset( Args: data (pandas.DataFrame): The sequential data. - metadata (SingleTableMetadata): - A SingleTableMetadata object describing the data. + metadata (Metadata): + A Metadata object describing the data. num_sequences (int): The number of sequences to subsample. max_sequence_length (int): @@ -91,6 +92,9 @@ def get_random_sequence_subset( - random: Randomly choose n rows to keep within the sequence. It is important to keep the randomly chosen rows in the same order as they appear in the original data. """ + if isinstance(metadata, Metadata): + metadata = metadata._convert_to_single_table() + if long_sequence_subsampling_method not in ['first_rows', 'last_rows', 'random']: raise ValueError( 'long_sequence_subsampling_method must be one of "first_rows", "last_rows" or "random"' diff --git a/tests/integration/data_processing/test_data_processor.py b/tests/integration/data_processing/test_data_processor.py index e448e95f0..a3217f712 100644 --- a/tests/integration/data_processing/test_data_processor.py +++ b/tests/integration/data_processing/test_data_processor.py @@ -52,10 +52,10 @@ def test_with_anonymized_columns(self): data, metadata = download_demo('single_table', 'adult') # Add anonymized field - metadata.update_column('occupation', sdtype='job', pii=True) + metadata.update_column('adult', 'occupation', sdtype='job', pii=True) # Instance ``DataProcessor`` - dp = DataProcessor(metadata) + dp = DataProcessor(metadata._convert_to_single_table()) # Fit dp.fit(data) @@ -100,18 +100,18 @@ def test_with_anonymized_columns_and_primary_key(self): data, metadata = download_demo('single_table', 'adult') # Add anonymized field - metadata.update_column('occupation', sdtype='job', pii=True) + metadata.update_column('adult', 'occupation', sdtype='job', pii=True) # Add primary key field - metadata.add_column('id', sdtype='id', regex_format='ID_\\d{4}[0-9]') - metadata.set_primary_key('id') + metadata.add_column('adult', 'id', sdtype='id', regex_format='ID_\\d{4}[0-9]') + metadata.set_primary_key('adult', 'id') # Add id size = len(data) data['id'] = np.arange(0, size).astype('O') # Instance ``DataProcessor`` - dp = DataProcessor(metadata) + dp = DataProcessor(metadata._convert_to_single_table()) # Fit dp.fit(data) @@ -247,7 +247,7 @@ def test_prepare_for_fitting(self): data, metadata = download_demo( modality='single_table', dataset_name='student_placements_pii' ) - dp = DataProcessor(metadata) + dp = DataProcessor(metadata._convert_to_single_table()) # Run dp.prepare_for_fitting(data) @@ -288,7 +288,7 @@ def test_reverse_transform_with_formatters(self): """End to end test using formatters.""" # Setup data, metadata = download_demo(modality='single_table', dataset_name='student_placements') - dp = DataProcessor(metadata) + dp = 
DataProcessor(metadata._convert_to_single_table()) # Run dp.fit(data) @@ -327,7 +327,7 @@ def test_refit_hypertransformer(self): """Test data processor re-fits _hyper_transformer.""" # Setup data, metadata = download_demo(modality='single_table', dataset_name='student_placements') - dp = DataProcessor(metadata) + dp = DataProcessor(metadata._convert_to_single_table()) # Run dp.fit(data) @@ -346,9 +346,9 @@ def test_localized_anonymized_columns(self): """Test data processor uses the default locale for anonymized columns.""" # Setup data, metadata = download_demo('single_table', 'adult') - metadata.update_column('occupation', sdtype='job', pii=True) + metadata.update_column('adult', 'occupation', sdtype='job', pii=True) - dp = DataProcessor(metadata, locales=['en_CA', 'fr_CA']) + dp = DataProcessor(metadata._convert_to_single_table(), locales=['en_CA', 'fr_CA']) # Run dp.fit(data) diff --git a/tests/integration/sequential/test_par.py b/tests/integration/sequential/test_par.py index a85c11309..567586525 100644 --- a/tests/integration/sequential/test_par.py +++ b/tests/integration/sequential/test_par.py @@ -183,7 +183,7 @@ def test_synthesize_sequences(tmp_path): assert model_path.exists() assert model_path.is_file() assert loaded_synthesizer.get_info() == synthesizer.get_info() - assert loaded_synthesizer.metadata.to_dict() == metadata.to_dict() + assert loaded_synthesizer.metadata.to_dict() == metadata._convert_to_single_table().to_dict() synthesizer.validate(synthetic_data) synthesizer.validate(custom_synthetic_data) synthesizer.validate(custom_synthetic_data_conditional) @@ -445,7 +445,7 @@ def test_par_categorical_column_represented_by_floats(): # Setup data, metadata = download_demo('sequential', 'nasdaq100_2019') data['category'] = [100.0 if i % 2 == 0 else 50.0 for i in data.index] - metadata.add_column('category', sdtype='categorical') + metadata.add_column('nasdaq100_2019', 'category', sdtype='categorical') # Run synth = PARSynthesizer(metadata) diff --git a/tests/integration/single_table/test_base.py b/tests/integration/single_table/test_base.py index d08abe502..4832a2746 100644 --- a/tests/integration/single_table/test_base.py +++ b/tests/integration/single_table/test_base.py @@ -201,7 +201,7 @@ def test_sample_keys_are_scrambled(): """Test that the keys are scrambled in the sampled data.""" # Setup data, metadata = download_demo(modality='single_table', dataset_name='fake_hotel_guests') - metadata.update_column('guest_email', sdtype='id', regex_format='[A-Z]{3}') + metadata.update_column('fake_hotel_guests', 'guest_email', sdtype='id', regex_format='[A-Z]{3}') synthesizer = GaussianCopulaSynthesizer(metadata) synthesizer.fit(data) diff --git a/tests/integration/single_table/test_copulas.py b/tests/integration/single_table/test_copulas.py index dd58f90bb..0a8477409 100644 --- a/tests/integration/single_table/test_copulas.py +++ b/tests/integration/single_table/test_copulas.py @@ -106,7 +106,7 @@ def test_synthesize_table_gaussian_copula(tmp_path): loaded_synthesizer = GaussianCopulaSynthesizer.load(model_path) assert isinstance(synthesizer, GaussianCopulaSynthesizer) assert loaded_synthesizer.get_info() == synthesizer.get_info() - assert loaded_synthesizer.metadata.to_dict() == metadata.to_dict() + assert loaded_synthesizer.metadata.to_dict() == metadata._convert_to_single_table().to_dict() loaded_synthesizer.sample(20) # Assert - custom synthesizer @@ -192,7 +192,7 @@ def test_adding_constraints(tmp_path): assert isinstance(loaded_synthesizer, GaussianCopulaSynthesizer) assert 
loaded_synthesizer.get_info() == synthesizer.get_info() - assert loaded_synthesizer.metadata.to_dict() == metadata.to_dict() + assert loaded_synthesizer.metadata.to_dict() == metadata._convert_to_single_table().to_dict() sampled_data = loaded_synthesizer.sample(100) validation = sampled_data[sampled_data['has_rewards']] assert validation['amenities_fee'].sum() == 0.0 @@ -251,7 +251,7 @@ def test_custom_processing_anonymization(): anonymized_sample = anonymization_synthesizer.sample(num_rows=100) # Assert - Pre-process data - assert pre_processed_data.index.name == metadata.primary_key + assert pre_processed_data.index.name == metadata.tables['fake_hotel_guests'].primary_key assert all(pre_processed_data.dtypes == 'float64') for column in sensitive_columns: assert default_sample[column].isin(real_data[column]).sum() == 0 diff --git a/tests/integration/single_table/test_ctgan.py b/tests/integration/single_table/test_ctgan.py index 4e0218985..26bc906b2 100644 --- a/tests/integration/single_table/test_ctgan.py +++ b/tests/integration/single_table/test_ctgan.py @@ -113,7 +113,7 @@ def test_synthesize_table_ctgan(tmp_path): loaded_synthesizer = CTGANSynthesizer.load(model_path) assert isinstance(synthesizer, CTGANSynthesizer) assert loaded_synthesizer.get_info() == synthesizer.get_info() - assert loaded_synthesizer.metadata.to_dict() == metadata.to_dict() + assert loaded_synthesizer.metadata.to_dict() == metadata._convert_to_single_table().to_dict() loaded_synthesizer.sample(20) # Assert - custom synthesizer diff --git a/tests/integration/utils/test_utils.py b/tests/integration/utils/test_utils.py index b03f4be5c..8c3498311 100644 --- a/tests/integration/utils/test_utils.py +++ b/tests/integration/utils/test_utils.py @@ -149,6 +149,7 @@ def test_get_random_sequence_subset(): data, metadata = download_demo(modality='sequential', dataset_name='nasdaq100_2019') # Run + metadata = metadata._convert_to_single_table() subset = get_random_sequence_subset(data, metadata, num_sequences=3, max_sequence_length=5) # Assert @@ -169,6 +170,7 @@ def test_get_random_sequence_subset_random_clipping(): """ # Setup data, metadata = download_demo(modality='sequential', dataset_name='nasdaq100_2019') + metadata = metadata._convert_to_single_table() # Run subset = get_random_sequence_subset( diff --git a/tests/unit/datasets/test_demo.py b/tests/unit/datasets/test_demo.py index 04ab9cf01..dbb2e7cf8 100644 --- a/tests/unit/datasets/test_demo.py +++ b/tests/unit/datasets/test_demo.py @@ -49,11 +49,16 @@ def test_download_demo_single_table(tmpdir): pd.testing.assert_frame_equal(table.head(2), expected_table) expected_metadata_dict = { - 'columns': { - '0': {'sdtype': 'numerical', 'computer_representation': 'Int64'}, - '1': {'sdtype': 'numerical', 'computer_representation': 'Int64'}, + 'tables': { + 'ring': { + 'columns': { + '0': {'sdtype': 'numerical', 'computer_representation': 'Int64'}, + '1': {'sdtype': 'numerical', 'computer_representation': 'Int64'}, + }, + } }, - 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1', + 'METADATA_SPEC_VERSION': 'V1', + 'relationships': [], } assert metadata.to_dict() == expected_metadata_dict @@ -99,12 +104,18 @@ def test_download_demo_single_table_no_output_folder(): pd.testing.assert_frame_equal(table.head(2), expected_table) expected_metadata_dict = { - 'columns': { - '0': {'sdtype': 'numerical', 'computer_representation': 'Int64'}, - '1': {'sdtype': 'numerical', 'computer_representation': 'Int64'}, + 'tables': { + 'ring': { + 'columns': { + '0': {'sdtype': 'numerical', 
'computer_representation': 'Int64'}, + '1': {'sdtype': 'numerical', 'computer_representation': 'Int64'}, + }, + } }, - 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1', + 'METADATA_SPEC_VERSION': 'V1', + 'relationships': [], } + assert metadata.to_dict() == expected_metadata_dict @@ -125,12 +136,17 @@ def test_download_demo_timeseries(tmpdir): pd.testing.assert_frame_equal(table.head(2), expected_table) expected_metadata_dict = { - 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1', - 'columns': { - 'e_id': {'sdtype': 'numerical', 'computer_representation': 'Int64'}, - 'dim_0': {'sdtype': 'numerical', 'computer_representation': 'Float'}, - 'dim_1': {'sdtype': 'numerical', 'computer_representation': 'Float'}, - 'ml_class': {'sdtype': 'categorical'}, + 'METADATA_SPEC_VERSION': 'V1', + 'relationships': [], + 'tables': { + 'Libras': { + 'columns': { + 'e_id': {'sdtype': 'numerical', 'computer_representation': 'Int64'}, + 'dim_0': {'sdtype': 'numerical', 'computer_representation': 'Float'}, + 'dim_1': {'sdtype': 'numerical', 'computer_representation': 'Float'}, + 'ml_class': {'sdtype': 'categorical'}, + } + } }, } assert metadata.to_dict() == expected_metadata_dict @@ -203,7 +219,7 @@ def test_download_demo_multi_table(tmpdir): 'child_foreign_key': 'character_id', }, ], - 'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1', + 'METADATA_SPEC_VERSION': 'V1', } assert metadata.to_dict() == expected_metadata_dict diff --git a/tests/unit/metadata/test_metadata.py b/tests/unit/metadata/test_metadata.py index 284bad8d6..0d282c347 100644 --- a/tests/unit/metadata/test_metadata.py +++ b/tests/unit/metadata/test_metadata.py @@ -123,7 +123,7 @@ def test_load_from_json_single_table(self, mock_json, mock_path): } # Run - instance = Metadata.load_from_json('filepath.json') + instance = Metadata.load_from_json('filepath.json', 'filepath') # Assert assert list(instance.tables.keys()) == ['filepath'] From b5d921b99f93551b667278a7bb166ad310371298 Mon Sep 17 00:00:00 2001 From: John La Date: Wed, 14 Aug 2024 10:36:44 -0500 Subject: [PATCH 06/16] Update tests to use Metadata (#2178) --- sdv/io/local/local.py | 4 +- sdv/lite/single_table.py | 15 ++- sdv/metadata/metadata.py | 25 +++- sdv/sampling/independent_sampler.py | 2 +- sdv/single_table/base.py | 5 +- sdv/single_table/copulagan.py | 3 +- sdv/single_table/copulas.py | 3 +- sdv/single_table/ctgan.py | 6 +- .../data_processing/test_data_processor.py | 27 ++-- .../evaluation/test_multi_table.py | 3 +- .../evaluation/test_single_table.py | 6 +- tests/integration/io/local/test_local.py | 8 +- tests/integration/lite/test_single_table.py | 14 +- .../metadata/test_visualization.py | 8 +- tests/integration/multi_table/test_hma.py | 59 +++++---- tests/integration/sequential/test_par.py | 54 ++++---- tests/integration/single_table/test_base.py | 123 ++++++++++-------- .../single_table/test_constraints.py | 110 +++++++++------- .../integration/single_table/test_copulas.py | 24 ++-- tests/integration/single_table/test_ctgan.py | 23 ++-- tests/integration/utils/test_poc.py | 4 +- tests/integration/utils/test_utils.py | 4 +- tests/unit/evaluation/test_multi_table.py | 18 +-- tests/unit/evaluation/test_single_table.py | 105 +++++++++------ tests/unit/io/local/test_local.py | 6 +- tests/unit/lite/test_single_table.py | 2 +- tests/unit/metadata/test_metadata.py | 14 +- tests/unit/multi_table/test_base.py | 19 +-- tests/unit/multi_table/test_hma.py | 116 ++++++++++++++++- tests/unit/multi_table/test_utils.py | 26 ++-- tests/unit/sequential/test_par.py | 45 ++++--- tests/unit/single_table/test_base.py | 
84 ++++++------
 tests/unit/single_table/test_copulagan.py | 46 +++----
 tests/unit/single_table/test_copulas.py | 38 +++---
 tests/unit/single_table/test_ctgan.py | 67 +++++-----
 tests/unit/utils/test_poc.py | 12 +-
 tests/utils.py | 4 +-
 37 files changed, 670 insertions(+), 462 deletions(-)

diff --git a/sdv/io/local/local.py b/sdv/io/local/local.py
index ccba2f51d..ca827eddf 100644
--- a/sdv/io/local/local.py
+++ b/sdv/io/local/local.py
@@ -7,7 +7,7 @@
 
 import pandas as pd
 
-from sdv.metadata import MultiTableMetadata
+from sdv.metadata.metadata import Metadata
 
 
 class BaseLocalHandler:
@@ -29,7 +29,7 @@ def create_metadata(self, data):
             An ``sdv.metadata.MultiTableMetadata`` object with the detected metadata properties
             from the data.
         """
-        metadata = MultiTableMetadata()
+        metadata = Metadata()
         metadata.detect_from_dataframes(data)
 
         return metadata
diff --git a/sdv/lite/single_table.py b/sdv/lite/single_table.py
index 7ad82bd1e..619ae234d 100644
--- a/sdv/lite/single_table.py
+++ b/sdv/lite/single_table.py
@@ -7,6 +7,7 @@
 
 import cloudpickle
 
+from sdv.metadata.metadata import Metadata
 from sdv.single_table import GaussianCopulaSynthesizer
 
 LOGGER = logging.getLogger(__name__)
@@ -20,13 +21,19 @@
     "functionality, please use the 'GaussianCopulaSynthesizer'."
 )
 
+META_DEPRECATION_MSG = (
+    "The 'SingleTableMetadata' is deprecated. Please use the new "
+    "'Metadata' class for synthesizers."
+)
+
 
 class SingleTablePreset:
     """Class for all single table synthesizer presets.
 
     Args:
-        metadata (sdv.metadata.SingleTableMetadata):
-            ``SingleTableMetadata`` instance.
+        metadata (sdv.metadata.Metadata):
+            ``Metadata`` instance.
+            * sdv.metadata.SingleTableMetadata can be used but will be deprecated.
         name (str):
             The preset to use.
         locales (list or str):
@@ -49,6 +56,10 @@ def __init__(self, metadata, name, locales=['en_US']):
             raise ValueError(f"'name' must be one of {PRESETS}.")
 
         self.name = name
+        if isinstance(metadata, Metadata):
+            metadata = metadata._convert_to_single_table()
+        else:
+            warnings.warn(META_DEPRECATION_MSG, FutureWarning)
 
         if name == FAST_ML_PRESET:
             self._setup_fast_preset(metadata, self.locales)
diff --git a/sdv/metadata/metadata.py b/sdv/metadata/metadata.py
index dcb1e0b4c..3e855a6eb 100644
--- a/sdv/metadata/metadata.py
+++ b/sdv/metadata/metadata.py
@@ -1,5 +1,7 @@
 """Metadata."""
 
+import warnings
+
 from sdv.metadata.errors import InvalidMetadataError
 from sdv.metadata.multi_table import MultiTableMetadata
 from sdv.metadata.single_table import SingleTableMetadata
@@ -10,6 +12,7 @@ class Metadata(MultiTableMetadata):
     """Metadata class that handles all metadata."""
 
     METADATA_SPEC_VERSION = 'V1'
+    DEFAULT_SINGLE_TABLE_NAME = 'default_table_name'
 
     @classmethod
     def load_from_json(cls, filepath, single_table_name=None):
@@ -66,9 +69,29 @@ def _set_metadata_dict(self, metadata, single_table_name=None):
             super()._set_metadata_dict(metadata)
         else:
             if single_table_name is None:
-                single_table_name = 'default_table_name'
+                single_table_name = self.DEFAULT_SINGLE_TABLE_NAME
 
             self.tables[single_table_name] = SingleTableMetadata.load_from_dict(metadata)
 
+    def _get_single_table_name(self):
+        """Get the table name if there is only one table.
+
+        Checks to see if the metadata contains only a single table; if so,
+        returns its name. Otherwise, warns the user and returns None.
+
+        Returns:
+            str or None:
+                The name of the single table if there is exactly one,
+                otherwise ``None``.
+        """
+        if len(self.tables) != 1:
+            warnings.warn(
+                'This metadata does not contain only a single table. 
Could not determine ' + 'single table name and will return None.' + ) + return None + + return next(iter(self.tables), None) + def _convert_to_single_table(self): if len(self.tables) > 1: raise InvalidMetadataError( diff --git a/sdv/sampling/independent_sampler.py b/sdv/sampling/independent_sampler.py index a86b0ddc0..9fc4507c7 100644 --- a/sdv/sampling/independent_sampler.py +++ b/sdv/sampling/independent_sampler.py @@ -98,7 +98,7 @@ def _finalize(self, sampled_data): final_data = {} for table_name, table_rows in sampled_data.items(): synthesizer = self._table_synthesizers.get(table_name) - metadata = synthesizer.get_metadata() + metadata = synthesizer.get_metadata()._convert_to_single_table() dtypes = synthesizer._data_processor._dtypes dtypes_to_sdtype = synthesizer._data_processor._DTYPE_TO_SDTYPE diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index d618a4c48..e18a3931d 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -112,7 +112,9 @@ def __init__( ): self._validate_inputs(enforce_min_max_values, enforce_rounding) self.metadata = metadata + self._table_name = Metadata.DEFAULT_SINGLE_TABLE_NAME if isinstance(metadata, Metadata): + self._table_name = metadata._get_single_table_name() self.metadata = metadata._convert_to_single_table() elif isinstance(metadata, SingleTableMetadata): warnings.warn(DEPRECATION_MSG, FutureWarning) @@ -284,7 +286,8 @@ def get_parameters(self): def get_metadata(self): """Return the ``Metadata`` for this synthesizer.""" - return Metadata.load_from_dict(self.metadata.to_dict()) + table_name = getattr(self, '_table_name', None) + return Metadata.load_from_dict(self.metadata.to_dict(), table_name) def load_custom_constraint_classes(self, filepath, class_names): """Load a custom constraint class for the current synthesizer. diff --git a/sdv/single_table/copulagan.py b/sdv/single_table/copulagan.py index 19ca50b2e..1713cef45 100644 --- a/sdv/single_table/copulagan.py +++ b/sdv/single_table/copulagan.py @@ -59,8 +59,9 @@ class CopulaGANSynthesizer(CTGANSynthesizer): Args: - metadata (sdv.metadata.SingleTableMetadata): + metadata (sdv.metadata.Metadata): Single table metadata representing the data that this synthesizer will be used for. + * sdv.metadata.SingleTableMetadata can be used but will be deprecated. enforce_min_max_values (bool): Specify whether or not to clip the data returned by ``reverse_transform`` of the numerical transformer, ``FloatFormatter``, to the min and max values seen diff --git a/sdv/single_table/copulas.py b/sdv/single_table/copulas.py index 26167a946..c47b67424 100644 --- a/sdv/single_table/copulas.py +++ b/sdv/single_table/copulas.py @@ -29,8 +29,9 @@ class GaussianCopulaSynthesizer(BaseSingleTableSynthesizer): """Model wrapping ``copulas.multivariate.GaussianMultivariate`` copula. Args: - metadata (sdv.metadata.SingleTableMetadata): + metadata (sdv.metadata.Metadata): Single table metadata representing the data that this synthesizer will be used for. + * sdv.metadata.SingleTableMetadata can be used but will be deprecated. enforce_min_max_values (bool): Specify whether or not to clip the data returned by ``reverse_transform`` of the numerical transformer, ``FloatFormatter``, to the min and max values seen diff --git a/sdv/single_table/ctgan.py b/sdv/single_table/ctgan.py index 30174f43a..1d03fe956 100644 --- a/sdv/single_table/ctgan.py +++ b/sdv/single_table/ctgan.py @@ -100,8 +100,9 @@ class CTGANSynthesizer(LossValuesMixin, BaseSingleTableSynthesizer): """Model wrapping ``CTGAN`` model. 
Args: - metadata (sdv.metadata.SingleTableMetadata): + metadata (sdv.metadata.Metadata): Single table metadata representing the data that this synthesizer will be used for. + * sdv.metadata.SingleTableMetadata can be used but will be deprecated. enforce_min_max_values (bool): Specify whether or not to clip the data returned by ``reverse_transform`` of the numerical transformer, ``FloatFormatter``, to the min and max values seen @@ -316,8 +317,9 @@ class TVAESynthesizer(LossValuesMixin, BaseSingleTableSynthesizer): """Model wrapping ``TVAE`` model. Args: - metadata (sdv.metadata.SingleTableMetadata): + metadata (sdv.metadata.Metadata): Single table metadata representing the data that this synthesizer will be used for. + * sdv.metadata.SingleTableMetadata can be used but will be deprecated. enforce_min_max_values (bool): Specify whether or not to clip the data returned by ``reverse_transform`` of the numerical transformer, ``FloatFormatter``, to the min and max values seen diff --git a/tests/integration/data_processing/test_data_processor.py b/tests/integration/data_processing/test_data_processor.py index a3217f712..dba25709b 100644 --- a/tests/integration/data_processing/test_data_processor.py +++ b/tests/integration/data_processing/test_data_processor.py @@ -23,6 +23,7 @@ from sdv.datasets.demo import download_demo from sdv.errors import SynthesizerInputError from sdv.metadata import SingleTableMetadata +from sdv.metadata.metadata import Metadata class TestDataProcessor: @@ -155,12 +156,12 @@ def test_with_primary_key_numerical(self): """ # Load metadata and data data, _ = download_demo('single_table', 'adult') - adult_metadata = SingleTableMetadata() - adult_metadata.detect_from_dataframe(data=data) + adult_metadata = Metadata() + adult_metadata.detect_from_dataframes({'adult': data}) # Add primary key field - adult_metadata.add_column('id', sdtype='id') - adult_metadata.set_primary_key('id') + adult_metadata.add_column('adult', 'id', sdtype='id') + adult_metadata.set_primary_key('adult', 'id') # Add id size = len(data) @@ -169,7 +170,7 @@ def test_with_primary_key_numerical(self): data['id'] = ids # Instance ``DataProcessor`` - dp = DataProcessor(adult_metadata) + dp = DataProcessor(adult_metadata._convert_to_single_table()) # Fit dp.fit(data) @@ -195,17 +196,17 @@ def test_with_alternate_keys(self): # Load metadata and data data, _ = download_demo('single_table', 'adult') data['fnlwgt'] = data['fnlwgt'].astype(str) - adult_metadata = SingleTableMetadata() - adult_metadata.detect_from_dataframe(data=data) + adult_metadata = Metadata() + adult_metadata.detect_from_dataframes({'adult': data}) # Add primary key field - adult_metadata.add_column('id', sdtype='id') - adult_metadata.set_primary_key('id') + adult_metadata.add_column('adult', 'id', sdtype='id') + adult_metadata.set_primary_key('adult', 'id') - adult_metadata.add_column('secondary_id', sdtype='id') - adult_metadata.update_column('fnlwgt', sdtype='id', regex_format='ID_\\d{4}[0-9]') + adult_metadata.add_column('adult', 'secondary_id', sdtype='id') + adult_metadata.update_column('adult', 'fnlwgt', sdtype='id', regex_format='ID_\\d{4}[0-9]') - adult_metadata.add_alternate_keys(['secondary_id', 'fnlwgt']) + adult_metadata.add_alternate_keys('adult', ['secondary_id', 'fnlwgt']) # Add id size = len(data) @@ -215,7 +216,7 @@ def test_with_alternate_keys(self): data['secondary_id'] = ids # Instance ``DataProcessor`` - dp = DataProcessor(adult_metadata) + dp = DataProcessor(adult_metadata._convert_to_single_table()) # Fit dp.fit(data) 
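
The test changes above follow the pattern that the rest of this patch series establishes for the unified ``Metadata`` class: detection always goes through a dictionary of dataframes whose keys become table names, column-level methods such as ``update_column`` take the table name as their first argument, and a plain single-table dictionary passed to ``load_from_dict`` is filed under ``default_table_name``. A minimal caller-side sketch of that workflow (the data is illustrative; it uses only the ``Metadata`` methods exercised in these patches):

    import pandas as pd

    from sdv.metadata.metadata import Metadata

    data = pd.DataFrame({
        'age': [22, 35, 41],
        'occupation': ['nurse', 'chef', 'pilot'],
    })

    # Detect a single table; the dictionary key becomes the table name.
    metadata = Metadata()
    metadata.detect_from_dataframes({'adult': data})

    # Column-level edits now address the table explicitly.
    metadata.update_column('adult', 'occupation', sdtype='job', pii=True)

    # A single-table dict with no name is stored under the default table name.
    single = Metadata.load_from_dict({'columns': {'age': {'sdtype': 'numerical'}}})
    assert 'default_table_name' in single.to_dict()['tables']
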
diff --git a/tests/integration/evaluation/test_multi_table.py b/tests/integration/evaluation/test_multi_table.py index 5f816fb87..7ca797435 100644 --- a/tests/integration/evaluation/test_multi_table.py +++ b/tests/integration/evaluation/test_multi_table.py @@ -2,7 +2,6 @@ from sdv.evaluation.multi_table import evaluate_quality, run_diagnostic from sdv.metadata.metadata import Metadata -from sdv.metadata.multi_table import MultiTableMetadata def test_evaluation(): @@ -18,7 +17,7 @@ def test_evaluation(): 'table1': table, 'table2': slightly_different_table, } - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'tables': { 'table1': { 'columns': { diff --git a/tests/integration/evaluation/test_single_table.py b/tests/integration/evaluation/test_single_table.py index 95e1b671e..8b3cf23fa 100644 --- a/tests/integration/evaluation/test_single_table.py +++ b/tests/integration/evaluation/test_single_table.py @@ -3,7 +3,6 @@ from sdv.datasets.demo import download_demo from sdv.evaluation.single_table import evaluate_quality, get_column_pair_plot, run_diagnostic from sdv.metadata.metadata import Metadata -from sdv.metadata.single_table import SingleTableMetadata from sdv.single_table.copulas import GaussianCopulaSynthesizer @@ -11,8 +10,9 @@ def test_evaluation(): """Test ``evaluate_quality`` and ``run_diagnostic``.""" # Setup data = pd.DataFrame({'col': [1, 2, 3]}) - metadata = SingleTableMetadata() - metadata.add_column('col', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col', sdtype='numerical') synthesizer = GaussianCopulaSynthesizer(metadata, default_distribution='truncnorm') # Run and Assert diff --git a/tests/integration/io/local/test_local.py b/tests/integration/io/local/test_local.py index 2fa7e00c3..c52372f30 100644 --- a/tests/integration/io/local/test_local.py +++ b/tests/integration/io/local/test_local.py @@ -1,7 +1,7 @@ import pandas as pd from sdv.io.local import CSVHandler, ExcelHandler -from sdv.metadata import MultiTableMetadata +from sdv.metadata import Metadata class TestCSVHandler: @@ -27,7 +27,7 @@ def test_integration_write_and_read(self, tmpdir): assert len(data) == 2 assert 'table1' in data assert 'table2' in data - assert isinstance(metadata, MultiTableMetadata) is True + assert isinstance(metadata, Metadata) is True # Check if the dataframes match the original synthetic data pd.testing.assert_frame_equal(data['table1'], synthetic_data['table1']) @@ -57,7 +57,7 @@ def test_integration_write_and_read(self, tmpdir): assert len(data) == 2 assert 'table1' in data assert 'table2' in data - assert isinstance(metadata, MultiTableMetadata) is True + assert isinstance(metadata, Metadata) is True # Check if the dataframes match the original synthetic data pd.testing.assert_frame_equal(data['table1'], synthetic_data['table1']) @@ -91,7 +91,7 @@ def test_integration_write_and_read_append_mode(self, tmpdir): assert len(data) == 2 assert 'table1' in data assert 'table2' in data - assert isinstance(metadata, MultiTableMetadata) is True + assert isinstance(metadata, Metadata) is True # Check if the dataframes match the original synthetic data expected_table_one = pd.concat( diff --git a/tests/integration/lite/test_single_table.py b/tests/integration/lite/test_single_table.py index 90c66cb5d..ad6a42aea 100644 --- a/tests/integration/lite/test_single_table.py +++ b/tests/integration/lite/test_single_table.py @@ -3,7 +3,7 @@ import pytest from sdv.lite import SingleTablePreset -from 
sdv.metadata import SingleTableMetadata +from sdv.metadata.metadata import Metadata def test_sample(): @@ -12,8 +12,8 @@ def test_sample(): data = pd.DataFrame({'a': [1, 2, 3, np.nan]}) # Run - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(data) + metadata = Metadata() + metadata.detect_from_dataframes({'adult': data}) preset = SingleTablePreset(metadata, name='FAST_ML') preset.fit(data) samples = preset.sample(num_rows=10, max_tries_per_batch=20, batch_size=5) @@ -29,8 +29,8 @@ def test_sample_with_constraints(): data = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) # Run - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(data) + metadata = Metadata() + metadata.detect_from_dataframes({'table': data}) preset = SingleTablePreset(metadata, name='FAST_ML') constraints = [ { @@ -57,8 +57,8 @@ def test_warnings_are_shown(): data = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) # Run - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(data) + metadata = Metadata() + metadata.detect_from_dataframes({'table': data}) with pytest.warns(FutureWarning, match=warn_message): preset = SingleTablePreset(metadata, name='FAST_ML') diff --git a/tests/integration/metadata/test_visualization.py b/tests/integration/metadata/test_visualization.py index 07cb870b6..eb8c6d657 100644 --- a/tests/integration/metadata/test_visualization.py +++ b/tests/integration/metadata/test_visualization.py @@ -1,6 +1,6 @@ import pandas as pd -from sdv.metadata import MultiTableMetadata, SingleTableMetadata +from sdv.metadata.metadata import Metadata from sdv.multi_table.hma import HMASynthesizer from sdv.single_table.copulas import GaussianCopulaSynthesizer @@ -9,8 +9,8 @@ def test_visualize_graph_for_single_table(): """Test it runs when a column name contains symbols.""" # Setup data = pd.DataFrame({'\\|=/bla@#$324%^,"&*()><...': ['a', 'b', 'c']}) - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(data) + metadata = Metadata() + metadata.detect_from_dataframes({'table': data}) model = GaussianCopulaSynthesizer(metadata) # Run @@ -26,7 +26,7 @@ def test_visualize_graph_for_multi_table(): data1 = pd.DataFrame({'\\|=/bla@#$324%^,"&*()><...': ['a', 'b', 'c']}) data2 = pd.DataFrame({'\\|=/bla@#$324%^,"&*()><...': ['a', 'b', 'c']}) tables = {'1': data1, '2': data2} - metadata = MultiTableMetadata() + metadata = Metadata() metadata.detect_from_dataframes(tables) metadata.update_column('1', '\\|=/bla@#$324%^,"&*()><...', sdtype='id') metadata.update_column('2', '\\|=/bla@#$324%^,"&*()><...', sdtype='id') diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index d161c6fa2..4393481e3 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -20,7 +20,6 @@ from sdv.errors import SamplingError, SynthesizerInputError, VersionError from sdv.evaluation.multi_table import evaluate_quality, get_column_pair_plot, get_column_plot from sdv.metadata.metadata import Metadata -from sdv.metadata.multi_table import MultiTableMetadata from sdv.multi_table import HMASynthesizer from tests.integration.single_table.custom_constraints import MyConstraint from tests.utils import catch_sdv_logs @@ -125,7 +124,7 @@ def test_get_info(self): # Setup data = {'tab': pd.DataFrame({'col': [1, 2, 3]})} today = datetime.datetime.today().strftime('%Y-%m-%d') - metadata = MultiTableMetadata() + metadata = Metadata() metadata.add_table('tab') metadata.add_column('tab', 'col', sdtype='numerical') synthesizer = 
HMASynthesizer(metadata) @@ -220,7 +219,7 @@ def get_custom_constraint_data_and_metadata(self): 'numerical_col_2': [2, 4, 6], }) - metadata = MultiTableMetadata() + metadata = Metadata() metadata.detect_table_from_dataframe('parent', parent_data) metadata.update_column('parent', 'primary_key', sdtype='id') metadata.detect_table_from_dataframe('child', child_data) @@ -360,7 +359,7 @@ def test_hma_with_inequality_constraint(self): data = {'parent_table': parent_table, 'child_table': child_table} - metadata = MultiTableMetadata() + metadata = Metadata() metadata.detect_table_from_dataframe(table_name='parent_table', data=parent_table) metadata.update_column('parent_table', 'id', sdtype='id') metadata.detect_table_from_dataframe(table_name='child_table', data=child_table) @@ -449,7 +448,7 @@ def test_hma_primary_key_and_foreign_key_only(self): data = {'users': users, 'sessions': sessions, 'games': games} - metadata = MultiTableMetadata() + metadata = Metadata() for table_name, table in data.items(): metadata.detect_table_from_dataframe(table_name, table) @@ -585,7 +584,7 @@ def test_use_own_data_using_hma(self, tmp_path): assert datasets.keys() == {'guests', 'hotels'} # Metadata - metadata = MultiTableMetadata() + metadata = Metadata() metadata.detect_table_from_dataframe(table_name='guests', data=datasets['guests']) metadata.detect_table_from_dataframe(table_name='hotels', data=datasets['hotels']) @@ -676,7 +675,7 @@ def test_use_own_data_using_hma(self, tmp_path): # Save and load metadata metadata_path = tmp_path / 'metadata.json' metadata.save_to_json(metadata_path) - loaded_metadata = MultiTableMetadata.load_from_json(metadata_path) + loaded_metadata = Metadata.load_from_json(metadata_path) # Assert loaded metadata matches saved assert metadata.to_dict() == loaded_metadata.to_dict() @@ -768,7 +767,7 @@ def test_hma_three_linear_nodes(self): } ) data = {'grandparent': grandparent, 'parent': parent, 'child': child} - metadata = MultiTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'tables': { 'grandparent': { 'primary_key': 'grandparent_ID', @@ -847,7 +846,7 @@ def test_hma_one_parent_two_children(self): } ) data = {'parent': parent, 'child1': child1, 'child2': child2} - metadata = MultiTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'tables': { 'parent': { 'primary_key': 'parent_ID', @@ -920,7 +919,7 @@ def test_hma_two_parents_one_child(self): data={'parent_ID2': [0, 1, 2, 3, 4], 'data': ['Yes', 'Yes', 'Maybe', 'No', 'No']} ) data = {'parent1': parent1, 'child': child, 'parent2': parent2} - metadata = MultiTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'tables': { 'parent1': { 'primary_key': 'parent_ID1', @@ -1018,7 +1017,7 @@ def test_hma_two_lineages_one_grandchild(self): 'child2': child2, 'grandchild': grandchild, } - metadata = MultiTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'tables': { 'root1': { 'primary_key': 'id', @@ -1173,7 +1172,7 @@ def test__extract_parameters(self): '__sessions__user_id__loc': 0.5, '__sessions__user_id__scale': -0.25, }) - instance = HMASynthesizer(MultiTableMetadata()) + instance = HMASynthesizer(Metadata()) instance.extended_columns = { 'sessions': { '__sessions__user_id__num_rows': FloatFormatter(enforce_min_max_values=True), @@ -1206,7 +1205,7 @@ def test__recreate_child_synthesizer_with_default_parameters(self): f'{prefix}univariates__brand__loc': 0.5, f'{prefix}univariates__brand__scale': -0.25, }) - metadata = MultiTableMetadata.load_from_dict({ + metadata = 
Metadata.load_from_dict({ 'tables': { 'users': {'columns': {'user_id': {'sdtype': 'id'}}, 'primary_key': 'user_id'}, 'sessions': { @@ -1351,7 +1350,7 @@ def test_metadata_updated_warning_detect(self): def test_null_foreign_keys(self): """Test that the synthesizer does not crash when there are null foreign keys.""" # Setup - metadata = MultiTableMetadata() + metadata = Metadata() metadata.add_table('parent_table1') metadata.add_column('parent_table1', 'id', sdtype='id') metadata.set_primary_key('parent_table1', 'id') @@ -1456,7 +1455,7 @@ def test_sampling_with_unknown_sdtype_numerical_column(self): tables_dict = {'people': table1, 'company': table2} - metadata = MultiTableMetadata() + metadata = Metadata() metadata.detect_from_dataframes(tables_dict) # Run @@ -1508,7 +1507,7 @@ def test_hma_0_1_child(num_rows): ) data = {'parent': parent_table, 'child': pd.DataFrame(data=child_table_data)} - metadata = MultiTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'tables': { 'parent': { 'primary_key': 'id', @@ -1596,7 +1595,7 @@ def test_hma_0_1_grandparent(): }, ], } - metadata = MultiTableMetadata().load_from_dict(metadata_dict) + metadata = Metadata().load_from_dict(metadata_dict) metadata.validate() metadata.validate_data(data) synthesizer = HMASynthesizer(metadata=metadata, verbose=False) @@ -1635,7 +1634,7 @@ def test_metadata_updated_warning(method, kwargs): The warning should be raised during synthesizer initialization. """ - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'tables': { 'departure': { 'primary_key': 'id', @@ -1685,7 +1684,7 @@ def test_metadata_updated_warning(method, kwargs): def test_save_and_load_with_downgraded_version(tmp_path): """Test that synthesizers are raising errors if loaded on a downgraded version.""" # Setup - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'tables': { 'departure': { 'primary_key': 'id', @@ -1736,7 +1735,7 @@ def test_save_and_load_with_downgraded_version(tmp_path): def test_fit_raises_version_error(): """Test that a ``VersionError`` is being raised if the current version is newer.""" # Setup - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'tables': { 'departure': { 'primary_key': 'id', @@ -1828,7 +1827,7 @@ def test_fit_and_sample_numerical_col_names(): data['0'][1] = primary_key data['1'][1] = primary_key data['1'][2] = primary_key_2 - metadata = MultiTableMetadata() + metadata = Metadata() metadata_dict = {'tables': {}} for table_idx in range(num_tables): metadata_dict['tables'][str(table_idx)] = {'columns': {}} @@ -1844,7 +1843,7 @@ def test_fit_and_sample_numerical_col_names(): 'child_foreign_key': 2, } ] - metadata = MultiTableMetadata.load_from_dict(metadata_dict) + metadata = Metadata.load_from_dict(metadata_dict) metadata.set_primary_key('0', '1') # Run @@ -1875,7 +1874,7 @@ def test_detect_from_dataframe_numerical_col(): 'parent_data': parent_data, 'child_data': child_data, } - metadata = MultiTableMetadata() + metadata = Metadata() metadata.detect_table_from_dataframe('parent_data', parent_data) metadata.detect_table_from_dataframe('child_data', child_data) metadata.update_column('parent_data', '1', sdtype='id') @@ -1890,7 +1889,7 @@ def test_detect_from_dataframe_numerical_col(): child_table_name='child_data', ) - test_metadata = MultiTableMetadata() + test_metadata = Metadata() test_metadata.detect_from_dataframes(data) test_metadata.update_column('parent_data', '1', sdtype='id') 
test_metadata.update_column('child_data', '3', sdtype='id') @@ -1914,7 +1913,7 @@ def test_detect_from_dataframe_numerical_col(): assert sample['parent_data'].columns.tolist() == data['parent_data'].columns.tolist() assert sample['child_data'].columns.tolist() == data['child_data'].columns.tolist() - test_metadata = MultiTableMetadata() + test_metadata = Metadata() test_metadata.detect_from_dataframes(data) @@ -1930,7 +1929,7 @@ def test_table_name_logging(caplog): 'parent_data': parent_data, 'child_data': child_data, } - metadata = MultiTableMetadata() + metadata = Metadata() metadata.detect_from_dataframes(data) instance = HMASynthesizer(metadata) @@ -1952,7 +1951,7 @@ def test_disjointed_tables(): remove_some_dict = metadata.to_dict() half_list = remove_some_dict['relationships'][1::2] remove_some_dict['relationships'] = half_list - disjoined_metadata = MultiTableMetadata.load_from_dict(remove_some_dict) + disjoined_metadata = Metadata.load_from_dict(remove_some_dict) # Run disjoin_synthesizer = HMASynthesizer(disjoined_metadata) @@ -2009,7 +2008,7 @@ def test_hma_synthesizer_with_fixed_combinations(): } # Creating metadata for the dataset - metadata = MultiTableMetadata() + metadata = Metadata() metadata.detect_from_dataframes(data) metadata.update_column('users', 'user_id', sdtype='id') @@ -2059,7 +2058,7 @@ def test_fit_int_primary_key_regex_includes_zero(regex): 'parent_data': parent_data, 'child_data': child_data, } - metadata = MultiTableMetadata() + metadata = Metadata() metadata.detect_from_dataframes(data) metadata.update_column('parent_data', 'parent_id', sdtype='id', regex_format=regex) metadata.set_primary_key('parent_data', 'parent_id') @@ -2106,7 +2105,7 @@ def test__estimate_num_columns_to_be_modeled_various_sdtypes(): 'grandparent': grandparent, 'parent': parent, } - metadata = MultiTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'tables': { 'root1': { 'primary_key': 'R1', diff --git a/tests/integration/sequential/test_par.py b/tests/integration/sequential/test_par.py index 567586525..c9628a429 100644 --- a/tests/integration/sequential/test_par.py +++ b/tests/integration/sequential/test_par.py @@ -8,7 +8,7 @@ from sdv.datasets.demo import download_demo from sdv.errors import SynthesizerInputError -from sdv.metadata import SingleTableMetadata +from sdv.metadata.metadata import Metadata from sdv.sequential import PARSynthesizer @@ -21,11 +21,11 @@ def _get_par_data_and_metadata(): 'entity': [1, 1, 2, 2], 'context': ['a', 'a', 'b', 'b'], }) - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(data) - metadata.update_column('entity', sdtype='id') - metadata.set_sequence_key('entity') - metadata.set_sequence_index('date') + metadata = Metadata() + metadata.detect_from_dataframes({'table': data}) + metadata.update_column('table', 'entity', sdtype='id') + metadata.set_sequence_key('table', 'entity') + metadata.set_sequence_index('table', 'date') return data, metadata @@ -34,11 +34,11 @@ def test_par(): # Setup data = load_demo() data['date'] = pd.to_datetime(data['date']) - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(data) - metadata.update_column('store_id', sdtype='id') - metadata.set_sequence_key('store_id') - metadata.set_sequence_index('date') + metadata = Metadata() + metadata.detect_from_dataframes({'table': data}) + metadata.update_column('table', 'store_id', sdtype='id') + metadata.set_sequence_key('table', 'store_id') + metadata.set_sequence_index('table', 'date') model = PARSynthesizer( metadata=metadata, 
context_columns=['region'], @@ -68,11 +68,11 @@ def test_column_after_date_simple(): 'date': [date, date], 'col2': ['hello', 'world'], }) - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(data) - metadata.update_column('col', sdtype='id') - metadata.set_sequence_key('col') - metadata.set_sequence_index('date') + metadata = Metadata() + metadata.detect_from_dataframes({'table': data}) + metadata.update_column('table', 'col', sdtype='id') + metadata.set_sequence_key('table', 'col') + metadata.set_sequence_index('table', 'date') # Run model = PARSynthesizer(metadata=metadata, epochs=1) @@ -114,7 +114,7 @@ def test_save_and_load(tmp_path): # Assert assert isinstance(loaded_instance, PARSynthesizer) - assert metadata == instance.metadata + assert metadata._convert_to_single_table().to_dict() == instance.metadata.to_dict() def test_synthesize_sequences(tmp_path): @@ -229,7 +229,7 @@ def test_par_subset_of_data_simplified(): 'date': ['2020-01-01', '2020-01-02', '2020-01-03'], }) data.index = [0, 1, 5] - metadata = SingleTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'sequence_index': 'date', 'sequence_key': 'id', 'columns': { @@ -261,7 +261,7 @@ def test_par_missing_sequence_index(): 'sequence_key': 'e_id', } - metadata = SingleTableMetadata().load_from_dict(metadata_dict) + metadata = Metadata().load_from_dict(metadata_dict) data = pd.DataFrame({'value': [10, 20, 30], 'e_id': [1, 2, 3]}) @@ -348,11 +348,11 @@ def test_par_unique_sequence_index_with_enforce_min_max(): test_df[['visits', 'pre_date']] = test_df[['visits', 'pre_date']].apply( pd.to_datetime, format='%Y-%m-%d', errors='coerce' ) - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(test_df) - metadata.update_column(column_name='s_key', sdtype='id') - metadata.set_sequence_key('s_key') - metadata.set_sequence_index('visits') + metadata = Metadata() + metadata.detect_from_dataframes({'table': test_df}) + metadata.update_column(table_name='table', column_name='s_key', sdtype='id') + metadata.set_sequence_key('table', 's_key') + metadata.set_sequence_index('table', 'visits') synthesizer = PARSynthesizer( metadata, enforce_min_max_values=True, enforce_rounding=False, epochs=100, verbose=True ) @@ -378,7 +378,7 @@ def test_par_sequence_index_is_numerical(): 'sequence_key': 'engine_no', 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1', } - metadata = SingleTableMetadata.load_from_dict(metadata_dict) + metadata = Metadata.load_from_dict(metadata_dict) data = pd.DataFrame({'engine_no': [0, 0, 1, 1], 'time_in_cycles': [1, 2, 3, 4]}) s1 = PARSynthesizer(metadata) @@ -396,7 +396,7 @@ def test_init_error_sequence_key_in_context(): }, 'sequence_key': 'A', } - metadata = SingleTableMetadata.load_from_dict(metadata_dict) + metadata = Metadata.load_from_dict(metadata_dict) sequence_key_context_column_error_msg = re.escape( "The sequence key ['A'] cannot be a context column. " 'To proceed, please remove the sequence key from the context_columns parameter.' 
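The test_par.py hunks above repeat the single-table migration applied throughout this patch: SingleTableMetadata().detect_from_dataframe(data) becomes Metadata().detect_from_dataframes({...}) with an explicit table name, and every mutator (update_column, set_sequence_key, set_sequence_index) gains a leading table-name argument. A minimal before/after sketch of that pattern; the frame contents and the table name 'table' are illustrative, not taken from the patch:

import pandas as pd

from sdv.metadata import SingleTableMetadata
from sdv.metadata.metadata import Metadata

data = pd.DataFrame({
    'entity': [1, 1, 2, 2],
    'date': ['2020-01-01', '2020-01-02', '2020-01-01', '2020-01-02'],
})

# Old single-table API: the metadata describes one anonymous table.
old = SingleTableMetadata()
old.detect_from_dataframe(data)
old.update_column('entity', sdtype='id')
old.set_sequence_key('entity')

# New unified API: tables are keyed by name, so the name threads through every call.
new = Metadata()
new.detect_from_dataframes({'table': data})
new.update_column('table', 'entity', sdtype='id')
new.set_sequence_key('table', 'entity')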
@@ -418,7 +418,7 @@ def test_par_with_datetime_context(): } ) - metadata = SingleTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'columns': { 'user_id': {'sdtype': 'id', 'regex_format': 'ID_[0-9]{2}'}, 'birthdate': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, diff --git a/tests/integration/single_table/test_base.py b/tests/integration/single_table/test_base.py index 4832a2746..5f99abf93 100644 --- a/tests/integration/single_table/test_base.py +++ b/tests/integration/single_table/test_base.py @@ -23,7 +23,7 @@ ) from sdv.single_table.base import BaseSingleTableSynthesizer -METADATA = SingleTableMetadata.load_from_dict({ +METADATA = Metadata.load_from_dict({ 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1', 'columns': { 'column1': {'sdtype': 'numerical'}, @@ -91,10 +91,11 @@ def test_sample_from_conditions_with_batch_size(): 'column3': list(range(100)), }) - metadata = SingleTableMetadata() - metadata.add_column('column1', sdtype='numerical') - metadata.add_column('column2', sdtype='numerical') - metadata.add_column('column3', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'column1', sdtype='numerical') + metadata.add_column('table', 'column2', sdtype='numerical') + metadata.add_column('table', 'column3', sdtype='numerical') model = GaussianCopulaSynthesizer(metadata) model.fit(data) @@ -117,10 +118,11 @@ def test_sample_from_conditions_negative_float(): 'column3': list(range(100)), }) - metadata = SingleTableMetadata() - metadata.add_column('column1', sdtype='numerical') - metadata.add_column('column2', sdtype='numerical') - metadata.add_column('column3', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'column1', sdtype='numerical') + metadata.add_column('table', 'column2', sdtype='numerical') + metadata.add_column('table', 'column3', sdtype='numerical') model = GaussianCopulaSynthesizer(metadata) model.fit(data) @@ -230,10 +232,11 @@ def test_multiple_fits(): 'state': ['CA', 'CA', 'IL', 'CA', 'CA'], 'measurement': [27.1, 28.7, 26.9, 21.2, 30.9], }) - metadata = SingleTableMetadata() - metadata.add_column('city', sdtype='categorical') - metadata.add_column('state', sdtype='categorical') - metadata.add_column('measurement', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'city', sdtype='categorical') + metadata.add_column('table', 'state', sdtype='categorical') + metadata.add_column('table', 'measurement', sdtype='numerical') constraint = { 'constraint_class': 'FixedCombinations', 'constraint_parameters': {'column_names': ['city', 'state']}, @@ -273,7 +276,7 @@ def test_sampling(synthesizer): @pytest.mark.parametrize('synthesizer', SYNTHESIZERS) def test_sampling_reset_sampling(synthesizer): """Test ``sample`` method for each synthesizer using ``reset_sampling``.""" - metadata = SingleTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1', 'columns': { 'column1': {'sdtype': 'numerical'}, @@ -310,11 +313,13 @@ def test_config_creation_doesnt_raise_error(): 'address_col': ['223 Williams Rd', '75 Waltham St', '77 Mass Ave'], 'numerical_col': [1, 2, 3], }) - test_metadata = SingleTableMetadata() + test_metadata = Metadata() # Run - test_metadata.detect_from_dataframe(test_data) - test_metadata.update_column(column_name='address_col', sdtype='address', pii=False) + test_metadata.detect_from_dataframes({'table': test_data}) + 
test_metadata.update_column( + table_name='table', column_name='address_col', sdtype='address', pii=False + ) synthesizer = GaussianCopulaSynthesizer(test_metadata) synthesizer.fit(test_data) @@ -330,11 +335,13 @@ def test_transformers_correctly_auto_assigned(): 'categorical_col': ['a', 'b', 'a'], }) - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(data) - metadata.update_column(column_name='primary_key', sdtype='id', regex_format='user-[0-9]{3}') - metadata.set_primary_key('primary_key') - metadata.update_column(column_name='pii_col', sdtype='address', pii=True) + metadata = Metadata() + metadata.detect_from_dataframes({'table': data}) + metadata.update_column( + table_name='table', column_name='primary_key', sdtype='id', regex_format='user-[0-9]{3}' + ) + metadata.set_primary_key('table', 'primary_key') + metadata.update_column(table_name='table', column_name='pii_col', sdtype='address', pii=True) synthesizer = GaussianCopulaSynthesizer( metadata, enforce_min_max_values=False, enforce_rounding=False ) @@ -391,7 +398,7 @@ def test_modeling_with_complex_datetimes(): } # Run - metadata = SingleTableMetadata.load_from_dict(test_metadata) + metadata = Metadata.load_from_dict(test_metadata) metadata.validate() synth = GaussianCopulaSynthesizer(metadata) synth.validate(data) @@ -418,13 +425,13 @@ def test_auto_assign_transformers_and_update_with_pii(): } ) - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(data) + metadata = Metadata() + metadata.detect_from_dataframes({'table': data}) # Run - metadata.update_column(column_name='id', sdtype='first_name') - metadata.update_column(column_name='name', sdtype='name') - metadata.set_primary_key('id') + metadata.update_column(table_name='table', column_name='id', sdtype='first_name') + metadata.update_column(table_name='table', column_name='name', sdtype='name') + metadata.set_primary_key('table', 'id') synthesizer = GaussianCopulaSynthesizer(metadata) synthesizer.auto_assign_transformers(data) @@ -451,11 +458,11 @@ def test_refitting_a_model(): } ) - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(data) - metadata.update_column(column_name='name', sdtype='name') - metadata.update_column('id', sdtype='id') - metadata.set_primary_key('id') + metadata = Metadata() + metadata.detect_from_dataframes({'table': data}) + metadata.update_column(table_name='table', column_name='name', sdtype='name') + metadata.update_column('table', 'id', sdtype='id') + metadata.set_primary_key('table', 'id') synthesizer = GaussianCopulaSynthesizer(metadata) synthesizer.fit(data) @@ -478,8 +485,9 @@ def test_get_info(): # Setup data = pd.DataFrame({'col': [1, 2, 3]}) today = datetime.datetime.today().strftime('%Y-%m-%d') - metadata = SingleTableMetadata() - metadata.add_column('col', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col', sdtype='numerical') synthesizer = GaussianCopulaSynthesizer(metadata) # Run @@ -512,7 +520,7 @@ def test_get_info(): def test_save_and_load(tmp_path): """Test that synthesizers can be saved and loaded properly.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() instance = BaseSingleTableSynthesizer(metadata) synthesizer_path = tmp_path / 'synthesizer.pkl' instance.save(synthesizer_path) @@ -534,7 +542,7 @@ def test_save_and_load(tmp_path): def test_save_and_load_no_id(tmp_path): """Test that synthesizers can be saved and loaded properly.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() instance 
= BaseSingleTableSynthesizer(metadata) synthesizer_path = tmp_path / 'synthesizer.pkl' delattr(instance, '_synthesizer_id') @@ -560,7 +568,7 @@ def test_save_and_load_no_id(tmp_path): def test_save_and_load_with_downgraded_version(tmp_path): """Test that synthesizers are raising errors if loaded on a downgraded version.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() instance = BaseSingleTableSynthesizer(metadata) instance._fitted = True instance._fitted_sdv_version = '10.0.0' @@ -693,13 +701,17 @@ def test_metadata_updated_warning(method, kwargs): The warning should be raised during synthesizer initialization. """ # Setup - metadata = SingleTableMetadata().load_from_dict({ - 'columns': { - 'col 1': {'sdtype': 'id'}, - 'col 2': {'sdtype': 'id'}, - 'col 3': {'sdtype': 'categorical'}, - 'city': {'sdtype': 'city'}, - 'country': {'sdtype': 'country_code'}, + metadata = Metadata().load_from_dict({ + 'tables': { + 'table': { + 'columns': { + 'col 1': {'sdtype': 'id'}, + 'col 2': {'sdtype': 'id'}, + 'col 3': {'sdtype': 'categorical'}, + 'city': {'sdtype': 'city'}, + 'country': {'sdtype': 'country_code'}, + } + } } }) expected_message = re.escape( @@ -708,12 +720,13 @@ def test_metadata_updated_warning(method, kwargs): ) # Run - metadata.__getattribute__(method)(**kwargs) + single_metadata = metadata._convert_to_single_table() + single_metadata.__getattribute__(method)(**kwargs) with pytest.warns(UserWarning, match=expected_message): - BaseSingleTableSynthesizer(metadata) + BaseSingleTableSynthesizer(single_metadata) # Assert - assert metadata._updated is False + assert single_metadata._updated is False def test_fit_raises_version_error(): @@ -724,8 +737,8 @@ def test_fit_raises_version_error(): 'col 2': [4, 5, 6], 'col 3': ['a', 'b', 'c'], }) - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(data) + metadata = Metadata() + metadata.detect_from_dataframes({'table': data}) instance = BaseSingleTableSynthesizer(metadata) instance._fitted_sdv_version = '1.0.0' @@ -755,11 +768,11 @@ def test_fit_and_sample_numerical_col_names(synthesizer_class): num_cols = 10 values = {i: np.random.randint(0, 100, size=num_rows) for i in range(num_cols)} data = pd.DataFrame(values) - metadata = SingleTableMetadata() + metadata = Metadata() metadata_dict = {'columns': {}} for i in range(num_cols): metadata_dict['columns'][i] = {'sdtype': 'numerical'} - metadata = SingleTableMetadata.load_from_dict(metadata_dict) + metadata = Metadata.load_from_dict(metadata_dict) # Run synth = synthesizer_class(metadata) @@ -779,7 +792,7 @@ def test_fit_and_sample_numerical_col_names(synthesizer_class): def test_sample_not_fitted(synthesizer): """Test that a synthesizer raises an error when trying to sample without fitting.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() synthesizer = synthesizer.__class__(metadata) expected_message = re.escape( 'This synthesizer has not been fitted. 
Please fit your synthesizer first before' @@ -800,10 +813,10 @@ def test_detect_from_dataframe_numerical_col(synthesizer_class): 2: [4, 5, 6], 3: ['a', 'b', 'c'], }) - metadata = SingleTableMetadata() + metadata = Metadata() # Run - metadata.detect_from_dataframe(data) + metadata.detect_from_dataframes({'table': data}) instance = synthesizer_class(metadata) instance.fit(data) sample = instance.sample(5) diff --git a/tests/integration/single_table/test_constraints.py b/tests/integration/single_table/test_constraints.py index 5a70ffea8..e42cdb227 100644 --- a/tests/integration/single_table/test_constraints.py +++ b/tests/integration/single_table/test_constraints.py @@ -11,7 +11,7 @@ from sdv.constraints import Constraint, create_custom_constraint_class from sdv.constraints.errors import AggregateConstraintsError from sdv.datasets.demo import download_demo -from sdv.metadata import SingleTableMetadata +from sdv.metadata.metadata import Metadata from sdv.sampling import Condition from sdv.single_table import GaussianCopulaSynthesizer from tests.integration.single_table.custom_constraints import MyConstraint @@ -72,10 +72,11 @@ def test_fit_with_unique_constraint_on_data_with_only_index_column(): ], }) - metadata = SingleTableMetadata() - metadata.add_column('key', sdtype='id') - metadata.add_column('index', sdtype='categorical') - metadata.set_primary_key('key') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'key', sdtype='id') + metadata.add_column('table', 'index', sdtype='categorical') + metadata.set_primary_key('table', 'key') model = GaussianCopulaSynthesizer(metadata) constraint = { @@ -136,11 +137,12 @@ def test_fit_with_unique_constraint_on_data_which_has_index_column(): ], }) - metadata = SingleTableMetadata() - metadata.add_column('key', sdtype='id') - metadata.add_column('index', sdtype='categorical') - metadata.add_column('test_column', sdtype='categorical') - metadata.set_primary_key('key') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'key', sdtype='id') + metadata.add_column('table', 'index', sdtype='categorical') + metadata.add_column('table', 'test_column', sdtype='categorical') + metadata.set_primary_key('table', 'key') model = GaussianCopulaSynthesizer(metadata) constraint = { @@ -194,10 +196,11 @@ def test_fit_with_unique_constraint_on_data_subset(): ], }) - metadata = SingleTableMetadata() - metadata.add_column('key', sdtype='id') - metadata.add_column('test_column', sdtype='categorical') - metadata.set_primary_key('key') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'key', sdtype='id') + metadata.add_column('table', 'test_column', sdtype='categorical') + metadata.set_primary_key('table', 'key') test_df = test_df.iloc[[1, 3, 4]] constraint = { @@ -227,7 +230,7 @@ def test_conditional_sampling_with_constraints(): } ) - metadata = SingleTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'columns': { 'A': {'sdtype': 'numerical'}, 'B': {'sdtype': 'numerical'}, @@ -290,10 +293,11 @@ def test_conditional_sampling_constraint_uses_reject_sampling(gm_mock, isinstanc 'age': [27, 28, 26, 21, 30], }) - metadata = SingleTableMetadata() - metadata.add_column('city', sdtype='categorical') - metadata.add_column('state', sdtype='categorical') - metadata.add_column('age', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'city', sdtype='categorical') + metadata.add_column('table', 'state', 
sdtype='categorical') + metadata.add_column('table', 'age', sdtype='numerical') model = GaussianCopulaSynthesizer(metadata) @@ -335,9 +339,9 @@ def test_custom_constraints_from_file(tmpdir): 'categorical_col': ['a', 'b', 'a'], }) - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(data) - metadata.update_column(column_name='pii_col', sdtype='address', pii=True) + metadata = Metadata() + metadata.detect_from_dataframes({'table': data}) + metadata.update_column(table_name='table', column_name='pii_col', sdtype='address', pii=True) synthesizer = GaussianCopulaSynthesizer( metadata, enforce_min_max_values=False, enforce_rounding=False ) @@ -379,9 +383,9 @@ def test_custom_constraints_from_object(tmpdir): 'categorical_col': ['a', 'b', 'a'], }) - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(data) - metadata.update_column(column_name='pii_col', sdtype='address', pii=True) + metadata = Metadata() + metadata.detect_from_dataframes({'table': data}) + metadata.update_column(table_name='table', column_name='pii_col', sdtype='address', pii=True) synthesizer = GaussianCopulaSynthesizer( metadata, enforce_min_max_values=False, enforce_rounding=False ) @@ -444,7 +448,7 @@ def test_inequality_constraint_with_datetimes_and_nones(): } ) - metadata = SingleTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'columns': { 'A': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, 'B': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, @@ -505,7 +509,7 @@ def test_scalar_inequality_constraint_with_datetimes_and_nones(): } ) - metadata = SingleTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'columns': { 'A': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, 'B': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, @@ -556,7 +560,7 @@ def test_scalar_range_constraint_with_datetimes_and_nones(): } ) - metadata = SingleTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'columns': { 'A': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, 'B': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, @@ -623,7 +627,7 @@ def test_range_constraint_with_datetimes_and_nones(): } ) - metadata = SingleTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'columns': { 'A': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, 'B': {'sdtype': 'datetime', 'datetime_format': '%Y-%m-%d'}, @@ -697,7 +701,7 @@ def test_inequality_constraint_all_possible_nans_configurations(): # Setup data = pd.DataFrame(data={'A': [0, 1, np.nan, np.nan, 2], 'B': [2, np.nan, 3, np.nan, 3]}) - metadata = SingleTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'columns': { 'A': {'sdtype': 'numerical'}, 'B': {'sdtype': 'numerical'}, @@ -742,7 +746,7 @@ def test_range_constraint_all_possible_nans_configurations(): } } - metadata = SingleTableMetadata.load_from_dict(metadata_dict) + metadata = Metadata.load_from_dict(metadata_dict) synthesizer = GaussianCopulaSynthesizer(metadata) my_constraint = { @@ -812,10 +816,10 @@ def reverse_transform(column_names, data): 'number': ['1', '2', '3'], 'other': [7, 8, 9], }) - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(data) - metadata.update_column('key', sdtype='id', regex_format=r'\w_\d') - metadata.set_primary_key('key') + metadata = Metadata() + metadata.detect_from_dataframes({'table': data}) + metadata.update_column('table', 'key', sdtype='id', regex_format=r'\w_\d') + metadata.set_primary_key('table', 'key') synth = 
GaussianCopulaSynthesizer(metadata) synth.add_custom_constraint_class(custom_constraint, 'custom') @@ -842,9 +846,10 @@ def test_timezone_aware_constraints(): data['col1'] = pd.to_datetime(data['col1']).dt.tz_localize('UTC') data['col2'] = pd.to_datetime(data['col2']).dt.tz_localize('UTC') - metadata = SingleTableMetadata() - metadata.add_column('col1', sdtype='datetime') - metadata.add_column('col2', sdtype='datetime') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col1', sdtype='datetime') + metadata.add_column('table', 'col2', sdtype='datetime') my_constraint = { 'constraint_class': 'Inequality', @@ -960,18 +965,25 @@ def _transform(self, table_data): def test_constraint_datetime_check(): """Test datetime columns are correctly identified in constraints. GH#1692""" # Setup - data = pd.DataFrame( - data={ - 'low_col': ['21 Sep, 15', '23 Aug, 14', '29 May, 12'], - 'high_col': ['02 Nov, 15', '12 Oct, 14', '08 Jul, 12'], - } - ) - metadata = SingleTableMetadata.load_from_dict({ - 'columns': { - 'low_col': {'sdtype': 'datetime', 'datetime_format': '%d %b, %y'}, - 'high_col': {'sdtype': 'datetime', 'datetime_format': '%d %b, %y'}, + data = { + 'table': pd.DataFrame( + data={ + 'low_col': ['21 Sep, 15', '23 Aug, 14', '29 May, 12'], + 'high_col': ['02 Nov, 15', '12 Oct, 14', '08 Jul, 12'], + } + ) + } + metadata = Metadata.load_from_dict({ + 'tables': { + 'table': { + 'columns': { + 'low_col': {'sdtype': 'datetime', 'datetime_format': '%d %b, %y'}, + 'high_col': {'sdtype': 'datetime', 'datetime_format': '%d %b, %y'}, + } + } } }) + my_constraint = { 'constraint_class': 'Inequality', 'constraint_parameters': { @@ -987,7 +999,7 @@ def test_constraint_datetime_check(): synth = GaussianCopulaSynthesizer(metadata) synth.add_constraints([my_constraint]) - synth.fit(data) + synth.fit(data['table']) samples = synth.sample(3) # Assert diff --git a/tests/integration/single_table/test_copulas.py b/tests/integration/single_table/test_copulas.py index 0a8477409..5e46a8417 100644 --- a/tests/integration/single_table/test_copulas.py +++ b/tests/integration/single_table/test_copulas.py @@ -15,7 +15,7 @@ from sdv.datasets.demo import download_demo from sdv.errors import ConstraintsNotMetError from sdv.evaluation.single_table import evaluate_quality, get_column_pair_plot, get_column_plot -from sdv.metadata import SingleTableMetadata +from sdv.metadata.metadata import Metadata from sdv.sampling import Condition from sdv.single_table import GaussianCopulaSynthesizer @@ -277,10 +277,10 @@ def test_update_transformers_with_id_generator(): sample_num = 20 data = pd.DataFrame({'user_id': list(range(4)), 'user_cat': ['a', 'b', 'c', 'd']}) - stm = SingleTableMetadata() - stm.detect_from_dataframe(data) - stm.update_column('user_id', sdtype='id') - stm.set_primary_key('user_id') + stm = Metadata() + stm.detect_from_dataframes({'table': data}) + stm.update_column('table', 'user_id', sdtype='id') + stm.set_primary_key('table', 'user_id') gc = GaussianCopulaSynthesizer(stm) custom_id = IDGenerator(starting_value=min_value_id) @@ -332,7 +332,7 @@ def test_numerical_columns_gets_pii(): data = pd.DataFrame( data={'id': [0, 1, 2, 3, 4], 'city': [0, 0, 0, 0, 0], 'numerical': [21, 22, 23, 24, 25]} ) - metadata = SingleTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'primary_key': 'id', 'columns': { 'id': {'sdtype': 'id'}, @@ -406,8 +406,8 @@ def test_categorical_column_with_numbers(): 'numerical_col': np.random.rand(20), }) - metadata = SingleTableMetadata() - 
metadata.detect_from_dataframe(data) + metadata = Metadata() + metadata.detect_from_dataframes({'table': data}) synthesizer = GaussianCopulaSynthesizer(metadata) @@ -435,9 +435,9 @@ def test_unknown_sdtype(): 'numerical_col': np.random.rand(3), }) - metadata = SingleTableMetadata() - metadata.detect_from_dataframe(data) - metadata.update_column('unknown', sdtype='unknown') + metadata = Metadata() + metadata.detect_from_dataframes({'table': data}) + metadata.update_column('table', 'unknown', sdtype='unknown') synthesizer = GaussianCopulaSynthesizer(metadata) @@ -482,7 +482,7 @@ def test_support_nullable_pandas_dtypes(): 'Float32': pd.Series([1.1, 2.2, 3.3, pd.NA], dtype='Float32'), 'Float64': pd.Series([1.113, 2.22, 3.3, pd.NA], dtype='Float64'), }) - metadata = SingleTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'columns': { 'Int8': {'sdtype': 'numerical', 'computer_representation': 'Int8'}, 'Int16': {'sdtype': 'numerical', 'computer_representation': 'Int16'}, diff --git a/tests/integration/single_table/test_ctgan.py b/tests/integration/single_table/test_ctgan.py index 26bc906b2..9f892f878 100644 --- a/tests/integration/single_table/test_ctgan.py +++ b/tests/integration/single_table/test_ctgan.py @@ -8,20 +8,21 @@ from sdv.datasets.demo import download_demo from sdv.errors import InvalidDataTypeError from sdv.evaluation.single_table import evaluate_quality, get_column_pair_plot, get_column_plot -from sdv.metadata import SingleTableMetadata +from sdv.metadata.metadata import Metadata from sdv.single_table import CTGANSynthesizer, TVAESynthesizer def test__estimate_num_columns(): """Test the number of columns is estimated correctly.""" # Setup - metadata = SingleTableMetadata() - metadata.add_column('numerical', sdtype='numerical') - metadata.add_column('categorical', sdtype='categorical') - metadata.add_column('categorical2', sdtype='categorical') - metadata.add_column('categorical3', sdtype='categorical') - metadata.add_column('datetime', sdtype='datetime') - metadata.add_column('boolean', sdtype='boolean') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'numerical', sdtype='numerical') + metadata.add_column('table', 'categorical', sdtype='categorical') + metadata.add_column('table', 'categorical2', sdtype='categorical') + metadata.add_column('table', 'categorical3', sdtype='categorical') + metadata.add_column('table', 'datetime', sdtype='datetime') + metadata.add_column('table', 'boolean', sdtype='boolean') data = pd.DataFrame({ 'numerical': [0.1, 0.2, 0.3], 'datetime': ['2020-01-01', '2020-01-02', '2020-01-03'], @@ -134,7 +135,7 @@ def test_categoricals_are_not_preprocessed(): 'alcohol': ['medium', 'medium', 'low', 'high', 'low'], } ) - metadata = SingleTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'columns': { 'age': {'sdtype': 'numerical'}, 'therapy': {'sdtype': 'boolean'}, @@ -180,7 +181,7 @@ def test_categorical_metadata_with_int_data(): }, } - metadata = SingleTableMetadata.load_from_dict(metadata_dict) + metadata = Metadata.load_from_dict(metadata_dict) data = pd.DataFrame({ 'A': list(range(50)), 'B': list(range(50)), @@ -270,7 +271,7 @@ def test_ctgan_with_dropped_columns(): 'columns': {'user_id': {'sdtype': 'id'}, 'user_ssn': {'sdtype': 'ssn'}}, } - metadata = SingleTableMetadata.load_from_dict(metadata_dict) + metadata = Metadata.load_from_dict(metadata_dict) # Run synth = CTGANSynthesizer(metadata) diff --git a/tests/integration/utils/test_poc.py b/tests/integration/utils/test_poc.py index 
b3dfcffdc..b7b3247e8 100644 --- a/tests/integration/utils/test_poc.py +++ b/tests/integration/utils/test_poc.py @@ -7,7 +7,7 @@ import pytest from sdv.datasets.demo import download_demo -from sdv.metadata import MultiTableMetadata +from sdv.metadata.metadata import Metadata from sdv.multi_table.hma import MAX_NUMBER_OF_COLUMNS, HMASynthesizer from sdv.multi_table.utils import _get_total_estimated_columns from sdv.utils.poc import get_random_subset, simplify_schema @@ -15,7 +15,7 @@ @pytest.fixture def metadata(): - return MultiTableMetadata.load_from_dict({ + return Metadata.load_from_dict({ 'tables': { 'parent': { 'columns': { diff --git a/tests/integration/utils/test_utils.py b/tests/integration/utils/test_utils.py index 8c3498311..ad381784a 100644 --- a/tests/integration/utils/test_utils.py +++ b/tests/integration/utils/test_utils.py @@ -7,13 +7,13 @@ from sdv.datasets.demo import download_demo from sdv.errors import InvalidDataError -from sdv.metadata import MultiTableMetadata +from sdv.metadata.metadata import Metadata from sdv.utils import drop_unknown_references, get_random_sequence_subset @pytest.fixture def metadata(): - return MultiTableMetadata.load_from_dict({ + return Metadata.load_from_dict({ 'tables': { 'parent': { 'columns': { diff --git a/tests/unit/evaluation/test_multi_table.py b/tests/unit/evaluation/test_multi_table.py index 2995f9dd5..2ecdf52d1 100644 --- a/tests/unit/evaluation/test_multi_table.py +++ b/tests/unit/evaluation/test_multi_table.py @@ -11,7 +11,7 @@ get_column_plot, run_diagnostic, ) -from sdv.metadata.multi_table import MultiTableMetadata +from sdv.metadata.metadata import Metadata def test_evaluate_quality(): @@ -20,7 +20,7 @@ def test_evaluate_quality(): table = pd.DataFrame({'col': [1, 2, 3]}) data1 = {'table': table} data2 = {'table': pd.DataFrame({'col': [2, 1, 3]})} - metadata = MultiTableMetadata() + metadata = Metadata() metadata.detect_table_from_dataframe('table', table) QualityReport.generate = Mock() @@ -37,7 +37,7 @@ def test_run_diagnostic(): table = pd.DataFrame({'col': [1, 2, 3]}) data1 = {'table': table} data2 = {'table': pd.DataFrame({'col': [2, 1, 3]})} - metadata = MultiTableMetadata() + metadata = Metadata() metadata.detect_table_from_dataframe('table', table) DiagnosticReport.generate = Mock() @@ -60,7 +60,7 @@ def test_get_column_plot(mock_plot): table2 = pd.DataFrame({'col': [2, 1, 3]}) data1 = {'table': table1} data2 = {'table': table2} - metadata = MultiTableMetadata() + metadata = Metadata() metadata.detect_table_from_dataframe('table', table1) mock_plot.return_value = 'plot' @@ -79,7 +79,7 @@ def test_get_column_plot_only_real_or_synthetic(mock_plot): # Setup table1 = pd.DataFrame({'col': [1, 2, 3]}) data1 = {'table': table1} - metadata = MultiTableMetadata() + metadata = Metadata() metadata.detect_table_from_dataframe('table', table1) mock_plot.return_value = 'plot' @@ -103,7 +103,7 @@ def test_get_column_pair_plot(mock_plot): table2 = pd.DataFrame({'col1': [2, 1, 3], 'col2': [1, 2, 3]}) data1 = {'table': table1} data2 = {'table': table2} - metadata = MultiTableMetadata() + metadata = Metadata() metadata.detect_table_from_dataframe('table', table1) mock_plot.return_value = 'plot' @@ -122,7 +122,7 @@ def test_get_column_pair_plot_only_real_or_synthetic(mock_plot): # Setup table1 = pd.DataFrame({'col1': [1, 2, 3], 'col2': [3, 2, 1]}) data1 = {'table': table1} - metadata = MultiTableMetadata() + metadata = Metadata() metadata.detect_table_from_dataframe('table', table1) mock_plot.return_value = 'plot' @@ -170,7 +170,7 @@ def 
test_get_cardinality_plot(mock_plot): ], 'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1', } - metadata = MultiTableMetadata.load_from_dict(metadata_dict) + metadata = Metadata.load_from_dict(metadata_dict) mock_plot.return_value = 'plot' # Run @@ -213,7 +213,7 @@ def test_get_cardinality_plot_plot_type(mock_plot): ], 'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1', } - metadata = MultiTableMetadata.load_from_dict(metadata_dict) + metadata = Metadata.load_from_dict(metadata_dict) mock_plot.return_value = 'plot' # Run diff --git a/tests/unit/evaluation/test_single_table.py b/tests/unit/evaluation/test_single_table.py index 3a53e11e2..f669a6e40 100644 --- a/tests/unit/evaluation/test_single_table.py +++ b/tests/unit/evaluation/test_single_table.py @@ -14,7 +14,6 @@ run_diagnostic, ) from sdv.metadata.metadata import Metadata -from sdv.metadata.single_table import SingleTableMetadata def test_evaluate_quality(): @@ -22,15 +21,18 @@ def test_evaluate_quality(): # Setup data1 = pd.DataFrame({'col': [1, 2, 3]}) data2 = pd.DataFrame({'col': [2, 1, 3]}) - metadata = SingleTableMetadata() - metadata.add_column('col', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col', sdtype='numerical') QualityReport.generate = Mock() # Run evaluate_quality(data1, data2, metadata) # Assert - QualityReport.generate.assert_called_once_with(data1, data2, metadata.to_dict(), True) + QualityReport.generate.assert_called_once_with( + data1, data2, metadata._convert_to_single_table().to_dict(), True + ) def test_evaluate_quality_metadata(): @@ -55,15 +57,18 @@ def test_run_diagnostic(): # Setup data1 = pd.DataFrame({'col': [1, 2, 3]}) data2 = pd.DataFrame({'col': [2, 1, 3]}) - metadata = SingleTableMetadata() - metadata.add_column('col', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col', sdtype='numerical') DiagnosticReport.generate = Mock(return_value=123) # Run run_diagnostic(data1, data2, metadata) # Assert - DiagnosticReport.generate.assert_called_once_with(data1, data2, metadata.to_dict(), True) + DiagnosticReport.generate.assert_called_once_with( + data1, data2, metadata._convert_to_single_table().to_dict(), True + ) def test_run_diagnostic_metadata(): @@ -93,8 +98,9 @@ def test_get_column_plot_continuous_data(mock_get_plot): # Setup data1 = pd.DataFrame({'col': [1, 2, 3]}) data2 = pd.DataFrame({'col': [2, 1, 3]}) - metadata = SingleTableMetadata() - metadata.add_column('col', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col', sdtype='numerical') # Run plot = get_column_plot(data1, data2, metadata, 'col') @@ -135,8 +141,9 @@ def test_get_column_plot_discrete_data(mock_get_plot): # Setup data1 = pd.DataFrame({'col': ['a', 'b', 'c']}) data2 = pd.DataFrame({'col': ['a', 'b', 'c']}) - metadata = SingleTableMetadata() - metadata.add_column('col', sdtype='categorical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col', sdtype='categorical') # Run plot = get_column_plot(data1, data2, metadata, 'col') @@ -178,8 +185,9 @@ def test_get_column_plot_discrete_data_with_distplot(mock_get_plot): # Setup data1 = pd.DataFrame({'col': ['a', 'b', 'c']}) data2 = pd.DataFrame({'col': ['a', 'b', 'c']}) - metadata = SingleTableMetadata() - metadata.add_column('col', sdtype='categorical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col', sdtype='categorical') # Run plot = 
get_column_plot(data1, data2, metadata, 'col', plot_type='distplot') @@ -221,8 +229,9 @@ def test_get_column_plot_invalid_sdtype(mock_get_plot): # Setup data1 = pd.DataFrame({'col': ['a', 'b', 'c']}) data2 = pd.DataFrame({'col': ['a', 'b', 'c']}) - metadata = SingleTableMetadata() - metadata.add_column('col', sdtype='id') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col', sdtype='id') # Run and Assert error_msg = re.escape( @@ -265,8 +274,9 @@ def test_get_column_plot_invalid_sdtype_with_plot_type(mock_get_plot): # Setup data1 = pd.DataFrame({'col': ['a', 'b', 'c']}) data2 = pd.DataFrame({'col': ['a', 'b', 'c']}) - metadata = SingleTableMetadata() - metadata.add_column('col', sdtype='id') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col', sdtype='id') # Run plot = get_column_plot(data1, data2, metadata, 'col', plot_type='bar') @@ -307,8 +317,9 @@ def test_get_column_plot_with_datetime_sdtype(mock_get_plot): # Setup real_data = pd.DataFrame({'datetime': ['2021-02-01', '2021-12-01']}) synthetic_data = pd.DataFrame({'datetime': ['2023-02-21', '2022-12-13']}) - metadata = SingleTableMetadata() - metadata.add_column('datetime', sdtype='datetime', datetime_format='%Y-%m-%d') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'datetime', sdtype='datetime', datetime_format='%Y-%m-%d') # Run plot = get_column_plot(real_data, synthetic_data, metadata, 'datetime') @@ -341,9 +352,10 @@ def test_get_column_pair_plot_with_continous_data(mock_get_plot): 'amount': [1.0, 2.0, 3.0], 'date': ['2021-01-01', '2022-01-01', '2023-01-01'], }) - metadata = SingleTableMetadata() - metadata.add_column('amount', sdtype='numerical') - metadata.add_column('date', sdtype='datetime') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'amount', sdtype='numerical') + metadata.add_column('table', 'date', sdtype='datetime') # Run plot = get_column_pair_plot(real_data, synthetic_data, metadata, columns) @@ -375,9 +387,10 @@ def test_get_column_pair_plot_with_discrete_data(mock_get_plot): columns = ['name', 'subscriber'] real_data = pd.DataFrame({'name': ['John', 'Emily'], 'subscriber': [True, False]}) synthetic_data = pd.DataFrame({'name': ['John', 'Johanna'], 'subscriber': [False, False]}) - metadata = SingleTableMetadata() - metadata.add_column('name', sdtype='categorical') - metadata.add_column('subscriber', sdtype='boolean') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'name', sdtype='categorical') + metadata.add_column('table', 'subscriber', sdtype='boolean') # Run plot = get_column_pair_plot(real_data, synthetic_data, metadata, columns) @@ -401,9 +414,10 @@ def test_get_column_pair_plot_with_mixed_data(mock_get_plot): columns = ['name', 'counts'] real_data = pd.DataFrame({'name': ['John', 'Emily'], 'counts': [1, 2]}) synthetic_data = pd.DataFrame({'name': ['John', 'Johanna'], 'counts': [3, 1]}) - metadata = SingleTableMetadata() - metadata.add_column('name', sdtype='categorical') - metadata.add_column('counts', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'name', sdtype='categorical') + metadata.add_column('table', 'counts', sdtype='numerical') # Run plot = get_column_pair_plot(real_data, synthetic_data, metadata, columns) @@ -433,9 +447,10 @@ def test_get_column_pair_plot_with_forced_plot_type(mock_get_plot): 'amount': [1.0, 2.0, 3.0], 'date': ['2021-01-01', '2022-01-01', 
'2023-01-01'], }) - metadata = SingleTableMetadata() - metadata.add_column('amount', sdtype='numerical') - metadata.add_column('date', sdtype='datetime') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'amount', sdtype='numerical') + metadata.add_column('table', 'date', sdtype='datetime') # Run plot = get_column_pair_plot(real_data, synthetic_data, metadata, columns, plot_type='heatmap') @@ -474,9 +489,10 @@ def test_get_column_pair_plot_with_invalid_sdtype(mock_get_plot): 'amount': [1.0, 2.0, 3.0], 'id': [1, 2, 3], }) - metadata = SingleTableMetadata() - metadata.add_column('amount', sdtype='numerical') - metadata.add_column('id', sdtype='id') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'amount', sdtype='numerical') + metadata.add_column('table', 'id', sdtype='id') # Run and Assert error_msg = re.escape( @@ -504,9 +520,10 @@ def test_get_column_pair_plot_with_invalid_sdtype_and_plot_type(mock_get_plot): 'amount': [1.0, 2.0, 3.0], 'id': [1, 2, 3], }) - metadata = SingleTableMetadata() - metadata.add_column('amount', sdtype='numerical') - metadata.add_column('id', sdtype='id') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'amount', sdtype='numerical') + metadata.add_column('table', 'id', sdtype='id') # Run plot = get_column_pair_plot(real_data, synthetic_data, metadata, columns, plot_type='heatmap') @@ -532,9 +549,10 @@ def test_get_column_pair_plot_with_sample_size(mock_get_plot): 'amount': [1.0, 2.0, 3.0], 'price': [11.0, 22.0, 33.0], }) - metadata = SingleTableMetadata() - metadata.add_column('amount', sdtype='numerical') - metadata.add_column('price', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'amount', sdtype='numerical') + metadata.add_column('table', 'price', sdtype='numerical') # Run get_column_pair_plot(real_data, synthetic_data, metadata, columns, sample_size=2) @@ -594,9 +612,10 @@ def test_get_column_pair_plot_with_sample_size_too_big(mock_get_plot): 'amount': [1.0, 2.0, 3.0], 'price': [11.0, 22.0, 33.0], }) - metadata = SingleTableMetadata() - metadata.add_column('amount', sdtype='numerical') - metadata.add_column('price', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'amount', sdtype='numerical') + metadata.add_column('table', 'price', sdtype='numerical') # Run plot = get_column_pair_plot(real_data, synthetic_data, metadata, columns, sample_size=10) diff --git a/tests/unit/io/local/test_local.py b/tests/unit/io/local/test_local.py index 48395f276..478af4c36 100644 --- a/tests/unit/io/local/test_local.py +++ b/tests/unit/io/local/test_local.py @@ -8,7 +8,7 @@ import pytest from sdv.io.local.local import BaseLocalHandler, CSVHandler, ExcelHandler -from sdv.metadata.multi_table import MultiTableMetadata +from sdv.metadata import Metadata class TestBaseLocalHandler: @@ -34,9 +34,9 @@ def test_create_metadata(self): metadata = instance.create_metadata(data) # Assert - assert isinstance(metadata, MultiTableMetadata) + assert isinstance(metadata, Metadata) assert metadata.to_dict() == { - 'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1', + 'METADATA_SPEC_VERSION': 'V1', 'relationships': [], 'tables': { 'guests': { diff --git a/tests/unit/lite/test_single_table.py b/tests/unit/lite/test_single_table.py index c51d99b83..d9880cc16 100644 --- a/tests/unit/lite/test_single_table.py +++ b/tests/unit/lite/test_single_table.py @@ -21,7 +21,7 @@ def 
test___init__invalid_name(self):
        """
        # Run and Assert
        with pytest.raises(ValueError, match=r"'name' must be one of *"):
-            SingleTablePreset(metadata=SingleTableMetadata(), name='invalid')
+            SingleTablePreset(metadata=Metadata(), name='invalid')

    @patch('sdv.lite.single_table.GaussianCopulaSynthesizer')
    def test__init__speed_passes_correct_parameters(self, gaussian_copula_mock):
diff --git a/tests/unit/metadata/test_metadata.py b/tests/unit/metadata/test_metadata.py
index 0d282c347..da2046596 100644
--- a/tests/unit/metadata/test_metadata.py
+++ b/tests/unit/metadata/test_metadata.py
@@ -391,7 +391,7 @@ def test__set_metadata_multi_table(self, mock_singletablemetadata):

        Setup:
            - instance of ``Metadata``.
-            - A dict representing a ``MultiTableMetadata``.
+            - A dict representing a ``Metadata``.

        Mock:
            - Mock ``SingleTableMetadata`` from ``sdv.metadata.multi_table``
@@ -457,16 +457,10 @@ def test__set_metadata_single_table(self):

        Setup:
            - instance of ``Metadata``.
-            - A dict representing a ``SingleTableMetadata``.
-
-        Mock:
-            - Mock ``SingleTableMetadata`` from ``sdv.metadata.multi_table``
-
-        Side Effects:
-            - ``SingleTableMetadata.load_from_dict`` has been called.
+            - A dict representing a single-table ``Metadata``.
        """
        # Setup
-        multitable_metadata = {
+        single_table_metadata = {
            'columns': {'my_column': 'value'},
            'primary_key': 'pk',
            'alternate_keys': [],
@@ -478,7 +472,7 @@
        instance = Metadata()

        # Run
-        instance._set_metadata_dict(multitable_metadata)
+        instance._set_metadata_dict(single_table_metadata)

        # Assert
        assert instance.tables['default_table_name'].columns == {'my_column': 'value'}
diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py
index 3ca501aaf..8e0c60629 100644
--- a/tests/unit/multi_table/test_base.py
+++ b/tests/unit/multi_table/test_base.py
@@ -151,7 +151,8 @@ def test___init___deprecated(self):
        """Test that init with old MultiTableMetadata gives a future warning."""
        # Setup
        metadata = get_multi_table_metadata()
-        metadata.validate = Mock()
+        multi_metadata = MultiTableMetadata.load_from_dict(metadata.to_dict())
+        multi_metadata.validate = Mock()

        deprecation_msg = re.escape(
            "The 'MultiTableMetadata' is deprecated. 
Please use the new " @@ -160,7 +161,7 @@ def test___init___deprecated(self): # Run with pytest.warns(FutureWarning, match=deprecation_msg): - BaseMultiTableSynthesizer(metadata) + BaseMultiTableSynthesizer(multi_metadata) @patch('sdv.metadata.single_table.is_faker_function') def test__init__column_relationship_warning(self, mock_is_faker_function): @@ -231,7 +232,7 @@ def test__check_metadata_updated(self): def test_set_address_columns(self): """Test the ``set_address_columns`` method.""" # Setup - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'tables': { 'address_table': { 'columns': { @@ -274,7 +275,7 @@ def test_set_address_columns(self): def test_set_address_columns_error(self): """Test that ``set_address_columns`` raises an error for unknown table.""" # Setup - metadata = MultiTableMetadata() + metadata = Metadata() columns = ('country_column', 'city_column') metadata.validate = Mock() SingleTableMetadata.validate = Mock() @@ -857,7 +858,7 @@ def test_preprocess_int_columns(self): } ], } - metadata = MultiTableMetadata.load_from_dict(metadata_dict) + metadata = Metadata.load_from_dict(metadata_dict) instance = BaseMultiTableSynthesizer(metadata) instance.validate = Mock() instance._table_synthesizers = {'first_table': Mock(), 'second_table': Mock()} @@ -1425,7 +1426,7 @@ def test_add_constraints_missing_table_name(self): """Test error raised when ``table_name`` is missing.""" # Setup data = pd.DataFrame({'col': [1, 2, 3]}) - metadata = MultiTableMetadata() + metadata = Metadata() metadata.detect_table_from_dataframe('table', data) constraint = {'constraint_class': 'Inequality'} model = BaseMultiTableSynthesizer(metadata) @@ -1524,7 +1525,7 @@ def test_get_info(self, mock_version): """ # Setup data = {'tab': pd.DataFrame({'col': [1, 2, 3]})} - metadata = MultiTableMetadata() + metadata = Metadata() metadata.add_table('tab') metadata.add_column('tab', 'col', sdtype='numerical') mock_version.public = '1.0.0' @@ -1572,7 +1573,7 @@ def test_get_info_with_enterprise(self, mock_version): """ # Setup data = {'tab': pd.DataFrame({'col': [1, 2, 3]})} - metadata = MultiTableMetadata() + metadata = Metadata() metadata.add_table('tab') metadata.add_column('tab', 'col', sdtype='numerical') mock_version.public = '1.0.0' @@ -1636,7 +1637,7 @@ def test_save(self, cloudpickle_mock, mock_datetime, tmp_path, caplog): def test_save_warning(self, tmp_path): """Test that the synthesizer produces a warning if saved without fitting.""" # Setup - synthesizer = BaseMultiTableSynthesizer(MultiTableMetadata()) + synthesizer = BaseMultiTableSynthesizer(Metadata()) # Run and Assert warn_msg = re.escape( diff --git a/tests/unit/multi_table/test_hma.py b/tests/unit/multi_table/test_hma.py index ef0085c19..72abc0e25 100644 --- a/tests/unit/multi_table/test_hma.py +++ b/tests/unit/multi_table/test_hma.py @@ -6,7 +6,7 @@ import pytest from sdv.errors import SynthesizerInputError -from sdv.metadata.multi_table import MultiTableMetadata +from sdv.metadata.metadata import Metadata from sdv.multi_table.hma import HMASynthesizer from sdv.single_table.copulas import GaussianCopulaSynthesizer from tests.utils import get_multi_table_data, get_multi_table_metadata @@ -790,7 +790,7 @@ def test__estimate_num_columns_to_be_modeled_multiple_foreign_keys(self): 'col1': [0, 1, 2], }) data = {'parent': parent, 'child': child} - metadata = MultiTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'tables': { 'parent': { 'primary_key': 'id', @@ -879,7 +879,7 @@ def 
test__estimate_num_columns_to_be_modeled_different_distributions(self):
                'col': {'sdtype': 'numerical'},
            },
        }
-        metadata = MultiTableMetadata.load_from_dict({
+        metadata = Metadata.load_from_dict({
            'tables': {
                'parent': {
                    'primary_key': 'id',
@@ -1022,7 +1022,7 @@ def test__estimate_num_columns_to_be_modeled(self):
            'parent': parent,
            'child': child,
        }
-        metadata = MultiTableMetadata.load_from_dict({
+        metadata = Metadata.load_from_dict({
            'tables': {
                'root1': {
                    'primary_key': 'R1',
@@ -1123,3 +1123,111 @@ def test__estimate_num_columns_to_be_modeled(self):
                num_table_cols -= 1

            assert num_table_cols == estimated_num_columns[table_name]
+
+    def test__estimate_num_columns_to_be_modeled_various_sdtypes(self):
+        """Test the estimated number of columns is correct for various sdtypes.
+
+        To check that the number of columns is correct we mock the ``_finalize`` method
+        and compare its output with the estimated number of columns.
+
+        The dataset used follows the structure below:
+            R1  R2
+             | /
+             GP
+             |
+             P
+        """
+        # Setup
+        root1 = pd.DataFrame({'R1': [0, 1, 2]})
+        root2 = pd.DataFrame({'R2': [0, 1, 2], 'data': [0, 1, 2]})
+        grandparent = pd.DataFrame({'GP': [0, 1, 2], 'R1': [0, 1, 2], 'R2': [0, 1, 2]})
+        parent = pd.DataFrame({
+            'P': [0, 1, 2],
+            'GP': [0, 1, 2],
+            'numerical': [0.1, 0.5, np.nan],
+            'categorical': ['a', np.nan, 'c'],
+            'datetime': [None, '2019-01-02', '2019-01-03'],
+            'boolean': [float('nan'), False, True],
+            'id': [0, 1, 2],
+        })
+        data = {
+            'root1': root1,
+            'root2': root2,
+            'grandparent': grandparent,
+            'parent': parent,
+        }
+        metadata = Metadata.load_from_dict({
+            'tables': {
+                'root1': {
+                    'primary_key': 'R1',
+                    'columns': {
+                        'R1': {'sdtype': 'id'},
+                    },
+                },
+                'root2': {
+                    'primary_key': 'R2',
+                    'columns': {'R2': {'sdtype': 'id'}, 'data': {'sdtype': 'numerical'}},
+                },
+                'grandparent': {
+                    'primary_key': 'GP',
+                    'columns': {
+                        'GP': {'sdtype': 'id'},
+                        'R1': {'sdtype': 'id'},
+                        'R2': {'sdtype': 'id'},
+                    },
+                },
+                'parent': {
+                    'primary_key': 'P',
+                    'columns': {
+                        'P': {'sdtype': 'id'},
+                        'GP': {'sdtype': 'id'},
+                        'numerical': {'sdtype': 'numerical'},
+                        'categorical': {'sdtype': 'categorical'},
+                        'datetime': {'sdtype': 'datetime'},
+                        'boolean': {'sdtype': 'boolean'},
+                        'id': {'sdtype': 'id'},
+                    },
+                },
+            },
+            'relationships': [
+                {
+                    'parent_table_name': 'root1',
+                    'parent_primary_key': 'R1',
+                    'child_table_name': 'grandparent',
+                    'child_foreign_key': 'R1',
+                },
+                {
+                    'parent_table_name': 'root2',
+                    'parent_primary_key': 'R2',
+                    'child_table_name': 'grandparent',
+                    'child_foreign_key': 'R2',
+                },
+                {
+                    'parent_table_name': 'grandparent',
+                    'parent_primary_key': 'GP',
+                    'child_table_name': 'parent',
+                    'child_foreign_key': 'GP',
+                },
+            ],
+        })
+        synthesizer = HMASynthesizer(metadata)
+        synthesizer._finalize = Mock(return_value=data)
+
+        # Run estimation
+        estimated_num_columns = synthesizer._estimate_num_columns(metadata)
+
+        # Run actual modeling
+        synthesizer.fit(data)
+        synthesizer.sample()
+
+        # Assert estimated number of columns is correct
+        tables = synthesizer._finalize.call_args[0][0]
+        for table_name, table in tables.items():
+            # Subtract all the id columns present in the data, as those are not estimated
+            num_table_cols = len(table.columns)
+            if table_name in {'parent', 'grandparent'}:
+                num_table_cols -= 3
+            if table_name in {'root1', 'root2'}:
+                num_table_cols -= 1
+
+            assert num_table_cols == estimated_num_columns[table_name]
diff --git a/tests/unit/multi_table/test_utils.py b/tests/unit/multi_table/test_utils.py
index e01195f89..57b52fa08 100644
--- 
a/tests/unit/multi_table/test_utils.py +++ b/tests/unit/multi_table/test_utils.py @@ -8,7 +8,7 @@ import pytest from sdv.errors import InvalidDataError, SamplingError -from sdv.metadata import MultiTableMetadata +from sdv.metadata.metadata import Metadata from sdv.multi_table.utils import ( _drop_rows, _get_all_descendant_per_root_at_order_n, @@ -605,7 +605,7 @@ def test__get_disconnected_roots_from_table(table_name, expected_result): def test__simplify_relationships_and_tables(): """Test the ``_simplify_relationships`` method.""" # Setup - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'tables': { 'grandparent': {'columns': {'col_1': {'sdtype': 'numerical'}}}, 'parent': {'columns': {'col_2': {'sdtype': 'numerical'}}}, @@ -646,7 +646,7 @@ def test__simplify_relationships_and_tables(): def test__simplify_grandchildren(): """Test the ``_simplify_grandchildren`` method.""" # Setup - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'tables': { 'grandparent': {'columns': {'col_1': {'sdtype': 'numerical'}}}, 'parent': {'columns': {'col_2': {'sdtype': 'numerical'}}}, @@ -697,7 +697,7 @@ def test__get_num_column_to_drop(): datetime_columns = {f'col_{i}': {'sdtype': 'datetime'} for i in range(600, 900)} id_columns = {f'col_{i}': {'sdtype': 'id'} for i in range(900, 910)} email_columns = {f'col_{i}': {'sdtype': 'email'} for i in range(910, 920)} - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'tables': { 'child': { 'columns': { @@ -881,11 +881,11 @@ def test__simplify_children(mock_get_columns_to_drop_child, mock_hma): child_1_before_simplify['columns']['col_4'] = {'sdtype': 'categorical'} child_2_before_simplify = deepcopy(child_2) child_2_before_simplify['columns']['col_8'] = {'sdtype': 'categorical'} - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'relationships': relatioships, 'tables': {'child_1': child_1_before_simplify, 'child_2': child_2_before_simplify}, }) - metadata_after_simplify_2 = MultiTableMetadata().load_from_dict({ + metadata_after_simplify_2 = Metadata().load_from_dict({ 'relationships': relatioships, 'tables': {'child_1': child_1, 'child_2': child_2}, }) @@ -951,7 +951,7 @@ def test__simplify_metadata_no_child_simplification(mock_hma): 'other_table': {'columns': {'col_8': {'sdtype': 'numerical'}}}, 'other_root': {'columns': {'col_9': {'sdtype': 'numerical'}}}, } - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'relationships': relationships, 'tables': tables, }) @@ -1047,7 +1047,7 @@ def test__simplify_metadata(mock_get_columns_to_drop_child, mock_hma): }, 'other_root': {'columns': {'col_9': {'sdtype': 'numerical'}}}, } - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'relationships': relationships, 'tables': tables, }) @@ -1122,7 +1122,7 @@ def test__simplify_metadata(mock_get_columns_to_drop_child, mock_hma): def test__simplify_data(): """Test the ``_simplify_data`` method.""" # Setup - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'tables': { 'parent': {'columns': {'col_1': {'sdtype': 'id'}}}, 'child': {'columns': {'col_2': {'sdtype': 'id'}}}, @@ -1249,7 +1249,7 @@ def test__subsample_disconnected_roots(mock_drop_rows, mock_get_disconnected_roo 'col_12': [6, 7, 8, 9, 10], }), } - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'tables': { 
'disconnected_root': { 'columns': { @@ -1378,7 +1378,7 @@ def test__get_primary_keys_referenced(): }), } - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'tables': { 'grandparent': { 'columns': { @@ -1598,7 +1598,7 @@ def test__subsample_ancestors(): 'child': {21, 22, 23, 24, 25}, } - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'tables': { 'grandparent': { 'columns': { @@ -1785,7 +1785,7 @@ def test__subsample_ancestors_schema_diamond_shape(): 'parent_2': {31, 32, 33, 34, 35}, } - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'tables': { 'grandparent': { 'columns': { diff --git a/tests/unit/sequential/test_par.py b/tests/unit/sequential/test_par.py index 21a40922c..966939869 100644 --- a/tests/unit/sequential/test_par.py +++ b/tests/unit/sequential/test_par.py @@ -19,16 +19,17 @@ class TestPARSynthesizer: def get_metadata(self, add_sequence_key=True, add_sequence_index=False): - metadata = SingleTableMetadata() - metadata.add_column('time', sdtype='datetime') - metadata.add_column('gender', sdtype='categorical') - metadata.add_column('name', sdtype='id') - metadata.add_column('measurement', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'time', sdtype='datetime') + metadata.add_column('table', 'gender', sdtype='categorical') + metadata.add_column('table', 'name', sdtype='id') + metadata.add_column('table', 'measurement', sdtype='numerical') if add_sequence_key: - metadata.set_sequence_key('name') + metadata.set_sequence_key('table', 'name') if add_sequence_index: - metadata.set_sequence_index('time') + metadata.set_sequence_index('table', 'time') return metadata @@ -76,7 +77,10 @@ def test___init__(self): 'verbose': False, } assert isinstance(synthesizer._data_processor, DataProcessor) - assert synthesizer._data_processor.metadata == metadata + assert ( + synthesizer._data_processor.metadata.to_dict() + == metadata._convert_to_single_table().to_dict() + ) assert isinstance(synthesizer._context_synthesizer, GaussianCopulaSynthesizer) assert synthesizer._context_synthesizer.metadata.columns == { 'gender': {'sdtype': 'categorical'}, @@ -238,14 +242,14 @@ def test_get_metadata(self): result = instance.get_metadata() # Assert - assert result._convert_to_single_table().to_dict() == metadata.to_dict() + assert result.to_dict() == metadata.to_dict() assert isinstance(result, Metadata) def test_validate_context_columns_unique_per_sequence_key(self): """Test error is raised if context column values vary for each tuple of sequence keys. Setup: - A ``SingleTableMetadata`` instance where the context columns vary for different + A ``Metadata`` instance where the context columns vary for different combinations of values of the sequence keys. 
""" # Setup @@ -255,12 +259,13 @@ def test_validate_context_columns_unique_per_sequence_key(self): 'ct_col1': [1, 2, 2, 3, 2], 'ct_col2': [3, 3, 4, 3, 2], }) - metadata = SingleTableMetadata() - metadata.add_column('sk_col1', sdtype='id') - metadata.add_column('sk_col2', sdtype='id') - metadata.add_column('ct_col1', sdtype='numerical') - metadata.add_column('ct_col2', sdtype='numerical') - metadata.set_sequence_key('sk_col1') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'sk_col1', sdtype='id') + metadata.add_column('table', 'sk_col2', sdtype='id') + metadata.add_column('table', 'ct_col1', sdtype='numerical') + metadata.add_column('table', 'ct_col2', sdtype='numerical') + metadata.set_sequence_key('table', 'sk_col1') instance = PARSynthesizer(metadata=metadata, context_columns=['ct_col1', 'ct_col2']) # Run and Assert @@ -524,7 +529,7 @@ def test_auto_assign_transformers_without_enforce_min_max(self, mock_get_transfo 'measurement': [55, 60, 65], }) metadata = self.get_metadata() - metadata.set_sequence_index('time') + metadata.set_sequence_index('table', 'time') mock_get_transfomers.return_value = {'time': FloatFormatter} # Run @@ -606,7 +611,7 @@ def test__fit_sequence_columns_with_categorical_float( data = self.get_data() data['measurement'] = data['measurement'].astype(float) metadata = self.get_metadata() - metadata.update_column('measurement', sdtype='categorical') + metadata.update_column('table', 'measurement', sdtype='categorical') par = PARSynthesizer(metadata=metadata, context_columns=['gender']) sequences = [ {'context': np.array(['M'], dtype=object), 'data': [['2020-01-03'], [65.0]]}, @@ -644,7 +649,7 @@ def test__fit_sequence_columns_with_sequence_index(self, assemble_sequences_mock 'measurement': [55, 60, 65, 65, 70], }) metadata = self.get_metadata() - metadata.set_sequence_index('time') + metadata.set_sequence_index('table', 'time') par = PARSynthesizer(metadata=metadata, context_columns=['gender']) sequences = [ {'context': np.array(['F'], dtype=object), 'data': [[1, 1], [55, 60], [1, 1]]}, @@ -835,7 +840,7 @@ def test__sample_from_par_with_sequence_index(self, tqdm_mock): """ # Setup metadata = self.get_metadata() - metadata.set_sequence_index('time') + metadata.set_sequence_index('table', 'time') par = PARSynthesizer(metadata=metadata, context_columns=['gender']) model_mock = Mock() par._model = model_mock diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py index 5fd111960..158f71325 100644 --- a/tests/unit/single_table/test_base.py +++ b/tests/unit/single_table/test_base.py @@ -206,7 +206,7 @@ def test___init__invalid_enforce_min_max_values(self): ' Please provide True or False.' ) with pytest.raises(SynthesizerInputError, match=err_msg): - BaseSingleTableSynthesizer(SingleTableMetadata(), enforce_min_max_values='invalid') + BaseSingleTableSynthesizer(Metadata(), enforce_min_max_values='invalid') def test___init__invalid_enforce_rounding(self): """Test it crashes when ``enforce_rounding`` is not a boolean.""" @@ -216,12 +216,12 @@ def test___init__invalid_enforce_rounding(self): ' Please provide True or False.' 
) with pytest.raises(SynthesizerInputError, match=err_msg): - BaseSingleTableSynthesizer(SingleTableMetadata(), enforce_rounding='invalid') + BaseSingleTableSynthesizer(Metadata(), enforce_rounding='invalid') def test_set_address_columns_warning(self): """Test ``set_address_columns`` method when the synthesizer has been fitted.""" # Setup - synthesizer = BaseSingleTableSynthesizer(SingleTableMetadata()) + synthesizer = BaseSingleTableSynthesizer(Metadata()) # Run and Assert expected_message = re.escape( @@ -286,7 +286,7 @@ def test_auto_assign_transformers(self): def test_auto_assign_transformers_with_invalid_data(self): """Test that auto_assign_transformer throws useful error about invalid data""" # Setup - metadata = SingleTableMetadata.load_from_dict({ + metadata = Metadata.load_from_dict({ 'columns': { 'a': {'sdtype': 'categorical'}, } @@ -590,7 +590,7 @@ def test_validate(self): """ # Setup data = pd.DataFrame() - metadata = SingleTableMetadata() + metadata = Metadata() instance = BaseSingleTableSynthesizer(metadata) instance._validate_metadata = Mock() instance._validate_constraints = Mock() @@ -612,7 +612,7 @@ def test_validate_raises_constraints_error(self): """ # Setup data = pd.DataFrame() - metadata = SingleTableMetadata() + metadata = Metadata() instance = BaseSingleTableSynthesizer(metadata) instance._validate_metadata = Mock(return_value=[]) instance._validate_constraints = Mock() @@ -639,7 +639,7 @@ def test_validate_raises_invalid_data_for_metadata(self): """ # Setup data = pd.DataFrame() - metadata = SingleTableMetadata() + metadata = Metadata() instance = BaseSingleTableSynthesizer(metadata) instance._validate_metadata = Mock(return_value=[]) instance._validate_constraints = Mock() @@ -707,11 +707,12 @@ def test_update_transformers_invalid_keys(self): """ # Setup column_name_to_transformer = {'col2': RegexGenerator(), 'col3': FloatFormatter()} - metadata = SingleTableMetadata() - metadata.add_column('col2', sdtype='id') - metadata.add_column('col3', sdtype='id') - metadata.set_sequence_key(('col2')) - metadata.add_alternate_keys(['col3']) + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col2', sdtype='id') + metadata.add_column('table', 'col3', sdtype='id') + metadata.set_sequence_key('table', 'col2') + metadata.add_alternate_keys('table', ['col3']) instance = BaseSingleTableSynthesizer(metadata) # Run and Assert @@ -728,9 +729,10 @@ def test_update_transformers_already_fitted(self): fitted_transformer = FloatFormatter() fitted_transformer.fit(pd.DataFrame({'col': [1]}), 'col') column_name_to_transformer = {'col1': BinaryEncoder(), 'col2': fitted_transformer} - metadata = SingleTableMetadata() - metadata.add_column('col1', sdtype='boolean') - metadata.add_column('col2', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col1', sdtype='boolean') + metadata.add_column('table', 'col2', sdtype='numerical') instance = BaseSingleTableSynthesizer(metadata) # Run and Assert @@ -742,9 +744,10 @@ def test_update_transformers_warns_gaussian_copula(self): """Test warning is raised when ohe is used for categorical column in the GaussianCopula.""" # Setup column_name_to_transformer = {'col1': OneHotEncoder(), 'col2': FloatFormatter()} - metadata = SingleTableMetadata() - metadata.add_column('col1', sdtype='categorical') - metadata.add_column('col2', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col1', sdtype='categorical') + 
metadata.add_column('table', 'col2', sdtype='numerical') instance = GaussianCopulaSynthesizer(metadata) instance._data_processor.fit(pd.DataFrame({'col1': [1, 2], 'col2': [1, 2]})) @@ -769,9 +772,10 @@ def test_update_transformers_warns_models(self): """ # Setup column_name_to_transformer = {'col1': OneHotEncoder(), 'col2': FloatFormatter()} - metadata = SingleTableMetadata() - metadata.add_column('col1', sdtype='categorical') - metadata.add_column('col2', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col1', sdtype='categorical') + metadata.add_column('table', 'col2', sdtype='numerical') # NOTE: when PARSynthesizer is implemented, add it here as well for model in [CTGANSynthesizer, CopulaGANSynthesizer, TVAESynthesizer]: @@ -794,9 +798,10 @@ def test_update_transformers_warns_fitted(self): """ # Setup column_name_to_transformer = {'col1': GaussianNormalizer(), 'col2': GaussianNormalizer()} - metadata = SingleTableMetadata() - metadata.add_column('col1', sdtype='numerical') - metadata.add_column('col2', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col1', sdtype='numerical') + metadata.add_column('table', 'col2', sdtype='numerical') instance = BaseSingleTableSynthesizer(metadata) instance._data_processor.fit(pd.DataFrame({'col1': [1, 2], 'col2': [1, 2]})) instance._fitted = True @@ -812,9 +817,10 @@ def test_update_transformers(self): """Test method correctly updates the transformers in the HyperTransformer.""" # Setup column_name_to_transformer = {'col1': GaussianNormalizer(), 'col2': GaussianNormalizer()} - metadata = SingleTableMetadata() - metadata.add_column('col1', sdtype='numerical') - metadata.add_column('col2', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col1', sdtype='numerical') + metadata.add_column('table', 'col2', sdtype='numerical') instance = BaseSingleTableSynthesizer(metadata) instance._data_processor.fit(pd.DataFrame({'col1': [1, 2], 'col2': [1, 2]})) @@ -1912,7 +1918,7 @@ def test_save(self, cloudpickle_mock, mock_datetime, tmp_path, caplog): def test_save_warning(self, tmp_path): """Test that the synthesizer produces a warning if saved without fitting.""" # Setup - synthesizer = BaseSynthesizer(SingleTableMetadata()) + synthesizer = BaseSynthesizer(Metadata()) # Run and Assert warn_msg = re.escape( @@ -2033,7 +2039,7 @@ def test_add_custom_constraint_class(self): def test_add_constraint_warning(self): """Test a warning is raised when the synthesizer had already been fitted.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() instance = BaseSingleTableSynthesizer(metadata) instance._fitted = True @@ -2045,8 +2051,9 @@ def test_add_constraint_warning(self): def test_add_constraints(self): """Test a list of constraints can be added to the synthesizer.""" # Setup - metadata = SingleTableMetadata() - metadata.add_column('col', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col', sdtype='numerical') instance = BaseSingleTableSynthesizer(metadata) positive_constraint = { 'constraint_class': 'Positive', @@ -2075,8 +2082,9 @@ def test_add_constraints(self): def test_get_constraints(self): """Test a list of constraints is returned by the method.""" # Setup - metadata = SingleTableMetadata() - metadata.add_column('col', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 
'col', sdtype='numerical') instance = BaseSingleTableSynthesizer(metadata) positive_constraint = { 'constraint_class': 'Positive', @@ -2109,8 +2117,9 @@ def test_get_info_no_enterprise(self, mock_sdv_version): data = pd.DataFrame({'col': [1, 2, 3]}) mock_sdv_version.public = '1.0.0' mock_sdv_version.enterprise = None - metadata = SingleTableMetadata() - metadata.add_column('col', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col', sdtype='numerical') with patch('sdv.single_table.base.datetime.datetime') as mock_date: mock_date.today.return_value = datetime(2023, 1, 23) @@ -2156,8 +2165,9 @@ def test_get_info_with_enterprise(self, mock_sdv_version): data = pd.DataFrame({'col': [1, 2, 3]}) mock_sdv_version.public = '1.0.0' mock_sdv_version.enterprise = '1.2.0' - metadata = SingleTableMetadata() - metadata.add_column('col', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col', sdtype='numerical') with patch('sdv.single_table.base.datetime.datetime') as mock_date: mock_date.today.return_value = datetime(2023, 1, 23) diff --git a/tests/unit/single_table/test_copulagan.py b/tests/unit/single_table/test_copulagan.py index e41fbdc83..e7c531a26 100644 --- a/tests/unit/single_table/test_copulagan.py +++ b/tests/unit/single_table/test_copulagan.py @@ -9,7 +9,6 @@ from sdv.errors import SynthesizerInputError from sdv.metadata.metadata import Metadata -from sdv.metadata.single_table import SingleTableMetadata from sdv.single_table.copulagan import CopulaGANSynthesizer @@ -17,7 +16,7 @@ class TestCopulaGANSynthesizer: def test___init__(self): """Test creating an instance of ``CopulaGANSynthesizer``.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() enforce_min_max_values = True enforce_rounding = True @@ -89,8 +88,9 @@ def test___init__with_unified_metadata(self): def test___init__custom(self): """Test creating an instance of ``CopulaGANSynthesizer`` with custom parameters.""" # Setup - metadata = SingleTableMetadata() - metadata.add_column('field', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'field', sdtype='numerical') enforce_min_max_values = False enforce_rounding = False embedding_dim = 64 @@ -158,7 +158,7 @@ def test___init__custom(self): def test___init__incorrect_numerical_distributions(self): """Test it crashes when ``numerical_distributions`` receives a non-dictionary.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() numerical_distributions = 'invalid' # Run @@ -169,7 +169,7 @@ def test___init__incorrect_numerical_distributions(self): def test___init__invalid_column_numerical_distributions(self): """Test it crashes when ``numerical_distributions`` includes invalid columns.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() numerical_distributions = {'totally_fake_column_name': 'beta'} # Run @@ -184,7 +184,7 @@ def test___init__invalid_column_numerical_distributions(self): def test_get_params(self): """Test that inherited method ``get_params`` returns all the specific init parameters.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() instance = CopulaGANSynthesizer(metadata) # Run @@ -224,18 +224,11 @@ def test__create_gaussian_normalizer_config(self, mock_rdt): """ # Setup numerical_distributions = {'age': 'gamma'} - metadata = SingleTableMetadata() - metadata.columns = { - 'name': { - 'sdtype': 'categorical', - }, - 'age': { - 'sdtype': 
'numerical', - }, - 'account': { - 'sdtype': 'numerical', - }, - } + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'name', sdtype='categorical') + metadata.add_column('table', 'age', sdtype='numerical') + metadata.add_column('table', 'account', sdtype='numerical') instance = CopulaGANSynthesizer(metadata, numerical_distributions=numerical_distributions) processed_data = pd.DataFrame({ @@ -280,8 +273,9 @@ def test__fit_logging(self, mock_rdt, mock_ctgansynthesizer__fit, mock_logger): were renamed/dropped during preprocessing. """ # Setup - metadata = SingleTableMetadata() - metadata.add_column('col', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col', sdtype='numerical') numerical_distributions = {'col': 'gamma'} instance = CopulaGANSynthesizer(metadata, numerical_distributions=numerical_distributions) processed_data = pd.DataFrame() @@ -305,7 +299,7 @@ def test__fit(self, mock_rdt, mock_ctgansynthesizer__fit): one of the ``copulas`` distributions. """ # Setup - metadata = SingleTableMetadata() + metadata = Metadata() instance = CopulaGANSynthesizer(metadata) instance._create_gaussian_normalizer_config = Mock() processed_data = pd.DataFrame() @@ -333,8 +327,8 @@ def test_get_learned_distributions(self): """ # Setup data = pd.DataFrame({'zero': [0, 0, 0], 'one': [1, 1, 1]}) - stm = SingleTableMetadata() - stm.detect_from_dataframe(data) + stm = Metadata() + stm.detect_from_dataframes({'table': data}) cgs = CopulaGANSynthesizer(stm) zero_transformer_mock = Mock(spec_set=GaussianNormalizer) zero_transformer_mock._univariate.to_dict.return_value = { @@ -378,8 +372,8 @@ def test_get_learned_distributions_raises_an_error(self): """Test that ``get_learned_distributions`` raises an error.""" # Setup data = pd.DataFrame({'zero': [0, 0, 0], 'one': [1, 1, 1]}) - stm = SingleTableMetadata() - stm.detect_from_dataframe(data) + stm = Metadata() + stm.detect_from_dataframes({'table': data}) cgs = CopulaGANSynthesizer(stm) # Run and Assert diff --git a/tests/unit/single_table/test_copulas.py b/tests/unit/single_table/test_copulas.py index 130fb068d..ce4300a12 100644 --- a/tests/unit/single_table/test_copulas.py +++ b/tests/unit/single_table/test_copulas.py @@ -9,7 +9,6 @@ from sdv.errors import SynthesizerInputError from sdv.metadata.metadata import Metadata -from sdv.metadata.single_table import SingleTableMetadata from sdv.single_table.copulas import GaussianCopulaSynthesizer @@ -37,7 +36,7 @@ def test_get_distribution_class_not_in_distributions(self): def test___init__(self): """Test creating an instance of ``GaussianCopulaSynthesizer``.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() enforce_min_max_values = True enforce_rounding = True numerical_distributions = None @@ -91,8 +90,9 @@ def test___init__with_unified_metadata(self): def test___init__custom(self): """Test creating an instance of ``GaussianCopulaSynthesizer`` with custom parameters.""" # Setup - metadata = SingleTableMetadata() - metadata.add_column('field', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'field', sdtype='numerical') enforce_min_max_values = False enforce_rounding = False numerical_distributions = {'field': 'gamma'} @@ -118,7 +118,7 @@ def test___init__custom(self): def test___init__incorrect_numerical_distributions(self): """Test it crashes when ``numerical_distributions`` receives a non-dictionary.""" # Setup - metadata = SingleTableMetadata() + 
metadata = Metadata() numerical_distributions = 'invalid' # Run @@ -129,7 +129,7 @@ def test___init__incorrect_numerical_distributions(self): def test___init__incorrect_column_numerical_distributions(self): """Test it crashes when ``numerical_distributions`` includes invalid columns.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() numerical_distributions = {'totally_fake_column_name': 'beta'} # Run @@ -144,7 +144,7 @@ def test___init__incorrect_column_numerical_distributions(self): def test_get_parameters(self): """Test that inherited method ``get_parameters`` returns the specified init parameters.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() instance = GaussianCopulaSynthesizer(metadata) # Run @@ -167,8 +167,9 @@ def test__fit_logging(self, mock_logger): were renamed/dropped during preprocessing. """ # Setup - metadata = SingleTableMetadata() - metadata.add_column('col', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'col', sdtype='numerical') numerical_distributions = {'col': 'gamma'} instance = GaussianCopulaSynthesizer( metadata, numerical_distributions=numerical_distributions @@ -194,9 +195,10 @@ def test__fit(self, mock_multivariate, mock_warnings): the ``numerical_distributions``. """ # Setup - metadata = SingleTableMetadata() - metadata.add_column('name', sdtype='numerical') - metadata.add_column('user.id', sdtype='numerical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'name', sdtype='numerical') + metadata.add_column('table', 'user.id', sdtype='numerical') numerical_distributions = {'name': 'uniform', 'user.id': 'gamma'} processed_data = pd.DataFrame({ @@ -339,7 +341,7 @@ def test__rebuild_gaussian_copula(self): - numpy array, Square correlation matrix """ # Setup - metadata = SingleTableMetadata() + metadata = Metadata() gaussian_copula = GaussianCopulaSynthesizer(metadata) model_parameters = { 'univariates': { @@ -373,7 +375,7 @@ def test__rebuild_gaussian_copula(self): def test__rebuild_gaussian_copula_with_defaults(self, logger_mock): """Test the method with invalid parameters and default fallbacks.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() gaussian_copula = GaussianCopulaSynthesizer(metadata, default_distribution='truncnorm') distribution_mock = Mock() delattr(distribution_mock.MODEL_CLASS, '_argcheck') @@ -472,8 +474,8 @@ def test_get_learned_distributions(self): """ # Setup data = pd.DataFrame({'zero': [0, 0, 0], 'one': [1, 1, 1]}) - stm = SingleTableMetadata() - stm.detect_from_dataframe(data) + stm = Metadata() + stm.detect_from_dataframes({'table': data}) gcs = GaussianCopulaSynthesizer(stm, numerical_distributions={'one': 'uniform'}) gcs.fit(data) @@ -497,8 +499,8 @@ def test_get_learned_distributions_raises_an_error(self): """ # Setup data = pd.DataFrame({'zero': [0, 0, 0], 'one': [1, 1, 1]}) - stm = SingleTableMetadata() - stm.detect_from_dataframe(data) + stm = Metadata() + stm.detect_from_dataframes({'table': data}) gcs = GaussianCopulaSynthesizer(stm) # Run and Assert diff --git a/tests/unit/single_table/test_ctgan.py b/tests/unit/single_table/test_ctgan.py index 0f81558e6..5cb3e6000 100644 --- a/tests/unit/single_table/test_ctgan.py +++ b/tests/unit/single_table/test_ctgan.py @@ -7,7 +7,7 @@ from sdmetrics import visualization from sdv.errors import InvalidDataTypeError, NotFittedError -from sdv.metadata.single_table import SingleTableMetadata +from sdv.metadata.metadata import Metadata from 
sdv.single_table.ctgan import CTGANSynthesizer, TVAESynthesizer, _validate_no_category_dtype @@ -33,7 +33,7 @@ class TestCTGANSynthesizer: def test___init__(self): """Test creating an instance of ``CTGANSynthesizer``.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() enforce_min_max_values = True enforce_rounding = True @@ -65,7 +65,7 @@ def test___init__(self): def test___init__with_unified_metadata(self): """Test creating an instance of ``CTGANSynthesizer`` with Metadata.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() enforce_min_max_values = True enforce_rounding = True @@ -97,7 +97,7 @@ def test___init__with_unified_metadata(self): def test___init__custom(self): """Test creating an instance of ``CTGANSynthesizer`` with custom parameters.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() enforce_min_max_values = False enforce_rounding = False embedding_dim = 64 @@ -157,7 +157,7 @@ def test___init__custom(self): def test_get_parameters(self): """Test that inherited method ``get_parameters`` returns the specific init parameters.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() instance = CTGANSynthesizer(metadata) # Run @@ -187,10 +187,11 @@ def test_get_parameters(self): def test__estimate_num_columns(self): """Test that ``_estimate_num_columns`` returns without crashing the number of columns.""" # Setup - metadata = SingleTableMetadata() - metadata.add_column('id', sdtype='numerical') - metadata.add_column('name', sdtype='categorical') - metadata.add_column('surname', sdtype='categorical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'id', sdtype='numerical') + metadata.add_column('table', 'name', sdtype='categorical') + metadata.add_column('table', 'surname', sdtype='categorical') data = pd.DataFrame({ 'id': np.random.rand(1_001), 'name': [f'cat_{i}' for i in range(1_001)], @@ -212,9 +213,10 @@ def test__estimate_num_columns(self): def test_preprocessing_many_categories(self, capfd): """Test a message is printed during preprocess when a column has many categories.""" # Setup - metadata = SingleTableMetadata() - metadata.add_column('name_longer_than_Original_Column_Name', sdtype='numerical') - metadata.add_column('categorical', sdtype='categorical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'name_longer_than_Original_Column_Name', sdtype='numerical') + metadata.add_column('table', 'categorical', sdtype='categorical') data = pd.DataFrame({ 'name_longer_than_Original_Column_Name': np.random.rand(1_001), 'categorical': [f'cat_{i}' for i in range(1_001)], @@ -242,9 +244,10 @@ def test_preprocessing_many_categories(self, capfd): def test_preprocessing_few_categories(self, capfd): """Test a message is not printed during preprocess when a column has few categories.""" # Setup - metadata = SingleTableMetadata() - metadata.add_column('name_longer_than_Original_Column_Name', sdtype='numerical') - metadata.add_column('categorical', sdtype='categorical') + metadata = Metadata() + metadata.add_table('table') + metadata.add_column('table', 'name_longer_than_Original_Column_Name', sdtype='numerical') + metadata.add_column('table', 'categorical', sdtype='categorical') data = pd.DataFrame({ 'name_longer_than_Original_Column_Name': np.random.rand(10), 'categorical': [f'cat_{i}' for i in range(10)], @@ -269,8 +272,9 @@ def test__fit(self, mock_category_validate, mock_detect_discrete_columns, mock_c that have been detected by the utility 
function. """ # Setup - metadata = SingleTableMetadata() - instance = CTGANSynthesizer(metadata) + metadata = Metadata() + single_metadata = metadata._convert_to_single_table() + instance = CTGANSynthesizer(single_metadata) processed_data = Mock() # Run @@ -279,7 +283,9 @@ def test__fit(self, mock_category_validate, mock_detect_discrete_columns, mock_c # Assert mock_category_validate.assert_called_once_with(processed_data) mock_detect_discrete_columns.assert_called_once_with( - metadata, processed_data, instance._data_processor._hyper_transformer.field_transformers + single_metadata, + processed_data, + instance._data_processor._hyper_transformer.field_transformers, ) mock_ctgan.assert_called_once_with( batch_size=500, @@ -309,7 +315,7 @@ def test_get_loss_values(self): mock_model = Mock() loss_values = pd.DataFrame({'Epoch': [0, 1, 2], 'Loss': [0.8, 0.6, 0.5]}) mock_model.loss_values = loss_values - metadata = SingleTableMetadata() + metadata = Metadata() instance = CTGANSynthesizer(metadata) instance._model = mock_model instance._fitted = True @@ -323,7 +329,7 @@ def test_get_loss_values(self): def test_get_loss_values_error(self): """Test the ``get_loss_values`` errors if synthesizer has not been fitted.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() instance = CTGANSynthesizer(metadata) # Run / Assert @@ -335,7 +341,7 @@ def test_get_loss_values_error(self): def test_get_loss_values_plot(self, mock_line_plot): """Test the ``get_loss_values_plot`` method from ``CTGANSynthesizer.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() instance = CTGANSynthesizer(metadata) mock_loss_value = Mock() mock_loss_value.item.return_value = 0.1 @@ -364,7 +370,7 @@ class TestTVAESynthesizer: def test___init__(self): """Test creating an instance of ``TVAESynthesizer``.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() enforce_min_max_values = True enforce_rounding = True @@ -391,7 +397,7 @@ def test___init__(self): def test___init__custom(self): """Test creating an instance of ``TVAESynthesizer`` with custom parameters.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() enforce_min_max_values = False enforce_rounding = False embedding_dim = 64 @@ -436,7 +442,7 @@ def test___init__custom(self): def test_get_parameters(self): """Test that inherited method ``get_parameters`` returns the specific init parameters.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() instance = TVAESynthesizer(metadata) # Run @@ -468,8 +474,9 @@ def test__fit(self, mock_category_validate, mock_detect_discrete_columns, mock_t that have been detected by the utility function. 
""" # Setup - metadata = SingleTableMetadata() - instance = TVAESynthesizer(metadata) + metadata = Metadata() + single_metadata = metadata._convert_to_single_table() + instance = TVAESynthesizer(single_metadata) processed_data = Mock() # Run @@ -478,7 +485,9 @@ def test__fit(self, mock_category_validate, mock_detect_discrete_columns, mock_t # Assert mock_category_validate.assert_called_once_with(processed_data) mock_detect_discrete_columns.assert_called_once_with( - metadata, processed_data, instance._data_processor._hyper_transformer.field_transformers + single_metadata, + processed_data, + instance._data_processor._hyper_transformer.field_transformers, ) mock_tvae.assert_called_once_with( batch_size=500, @@ -503,7 +512,7 @@ def test_get_loss_values(self): mock_model = Mock() loss_values = pd.DataFrame({'Epoch': [0, 1, 2], 'Loss': [0.8, 0.6, 0.5]}) mock_model.loss_values = loss_values - metadata = SingleTableMetadata() + metadata = Metadata() instance = TVAESynthesizer(metadata) instance._model = mock_model instance._fitted = True @@ -517,7 +526,7 @@ def test_get_loss_values(self): def test_get_loss_values_error(self): """Test the ``get_loss_values`` errors if synthesizer has not been fitted.""" # Setup - metadata = SingleTableMetadata() + metadata = Metadata() instance = TVAESynthesizer(metadata) # Run / Assert diff --git a/tests/unit/utils/test_poc.py b/tests/unit/utils/test_poc.py index 6cd82c183..1473ed539 100644 --- a/tests/unit/utils/test_poc.py +++ b/tests/unit/utils/test_poc.py @@ -6,8 +6,8 @@ import pytest from sdv.errors import InvalidDataError -from sdv.metadata import MultiTableMetadata from sdv.metadata.errors import InvalidMetadataError +from sdv.metadata.metadata import Metadata from sdv.utils.poc import ( drop_unknown_references, get_random_subset, @@ -68,7 +68,7 @@ def test_simplify_schema( # Setup data = Mock() metadata = Mock() - simplified_metatadata = MultiTableMetadata() + simplified_metatadata = Metadata() mock_get_total_estimated_columns.return_value = 2000 mock_simplify_metadata.return_value = simplified_metatadata mock_simplify_data.return_value = { @@ -91,7 +91,7 @@ def test_simplify_schema( def test_simplify_schema_invalid_metadata(): """Test ``simplify_schema`` when the metadata is not invalid.""" # Setup - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'tables': {'table1': {'columns': {'column1': {'sdtype': 'categorical'}}}}, 'relationships': [ { @@ -119,7 +119,7 @@ def test_simplify_schema_invalid_metadata(): def test_simplify_schema_invalid_data(): """Test ``simplify_schema`` when the data is not valid.""" # Setup - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'tables': { 'table1': {'columns': {'column1': {'sdtype': 'id'}}, 'primary_key': 'column1'}, 'table2': { @@ -152,7 +152,7 @@ def test_simplify_schema_invalid_data(): def test_get_random_subset_invalid_metadata(): """Test ``get_random_subset`` when the metadata is invalid.""" # Setup - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'tables': {'table1': {'columns': {'column1': {'sdtype': 'categorical'}}}}, 'relationships': [ { @@ -180,7 +180,7 @@ def test_get_random_subset_invalid_metadata(): def test_get_random_subset_invalid_data(): """Test ``get_random_subset`` when the data is not valid.""" # Setup - metadata = MultiTableMetadata().load_from_dict({ + metadata = Metadata().load_from_dict({ 'tables': { 'table1': {'columns': {'column1': {'sdtype': 'id'}}, 'primary_key': 
'column1'}, 'table2': { diff --git a/tests/utils.py b/tests/utils.py index a2fce6d1f..eded42cb1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,7 +5,7 @@ import pandas as pd from sdv.logging import get_sdv_logger -from sdv.metadata.multi_table import MultiTableMetadata +from sdv.metadata.metadata import Metadata class DataFrameMatcher: @@ -80,7 +80,7 @@ def get_multi_table_metadata(): 'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1', } - return MultiTableMetadata.load_from_dict(dict_metadata) + return Metadata.load_from_dict(dict_metadata) def get_multi_table_data(): From a9f089040731110d4ec271eccf27e79a85b330bf Mon Sep 17 00:00:00 2001 From: John La Date: Wed, 28 Aug 2024 14:17:39 -0500 Subject: [PATCH 07/16] Hide internal warnings about using SingleTableSynthesizer (#2201) --- sdv/sampling/hierarchical_sampler.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sdv/sampling/hierarchical_sampler.py b/sdv/sampling/hierarchical_sampler.py index 0d5bf855d..be0fbaa3e 100644 --- a/sdv/sampling/hierarchical_sampler.py +++ b/sdv/sampling/hierarchical_sampler.py @@ -95,7 +95,11 @@ def _add_child_rows(self, child_name, parent_name, parent_row, sampled_data, num foreign_key = self.metadata._get_foreign_keys(parent_name, child_name)[0] if num_rows is None: num_rows = parent_row[f'__{child_name}__{foreign_key}__num_rows'] - child_synthesizer = self._recreate_child_synthesizer(child_name, parent_name, parent_row) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', message=".*The 'SingleTableMetadata' is deprecated.*") + child_synthesizer = self._recreate_child_synthesizer( + child_name, parent_name, parent_row + ) sampled_rows = self._sample_rows(child_synthesizer, num_rows) if len(sampled_rows): From 6767125feed2c71ddcf5a5a5be88da97f073a09d Mon Sep 17 00:00:00 2001 From: John La Date: Fri, 30 Aug 2024 14:49:23 -0500 Subject: [PATCH 08/16] Hide warnings when using Metadata in PARSynthesizer (#2202) --- sdv/metadata/metadata.py | 24 ++++++++++++++++++++++++ sdv/sequential/par.py | 30 ++++++++++++++++++++---------- tests/unit/sequential/test_par.py | 4 +--- 3 files changed, 45 insertions(+), 13 deletions(-) diff --git a/sdv/metadata/metadata.py b/sdv/metadata/metadata.py index 3e855a6eb..fc3942a26 100644 --- a/sdv/metadata/metadata.py +++ b/sdv/metadata/metadata.py @@ -99,3 +99,27 @@ def _convert_to_single_table(self): ) return next(iter(self.tables.values()), SingleTableMetadata()) + + def set_sequence_index(self, table_name, column_name): + """Set the sequence index of a table. + + Args: + table_name (str): + Name of the table to set the sequence index. + column_name (str): + Name of the sequence index column. + """ + self._validate_table_exists(table_name) + self.tables[table_name].set_sequence_index(column_name) + + def set_sequence_key(self, table_name, column_name): + """Set the sequence key of a table. + + Args: + table_name (str): + Name of the table to set the sequence key. + column_name (str, tulple[str]): + Name (or tuple of names) of the sequence key column(s). 
+ """ + self._validate_table_exists(table_name) + self.tables[table_name].set_sequence_key(column_name) diff --git a/sdv/sequential/par.py b/sdv/sequential/par.py index 59908acf4..13bcc8336 100644 --- a/sdv/sequential/par.py +++ b/sdv/sequential/par.py @@ -3,6 +3,7 @@ import inspect import logging import uuid +import warnings import numpy as np import pandas as pd @@ -13,6 +14,8 @@ from sdv._utils import _cast_to_iterable, _groupby_list from sdv.errors import SamplingError, SynthesizerInputError +from sdv.metadata.errors import InvalidMetadataError +from sdv.metadata.metadata import Metadata from sdv.metadata.single_table import SingleTableMetadata from sdv.sampling import Condition from sdv.single_table import GaussianCopulaSynthesizer @@ -93,6 +96,9 @@ def __init__( cuda=True, verbose=False, ): + if type(metadata) is Metadata and len(metadata.tables) > 1: + raise InvalidMetadataError('PARSynthesizer can only be used with a single table.') + super().__init__( metadata=metadata, enforce_min_max_values=enforce_min_max_values, @@ -121,11 +127,13 @@ def __init__( 'verbose': verbose, } context_metadata = self._get_context_metadata() - self._context_synthesizer = GaussianCopulaSynthesizer( - metadata=context_metadata, - enforce_min_max_values=enforce_min_max_values, - enforce_rounding=enforce_rounding, - ) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', message=".*The 'SingleTableMetadata' is deprecated.*") + self._context_synthesizer = GaussianCopulaSynthesizer( + metadata=context_metadata, + enforce_min_max_values=enforce_min_max_values, + enforce_rounding=enforce_rounding, + ) def get_parameters(self): """Return the parameters used to instantiate the synthesizer.""" @@ -350,11 +358,13 @@ def _fit_context_model(self, transformed): if pd.api.types.is_numeric_dtype(context[column]): context_metadata.update_column(column, sdtype='numerical') - self._context_synthesizer = GaussianCopulaSynthesizer( - context_metadata, - enforce_min_max_values=self._context_synthesizer.enforce_min_max_values, - enforce_rounding=self._context_synthesizer.enforce_rounding, - ) + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', message=".*The 'SingleTableMetadata' is deprecated.*") + self._context_synthesizer = GaussianCopulaSynthesizer( + context_metadata, + enforce_min_max_values=self._context_synthesizer.enforce_min_max_values, + enforce_rounding=self._context_synthesizer.enforce_rounding, + ) context = context.groupby(self._sequence_key).first().reset_index() self._context_synthesizer.fit(context) diff --git a/tests/unit/sequential/test_par.py b/tests/unit/sequential/test_par.py index 966939869..d7dae5832 100644 --- a/tests/unit/sequential/test_par.py +++ b/tests/unit/sequential/test_par.py @@ -1075,9 +1075,7 @@ def test___init__with_unified_metadata(self): # Run and Assert PARSynthesizer(metadata) - error_msg = re.escape( - 'Metadata contains more than one table, use a MultiTableSynthesizer instead.' 
- )
+ error_msg = re.escape('PARSynthesizer can only be used with a single table.')
with pytest.raises(InvalidMetadataError, match=error_msg):
PARSynthesizer(multi_metadata)
From d3ce85759491ca3d113b5987f0cbf684b4a1d4f0 Mon Sep 17 00:00:00 2001
From: John La
Date: Tue, 3 Sep 2024 16:05:40 -0500
Subject: [PATCH 09/16] Add warnings that were suggested from metadata bughunt (#2203)
---
sdv/metadata/metadata.py | 4 +++
sdv/metadata/multi_table.py | 11 +++++--
sdv/multi_table/base.py | 7 +++--
tests/unit/metadata/test_multi_table.py | 39 +++++++++++++++++++++++++
4 files changed, 55 insertions(+), 6 deletions(-)
diff --git a/sdv/metadata/metadata.py b/sdv/metadata/metadata.py
index fc3942a26..537d8bba2 100644
--- a/sdv/metadata/metadata.py
+++ b/sdv/metadata/metadata.py
@@ -70,6 +70,10 @@ def _set_metadata_dict(self, metadata, single_table_name=None):
else:
if single_table_name is None:
single_table_name = self.DEFAULT_SINGLE_TABLE_NAME
+ warnings.warn(
+ 'No table name was provided to metadata containing only one table. '
+ f'Assigning name: {single_table_name}'
+ )
self.tables[single_table_name] = SingleTableMetadata.load_from_dict(metadata)
def _get_single_table_name(self):
diff --git a/sdv/metadata/multi_table.py b/sdv/metadata/multi_table.py
index 61be08db8..cb38bdb12 100644
--- a/sdv/metadata/multi_table.py
+++ b/sdv/metadata/multi_table.py
@@ -409,6 +409,8 @@ def update_columns(self, table_name, column_names, **kwargs):
**kwargs:
Any keyword arguments that describe metadata for the columns.
"""
+ if not isinstance(column_names, list):
+ raise InvalidMetadataError('Please pass in a list to column_names arg.')
self._validate_table_exists(table_name)
table = self.tables.get(table_name)
table.update_columns(column_names, **kwargs)
@@ -832,8 +834,8 @@ def validate_data(self, data):
* all foreign keys belong to a primary key
Args:
- data (pd.DataFrame):
- The data to validate.
+ data (dict):
+ A dictionary of table names to pd.DataFrames.
Raises:
InvalidDataError:
@@ -843,6 +845,9 @@
A warning is being raised if ``datetime_format`` is missing from a column
represented as ``object`` in the dataframe and its sdtype is ``datetime``.
"""
+ if not isinstance(data, dict):
+ raise InvalidMetadataError('Please pass in a dictionary mapping tables to dataframes.')
+
errors = []
errors += self._validate_missing_tables(data)
errors += self._validate_all_tables(data)
@@ -880,7 +885,7 @@ def get_column_names(self, table_name, **kwargs):
Args:
table_name (str):
- The name of the table to get column names for.s
+ The name of the table to get column names for.
**kwargs:
Metadata keywords to filter on, for example sdtype='id' or pii=True.
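As a quick reference for the two guardrails added above, here is a minimal usage sketch. It assumes only the patched API shown in this diff; the 'users' table and its 'age' column are illustrative, not part of the patch.

import pandas as pd

from sdv.metadata.errors import InvalidMetadataError
from sdv.metadata.metadata import Metadata

# Illustrative metadata (not taken from the patch).
metadata = Metadata.load_from_dict({
    'tables': {'users': {'columns': {'age': {'sdtype': 'numerical'}}}},
})

# A bare string is now rejected: column_names must be a list.
try:
    metadata.update_columns('users', 'age', sdtype='numerical')
except InvalidMetadataError as error:
    print(error)  # Please pass in a list to column_names arg.

metadata.update_columns('users', ['age'], sdtype='numerical')  # OK

# A bare DataFrame is now rejected: validate_data expects a dict of tables.
try:
    metadata.validate_data(pd.DataFrame({'age': [1, 2, 3]}))
except InvalidMetadataError as error:
    print(error)  # Please pass in a dictionary mapping tables to dataframes.

metadata.validate_data({'users': pd.DataFrame({'age': [1, 2, 3]})})  # OK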
diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index c2f09d6b9..0588936cf 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -76,10 +76,11 @@ def _set_temp_numpy_seed(self): def _initialize_models(self): with disable_single_table_logger(): for table_name, table_metadata in self.metadata.tables.items(): - synthesizer_parameters = {'locales': self.locales} - synthesizer_parameters.update(self._table_parameters.get(table_name, {})) + synthesizer_parameters = self._table_parameters.get(table_name, {}) + metadata_dict = {'tables': {table_name: table_metadata.to_dict()}} + metadata = Metadata.load_from_dict(metadata_dict) self._table_synthesizers[table_name] = self._synthesizer( - metadata=table_metadata, **synthesizer_parameters + metadata=metadata, **synthesizer_parameters ) self._table_synthesizers[table_name]._data_processor.table_name = table_name diff --git a/tests/unit/metadata/test_multi_table.py b/tests/unit/metadata/test_multi_table.py index 43c9941a6..cd20b7434 100644 --- a/tests/unit/metadata/test_multi_table.py +++ b/tests/unit/metadata/test_multi_table.py @@ -3160,3 +3160,42 @@ def test_anonymize(self, mock_load): 'parent_primary_key': 'col1', 'child_foreign_key': 'col2', } + + def test_update_columns_no_list_error(self): + """Test that ``update_columns`` only takes in list and that an error is thrown.""" + # Setup + metadata = MultiTableMetadata() + metadata.add_table('table') + metadata.add_column('table', 'col1', sdtype='numerical') + + error_msg = re.escape('Please pass in a list to column_names arg.') + # Run and Assert + with pytest.raises(InvalidMetadataError, match=error_msg): + metadata.update_columns('table', 'col1', sdtype='categorical') + + def test_validate_data_without_dict(self): + """Test that ``validate_data`` only takes in dict and that an error is thrown otherwise.""" + # Setup + metadata = MultiTableMetadata.load_from_dict({ + 'tables': { + 'table_1': { + 'columns': { + 'col_1': {'sdtype': 'numerical'}, + 'col_2': {'sdtype': 'categorical'}, + 'latitude': {'sdtype': 'latitude'}, + 'longitude': {'sdtype': 'longitude'}, + } + } + } + }) + data = pd.DataFrame({ + 'col_1': [1, 2, 3], + 'col_2': ['a', 'b', 'c'], + 'latitude': [1, 2, 3], + 'longitude': [1, 2, 3], + }) + error_msg = re.escape('Please pass in a dictionary mapping tables to dataframes.') + + # Run and Assert + with pytest.raises(InvalidMetadataError, match=error_msg): + metadata.validate_data(data) From b977f0e6098767950dd19af78b48d0e20fddf17c Mon Sep 17 00:00:00 2001 From: lajohn4747 Date: Wed, 11 Sep 2024 11:33:01 -0500 Subject: [PATCH 10/16] Update documentation for create_metadata --- sdv/io/local/local.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sdv/io/local/local.py b/sdv/io/local/local.py index ca827eddf..d7ac71b09 100644 --- a/sdv/io/local/local.py +++ b/sdv/io/local/local.py @@ -25,8 +25,8 @@ def create_metadata(self, data): Dictionary of table names to dataframes. Returns: - MultiTableMetadata: - An ``sdv.metadata.MultiTableMetadata`` object with the detected metadata + Metadata: + An ``sdv.metadata.Metadata`` object with the detected metadata properties from the data. 
""" metadata = Metadata() From ef3db2f7ba0ad95d4f6aa8761775c23da360ff89 Mon Sep 17 00:00:00 2001 From: Andrew Montanez Date: Thu, 12 Sep 2024 10:03:22 -0500 Subject: [PATCH 11/16] SDV - Add a Metadata.detect_from_dataframe function (#2222) --- sdv/metadata/metadata.py | 24 ++++++++++++++++ tests/integration/metadata/test_metadata.py | 32 ++++++++++++++++++++- tests/unit/metadata/test_metadata.py | 31 ++++++++++++++++++-- 3 files changed, 84 insertions(+), 3 deletions(-) diff --git a/sdv/metadata/metadata.py b/sdv/metadata/metadata.py index 537d8bba2..8473dd3a0 100644 --- a/sdv/metadata/metadata.py +++ b/sdv/metadata/metadata.py @@ -2,6 +2,8 @@ import warnings +import pandas as pd + from sdv.metadata.errors import InvalidMetadataError from sdv.metadata.multi_table import MultiTableMetadata from sdv.metadata.single_table import SingleTableMetadata @@ -51,6 +53,28 @@ def load_from_dict(cls, metadata_dict, single_table_name=None): instance._set_metadata_dict(metadata_dict, single_table_name) return instance + @classmethod + def detect_from_dataframe(cls, data, table_name=DEFAULT_SINGLE_TABLE_NAME): + """Detect the metadata for a DataFrame. + + This method automatically detects the ``sdtypes`` for the given ``pandas.DataFrame``. + All data column names are converted to strings. + + Args: + data (pandas.DataFrame): + Dictionary of table names to dataframes. + + Returns: + Metadata: + A new metadata object with the sdtypes detected from the data. + """ + if not isinstance(data, pd.DataFrame): + raise ValueError('The provided data must be a pandas DataFrame object.') + + metadata = Metadata() + metadata.detect_table_from_dataframe(table_name, data) + return metadata + def _set_metadata_dict(self, metadata, single_table_name=None): """Set a ``metadata`` dictionary to the current instance. 
diff --git a/tests/integration/metadata/test_metadata.py b/tests/integration/metadata/test_metadata.py index 7efe8421d..d83544896 100644 --- a/tests/integration/metadata/test_metadata.py +++ b/tests/integration/metadata/test_metadata.py @@ -85,7 +85,7 @@ def test_detect_from_dataframes_multi_table(): assert metadata.to_dict() == expected_metadata -def test_detect_from_data_frames_single_table(): +def test_detect_from_dataframes_single_table(): """Test the ``detect_from_dataframes`` method works with a single table.""" # Setup data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') @@ -116,6 +116,36 @@ def test_detect_from_data_frames_single_table(): assert metadata.to_dict() == expected_metadata +def test_detect_from_dataframe(): + """Test that a single table can be detected as a DataFrame.""" + # Setup + data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') + + metadata = Metadata.detect_from_dataframe(data['hotels']) + + # Run + metadata.validate() + + # Assert + expected_metadata = { + 'METADATA_SPEC_VERSION': 'V1', + 'tables': { + DEFAULT_TABLE_NAME: { + 'columns': { + 'hotel_id': {'sdtype': 'id'}, + 'city': {'sdtype': 'city', 'pii': True}, + 'state': {'sdtype': 'administrative_unit', 'pii': True}, + 'rating': {'sdtype': 'numerical'}, + 'classification': {'sdtype': 'unknown', 'pii': True}, + }, + 'primary_key': 'hotel_id', + } + }, + 'relationships': [], + } + assert metadata.to_dict() == expected_metadata + + def test_detect_from_csvs(tmp_path): """Test the ``detect_from_csvs`` method.""" # Setup diff --git a/tests/unit/metadata/test_metadata.py b/tests/unit/metadata/test_metadata.py index da2046596..64c55999d 100644 --- a/tests/unit/metadata/test_metadata.py +++ b/tests/unit/metadata/test_metadata.py @@ -1,9 +1,10 @@ -from unittest.mock import patch +from unittest.mock import Mock, patch +import pandas as pd import pytest from sdv.metadata.metadata import Metadata -from tests.utils import get_multi_table_data, get_multi_table_metadata +from tests.utils import DataFrameMatcher, get_multi_table_data, get_multi_table_metadata class TestMetadataClass: @@ -537,3 +538,29 @@ def test_validate_data_no_relationships(self): # Run and Assert metadata.validate_data(data) assert metadata.METADATA_SPEC_VERSION == 'V1' + + @patch('sdv.metadata.metadata.Metadata') + def test_detect_from_dataframe(self, mock_metadata): + """Test that the method calls the detection method and returns the metadata. + + Expected to call ``detect_table_from_dataframe`` for the dataframe. + """ + # Setup + mock_metadata.detect_table_from_dataframe = Mock() + data = pd.DataFrame() + + # Run + metadata = Metadata.detect_from_dataframe(data) + + # Assert + mock_metadata.return_value.detect_table_from_dataframe.assert_any_call( + Metadata.DEFAULT_SINGLE_TABLE_NAME, DataFrameMatcher(data) + ) + assert metadata == mock_metadata.return_value + + def test_detect_from_dataframe_raises_error_if_not_dataframe(self): + """Test that the method raises an error if data isn't a DataFrame.""" + # Run and assert + expected_message = 'The provided data must be a pandas DataFrame object.' 
+ with pytest.raises(ValueError, match=expected_message): + Metadata.detect_from_dataframe(Mock()) From 097d5f33b3a60a780b1f31d4879c5ddd299288f2 Mon Sep 17 00:00:00 2001 From: Andrew Montanez Date: Thu, 12 Sep 2024 16:03:54 -0500 Subject: [PATCH 12/16] Improve usage of `detect_from_dataframes` function (#2221) --- sdv/io/local/local.py | 3 +- sdv/metadata/metadata.py | 25 ++++++++++++ .../data_processing/test_data_processor.py | 6 +-- tests/integration/lite/test_single_table.py | 9 ++--- tests/integration/metadata/test_metadata.py | 7 +--- .../metadata/test_visualization.py | 6 +-- tests/integration/multi_table/test_hma.py | 24 ++++------- tests/integration/sequential/test_par.py | 12 ++---- tests/integration/single_table/test_base.py | 21 ++++------ .../single_table/test_constraints.py | 9 ++--- .../integration/single_table/test_copulas.py | 9 ++--- tests/unit/metadata/test_metadata.py | 40 +++++++++++++++++++ tests/unit/single_table/test_copulas.py | 3 +- 13 files changed, 101 insertions(+), 73 deletions(-) diff --git a/sdv/io/local/local.py b/sdv/io/local/local.py index d7ac71b09..bf32f8294 100644 --- a/sdv/io/local/local.py +++ b/sdv/io/local/local.py @@ -29,8 +29,7 @@ def create_metadata(self, data): An ``sdv.metadata.Metadata`` object with the detected metadata properties from the data. """ - metadata = Metadata() - metadata.detect_from_dataframes(data) + metadata = Metadata.detect_from_dataframes(data) return metadata def read(self): diff --git a/sdv/metadata/metadata.py b/sdv/metadata/metadata.py index 8473dd3a0..cbfd0a1b5 100644 --- a/sdv/metadata/metadata.py +++ b/sdv/metadata/metadata.py @@ -53,6 +53,31 @@ def load_from_dict(cls, metadata_dict, single_table_name=None): instance._set_metadata_dict(metadata_dict, single_table_name) return instance + @classmethod + def detect_from_dataframes(cls, data): + """Detect the metadata for all tables in a dictionary of dataframes. + + This method automatically detects the ``sdtypes`` for the given ``pandas.DataFrames``. + All data column names are converted to strings. + + Args: + data (dict): + Dictionary of table names to dataframes. + + Returns: + Metadata: + A new metadata object with the sdtypes detected from the data. + """ + if not data or not all(isinstance(df, pd.DataFrame) for df in data.values()): + raise ValueError('The provided dictionary must contain only pandas DataFrame objects.') + + metadata = Metadata() + for table_name, dataframe in data.items(): + metadata.detect_table_from_dataframe(table_name, dataframe) + + metadata._detect_relationships(data) + return metadata + @classmethod def detect_from_dataframe(cls, data, table_name=DEFAULT_SINGLE_TABLE_NAME): """Detect the metadata for a DataFrame. 
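A matching sketch for the classmethod form of detect_from_dataframes; the table names and values below are illustrative. Relationship detection now runs inside the same call, so no separate instance-level detection step is needed.

import pandas as pd

from sdv.metadata.metadata import Metadata

tables = {
    'hotels': pd.DataFrame({'hotel_id': [1, 2], 'rating': [4.5, 3.0]}),
    'guests': pd.DataFrame({'guest_id': [10, 11], 'hotel_id': [1, 2]}),
}

# Detect per-table sdtypes and cross-table relationships in one call.
metadata = Metadata.detect_from_dataframes(tables)
metadata.validate()

# An empty dict, or any value that is not a DataFrame, raises a ValueError:
# Metadata.detect_from_dataframes({}) raises ValueError.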
diff --git a/tests/integration/data_processing/test_data_processor.py b/tests/integration/data_processing/test_data_processor.py index dba25709b..65f865d2f 100644 --- a/tests/integration/data_processing/test_data_processor.py +++ b/tests/integration/data_processing/test_data_processor.py @@ -156,8 +156,7 @@ def test_with_primary_key_numerical(self): """ # Load metadata and data data, _ = download_demo('single_table', 'adult') - adult_metadata = Metadata() - adult_metadata.detect_from_dataframes({'adult': data}) + adult_metadata = Metadata.detect_from_dataframes({'adult': data}) # Add primary key field adult_metadata.add_column('adult', 'id', sdtype='id') @@ -196,8 +195,7 @@ def test_with_alternate_keys(self): # Load metadata and data data, _ = download_demo('single_table', 'adult') data['fnlwgt'] = data['fnlwgt'].astype(str) - adult_metadata = Metadata() - adult_metadata.detect_from_dataframes({'adult': data}) + adult_metadata = Metadata.detect_from_dataframes({'adult': data}) # Add primary key field adult_metadata.add_column('adult', 'id', sdtype='id') diff --git a/tests/integration/lite/test_single_table.py b/tests/integration/lite/test_single_table.py index ad6a42aea..e0f07a7bb 100644 --- a/tests/integration/lite/test_single_table.py +++ b/tests/integration/lite/test_single_table.py @@ -12,8 +12,7 @@ def test_sample(): data = pd.DataFrame({'a': [1, 2, 3, np.nan]}) # Run - metadata = Metadata() - metadata.detect_from_dataframes({'adult': data}) + metadata = Metadata.detect_from_dataframes({'adult': data}) preset = SingleTablePreset(metadata, name='FAST_ML') preset.fit(data) samples = preset.sample(num_rows=10, max_tries_per_batch=20, batch_size=5) @@ -29,8 +28,7 @@ def test_sample_with_constraints(): data = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) # Run - metadata = Metadata() - metadata.detect_from_dataframes({'table': data}) + metadata = Metadata.detect_from_dataframes({'table': data}) preset = SingleTablePreset(metadata, name='FAST_ML') constraints = [ { @@ -57,8 +55,7 @@ def test_warnings_are_shown(): data = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) # Run - metadata = Metadata() - metadata.detect_from_dataframes({'table': data}) + metadata = Metadata.detect_from_dataframes({'table': data}) with pytest.warns(FutureWarning, match=warn_message): preset = SingleTablePreset(metadata, name='FAST_ML') diff --git a/tests/integration/metadata/test_metadata.py b/tests/integration/metadata/test_metadata.py index d83544896..ffc2b1bc8 100644 --- a/tests/integration/metadata/test_metadata.py +++ b/tests/integration/metadata/test_metadata.py @@ -32,10 +32,8 @@ def test_detect_from_dataframes_multi_table(): # Setup real_data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') - metadata = Metadata() - # Run - metadata.detect_from_dataframes(real_data) + metadata = Metadata.detect_from_dataframes(real_data) # Assert metadata.update_column( @@ -90,8 +88,7 @@ def test_detect_from_dataframes_single_table(): # Setup data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') - metadata = Metadata() - metadata.detect_from_dataframes({'table_1': data['hotels']}) + metadata = Metadata.detect_from_dataframes({'table_1': data['hotels']}) # Run metadata.validate() diff --git a/tests/integration/metadata/test_visualization.py b/tests/integration/metadata/test_visualization.py index eb8c6d657..2d801ea5d 100644 --- a/tests/integration/metadata/test_visualization.py +++ b/tests/integration/metadata/test_visualization.py @@ -9,8 +9,7 @@ def 
test_visualize_graph_for_single_table(): """Test it runs when a column name contains symbols.""" # Setup data = pd.DataFrame({'\\|=/bla@#$324%^,"&*()><...': ['a', 'b', 'c']}) - metadata = Metadata() - metadata.detect_from_dataframes({'table': data}) + metadata = Metadata.detect_from_dataframes({'table': data}) model = GaussianCopulaSynthesizer(metadata) # Run @@ -26,8 +25,7 @@ def test_visualize_graph_for_multi_table(): data1 = pd.DataFrame({'\\|=/bla@#$324%^,"&*()><...': ['a', 'b', 'c']}) data2 = pd.DataFrame({'\\|=/bla@#$324%^,"&*()><...': ['a', 'b', 'c']}) tables = {'1': data1, '2': data2} - metadata = Metadata() - metadata.detect_from_dataframes(tables) + metadata = Metadata.detect_from_dataframes(tables) metadata.update_column('1', '\\|=/bla@#$324%^,"&*()><...', sdtype='id') metadata.update_column('2', '\\|=/bla@#$324%^,"&*()><...', sdtype='id') metadata.set_primary_key('1', '\\|=/bla@#$324%^,"&*()><...') diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 4393481e3..5efc300bb 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -1286,8 +1286,7 @@ def test_metadata_updated_no_warning(self, tmp_path): assert len(captured_warnings) == 0 # Run 2 - metadata_detect = Metadata() - metadata_detect.detect_from_dataframes(data) + metadata_detect = Metadata.detect_from_dataframes(data) metadata_detect.relationships = metadata.relationships for table_name, table_metadata in metadata.tables.items(): @@ -1326,8 +1325,7 @@ def test_metadata_updated_warning_detect(self): """ # Setup data, metadata = download_demo('multi_table', 'got_families') - metadata_detect = Metadata() - metadata_detect.detect_from_dataframes(data) + metadata_detect = Metadata.detect_from_dataframes(data) metadata_detect.relationships = metadata.relationships for table_name, table_metadata in metadata.tables.items(): @@ -1455,8 +1453,7 @@ def test_sampling_with_unknown_sdtype_numerical_column(self): tables_dict = {'people': table1, 'company': table2} - metadata = Metadata() - metadata.detect_from_dataframes(tables_dict) + metadata = Metadata.detect_from_dataframes(tables_dict) # Run synth = HMASynthesizer(metadata) @@ -1889,8 +1886,7 @@ def test_detect_from_dataframe_numerical_col(): child_table_name='child_data', ) - test_metadata = Metadata() - test_metadata.detect_from_dataframes(data) + test_metadata = Metadata.detect_from_dataframes(data) test_metadata.update_column('parent_data', '1', sdtype='id') test_metadata.update_column('child_data', '3', sdtype='id') test_metadata.update_column('child_data', '4', sdtype='id') @@ -1913,8 +1909,7 @@ def test_detect_from_dataframe_numerical_col(): assert sample['parent_data'].columns.tolist() == data['parent_data'].columns.tolist() assert sample['child_data'].columns.tolist() == data['child_data'].columns.tolist() - test_metadata = Metadata() - test_metadata.detect_from_dataframes(data) + test_metadata = Metadata.detect_from_dataframes(data) def test_table_name_logging(caplog): @@ -1929,8 +1924,7 @@ def test_table_name_logging(caplog): 'parent_data': parent_data, 'child_data': child_data, } - metadata = Metadata() - metadata.detect_from_dataframes(data) + metadata = Metadata.detect_from_dataframes(data) instance = HMASynthesizer(metadata) # Run @@ -2008,8 +2002,7 @@ def test_hma_synthesizer_with_fixed_combinations(): } # Creating metadata for the dataset - metadata = Metadata() - metadata.detect_from_dataframes(data) + metadata = Metadata.detect_from_dataframes(data) 
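# A hedged aside, not part of this test: the new classmethod validates its input
# up front, so an empty dict or any non-DataFrame value fails fast with
#     ValueError: The provided dictionary must contain only pandas DataFrame objects.
# For example, both of the following raise before any detection work happens:
#     Metadata.detect_from_dataframes({})
#     Metadata.detect_from_dataframes({'table': 'not-a-dataframe'})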
metadata.update_column('users', 'user_id', sdtype='id') metadata.update_column('records', 'record_id', sdtype='id') @@ -2058,8 +2051,7 @@ def test_fit_int_primary_key_regex_includes_zero(regex): 'parent_data': parent_data, 'child_data': child_data, } - metadata = Metadata() - metadata.detect_from_dataframes(data) + metadata = Metadata.detect_from_dataframes(data) metadata.update_column('parent_data', 'parent_id', sdtype='id', regex_format=regex) metadata.set_primary_key('parent_data', 'parent_id') diff --git a/tests/integration/sequential/test_par.py b/tests/integration/sequential/test_par.py index c9628a429..6ae0af54e 100644 --- a/tests/integration/sequential/test_par.py +++ b/tests/integration/sequential/test_par.py @@ -21,8 +21,7 @@ def _get_par_data_and_metadata(): 'entity': [1, 1, 2, 2], 'context': ['a', 'a', 'b', 'b'], }) - metadata = Metadata() - metadata.detect_from_dataframes({'table': data}) + metadata = Metadata.detect_from_dataframes({'table': data}) metadata.update_column('table', 'entity', sdtype='id') metadata.set_sequence_key('table', 'entity') metadata.set_sequence_index('table', 'date') @@ -34,8 +33,7 @@ def test_par(): # Setup data = load_demo() data['date'] = pd.to_datetime(data['date']) - metadata = Metadata() - metadata.detect_from_dataframes({'table': data}) + metadata = Metadata.detect_from_dataframes({'table': data}) metadata.update_column('table', 'store_id', sdtype='id') metadata.set_sequence_key('table', 'store_id') metadata.set_sequence_index('table', 'date') @@ -68,8 +66,7 @@ def test_column_after_date_simple(): 'date': [date, date], 'col2': ['hello', 'world'], }) - metadata = Metadata() - metadata.detect_from_dataframes({'table': data}) + metadata = Metadata.detect_from_dataframes({'table': data}) metadata.update_column('table', 'col', sdtype='id') metadata.set_sequence_key('table', 'col') metadata.set_sequence_index('table', 'date') @@ -348,8 +345,7 @@ def test_par_unique_sequence_index_with_enforce_min_max(): test_df[['visits', 'pre_date']] = test_df[['visits', 'pre_date']].apply( pd.to_datetime, format='%Y-%m-%d', errors='coerce' ) - metadata = Metadata() - metadata.detect_from_dataframes({'table': test_df}) + metadata = Metadata.detect_from_dataframes({'table': test_df}) metadata.update_column(table_name='table', column_name='s_key', sdtype='id') metadata.set_sequence_key('table', 's_key') metadata.set_sequence_index('table', 'visits') diff --git a/tests/integration/single_table/test_base.py b/tests/integration/single_table/test_base.py index 5f99abf93..7ab9bb27c 100644 --- a/tests/integration/single_table/test_base.py +++ b/tests/integration/single_table/test_base.py @@ -313,10 +313,9 @@ def test_config_creation_doesnt_raise_error(): 'address_col': ['223 Williams Rd', '75 Waltham St', '77 Mass Ave'], 'numerical_col': [1, 2, 3], }) - test_metadata = Metadata() # Run - test_metadata.detect_from_dataframes({'table': test_data}) + test_metadata = Metadata.detect_from_dataframes({'table': test_data}) test_metadata.update_column( table_name='table', column_name='address_col', sdtype='address', pii=False ) @@ -335,8 +334,7 @@ def test_transformers_correctly_auto_assigned(): 'categorical_col': ['a', 'b', 'a'], }) - metadata = Metadata() - metadata.detect_from_dataframes({'table': data}) + metadata = Metadata.detect_from_dataframes({'table': data}) metadata.update_column( table_name='table', column_name='primary_key', sdtype='id', regex_format='user-[0-9]{3}' ) @@ -425,8 +423,7 @@ def test_auto_assign_transformers_and_update_with_pii(): } ) - metadata = 
Metadata() - metadata.detect_from_dataframes({'table': data}) + metadata = Metadata.detect_from_dataframes({'table': data}) # Run metadata.update_column(table_name='table', column_name='id', sdtype='first_name') @@ -458,8 +455,7 @@ def test_refitting_a_model(): } ) - metadata = Metadata() - metadata.detect_from_dataframes({'table': data}) + metadata = Metadata.detect_from_dataframes({'table': data}) metadata.update_column(table_name='table', column_name='name', sdtype='name') metadata.update_column('table', 'id', sdtype='id') metadata.set_primary_key('table', 'id') @@ -619,8 +615,7 @@ def test_metadata_updated_no_warning(mock__fit, tmp_path): assert len(captured_warnings) == 0 # Run 2 - metadata_detect = Metadata() - metadata_detect.detect_from_dataframes({'mock_table': data}) + metadata_detect = Metadata.detect_from_dataframes({'mock_table': data}) file_name = tmp_path / 'singletable.json' metadata_detect.save_to_json(file_name) with warnings.catch_warnings(record=True) as captured_warnings: @@ -737,8 +732,7 @@ def test_fit_raises_version_error(): 'col 2': [4, 5, 6], 'col 3': ['a', 'b', 'c'], }) - metadata = Metadata() - metadata.detect_from_dataframes({'table': data}) + metadata = Metadata.detect_from_dataframes({'table': data}) instance = BaseSingleTableSynthesizer(metadata) instance._fitted_sdv_version = '1.0.0' @@ -813,10 +807,9 @@ def test_detect_from_dataframe_numerical_col(synthesizer_class): 2: [4, 5, 6], 3: ['a', 'b', 'c'], }) - metadata = Metadata() # Run - metadata.detect_from_dataframes({'table': data}) + metadata = Metadata.detect_from_dataframes({'table': data}) instance = synthesizer_class(metadata) instance.fit(data) sample = instance.sample(5) diff --git a/tests/integration/single_table/test_constraints.py b/tests/integration/single_table/test_constraints.py index e42cdb227..3477d4557 100644 --- a/tests/integration/single_table/test_constraints.py +++ b/tests/integration/single_table/test_constraints.py @@ -339,8 +339,7 @@ def test_custom_constraints_from_file(tmpdir): 'categorical_col': ['a', 'b', 'a'], }) - metadata = Metadata() - metadata.detect_from_dataframes({'table': data}) + metadata = Metadata.detect_from_dataframes({'table': data}) metadata.update_column(table_name='table', column_name='pii_col', sdtype='address', pii=True) synthesizer = GaussianCopulaSynthesizer( metadata, enforce_min_max_values=False, enforce_rounding=False @@ -383,8 +382,7 @@ def test_custom_constraints_from_object(tmpdir): 'categorical_col': ['a', 'b', 'a'], }) - metadata = Metadata() - metadata.detect_from_dataframes({'table': data}) + metadata = Metadata.detect_from_dataframes({'table': data}) metadata.update_column(table_name='table', column_name='pii_col', sdtype='address', pii=True) synthesizer = GaussianCopulaSynthesizer( metadata, enforce_min_max_values=False, enforce_rounding=False @@ -816,8 +814,7 @@ def reverse_transform(column_names, data): 'number': ['1', '2', '3'], 'other': [7, 8, 9], }) - metadata = Metadata() - metadata.detect_from_dataframes({'table': data}) + metadata = Metadata.detect_from_dataframes({'table': data}) metadata.update_column('table', 'key', sdtype='id', regex_format=r'\w_\d') metadata.set_primary_key('table', 'key') synth = GaussianCopulaSynthesizer(metadata) diff --git a/tests/integration/single_table/test_copulas.py b/tests/integration/single_table/test_copulas.py index 5e46a8417..6a25fe5dc 100644 --- a/tests/integration/single_table/test_copulas.py +++ b/tests/integration/single_table/test_copulas.py @@ -277,8 +277,7 @@ def 
test_update_transformers_with_id_generator(): sample_num = 20 data = pd.DataFrame({'user_id': list(range(4)), 'user_cat': ['a', 'b', 'c', 'd']}) - stm = Metadata() - stm.detect_from_dataframes({'table': data}) + stm = Metadata.detect_from_dataframes({'table': data}) stm.update_column('table', 'user_id', sdtype='id') stm.set_primary_key('table', 'user_id') @@ -406,8 +405,7 @@ def test_categorical_column_with_numbers(): 'numerical_col': np.random.rand(20), }) - metadata = Metadata() - metadata.detect_from_dataframes({'table': data}) + metadata = Metadata.detect_from_dataframes({'table': data}) synthesizer = GaussianCopulaSynthesizer(metadata) @@ -435,8 +433,7 @@ def test_unknown_sdtype(): 'numerical_col': np.random.rand(3), }) - metadata = Metadata() - metadata.detect_from_dataframes({'table': data}) + metadata = Metadata.detect_from_dataframes({'table': data}) metadata.update_column('table', 'unknown', sdtype='unknown') synthesizer = GaussianCopulaSynthesizer(metadata) diff --git a/tests/unit/metadata/test_metadata.py b/tests/unit/metadata/test_metadata.py index 64c55999d..68147fc5c 100644 --- a/tests/unit/metadata/test_metadata.py +++ b/tests/unit/metadata/test_metadata.py @@ -539,6 +539,46 @@ def test_validate_data_no_relationships(self): metadata.validate_data(data) assert metadata.METADATA_SPEC_VERSION == 'V1' + @patch('sdv.metadata.metadata.Metadata') + def test_detect_from_dataframes(self, mock_metadata): + """Test ``detect_from_dataframes``. + + Expected to call ``detect_table_from_dataframe`` for each table name and dataframe + in the input. + """ + # Setup + mock_metadata.detect_table_from_dataframe = Mock() + mock_metadata._detect_relationships = Mock() + guests_table = pd.DataFrame() + hotels_table = pd.DataFrame() + data = {'guests': guests_table, 'hotels': hotels_table} + + # Run + metadata = Metadata.detect_from_dataframes(data) + + # Assert + mock_metadata.return_value.detect_table_from_dataframe.assert_any_call( + 'guests', guests_table + ) + mock_metadata.return_value.detect_table_from_dataframe.assert_any_call( + 'hotels', hotels_table + ) + mock_metadata.return_value._detect_relationships.assert_called_once_with(data) + assert metadata == mock_metadata.return_value + + def test_detect_from_dataframes_bad_input(self): + """Test that an error is raised if the dictionary contains something other than DataFrames. + + If the data contains values that aren't pandas.DataFrames, it should error. + """ + # Setup + data = {'guests': Mock(), 'hotels': Mock()} + + # Run and Assert + expected_message = 'The provided dictionary must contain only pandas DataFrame objects.' + with pytest.raises(ValueError, match=expected_message): + Metadata.detect_from_dataframes(data) + @patch('sdv.metadata.metadata.Metadata') def test_detect_from_dataframe(self, mock_metadata): """Test that the method calls the detection method and returns the metadata. 
diff --git a/tests/unit/single_table/test_copulas.py b/tests/unit/single_table/test_copulas.py index ce4300a12..be523196d 100644 --- a/tests/unit/single_table/test_copulas.py +++ b/tests/unit/single_table/test_copulas.py @@ -474,8 +474,7 @@ def test_get_learned_distributions(self): """ # Setup data = pd.DataFrame({'zero': [0, 0, 0], 'one': [1, 1, 1]}) - stm = Metadata() - stm.detect_from_dataframes({'table': data}) + stm = Metadata.detect_from_dataframes({'table': data}) gcs = GaussianCopulaSynthesizer(stm, numerical_distributions={'one': 'uniform'}) gcs.fit(data) From 573ae4c9fd81334c1ecf9f8e9669b03b227c9245 Mon Sep 17 00:00:00 2001 From: R-Palazzo <116157184+R-Palazzo@users.noreply.github.com> Date: Fri, 13 Sep 2024 09:21:24 +0200 Subject: [PATCH 13/16] Add a warning if you're loading a SingleTableMetadata object (#2224) --- sdv/metadata/metadata.py | 8 +++ tests/integration/metadata/test_metadata.py | 35 +++++++++++++ tests/unit/metadata/test_metadata.py | 56 +++++++++++++-------- 3 files changed, 77 insertions(+), 22 deletions(-) diff --git a/sdv/metadata/metadata.py b/sdv/metadata/metadata.py index cbfd0a1b5..d32284626 100644 --- a/sdv/metadata/metadata.py +++ b/sdv/metadata/metadata.py @@ -32,6 +32,14 @@ def load_from_json(cls, filepath, single_table_name=None): A ``Metadata`` instance. """ metadata = read_json(filepath) + if metadata.get('METADATA_SPEC_VERSION') == 'SINGLE_TABLE_V1': + single_table_name = single_table_name or cls.DEFAULT_SINGLE_TABLE_NAME + warnings.warn( + 'You are loading an older SingleTableMetadata object. This will be converted into' + f" the new Metadata object with a placeholder table name ('{single_table_name}')." + ' Please save this new object for future usage.' + ) + return cls.load_from_dict(metadata, single_table_name) @classmethod diff --git a/tests/integration/metadata/test_metadata.py b/tests/integration/metadata/test_metadata.py index ffc2b1bc8..5467dfb49 100644 --- a/tests/integration/metadata/test_metadata.py +++ b/tests/integration/metadata/test_metadata.py @@ -27,6 +27,41 @@ def test_metadata(): assert instance.relationships == [] +def test_load_from_json_single_table_metadata(tmp_path): + """Test the ``load_from_json`` method with a single table metadata.""" + # Setup + old_metadata = SingleTableMetadata.load_from_dict({ + 'columns': { + 'column_1': {'sdtype': 'numerical'}, + 'column_2': {'sdtype': 'categorical'}, + }, + }) + old_metadata.save_to_json(tmp_path / 'metadata.json') + expected_warning = re.escape( + 'You are loading an older SingleTableMetadata object. This will be converted ' + f"into the new Metadata object with a placeholder table name ('{DEFAULT_TABLE_NAME}')." + ' Please save this new object for future usage.' 
+ ) + + # Run + with pytest.warns(UserWarning, match=expected_warning): + metadata = Metadata.load_from_json(tmp_path / 'metadata.json') + + # Assert + assert metadata.to_dict() == { + 'tables': { + DEFAULT_TABLE_NAME: { + 'columns': { + 'column_1': {'sdtype': 'numerical'}, + 'column_2': {'sdtype': 'categorical'}, + }, + }, + }, + 'relationships': [], + 'METADATA_SPEC_VERSION': 'V1', + } + + def test_detect_from_dataframes_multi_table(): """Test the ``detect_from_dataframes`` method works with multi-table.""" # Setup diff --git a/tests/unit/metadata/test_metadata.py b/tests/unit/metadata/test_metadata.py index 68147fc5c..a69ada211 100644 --- a/tests/unit/metadata/test_metadata.py +++ b/tests/unit/metadata/test_metadata.py @@ -1,4 +1,5 @@ -from unittest.mock import Mock, patch +import re +from unittest.mock import Mock, call, patch import pandas as pd import pytest @@ -94,17 +95,12 @@ def test_load_from_json_path_does_not_exist(self, mock_path): with pytest.raises(ValueError, match=error_msg): Metadata.load_from_json('filepath.json') - @patch('sdv.metadata.utils.Path') - @patch('sdv.metadata.utils.json') - def test_load_from_json_single_table(self, mock_json, mock_path): - """Test the ``load_from_json`` method. - - Test that ``load_from_json`` function creates an instance with the contents returned by the - ``json`` load function when passing in a single table metadata json. + @patch('sdv.metadata.metadata.read_json') + def test_load_from_json_single_table(self, mock_read_json): + """Test the ``load_from_json`` method for single table metadata. Mock: - - Mock the ``Path`` library in order to return ``True``. - - Mock the ``json`` library in order to use a custom return. + - Mock the ``read_json`` function in order to return a custom json. Input: - String representing a filepath. @@ -114,26 +110,42 @@ def test_load_from_json_single_table(self, mock_json, mock_path): file (``json.load`` return value) """ # Setup - instance = Metadata() - mock_path.return_value.exists.return_value = True - mock_path.return_value.name = 'filepath.json' - mock_json.load.return_value = { + mock_read_json.return_value = { 'columns': {'animals': {'type': 'categorical'}}, 'primary_key': 'animals', 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1', } + warning_message = ( + 'You are loading an older SingleTableMetadata object. This will be converted' + " into the new Metadata object with a placeholder table name ('{}')." + ' Please save this new object for future usage.' 
+ ) + + expected_warning_with_table_name = re.escape(warning_message.format('filepath')) + expected_warning_without_table_name = re.escape( + warning_message.format('default_table_name') + ) # Run - instance = Metadata.load_from_json('filepath.json', 'filepath') + with pytest.warns(UserWarning, match=expected_warning_with_table_name): + instance_with_table_name = Metadata.load_from_json('filepath.json', 'filepath') + with pytest.warns(UserWarning, match=expected_warning_without_table_name): + instance_without_table_name = Metadata.load_from_json('filepath.json') # Assert - assert list(instance.tables.keys()) == ['filepath'] - assert instance.tables['filepath'].columns == {'animals': {'type': 'categorical'}} - assert instance.tables['filepath'].primary_key == 'animals' - assert instance.tables['filepath'].sequence_key is None - assert instance.tables['filepath'].alternate_keys == [] - assert instance.tables['filepath'].sequence_index is None - assert instance.tables['filepath']._version == 'SINGLE_TABLE_V1' + mock_read_json.assert_has_calls([call('filepath.json'), call('filepath.json')]) + table_name_to_instance = { + 'filepath': instance_with_table_name, + 'default_table_name': instance_without_table_name, + } + for table_name, instance in table_name_to_instance.items(): + assert list(instance.tables.keys()) == [table_name] + assert instance.tables[table_name].columns == {'animals': {'type': 'categorical'}} + assert instance.tables[table_name].primary_key == 'animals' + assert instance.tables[table_name].sequence_key is None + assert instance.tables[table_name].alternate_keys == [] + assert instance.tables[table_name].sequence_index is None + assert instance.tables[table_name]._version == 'SINGLE_TABLE_V1' @patch('sdv.metadata.utils.Path') @patch('sdv.metadata.utils.json') From 6dbffa723fcd37fdefb181cac81edb78866dae04 Mon Sep 17 00:00:00 2001 From: R-Palazzo <116157184+R-Palazzo@users.noreply.github.com> Date: Mon, 16 Sep 2024 17:13:29 +0200 Subject: [PATCH 14/16] Add `metadata.validate_table` method for single table usage (#2225) --- sdv/metadata/metadata.py | 22 ++++++++++++++++++++++ tests/unit/metadata/test_metadata.py | 27 +++++++++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/sdv/metadata/metadata.py b/sdv/metadata/metadata.py index d32284626..053ef5896 100644 --- a/sdv/metadata/metadata.py +++ b/sdv/metadata/metadata.py @@ -184,3 +184,25 @@ def set_sequence_key(self, table_name, column_name): """ self._validate_table_exists(table_name) self.tables[table_name].set_sequence_key(column_name) + + def validate_table(self, data, table_name=None): + """Validate a table against the metadata. + + Args: + data (pandas.DataFrame): + Data to validate. + table_name (str): + Name of the table to validate. + """ + if table_name is None: + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + table_name = self._get_single_table_name() + + if not table_name: + raise InvalidMetadataError( + 'Metadata contains more than one table, please specify the `table_name` ' + 'to validate.' 
+ ) + + return self.validate_data({table_name: data}) diff --git a/tests/unit/metadata/test_metadata.py b/tests/unit/metadata/test_metadata.py index a69ada211..1524c0757 100644 --- a/tests/unit/metadata/test_metadata.py +++ b/tests/unit/metadata/test_metadata.py @@ -4,6 +4,8 @@ import pandas as pd import pytest +from sdv.errors import InvalidDataError +from sdv.metadata.errors import InvalidMetadataError from sdv.metadata.metadata import Metadata from tests.utils import DataFrameMatcher, get_multi_table_data, get_multi_table_metadata @@ -551,6 +553,31 @@ def test_validate_data_no_relationships(self): metadata.validate_data(data) assert metadata.METADATA_SPEC_VERSION == 'V1' + def test_validate_table(self): + """Test the ``validate_table``method.""" + # Setup + metadata_multi_table = get_multi_table_metadata() + metadata_single_table = Metadata.load_from_dict( + metadata_multi_table.to_dict()['tables']['nesreca'], 'nesreca' + ) + table = get_multi_table_data()['nesreca'] + + expected_error_wrong_name = re.escape( + 'The provided data does not match the metadata:\n' + "The provided data is missing the tables {'nesreca'}." + ) + expected_error_mutli_table = re.escape( + 'Metadata contains more than one table, please specify the `table_name` to validate.' + ) + + # Run and Assert + metadata_single_table.validate_table(table) + metadata_single_table.validate_table(table, 'nesreca') + with pytest.raises(InvalidDataError, match=expected_error_wrong_name): + metadata_single_table.validate_table(table, 'wrong_name') + with pytest.raises(InvalidMetadataError, match=expected_error_mutli_table): + metadata_multi_table.validate_table(table) + @patch('sdv.metadata.metadata.Metadata') def test_detect_from_dataframes(self, mock_metadata): """Test ``detect_from_dataframes``. 
From 2a93baf3b165519df97621b1a590ac6a0da1cc7d Mon Sep 17 00:00:00 2001 From: R-Palazzo <116157184+R-Palazzo@users.noreply.github.com> Date: Wed, 18 Sep 2024 20:23:41 +0200 Subject: [PATCH 15/16] For single-table use cases, make it frictionless to update Metadata (#2228) --- sdv/metadata/metadata.py | 72 +++++++++- sdv/metadata/multi_table.py | 8 +- sdv/multi_table/base.py | 3 +- .../data_processing/test_data_processor.py | 24 ++-- .../evaluation/test_single_table.py | 2 +- tests/integration/metadata/test_metadata.py | 130 +++++++++++------- .../integration/metadata/test_multi_table.py | 12 +- .../metadata/test_visualization.py | 6 +- tests/integration/multi_table/test_hma.py | 102 +++++++------- tests/integration/sequential/test_par.py | 31 +++-- tests/integration/single_table/test_base.py | 32 ++--- .../single_table/test_constraints.py | 34 ++--- .../integration/single_table/test_copulas.py | 6 +- tests/integration/single_table/test_ctgan.py | 12 +- tests/unit/evaluation/test_single_table.py | 48 +++---- tests/unit/metadata/test_metadata.py | 51 +++++++ tests/unit/metadata/test_multi_table.py | 6 +- tests/unit/multi_table/test_base.py | 20 +-- tests/unit/multi_table/test_hma.py | 12 +- tests/unit/multi_table/test_utils.py | 2 +- tests/unit/sequential/test_par.py | 34 +++-- tests/unit/single_table/test_base.py | 37 ++--- tests/unit/single_table/test_copulagan.py | 10 +- tests/unit/single_table/test_copulas.py | 8 +- tests/unit/single_table/test_ctgan.py | 14 +- 25 files changed, 439 insertions(+), 277 deletions(-) diff --git a/sdv/metadata/metadata.py b/sdv/metadata/metadata.py index 053ef5896..06b9c2ddf 100644 --- a/sdv/metadata/metadata.py +++ b/sdv/metadata/metadata.py @@ -161,7 +161,18 @@ def _convert_to_single_table(self): return next(iter(self.tables.values()), SingleTableMetadata()) - def set_sequence_index(self, table_name, column_name): + def _handle_table_name(self, table_name): + if table_name is None: + if len(self.tables) == 1: + table_name = next(iter(self.tables)) + else: + raise ValueError( + 'Metadata contains more than one table, please specify the `table_name`.' + ) + + return table_name + + def set_sequence_index(self, column_name, table_name=None): """Set the sequence index of a table. Args: @@ -170,18 +181,21 @@ def set_sequence_index(self, table_name, column_name): column_name (str): Name of the sequence index column. """ + table_name = self._handle_table_name(table_name) self._validate_table_exists(table_name) self.tables[table_name].set_sequence_index(column_name) - def set_sequence_key(self, table_name, column_name): + def set_sequence_key(self, column_name, table_name=None): """Set the sequence key of a table. Args: - table_name (str): - Name of the table to set the sequence key. column_name (str, tulple[str]): Name (or tuple of names) of the sequence key column(s). + table_name (str): + Name of the table to set the sequence key. + Defaults to None. 
""" + table_name = self._handle_table_name(table_name) self._validate_table_exists(table_name) self.tables[table_name].set_sequence_key(column_name) @@ -206,3 +220,53 @@ def validate_table(self, data, table_name=None): ) return self.validate_data({table_name: data}) + + def get_column_names(self, table_name=None, **kwargs): + """Return a list of column names that match the given metadata keyword arguments.""" + table_name = self._handle_table_name(table_name) + return super().get_column_names(table_name, **kwargs) + + def update_column(self, column_name, table_name=None, **kwargs): + """Update an existing column for a table in the ``Metadata``.""" + table_name = self._handle_table_name(table_name) + super().update_column(table_name, column_name, **kwargs) + + def update_columns(self, column_names, table_name=None, **kwargs): + """Update the metadata of multiple columns.""" + table_name = self._handle_table_name(table_name) + super().update_columns(table_name, column_names, **kwargs) + + def update_columns_metadata(self, column_metadata, table_name=None): + """Update the metadata of multiple columns.""" + table_name = self._handle_table_name(table_name) + super().update_columns_metadata(table_name, column_metadata) + + def add_column(self, column_name, table_name=None, **kwargs): + """Add a column to the metadata.""" + table_name = self._handle_table_name(table_name) + super().add_column(table_name, column_name, **kwargs) + + def add_column_relationship( + self, + relationship_type, + column_names, + table_name=None, + ): + """Add a column relationship to the metadata.""" + table_name = self._handle_table_name(table_name) + super().add_column_relationship(table_name, relationship_type, column_names) + + def set_primary_key(self, column_name, table_name=None): + """Set the primary key of a table.""" + table_name = self._handle_table_name(table_name) + super().set_primary_key(table_name, column_name) + + def remove_primary_key(self, table_name=None): + """Remove the primary key of a table.""" + table_name = self._handle_table_name(table_name) + super().remove_primary_key(table_name) + + def add_alternate_keys(self, column_names, table_name=None): + """Add alternate keys to a table.""" + table_name = self._handle_table_name(table_name) + super().add_alternate_keys(table_name, column_names) diff --git a/sdv/metadata/multi_table.py b/sdv/metadata/multi_table.py index cb38bdb12..201a8c8c6 100644 --- a/sdv/metadata/multi_table.py +++ b/sdv/metadata/multi_table.py @@ -515,14 +515,18 @@ def _detect_relationships(self, data=None): try: original_foreign_key_sdtype = child_meta.columns[primary_key]['sdtype'] if original_foreign_key_sdtype != 'id': - self.update_column(child_candidate, primary_key, sdtype='id') + self.update_column( + table_name=child_candidate, column_name=primary_key, sdtype='id' + ) self.add_relationship( parent_candidate, child_candidate, primary_key, primary_key ) except InvalidMetadataError: self.update_column( - child_candidate, primary_key, sdtype=original_foreign_key_sdtype + table_name=child_candidate, + column_name=primary_key, + sdtype=original_foreign_key_sdtype, ) continue diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 0588936cf..843653d06 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -76,7 +76,8 @@ def _set_temp_numpy_seed(self): def _initialize_models(self): with disable_single_table_logger(): for table_name, table_metadata in self.metadata.tables.items(): - synthesizer_parameters = self._table_parameters.get(table_name, {}) 
+ synthesizer_parameters = {'locales': self.locales} + synthesizer_parameters.update(self._table_parameters.get(table_name, {})) metadata_dict = {'tables': {table_name: table_metadata.to_dict()}} metadata = Metadata.load_from_dict(metadata_dict) self._table_synthesizers[table_name] = self._synthesizer( diff --git a/tests/integration/data_processing/test_data_processor.py b/tests/integration/data_processing/test_data_processor.py index 65f865d2f..bb87029b4 100644 --- a/tests/integration/data_processing/test_data_processor.py +++ b/tests/integration/data_processing/test_data_processor.py @@ -53,7 +53,7 @@ def test_with_anonymized_columns(self): data, metadata = download_demo('single_table', 'adult') # Add anonymized field - metadata.update_column('adult', 'occupation', sdtype='job', pii=True) + metadata.update_column('occupation', 'adult', sdtype='job', pii=True) # Instance ``DataProcessor`` dp = DataProcessor(metadata._convert_to_single_table()) @@ -101,11 +101,11 @@ def test_with_anonymized_columns_and_primary_key(self): data, metadata = download_demo('single_table', 'adult') # Add anonymized field - metadata.update_column('adult', 'occupation', sdtype='job', pii=True) + metadata.update_column('occupation', 'adult', sdtype='job', pii=True) # Add primary key field - metadata.add_column('adult', 'id', sdtype='id', regex_format='ID_\\d{4}[0-9]') - metadata.set_primary_key('adult', 'id') + metadata.add_column('id', 'adult', sdtype='id', regex_format='ID_\\d{4}[0-9]') + metadata.set_primary_key('id', 'adult') # Add id size = len(data) @@ -159,8 +159,8 @@ def test_with_primary_key_numerical(self): adult_metadata = Metadata.detect_from_dataframes({'adult': data}) # Add primary key field - adult_metadata.add_column('adult', 'id', sdtype='id') - adult_metadata.set_primary_key('adult', 'id') + adult_metadata.add_column('id', 'adult', sdtype='id') + adult_metadata.set_primary_key('id', 'adult') # Add id size = len(data) @@ -198,13 +198,13 @@ def test_with_alternate_keys(self): adult_metadata = Metadata.detect_from_dataframes({'adult': data}) # Add primary key field - adult_metadata.add_column('adult', 'id', sdtype='id') - adult_metadata.set_primary_key('adult', 'id') + adult_metadata.add_column('id', 'adult', sdtype='id') + adult_metadata.set_primary_key('id', 'adult') - adult_metadata.add_column('adult', 'secondary_id', sdtype='id') - adult_metadata.update_column('adult', 'fnlwgt', sdtype='id', regex_format='ID_\\d{4}[0-9]') + adult_metadata.add_column('secondary_id', 'adult', sdtype='id') + adult_metadata.update_column('fnlwgt', 'adult', sdtype='id', regex_format='ID_\\d{4}[0-9]') - adult_metadata.add_alternate_keys('adult', ['secondary_id', 'fnlwgt']) + adult_metadata.add_alternate_keys(['secondary_id', 'fnlwgt'], 'adult') # Add id size = len(data) @@ -345,7 +345,7 @@ def test_localized_anonymized_columns(self): """Test data processor uses the default locale for anonymized columns.""" # Setup data, metadata = download_demo('single_table', 'adult') - metadata.update_column('adult', 'occupation', sdtype='job', pii=True) + metadata.update_column('occupation', 'adult', sdtype='job', pii=True) dp = DataProcessor(metadata._convert_to_single_table(), locales=['en_CA', 'fr_CA']) diff --git a/tests/integration/evaluation/test_single_table.py b/tests/integration/evaluation/test_single_table.py index 8b3cf23fa..d943b02d3 100644 --- a/tests/integration/evaluation/test_single_table.py +++ b/tests/integration/evaluation/test_single_table.py @@ -12,7 +12,7 @@ def test_evaluation(): data = pd.DataFrame({'col': 
[1, 2, 3]}) metadata = Metadata() metadata.add_table('table') - metadata.add_column('table', 'col', sdtype='numerical') + metadata.add_column('col', 'table', sdtype='numerical') synthesizer = GaussianCopulaSynthesizer(metadata, default_distribution='truncnorm') # Run and Assert diff --git a/tests/integration/metadata/test_metadata.py b/tests/integration/metadata/test_metadata.py index 5467dfb49..fd5909b24 100644 --- a/tests/integration/metadata/test_metadata.py +++ b/tests/integration/metadata/test_metadata.py @@ -1,5 +1,6 @@ import os import re +from copy import deepcopy import pytest @@ -241,56 +242,6 @@ def test_detect_from_csvs(tmp_path): assert metadata.to_dict() == expected_metadata -def test_detect_table_from_csv(tmp_path): - """Test the ``detect_table_from_csv`` method.""" - # Setup - real_data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels') - - metadata = Metadata() - - for table_name, dataframe in real_data.items(): - csv_path = tmp_path / f'{table_name}.csv' - dataframe.to_csv(csv_path, index=False) - - # Run - metadata.detect_table_from_csv('hotels', tmp_path / 'hotels.csv') - - # Assert - metadata.update_column( - table_name='hotels', - column_name='city', - sdtype='categorical', - ) - metadata.update_column( - table_name='hotels', - column_name='state', - sdtype='categorical', - ) - metadata.update_column( - table_name='hotels', - column_name='classification', - sdtype='categorical', - ) - expected_metadata = { - 'tables': { - 'hotels': { - 'columns': { - 'hotel_id': {'sdtype': 'id'}, - 'city': {'sdtype': 'categorical'}, - 'state': {'sdtype': 'categorical'}, - 'rating': {'sdtype': 'numerical'}, - 'classification': {'sdtype': 'categorical'}, - }, - 'primary_key': 'hotel_id', - } - }, - 'relationships': [], - 'METADATA_SPEC_VERSION': 'V1', - } - - assert metadata.to_dict() == expected_metadata - - def test_single_table_compatibility(tmp_path): """Test if SingleTableMetadata still has compatibility with single table synthesizers.""" # Setup @@ -432,3 +383,82 @@ def test_multi_table_compatibility(tmp_path): assert expected_metadata == synthesizer_2.metadata.to_dict() for table in metadata_sample: assert metadata_sample[table].columns.to_list() == loaded_sample[table].columns.to_list() + + +params = [ + ('update_column', ['column_name'], {'column_name': 'has_rewards', 'sdtype': 'categorical'}), + ( + 'update_columns', + ['column_names'], + {'column_names': ['has_rewards', 'billing_address'], 'sdtype': 'categorical'}, + ), + ( + 'update_columns_metadata', + ['column_metadata'], + {'column_metadata': {'has_rewards': {'sdtype': 'categorical'}}}, + ), + ('add_column', ['column_name'], {'column_name': 'has_rewards_2', 'sdtype': 'categorical'}), + ('set_primary_key', ['column_name'], {'column_name': 'billing_address'}), + ('remove_primary_key', [], {}), + ( + 'add_column_relationship', + ['relationship_type', 'column_names'], + {'column_names': ['billing_address'], 'relationship_type': 'address'}, + ), + ('add_alternate_keys', ['column_names'], {'column_names': ['billing_address']}), + ('set_sequence_key', ['column_name'], {'column_name': 'billing_address'}), + ('get_column_names', [], {'sdtype': 'datetime'}), +] + + +@pytest.mark.parametrize('method, args, kwargs', params) +def test_any_metadata_update_single_table(method, args, kwargs): + """Test that any method that updates metadata works for single-table case.""" + # Setup + _, metadata = download_demo('single_table', 'fake_hotel_guests') + metadata.update_column( + table_name='fake_hotel_guests', 
column_name='billing_address', sdtype='street_address' + ) + parameter = [kwargs[arg] for arg in args] + remaining_kwargs = {key: value for key, value in kwargs.items() if key not in args} + metadata_before = deepcopy(metadata).to_dict() + + # Run + result = getattr(metadata, method)(*parameter, **remaining_kwargs) + + # Assert + expected_dict = metadata.to_dict() + if method != 'get_column_names': + assert expected_dict != metadata_before + else: + assert result == ['checkin_date', 'checkout_date'] + + +@pytest.mark.parametrize('method, args, kwargs', params) +def test_any_metadata_update_multi_table(method, args, kwargs): + """Test that any method that updates metadata works for multi-table case.""" + # Setup + _, metadata = download_demo('multi_table', 'fake_hotels') + metadata.update_column( + table_name='guests', column_name='billing_address', sdtype='street_address' + ) + parameter = [kwargs[arg] for arg in args] + remaining_kwargs = {key: value for key, value in kwargs.items() if key not in args} + metadata_before = deepcopy(metadata).to_dict() + expected_error = re.escape( + 'Metadata contains more than one table, please specify the `table_name`.' + ) + + # Run + with pytest.raises(ValueError, match=expected_error): + getattr(metadata, method)(*parameter, **remaining_kwargs) + + parameter.append('guests') + result = getattr(metadata, method)(*parameter, **remaining_kwargs) + + # Assert + expected_dict = metadata.to_dict() + if method != 'get_column_names': + assert expected_dict != metadata_before + else: + assert result == ['checkin_date', 'checkout_date'] diff --git a/tests/integration/metadata/test_multi_table.py b/tests/integration/metadata/test_multi_table.py index 1384eaf74..23e16de42 100644 --- a/tests/integration/metadata/test_multi_table.py +++ b/tests/integration/metadata/test_multi_table.py @@ -34,11 +34,11 @@ def _validate_sdtypes(cls, columns_to_sdtypes): mock_rdt_transformers.address.RandomLocationGenerator = RandomLocationGeneratorMock _, instance = download_demo('multi_table', 'fake_hotels') - instance.update_column('hotels', 'city', sdtype='city') - instance.update_column('hotels', 'state', sdtype='state') + instance.update_column('city', 'hotels', sdtype='city') + instance.update_column('state', 'hotels', sdtype='state') # Run - instance.add_column_relationship('hotels', 'address', ['city', 'state']) + instance.add_column_relationship('address', ['city', 'state'], 'hotels') # Assert instance.validate() @@ -303,9 +303,9 @@ def test_get_table_metadata(): """Test the ``get_table_metadata`` method.""" # Setup metadata = get_multi_table_metadata() - metadata.add_column('nesreca', 'latitude', sdtype='latitude') - metadata.add_column('nesreca', 'longitude', sdtype='longitude') - metadata.add_column_relationship('nesreca', 'gps', ['latitude', 'longitude']) + metadata.add_column('latitude', 'nesreca', sdtype='latitude') + metadata.add_column('longitude', 'nesreca', sdtype='longitude') + metadata.add_column_relationship('gps', ['latitude', 'longitude'], 'nesreca') # Run table_metadata = metadata.get_table_metadata('nesreca') diff --git a/tests/integration/metadata/test_visualization.py b/tests/integration/metadata/test_visualization.py index 2d801ea5d..23a7aee17 100644 --- a/tests/integration/metadata/test_visualization.py +++ b/tests/integration/metadata/test_visualization.py @@ -26,9 +26,9 @@ def test_visualize_graph_for_multi_table(): data2 = pd.DataFrame({'\\|=/bla@#$324%^,"&*()><...': ['a', 'b', 'c']}) tables = {'1': data1, '2': data2} metadata = 
Metadata.detect_from_dataframes(tables) - metadata.update_column('1', '\\|=/bla@#$324%^,"&*()><...', sdtype='id') - metadata.update_column('2', '\\|=/bla@#$324%^,"&*()><...', sdtype='id') - metadata.set_primary_key('1', '\\|=/bla@#$324%^,"&*()><...') + metadata.update_column('\\|=/bla@#$324%^,"&*()><...', '1', sdtype='id') + metadata.update_column('\\|=/bla@#$324%^,"&*()><...', '2', sdtype='id') + metadata.set_primary_key('\\|=/bla@#$324%^,"&*()><...', '1') metadata.add_relationship( '1', '2', '\\|=/bla@#$324%^,"&*()><...', '\\|=/bla@#$324%^,"&*()><...' ) diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index 5efc300bb..288e470ad 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -87,8 +87,8 @@ def test_hma_reset_sampling(self): faker = Faker() data, metadata = download_demo('multi_table', 'got_families') metadata.add_column( - 'characters', 'ssn', + 'characters', sdtype='ssn', ) data['characters']['ssn'] = [faker.lexify() for _ in range(len(data['characters']))] @@ -126,7 +126,7 @@ def test_get_info(self): today = datetime.datetime.today().strftime('%Y-%m-%d') metadata = Metadata() metadata.add_table('tab') - metadata.add_column('tab', 'col', sdtype='numerical') + metadata.add_column('col', 'tab', sdtype='numerical') synthesizer = HMASynthesizer(metadata) # Run @@ -221,12 +221,12 @@ def get_custom_constraint_data_and_metadata(self): metadata = Metadata() metadata.detect_table_from_dataframe('parent', parent_data) - metadata.update_column('parent', 'primary_key', sdtype='id') + metadata.update_column('primary_key', 'parent', sdtype='id') metadata.detect_table_from_dataframe('child', child_data) - metadata.update_column('child', 'user_id', sdtype='id') - metadata.update_column('child', 'id', sdtype='id') - metadata.set_primary_key('parent', 'primary_key') - metadata.set_primary_key('child', 'id') + metadata.update_column('user_id', 'child', sdtype='id') + metadata.update_column('id', 'child', sdtype='id') + metadata.set_primary_key('primary_key', 'parent') + metadata.set_primary_key('id', 'child') metadata.add_relationship( parent_primary_key='primary_key', parent_table_name='parent', @@ -361,10 +361,10 @@ def test_hma_with_inequality_constraint(self): metadata = Metadata() metadata.detect_table_from_dataframe(table_name='parent_table', data=parent_table) - metadata.update_column('parent_table', 'id', sdtype='id') + metadata.update_column('id', 'parent_table', sdtype='id') metadata.detect_table_from_dataframe(table_name='child_table', data=child_table) - metadata.update_column('child_table', 'id', sdtype='id') - metadata.update_column('child_table', 'parent_id', sdtype='id') + metadata.update_column('id', 'child_table', sdtype='id') + metadata.update_column('parent_id', 'child_table', sdtype='id') metadata.set_primary_key(table_name='parent_table', column_name='id') metadata.set_primary_key(table_name='child_table', column_name='id') @@ -452,14 +452,14 @@ def test_hma_primary_key_and_foreign_key_only(self): for table_name, table in data.items(): metadata.detect_table_from_dataframe(table_name, table) - metadata.update_column('users', 'user_id', sdtype='id') - metadata.update_column('sessions', 'session_id', sdtype='id') - metadata.update_column('games', 'game_id', sdtype='id') - metadata.update_column('games', 'session_id', sdtype='id') - metadata.update_column('games', 'user_id', sdtype='id') - metadata.set_primary_key('users', 'user_id') - metadata.set_primary_key('sessions', 
'session_id') - metadata.set_primary_key('games', 'game_id') + metadata.update_column('user_id', 'users', sdtype='id') + metadata.update_column('session_id', 'sessions', sdtype='id') + metadata.update_column('game_id', 'games', sdtype='id') + metadata.update_column('session_id', 'games', sdtype='id') + metadata.update_column('user_id', 'games', sdtype='id') + metadata.set_primary_key('user_id', 'users') + metadata.set_primary_key('session_id', 'sessions') + metadata.set_primary_key('game_id', 'games') metadata.add_relationship('users', 'games', 'user_id', 'user_id') metadata.add_relationship('sessions', 'games', 'session_id', 'session_id') @@ -1350,25 +1350,25 @@ def test_null_foreign_keys(self): # Setup metadata = Metadata() metadata.add_table('parent_table1') - metadata.add_column('parent_table1', 'id', sdtype='id') - metadata.set_primary_key('parent_table1', 'id') + metadata.add_column('id', 'parent_table1', sdtype='id') + metadata.set_primary_key('id', 'parent_table1') metadata.add_table('parent_table2') - metadata.add_column('parent_table2', 'id', sdtype='id') - metadata.set_primary_key('parent_table2', 'id') + metadata.add_column('id', 'parent_table2', sdtype='id') + metadata.set_primary_key('id', 'parent_table2') metadata.add_table('child_table1') - metadata.add_column('child_table1', 'id', sdtype='id') - metadata.set_primary_key('child_table1', 'id') - metadata.add_column('child_table1', 'fk1', sdtype='id') - metadata.add_column('child_table1', 'fk2', sdtype='id') + metadata.add_column('id', 'child_table1', sdtype='id') + metadata.set_primary_key('id', 'child_table1') + metadata.add_column('fk1', 'child_table1', sdtype='id') + metadata.add_column('fk2', 'child_table1', sdtype='id') metadata.add_table('child_table2') - metadata.add_column('child_table2', 'id', sdtype='id') - metadata.set_primary_key('child_table2', 'id') - metadata.add_column('child_table2', 'fk1', sdtype='id') - metadata.add_column('child_table2', 'fk2', sdtype='id') - metadata.add_column('child_table2', 'cat_type', sdtype='categorical') + metadata.add_column('id', 'child_table2', sdtype='id') + metadata.set_primary_key('id', 'child_table2') + metadata.add_column('fk1', 'child_table2', sdtype='id') + metadata.add_column('fk2', 'child_table2', sdtype='id') + metadata.add_column('cat_type', 'child_table2', sdtype='categorical') metadata.add_relationship( parent_table_name='parent_table1', @@ -1841,7 +1841,7 @@ def test_fit_and_sample_numerical_col_names(): } ] metadata = Metadata.load_from_dict(metadata_dict) - metadata.set_primary_key('0', '1') + metadata.set_primary_key('1', '0') # Run synth = HMASynthesizer(metadata) @@ -1874,11 +1874,11 @@ def test_detect_from_dataframe_numerical_col(): metadata = Metadata() metadata.detect_table_from_dataframe('parent_data', parent_data) metadata.detect_table_from_dataframe('child_data', child_data) - metadata.update_column('parent_data', '1', sdtype='id') - metadata.update_column('child_data', '3', sdtype='id') - metadata.update_column('child_data', '4', sdtype='id') - metadata.set_primary_key('parent_data', '1') - metadata.set_primary_key('child_data', '4') + metadata.update_column('1', 'parent_data', sdtype='id') + metadata.update_column('3', 'child_data', sdtype='id') + metadata.update_column('4', 'child_data', sdtype='id') + metadata.set_primary_key('1', 'parent_data') + metadata.set_primary_key('4', 'child_data') metadata.add_relationship( parent_primary_key='1', parent_table_name='parent_data', @@ -1887,11 +1887,11 @@ def test_detect_from_dataframe_numerical_col(): ) 
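# A hedged aside on this test's numerical column names: detection converts all
# column names to strings (per the detect_from_dataframes docstring), which is
# why the integer-named columns are addressed as '1', '3' and '4' here. A short
# sketch with a hypothetical frame:
#     df = pd.DataFrame({0: [1, 2], 1: ['a', 'b']})
#     md = Metadata.detect_from_dataframes({'t': df})
#     list(md.tables['t'].columns)  # ['0', '1'] -- keys are strings after detection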
test_metadata = Metadata.detect_from_dataframes(data) - test_metadata.update_column('parent_data', '1', sdtype='id') - test_metadata.update_column('child_data', '3', sdtype='id') - test_metadata.update_column('child_data', '4', sdtype='id') - test_metadata.set_primary_key('parent_data', '1') - test_metadata.set_primary_key('child_data', '4') + test_metadata.update_column('1', 'parent_data', sdtype='id') + test_metadata.update_column('3', 'child_data', sdtype='id') + test_metadata.update_column('4', 'child_data', sdtype='id') + test_metadata.set_primary_key('1', 'parent_data') + test_metadata.set_primary_key('4', 'child_data') test_metadata.add_relationship( parent_primary_key='1', parent_table_name='parent_data', @@ -2004,13 +2004,13 @@ def test_hma_synthesizer_with_fixed_combinations(): # Creating metadata for the dataset metadata = Metadata.detect_from_dataframes(data) - metadata.update_column('users', 'user_id', sdtype='id') - metadata.update_column('records', 'record_id', sdtype='id') - metadata.update_column('records', 'user_id', sdtype='id') - metadata.update_column('records', 'location_id', sdtype='id') - metadata.update_column('locations', 'location_id', sdtype='id') - metadata.set_primary_key('users', 'user_id') - metadata.set_primary_key('locations', 'location_id') + metadata.update_column('user_id', 'users', sdtype='id') + metadata.update_column('record_id', 'records', sdtype='id') + metadata.update_column('user_id', 'records', sdtype='id') + metadata.update_column('location_id', 'records', sdtype='id') + metadata.update_column('location_id', 'locations', sdtype='id') + metadata.set_primary_key('user_id', 'users') + metadata.set_primary_key('location_id', 'locations') metadata.add_relationship('users', 'records', 'user_id', 'user_id') metadata.add_relationship('locations', 'records', 'location_id', 'location_id') @@ -2052,8 +2052,8 @@ def test_fit_int_primary_key_regex_includes_zero(regex): 'child_data': child_data, } metadata = Metadata.detect_from_dataframes(data) - metadata.update_column('parent_data', 'parent_id', sdtype='id', regex_format=regex) - metadata.set_primary_key('parent_data', 'parent_id') + metadata.update_column('parent_id', 'parent_data', sdtype='id', regex_format=regex) + metadata.set_primary_key('parent_id', 'parent_data') # Run and Assert instance = HMASynthesizer(metadata) diff --git a/tests/integration/sequential/test_par.py b/tests/integration/sequential/test_par.py index 6ae0af54e..620a95cf2 100644 --- a/tests/integration/sequential/test_par.py +++ b/tests/integration/sequential/test_par.py @@ -22,9 +22,11 @@ def _get_par_data_and_metadata(): 'context': ['a', 'a', 'b', 'b'], }) metadata = Metadata.detect_from_dataframes({'table': data}) - metadata.update_column('table', 'entity', sdtype='id') - metadata.set_sequence_key('table', 'entity') - metadata.set_sequence_index('table', 'date') + metadata.update_column('entity', 'table', sdtype='id') + metadata.set_sequence_key('entity', 'table') + + metadata.set_sequence_index('date', 'table') + return data, metadata @@ -34,9 +36,11 @@ def test_par(): data = load_demo() data['date'] = pd.to_datetime(data['date']) metadata = Metadata.detect_from_dataframes({'table': data}) - metadata.update_column('table', 'store_id', sdtype='id') - metadata.set_sequence_key('table', 'store_id') - metadata.set_sequence_index('table', 'date') + metadata.update_column('store_id', 'table', sdtype='id') + metadata.set_sequence_key('store_id', 'table') + + metadata.set_sequence_index('date', 'table') + model = PARSynthesizer( 
         metadata=metadata,
         context_columns=['region'],
@@ -67,9 +71,10 @@ def test_column_after_date_simple():
         'col2': ['hello', 'world'],
     })
     metadata = Metadata.detect_from_dataframes({'table': data})
-    metadata.update_column('table', 'col', sdtype='id')
-    metadata.set_sequence_key('table', 'col')
-    metadata.set_sequence_index('table', 'date')
+    metadata.update_column('col', 'table', sdtype='id')
+    metadata.set_sequence_key('col', 'table')
+
+    metadata.set_sequence_index('date', 'table')
 
     # Run
     model = PARSynthesizer(metadata=metadata, epochs=1)
@@ -347,8 +352,10 @@ def test_par_unique_sequence_index_with_enforce_min_max():
     )
     metadata = Metadata.detect_from_dataframes({'table': test_df})
     metadata.update_column(table_name='table', column_name='s_key', sdtype='id')
-    metadata.set_sequence_key('table', 's_key')
-    metadata.set_sequence_index('table', 'visits')
+    metadata.set_sequence_key('s_key', 'table')
+
+    metadata.set_sequence_index('visits', 'table')
+
     synthesizer = PARSynthesizer(
         metadata, enforce_min_max_values=True, enforce_rounding=False, epochs=100, verbose=True
     )
@@ -441,7 +448,7 @@ def test_par_categorical_column_represented_by_floats():
     # Setup
     data, metadata = download_demo('sequential', 'nasdaq100_2019')
     data['category'] = [100.0 if i % 2 == 0 else 50.0 for i in data.index]
-    metadata.add_column('nasdaq100_2019', 'category', sdtype='categorical')
+    metadata.add_column('category', 'nasdaq100_2019', sdtype='categorical')
 
     # Run
     synth = PARSynthesizer(metadata)
diff --git a/tests/integration/single_table/test_base.py b/tests/integration/single_table/test_base.py
index 7ab9bb27c..88265ec1f 100644
--- a/tests/integration/single_table/test_base.py
+++ b/tests/integration/single_table/test_base.py
@@ -93,9 +93,9 @@ def test_sample_from_conditions_with_batch_size():
 
     metadata = Metadata()
     metadata.add_table('table')
-    metadata.add_column('table', 'column1', sdtype='numerical')
-    metadata.add_column('table', 'column2', sdtype='numerical')
-    metadata.add_column('table', 'column3', sdtype='numerical')
+    metadata.add_column('column1', 'table', sdtype='numerical')
+    metadata.add_column('column2', 'table', sdtype='numerical')
+    metadata.add_column('column3', 'table', sdtype='numerical')
 
     model = GaussianCopulaSynthesizer(metadata)
     model.fit(data)
@@ -120,9 +120,9 @@ def test_sample_from_conditions_negative_float():
 
     metadata = Metadata()
     metadata.add_table('table')
-    metadata.add_column('table', 'column1', sdtype='numerical')
-    metadata.add_column('table', 'column2', sdtype='numerical')
-    metadata.add_column('table', 'column3', sdtype='numerical')
+    metadata.add_column('column1', 'table', sdtype='numerical')
+    metadata.add_column('column2', 'table', sdtype='numerical')
+    metadata.add_column('column3', 'table', sdtype='numerical')
 
     model = GaussianCopulaSynthesizer(metadata)
     model.fit(data)
@@ -203,7 +203,7 @@ def test_sample_keys_are_scrambled():
     """Test that the keys are scrambled in the sampled data."""
     # Setup
     data, metadata = download_demo(modality='single_table', dataset_name='fake_hotel_guests')
-    metadata.update_column('fake_hotel_guests', 'guest_email', sdtype='id', regex_format='[A-Z]{3}')
+    metadata.update_column('guest_email', 'fake_hotel_guests', sdtype='id', regex_format='[A-Z]{3}')
 
     synthesizer = GaussianCopulaSynthesizer(metadata)
     synthesizer.fit(data)
@@ -234,9 +234,9 @@ def test_multiple_fits():
     })
     metadata = Metadata()
     metadata.add_table('table')
-    metadata.add_column('table', 'city', sdtype='categorical')
-    metadata.add_column('table', 'state', sdtype='categorical')
-    metadata.add_column('table', 'measurement', sdtype='numerical')
+    metadata.add_column('city', 'table', sdtype='categorical')
+    metadata.add_column('state', 'table', sdtype='categorical')
+    metadata.add_column('measurement', 'table', sdtype='numerical')
     constraint = {
         'constraint_class': 'FixedCombinations',
         'constraint_parameters': {'column_names': ['city', 'state']},
@@ -338,7 +338,7 @@ def test_transformers_correctly_auto_assigned():
     metadata.update_column(
         table_name='table', column_name='primary_key', sdtype='id', regex_format='user-[0-9]{3}'
     )
-    metadata.set_primary_key('table', 'primary_key')
+    metadata.set_primary_key('primary_key', 'table')
     metadata.update_column(table_name='table', column_name='pii_col', sdtype='address', pii=True)
     synthesizer = GaussianCopulaSynthesizer(
         metadata, enforce_min_max_values=False, enforce_rounding=False
@@ -428,7 +428,7 @@ def test_auto_assign_transformers_and_update_with_pii():
     # Run
     metadata.update_column(table_name='table', column_name='id', sdtype='first_name')
     metadata.update_column(table_name='table', column_name='name', sdtype='name')
-    metadata.set_primary_key('table', 'id')
+    metadata.set_primary_key('id', 'table')
 
     synthesizer = GaussianCopulaSynthesizer(metadata)
     synthesizer.auto_assign_transformers(data)
@@ -457,8 +457,8 @@ def test_refitting_a_model():
 
     metadata = Metadata.detect_from_dataframes({'table': data})
     metadata.update_column(table_name='table', column_name='name', sdtype='name')
-    metadata.update_column('table', 'id', sdtype='id')
-    metadata.set_primary_key('table', 'id')
+    metadata.update_column('id', 'table', sdtype='id')
+    metadata.set_primary_key('id', 'table')
 
     synthesizer = GaussianCopulaSynthesizer(metadata)
     synthesizer.fit(data)
@@ -483,7 +483,7 @@ def test_get_info():
     today = datetime.datetime.today().strftime('%Y-%m-%d')
     metadata = Metadata()
     metadata.add_table('table')
-    metadata.add_column('table', 'col', sdtype='numerical')
+    metadata.add_column('col', 'table', sdtype='numerical')
     synthesizer = GaussianCopulaSynthesizer(metadata)
 
     # Run
@@ -628,7 +628,7 @@ def test_metadata_updated_no_warning(mock__fit, tmp_path):
 
     # Run 3
     instance = BaseSingleTableSynthesizer(metadata_detect)
-    metadata_detect.update_column('mock_table', 'col 1', sdtype='categorical')
+    metadata_detect.update_column('col 1', 'mock_table', sdtype='categorical')
     file_name = tmp_path / 'singletable_2.json'
     metadata_detect.save_to_json(file_name)
     with warnings.catch_warnings(record=True) as captured_warnings:
diff --git a/tests/integration/single_table/test_constraints.py b/tests/integration/single_table/test_constraints.py
index 3477d4557..6e7a8c9e0 100644
--- a/tests/integration/single_table/test_constraints.py
+++ b/tests/integration/single_table/test_constraints.py
@@ -74,9 +74,9 @@ def test_fit_with_unique_constraint_on_data_with_only_index_column():
 
     metadata = Metadata()
     metadata.add_table('table')
-    metadata.add_column('table', 'key', sdtype='id')
-    metadata.add_column('table', 'index', sdtype='categorical')
-    metadata.set_primary_key('table', 'key')
+    metadata.add_column('key', 'table', sdtype='id')
+    metadata.add_column('index', 'table', sdtype='categorical')
+    metadata.set_primary_key('key', 'table')
 
     model = GaussianCopulaSynthesizer(metadata)
     constraint = {
@@ -139,10 +139,10 @@ def test_fit_with_unique_constraint_on_data_which_has_index_column():
 
     metadata = Metadata()
     metadata.add_table('table')
-    metadata.add_column('table', 'key', sdtype='id')
-    metadata.add_column('table', 'index', sdtype='categorical')
-    metadata.add_column('table', 'test_column', sdtype='categorical')
-    metadata.set_primary_key('table', 'key')
+    metadata.add_column('key', 'table', sdtype='id')
+    metadata.add_column('index', 'table', sdtype='categorical')
+    metadata.add_column('test_column', 'table', sdtype='categorical')
+    metadata.set_primary_key('key', 'table')
 
     model = GaussianCopulaSynthesizer(metadata)
     constraint = {
@@ -198,9 +198,9 @@ def test_fit_with_unique_constraint_on_data_subset():
 
     metadata = Metadata()
     metadata.add_table('table')
-    metadata.add_column('table', 'key', sdtype='id')
-    metadata.add_column('table', 'test_column', sdtype='categorical')
-    metadata.set_primary_key('table', 'key')
+    metadata.add_column('key', 'table', sdtype='id')
+    metadata.add_column('test_column', 'table', sdtype='categorical')
+    metadata.set_primary_key('key', 'table')
 
     test_df = test_df.iloc[[1, 3, 4]]
     constraint = {
@@ -295,9 +295,9 @@ def test_conditional_sampling_constraint_uses_reject_sampling(gm_mock, isinstance_mock):
 
     metadata = Metadata()
     metadata.add_table('table')
-    metadata.add_column('table', 'city', sdtype='categorical')
-    metadata.add_column('table', 'state', sdtype='categorical')
-    metadata.add_column('table', 'age', sdtype='numerical')
+    metadata.add_column('city', 'table', sdtype='categorical')
+    metadata.add_column('state', 'table', sdtype='categorical')
+    metadata.add_column('age', 'table', sdtype='numerical')
 
     model = GaussianCopulaSynthesizer(metadata)
 
@@ -815,8 +815,8 @@ def reverse_transform(column_names, data):
         'other': [7, 8, 9],
     })
     metadata = Metadata.detect_from_dataframes({'table': data})
-    metadata.update_column('table', 'key', sdtype='id', regex_format=r'\w_\d')
-    metadata.set_primary_key('table', 'key')
+    metadata.update_column('key', 'table', sdtype='id', regex_format=r'\w_\d')
+    metadata.set_primary_key('key', 'table')
 
     synth = GaussianCopulaSynthesizer(metadata)
     synth.add_custom_constraint_class(custom_constraint, 'custom')
@@ -845,8 +845,8 @@ def test_timezone_aware_constraints():
 
     metadata = Metadata()
     metadata.add_table('table')
-    metadata.add_column('table', 'col1', sdtype='datetime')
-    metadata.add_column('table', 'col2', sdtype='datetime')
+    metadata.add_column('col1', 'table', sdtype='datetime')
+    metadata.add_column('col2', 'table', sdtype='datetime')
 
     my_constraint = {
         'constraint_class': 'Inequality',
diff --git a/tests/integration/single_table/test_copulas.py b/tests/integration/single_table/test_copulas.py
index 6a25fe5dc..f22427104 100644
--- a/tests/integration/single_table/test_copulas.py
+++ b/tests/integration/single_table/test_copulas.py
@@ -278,8 +278,8 @@ def test_update_transformers_with_id_generator():
     data = pd.DataFrame({'user_id': list(range(4)), 'user_cat': ['a', 'b', 'c', 'd']})
 
     stm = Metadata.detect_from_dataframes({'table': data})
-    stm.update_column('table', 'user_id', sdtype='id')
-    stm.set_primary_key('table', 'user_id')
+    stm.update_column('user_id', 'table', sdtype='id')
+    stm.set_primary_key('user_id', 'table')
 
     gc = GaussianCopulaSynthesizer(stm)
     custom_id = IDGenerator(starting_value=min_value_id)
@@ -434,7 +434,7 @@ def test_unknown_sdtype():
     })
 
     metadata = Metadata.detect_from_dataframes({'table': data})
-    metadata.update_column('table', 'unknown', sdtype='unknown')
+    metadata.update_column('unknown', 'table', sdtype='unknown')
 
     synthesizer = GaussianCopulaSynthesizer(metadata)
 
diff --git a/tests/integration/single_table/test_ctgan.py b/tests/integration/single_table/test_ctgan.py
index 9f892f878..0f87c5fe7 100644
--- a/tests/integration/single_table/test_ctgan.py
+++ b/tests/integration/single_table/test_ctgan.py
@@ -17,12 +17,12 @@ def test__estimate_num_columns():
     # Setup
     metadata = Metadata()
     metadata.add_table('table')
-    metadata.add_column('table', 'numerical', sdtype='numerical')
-    metadata.add_column('table', 'categorical', sdtype='categorical')
-    metadata.add_column('table', 'categorical2', sdtype='categorical')
-    metadata.add_column('table', 'categorical3', sdtype='categorical')
-    metadata.add_column('table', 'datetime', sdtype='datetime')
-    metadata.add_column('table', 'boolean', sdtype='boolean')
+    metadata.add_column('numerical', 'table', sdtype='numerical')
+    metadata.add_column('categorical', 'table', sdtype='categorical')
+    metadata.add_column('categorical2', 'table', sdtype='categorical')
+    metadata.add_column('categorical3', 'table', sdtype='categorical')
+    metadata.add_column('datetime', 'table', sdtype='datetime')
+    metadata.add_column('boolean', 'table', sdtype='boolean')
     data = pd.DataFrame({
         'numerical': [0.1, 0.2, 0.3],
         'datetime': ['2020-01-01', '2020-01-02', '2020-01-03'],
diff --git a/tests/unit/evaluation/test_single_table.py b/tests/unit/evaluation/test_single_table.py
index f669a6e40..08106923f 100644
--- a/tests/unit/evaluation/test_single_table.py
+++ b/tests/unit/evaluation/test_single_table.py
@@ -23,7 +23,7 @@ def test_evaluate_quality():
     data2 = pd.DataFrame({'col': [2, 1, 3]})
     metadata = Metadata()
     metadata.add_table('table')
-    metadata.add_column('table', 'col', sdtype='numerical')
+    metadata.add_column('col', 'table', sdtype='numerical')
     QualityReport.generate = Mock()
 
     # Run
@@ -59,7 +59,7 @@ def test_run_diagnostic():
     data2 = pd.DataFrame({'col': [2, 1, 3]})
     metadata = Metadata()
     metadata.add_table('table')
-    metadata.add_column('table', 'col', sdtype='numerical')
+    metadata.add_column('col', 'table', sdtype='numerical')
     DiagnosticReport.generate = Mock(return_value=123)
 
     # Run
@@ -100,7 +100,7 @@ def test_get_column_plot_continuous_data(mock_get_plot):
     data2 = pd.DataFrame({'col': [2, 1, 3]})
     metadata = Metadata()
     metadata.add_table('table')
-    metadata.add_column('table', 'col', sdtype='numerical')
+    metadata.add_column('col', 'table', sdtype='numerical')
 
     # Run
     plot = get_column_plot(data1, data2, metadata, 'col')
@@ -143,7 +143,7 @@ def test_get_column_plot_discrete_data(mock_get_plot):
    data2 = pd.DataFrame({'col': ['a', 'b', 'c']})
     metadata = Metadata()
     metadata.add_table('table')
-    metadata.add_column('table', 'col', sdtype='categorical')
+    metadata.add_column('col', 'table', sdtype='categorical')
 
     # Run
     plot = get_column_plot(data1, data2, metadata, 'col')
@@ -187,7 +187,7 @@ def test_get_column_plot_discrete_data_with_distplot(mock_get_plot):
     data2 = pd.DataFrame({'col': ['a', 'b', 'c']})
     metadata = Metadata()
     metadata.add_table('table')
-    metadata.add_column('table', 'col', sdtype='categorical')
+    metadata.add_column('col', 'table', sdtype='categorical')
 
     # Run
     plot = get_column_plot(data1, data2, metadata, 'col', plot_type='distplot')
@@ -231,7 +231,7 @@ def test_get_column_plot_invalid_sdtype(mock_get_plot):
     data2 = pd.DataFrame({'col': ['a', 'b', 'c']})
     metadata = Metadata()
     metadata.add_table('table')
-    metadata.add_column('table', 'col', sdtype='id')
+    metadata.add_column('col', 'table', sdtype='id')
 
     # Run and Assert
     error_msg = re.escape(
@@ -276,7 +276,7 @@ def test_get_column_plot_invalid_sdtype_with_plot_type(mock_get_plot):
     data2 = pd.DataFrame({'col': ['a', 'b', 'c']})
     metadata = Metadata()
     metadata.add_table('table')
-    metadata.add_column('table', 'col', sdtype='id')
+    metadata.add_column('col', 'table', sdtype='id')
 
     # Run
     plot = get_column_plot(data1, data2, metadata, 'col', plot_type='bar')
@@ -319,7 +319,7 @@ def test_get_column_plot_with_datetime_sdtype(mock_get_plot):
     synthetic_data = pd.DataFrame({'datetime': ['2023-02-21', '2022-12-13']})
     metadata = Metadata()
     metadata.add_table('table')
-    metadata.add_column('table', 'datetime', sdtype='datetime', datetime_format='%Y-%m-%d')
+    metadata.add_column('datetime', 'table', sdtype='datetime', datetime_format='%Y-%m-%d')
 
     # Run
     plot = get_column_plot(real_data, synthetic_data, metadata, 'datetime')
@@ -354,8 +354,8 @@ def test_get_column_pair_plot_with_continous_data(mock_get_plot):
     })
     metadata = Metadata()
     metadata.add_table('table')
-    metadata.add_column('table', 'amount', sdtype='numerical')
-    metadata.add_column('table', 'date', sdtype='datetime')
+    metadata.add_column('amount', 'table', sdtype='numerical')
+    metadata.add_column('date', 'table', sdtype='datetime')
 
     # Run
     plot = get_column_pair_plot(real_data, synthetic_data, metadata, columns)
@@ -389,8 +389,8 @@ def test_get_column_pair_plot_with_discrete_data(mock_get_plot):
     synthetic_data = pd.DataFrame({'name': ['John', 'Johanna'], 'subscriber': [False, False]})
     metadata = Metadata()
     metadata.add_table('table')
-    metadata.add_column('table', 'name', sdtype='categorical')
-    metadata.add_column('table', 'subscriber', sdtype='boolean')
+    metadata.add_column('name', 'table', sdtype='categorical')
+    metadata.add_column('subscriber', 'table', sdtype='boolean')
 
     # Run
     plot = get_column_pair_plot(real_data, synthetic_data, metadata, columns)
@@ -416,8 +416,8 @@ def test_get_column_pair_plot_with_mixed_data(mock_get_plot):
     synthetic_data = pd.DataFrame({'name': ['John', 'Johanna'], 'counts': [3, 1]})
     metadata = Metadata()
     metadata.add_table('table')
-    metadata.add_column('table', 'name', sdtype='categorical')
-    metadata.add_column('table', 'counts', sdtype='numerical')
+    metadata.add_column('name', 'table', sdtype='categorical')
+    metadata.add_column('counts', 'table', sdtype='numerical')
 
     # Run
     plot = get_column_pair_plot(real_data, synthetic_data, metadata, columns)
@@ -449,8 +449,8 @@ def test_get_column_pair_plot_with_forced_plot_type(mock_get_plot):
     })
     metadata = Metadata()
     metadata.add_table('table')
-    metadata.add_column('table', 'amount', sdtype='numerical')
-    metadata.add_column('table', 'date', sdtype='datetime')
+    metadata.add_column('amount', 'table', sdtype='numerical')
+    metadata.add_column('date', 'table', sdtype='datetime')
 
     # Run
     plot = get_column_pair_plot(real_data, synthetic_data, metadata, columns, plot_type='heatmap')
@@ -491,8 +491,8 @@ def test_get_column_pair_plot_with_invalid_sdtype(mock_get_plot):
     })
     metadata = Metadata()
     metadata.add_table('table')
-    metadata.add_column('table', 'amount', sdtype='numerical')
-    metadata.add_column('table', 'id', sdtype='id')
+    metadata.add_column('amount', 'table', sdtype='numerical')
+    metadata.add_column('id', 'table', sdtype='id')
 
     # Run and Assert
     error_msg = re.escape(
@@ -522,8 +522,8 @@ def test_get_column_pair_plot_with_invalid_sdtype_and_plot_type(mock_get_plot):
     })
     metadata = Metadata()
     metadata.add_table('table')
-    metadata.add_column('table', 'amount', sdtype='numerical')
-    metadata.add_column('table', 'id', sdtype='id')
+    metadata.add_column('amount', 'table', sdtype='numerical')
+    metadata.add_column('id', 'table', sdtype='id')
 
     # Run
     plot = get_column_pair_plot(real_data, synthetic_data, metadata, columns, plot_type='heatmap')
@@ -551,8 +551,8 @@ def test_get_column_pair_plot_with_sample_size(mock_get_plot):
     })
     metadata = Metadata()
     metadata.add_table('table')
-    metadata.add_column('table', 'amount', sdtype='numerical')
-    metadata.add_column('table', 'price', sdtype='numerical')
+    metadata.add_column('amount', 'table', sdtype='numerical')
+    metadata.add_column('price', 'table', sdtype='numerical')
 
     # Run
     get_column_pair_plot(real_data, synthetic_data, metadata, columns, sample_size=2)
@@ -614,8 +614,8 @@ def test_get_column_pair_plot_with_sample_size_too_big(mock_get_plot):
     })
     metadata = Metadata()
     metadata.add_table('table')
-    metadata.add_column('table', 'amount', sdtype='numerical')
-    metadata.add_column('table', 'price', sdtype='numerical')
+    metadata.add_column('amount', 'table', sdtype='numerical')
+    metadata.add_column('price', 'table', sdtype='numerical')
 
     # Run
     plot = get_column_pair_plot(real_data, synthetic_data, metadata, columns, sample_size=10)
diff --git a/tests/unit/metadata/test_metadata.py b/tests/unit/metadata/test_metadata.py
index 1524c0757..c7f38052e 100644
--- a/tests/unit/metadata/test_metadata.py
+++ b/tests/unit/metadata/test_metadata.py
@@ -643,3 +643,54 @@ def test_detect_from_dataframe_raises_error_if_not_dataframe(self):
         expected_message = 'The provided data must be a pandas DataFrame object.'
         with pytest.raises(ValueError, match=expected_message):
             Metadata.detect_from_dataframe(Mock())
+
+    def test__handle_table_name(self):
+        """Test the ``_handle_table_name`` method."""
+        # Setup
+        metadata = Metadata()
+        metadata.tables = ['table_name']
+        expected_error = re.escape(
+            'Metadata contains more than one table, please specify the `table_name`.'
+        )
+
+        # Run
+        result_none = metadata._handle_table_name(None)
+        result_table_1 = metadata._handle_table_name('table_1')
+        metadata.tables = ['table_1', 'table_2']
+        with pytest.raises(ValueError, match=expected_error):
+            metadata._handle_table_name(None)
+
+        result_table_2 = metadata._handle_table_name('table_2')
+
+        # Assert
+        assert result_none == 'table_name'
+        assert result_table_1 == 'table_1'
+        assert result_table_2 == 'table_2'
+
+    params = [
+        ('update_column', ['column_name']),
+        ('update_columns', ['column_names']),
+        ('update_columns_metadata', ['column_metadata']),
+        ('add_column', ['column_name']),
+        ('set_primary_key', ['column_name']),
+        ('remove_primary_key', []),
+        ('add_column_relationship', ['relationship_type', 'column_names']),
+        ('add_alternate_keys', ['column_names']),
+        ('get_column_names', []),
+    ]
+
+    @pytest.mark.parametrize('method, args', params)
+    def test_update_methods(self, method, args):
+        """Test that all update methods call the superclass method with the resolved arguments."""
+        # Setup
+        metadata = Metadata()
+        metadata._handle_table_name = Mock(return_value='table_name')
+        superclass = Metadata.__bases__[0]
+
+        with patch.object(superclass, method) as mock_super_method:
+            # Run
+            getattr(metadata, method)(*args, 'table_name')
+
+            # Assert
+            metadata._handle_table_name.assert_called_once_with('table_name')
+            mock_super_method.assert_called_once_with('table_name', *args)
diff --git a/tests/unit/metadata/test_multi_table.py b/tests/unit/metadata/test_multi_table.py
index cd20b7434..aefe8eade 100644
--- a/tests/unit/metadata/test_multi_table.py
+++ b/tests/unit/metadata/test_multi_table.py
@@ -1396,11 +1396,11 @@ def test_validate_data_datetime_warning(self):
             '20220826',
             '20220826',
         ])
-        metadata.add_column('upravna_enota', 'warning_date_str', sdtype='datetime')
+        metadata.add_column('warning_date_str', 'upravna_enota', sdtype='datetime')
         metadata.add_column(
-            'upravna_enota', 'valid_date', sdtype='datetime', datetime_format='%Y%m%d%H%M%S%f'
+            'valid_date', 'upravna_enota', sdtype='datetime', datetime_format='%Y%m%d%H%M%S%f'
         )
-        metadata.add_column('upravna_enota', 'datetime', sdtype='datetime')
+        metadata.add_column('datetime', 'upravna_enota', sdtype='datetime')
 
         # Run and Assert
         warning_df = pd.DataFrame({
diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py
index 8e0c60629..446c92531 100644
--- a/tests/unit/multi_table/test_base.py
+++ b/tests/unit/multi_table/test_base.py
@@ -169,10 +169,10 @@ def test__init__column_relationship_warning(self, mock_is_faker_function):
         # Setup
         mock_is_faker_function.return_value = True
         metadata = get_multi_table_metadata()
-        metadata.add_column('nesreca', 'lat', sdtype='latitude')
-        metadata.add_column('nesreca', 'lon', sdtype='longitude')
+        metadata.add_column('lat', 'nesreca', sdtype='latitude')
+        metadata.add_column('lon', 'nesreca', sdtype='longitude')
 
-        metadata.add_column_relationship('nesreca', 'gps', ['lat', 'lon'])
+        metadata.add_column_relationship('gps', ['lat', 'lon'], 'nesreca')
 
         expected_warning = (
             "The metadata contains a column relationship of type 'gps' "
@@ -540,7 +540,7 @@ def test_validate_constraints_not_met(self):
         metadata = get_multi_table_metadata()
         data = get_multi_table_data()
         data['nesreca']['val'] = list(range(4))
-        metadata.add_column('nesreca', 'val', sdtype='numerical')
+        metadata.add_column('val', 'nesreca', sdtype='numerical')
         instance = BaseMultiTableSynthesizer(metadata)
         inequality_constraint = {
             'constraint_class': 'Inequality',
@@ -1344,8 +1344,8 @@ def test_add_constraints(self):
         # Setup
         metadata = get_multi_table_metadata()
         instance = BaseMultiTableSynthesizer(metadata)
-        metadata.add_column('nesreca', 'positive_int', sdtype='numerical')
-        metadata.add_column('oseba', 'negative_int', sdtype='numerical')
+        metadata.add_column('positive_int', 'nesreca', sdtype='numerical')
+        metadata.add_column('negative_int', 'oseba', sdtype='numerical')
         positive_constraint = {
             'constraint_class': 'Positive',
             'table_name': 'nesreca',
@@ -1401,8 +1401,8 @@ def test_get_constraints(self):
         # Setup
         metadata = get_multi_table_metadata()
         instance = BaseMultiTableSynthesizer(metadata)
-        metadata.add_column('nesreca', 'positive_int', sdtype='numerical')
-        metadata.add_column('oseba', 'negative_int', sdtype='numerical')
+        metadata.add_column('positive_int', 'nesreca', sdtype='numerical')
+        metadata.add_column('negative_int', 'oseba', sdtype='numerical')
         positive_constraint = {
             'constraint_class': 'Positive',
             'table_name': 'nesreca',
@@ -1527,7 +1527,7 @@ def test_get_info(self, mock_version):
         data = {'tab': pd.DataFrame({'col': [1, 2, 3]})}
         metadata = Metadata()
         metadata.add_table('tab')
-        metadata.add_column('tab', 'col', sdtype='numerical')
+        metadata.add_column('col', 'tab', sdtype='numerical')
         mock_version.public = '1.0.0'
         mock_version.enterprise = None
 
@@ -1575,7 +1575,7 @@ def test_get_info_with_enterprise(self, mock_version):
         data = {'tab': pd.DataFrame({'col': [1, 2, 3]})}
         metadata = Metadata()
         metadata.add_table('tab')
-        metadata.add_column('tab', 'col', sdtype='numerical')
+        metadata.add_column('col', 'tab', sdtype='numerical')
         mock_version.public = '1.0.0'
         mock_version.enterprise = '1.1.0'
 
diff --git a/tests/unit/multi_table/test_hma.py b/tests/unit/multi_table/test_hma.py
index 72abc0e25..727b40f5a 100644
--- a/tests/unit/multi_table/test_hma.py
+++ b/tests/unit/multi_table/test_hma.py
@@ -166,9 +166,9 @@ def test__augment_table(self):
         # Setup
         metadata = get_multi_table_metadata()
         instance = HMASynthesizer(metadata)
-        metadata.add_column('nesreca', 'value', sdtype='numerical')
-        metadata.add_column('oseba', 'oseba_value', sdtype='numerical')
-        metadata.add_column('upravna_enota', 'name', sdtype='categorical')
+        metadata.add_column('value', 'nesreca', sdtype='numerical')
+        metadata.add_column('oseba_value', 'oseba', sdtype='numerical')
+        metadata.add_column('name', 'upravna_enota', sdtype='categorical')
 
         data = get_multi_table_data()
         data['nesreca']['value'] = [0, 1, 2, 3]
@@ -712,9 +712,9 @@ def test_get_learned_distributions_raises_an_error(self):
         """Test that ``get_learned_distributions`` raises an error."""
         # Setup
         metadata = get_multi_table_metadata()
-        metadata.add_column('nesreca', 'value', sdtype='numerical')
-        metadata.add_column('oseba', 'value', sdtype='numerical')
-        metadata.add_column('upravna_enota', 'a_value', sdtype='numerical')
+        metadata.add_column('value', 'nesreca', sdtype='numerical')
+        metadata.add_column('value', 'oseba', sdtype='numerical')
+        metadata.add_column('a_value', 'upravna_enota', sdtype='numerical')
         instance = HMASynthesizer(metadata)
 
         # Run and Assert
diff --git a/tests/unit/multi_table/test_utils.py b/tests/unit/multi_table/test_utils.py
index 57b52fa08..a08ddfb38 100644
--- a/tests/unit/multi_table/test_utils.py
+++ b/tests/unit/multi_table/test_utils.py
@@ -1947,7 +1947,7 @@ def test__subsample_data(
 def test__subsample_data_with_null_foreing_keys():
     """Test the ``_subsample_data`` method when there are null foreign keys."""
     # Setup
-    metadata = MultiTableMetadata.load_from_dict({
+    metadata = Metadata.load_from_dict({
         'tables': {
             'parent': {
                 'columns': {
diff --git a/tests/unit/sequential/test_par.py b/tests/unit/sequential/test_par.py
index d7dae5832..e1062fcfd 100644
--- a/tests/unit/sequential/test_par.py
+++ b/tests/unit/sequential/test_par.py
@@ -21,15 +21,15 @@ class TestPARSynthesizer:
     def get_metadata(self, add_sequence_key=True, add_sequence_index=False):
         metadata = Metadata()
         metadata.add_table('table')
-        metadata.add_column('table', 'time', sdtype='datetime')
-        metadata.add_column('table', 'gender', sdtype='categorical')
-        metadata.add_column('table', 'name', sdtype='id')
-        metadata.add_column('table', 'measurement', sdtype='numerical')
+        metadata.add_column('time', 'table', sdtype='datetime')
+        metadata.add_column('gender', 'table', sdtype='categorical')
+        metadata.add_column('name', 'table', sdtype='id')
+        metadata.add_column('measurement', 'table', sdtype='numerical')
         if add_sequence_key:
-            metadata.set_sequence_key('table', 'name')
+            metadata.set_sequence_key('name', 'table')
 
         if add_sequence_index:
-            metadata.set_sequence_index('table', 'time')
+            metadata.set_sequence_index('time', 'table')
 
         return metadata
 
@@ -261,11 +261,12 @@ def test_validate_context_columns_unique_per_sequence_key(self):
         })
         metadata = Metadata()
         metadata.add_table('table')
-        metadata.add_column('table', 'sk_col1', sdtype='id')
-        metadata.add_column('table', 'sk_col2', sdtype='id')
-        metadata.add_column('table', 'ct_col1', sdtype='numerical')
-        metadata.add_column('table', 'ct_col2', sdtype='numerical')
-        metadata.set_sequence_key('table', 'sk_col1')
+        metadata.add_column('sk_col1', 'table', sdtype='id')
+        metadata.add_column('sk_col2', 'table', sdtype='id')
+        metadata.add_column('ct_col1', 'table', sdtype='numerical')
+        metadata.add_column('ct_col2', 'table', sdtype='numerical')
+        metadata.set_sequence_key('sk_col1', 'table')
+
         instance = PARSynthesizer(metadata=metadata, context_columns=['ct_col1', 'ct_col2'])
 
         # Run and Assert
@@ -529,7 +530,8 @@ def test_auto_assign_transformers_without_enforce_min_max(self, mock_get_transfomers):
             'measurement': [55, 60, 65],
         })
         metadata = self.get_metadata()
-        metadata.set_sequence_index('table', 'time')
+        metadata.set_sequence_index('time', 'table')
+
         mock_get_transfomers.return_value = {'time': FloatFormatter}
 
         # Run
@@ -611,7 +613,7 @@ def test__fit_sequence_columns_with_categorical_float(
         data = self.get_data()
         data['measurement'] = data['measurement'].astype(float)
         metadata = self.get_metadata()
-        metadata.update_column('table', 'measurement', sdtype='categorical')
+        metadata.update_column('measurement', 'table', sdtype='categorical')
         par = PARSynthesizer(metadata=metadata, context_columns=['gender'])
         sequences = [
             {'context': np.array(['M'], dtype=object), 'data': [['2020-01-03'], [65.0]]},
@@ -649,7 +651,8 @@ def test__fit_sequence_columns_with_sequence_index(self, assemble_sequences_mock):
             'measurement': [55, 60, 65, 65, 70],
         })
         metadata = self.get_metadata()
-        metadata.set_sequence_index('table', 'time')
+        metadata.set_sequence_index('time', 'table')
+
         par = PARSynthesizer(metadata=metadata, context_columns=['gender'])
         sequences = [
             {'context': np.array(['F'], dtype=object), 'data': [[1, 1], [55, 60], [1, 1]]},
@@ -840,7 +843,8 @@ def test__sample_from_par_with_sequence_index(self, tqdm_mock):
         """
         # Setup
         metadata = self.get_metadata()
-        metadata.set_sequence_index('table', 'time')
+        metadata.set_sequence_index('time', 'table')
+
         par = PARSynthesizer(metadata=metadata, context_columns=['gender'])
         model_mock = Mock()
         par._model = model_mock
diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py
index 158f71325..5541030f7 100644
--- a/tests/unit/single_table/test_base.py
+++ b/tests/unit/single_table/test_base.py
@@ -709,10 +709,11 @@ def test_update_transformers_invalid_keys(self):
         column_name_to_transformer = {'col2': RegexGenerator(), 'col3': FloatFormatter()}
         metadata = Metadata()
         metadata.add_table('table')
-        metadata.add_column('table', 'col2', sdtype='id')
-        metadata.add_column('table', 'col3', sdtype='id')
-        metadata.set_sequence_key('table', 'col2')
-        metadata.add_alternate_keys('table', ['col3'])
+        metadata.add_column('col2', 'table', sdtype='id')
+        metadata.add_column('col3', 'table', sdtype='id')
+        metadata.set_sequence_key('col2', 'table')
+
+        metadata.add_alternate_keys(['col3'], 'table')
         instance = BaseSingleTableSynthesizer(metadata)
 
         # Run and Assert
@@ -731,8 +732,8 @@ def test_update_transformers_already_fitted(self):
         column_name_to_transformer = {'col1': BinaryEncoder(), 'col2': fitted_transformer}
         metadata = Metadata()
         metadata.add_table('table')
-        metadata.add_column('table', 'col1', sdtype='boolean')
-        metadata.add_column('table', 'col2', sdtype='numerical')
+        metadata.add_column('col1', 'table', sdtype='boolean')
+        metadata.add_column('col2', 'table', sdtype='numerical')
         instance = BaseSingleTableSynthesizer(metadata)
 
         # Run and Assert
@@ -746,8 +747,8 @@ def test_update_transformers_warns_gaussian_copula(self):
         column_name_to_transformer = {'col1': OneHotEncoder(), 'col2': FloatFormatter()}
         metadata = Metadata()
         metadata.add_table('table')
-        metadata.add_column('table', 'col1', sdtype='categorical')
-        metadata.add_column('table', 'col2', sdtype='numerical')
+        metadata.add_column('col1', 'table', sdtype='categorical')
+        metadata.add_column('col2', 'table', sdtype='numerical')
         instance = GaussianCopulaSynthesizer(metadata)
         instance._data_processor.fit(pd.DataFrame({'col1': [1, 2], 'col2': [1, 2]}))
 
@@ -774,8 +775,8 @@ def test_update_transformers_warns_models(self):
         column_name_to_transformer = {'col1': OneHotEncoder(), 'col2': FloatFormatter()}
         metadata = Metadata()
         metadata.add_table('table')
-        metadata.add_column('table', 'col1', sdtype='categorical')
-        metadata.add_column('table', 'col2', sdtype='numerical')
+        metadata.add_column('col1', 'table', sdtype='categorical')
+        metadata.add_column('col2', 'table', sdtype='numerical')
 
         # NOTE: when PARSynthesizer is implemented, add it here as well
         for model in [CTGANSynthesizer, CopulaGANSynthesizer, TVAESynthesizer]:
@@ -800,8 +801,8 @@ def test_update_transformers_warns_fitted(self):
         column_name_to_transformer = {'col1': GaussianNormalizer(), 'col2': GaussianNormalizer()}
         metadata = Metadata()
         metadata.add_table('table')
-        metadata.add_column('table', 'col1', sdtype='numerical')
-        metadata.add_column('table', 'col2', sdtype='numerical')
+        metadata.add_column('col1', 'table', sdtype='numerical')
+        metadata.add_column('col2', 'table', sdtype='numerical')
         instance = BaseSingleTableSynthesizer(metadata)
         instance._data_processor.fit(pd.DataFrame({'col1': [1, 2], 'col2': [1, 2]}))
         instance._fitted = True
@@ -819,8 +820,8 @@ def test_update_transformers(self):
         column_name_to_transformer = {'col1': GaussianNormalizer(), 'col2': GaussianNormalizer()}
         metadata = Metadata()
         metadata.add_table('table')
-        metadata.add_column('table', 'col1', sdtype='numerical')
-        metadata.add_column('table', 'col2', sdtype='numerical')
+        metadata.add_column('col1', 'table', sdtype='numerical')
+        metadata.add_column('col2', 'table', sdtype='numerical')
         instance = BaseSingleTableSynthesizer(metadata)
         instance._data_processor.fit(pd.DataFrame({'col1': [1, 2], 'col2': [1, 2]}))
 
@@ -2053,7 +2054,7 @@ def test_add_constraints(self):
         # Setup
         metadata = Metadata()
         metadata.add_table('table')
-        metadata.add_column('table', 'col', sdtype='numerical')
+        metadata.add_column('col', 'table', sdtype='numerical')
         instance = BaseSingleTableSynthesizer(metadata)
         positive_constraint = {
             'constraint_class': 'Positive',
@@ -2084,7 +2085,7 @@ def test_get_constraints(self):
         # Setup
         metadata = Metadata()
         metadata.add_table('table')
-        metadata.add_column('table', 'col', sdtype='numerical')
+        metadata.add_column('col', 'table', sdtype='numerical')
         instance = BaseSingleTableSynthesizer(metadata)
         positive_constraint = {
             'constraint_class': 'Positive',
@@ -2119,7 +2120,7 @@ def test_get_info_no_enterprise(self, mock_sdv_version):
         mock_sdv_version.enterprise = None
         metadata = Metadata()
         metadata.add_table('table')
-        metadata.add_column('table', 'col', sdtype='numerical')
+        metadata.add_column('col', 'table', sdtype='numerical')
 
         with patch('sdv.single_table.base.datetime.datetime') as mock_date:
             mock_date.today.return_value = datetime(2023, 1, 23)
@@ -2167,7 +2168,7 @@ def test_get_info_with_enterprise(self, mock_sdv_version):
         mock_sdv_version.enterprise = '1.2.0'
         metadata = Metadata()
         metadata.add_table('table')
-        metadata.add_column('table', 'col', sdtype='numerical')
+        metadata.add_column('col', 'table', sdtype='numerical')
 
         with patch('sdv.single_table.base.datetime.datetime') as mock_date:
             mock_date.today.return_value = datetime(2023, 1, 23)
diff --git a/tests/unit/single_table/test_copulagan.py b/tests/unit/single_table/test_copulagan.py
index e7c531a26..354ca841c 100644
--- a/tests/unit/single_table/test_copulagan.py
+++ b/tests/unit/single_table/test_copulagan.py
@@ -90,7 +90,7 @@ def test___init__custom(self):
         # Setup
         metadata = Metadata()
         metadata.add_table('table')
-        metadata.add_column('table', 'field', sdtype='numerical')
+        metadata.add_column('field', 'table', sdtype='numerical')
         enforce_min_max_values = False
         enforce_rounding = False
         embedding_dim = 64
@@ -226,9 +226,9 @@ def test__create_gaussian_normalizer_config(self, mock_rdt):
         numerical_distributions = {'age': 'gamma'}
         metadata = Metadata()
         metadata.add_table('table')
-        metadata.add_column('table', 'name', sdtype='categorical')
-        metadata.add_column('table', 'age', sdtype='numerical')
-        metadata.add_column('table', 'account', sdtype='numerical')
+        metadata.add_column('name', 'table', sdtype='categorical')
+        metadata.add_column('age', 'table', sdtype='numerical')
+        metadata.add_column('account', 'table', sdtype='numerical')
         instance = CopulaGANSynthesizer(metadata, numerical_distributions=numerical_distributions)
 
         processed_data = pd.DataFrame({
@@ -275,7 +275,7 @@ def test__fit_logging(self, mock_rdt, mock_ctgansynthesizer__fit, mock_logger):
         # Setup
         metadata = Metadata()
         metadata.add_table('table')
-        metadata.add_column('table', 'col', sdtype='numerical')
+        metadata.add_column('col', 'table', sdtype='numerical')
         numerical_distributions = {'col': 'gamma'}
         instance = CopulaGANSynthesizer(metadata, numerical_distributions=numerical_distributions)
         processed_data = pd.DataFrame()
diff --git a/tests/unit/single_table/test_copulas.py b/tests/unit/single_table/test_copulas.py
index be523196d..57fc1dd08 100644
--- a/tests/unit/single_table/test_copulas.py
+++ b/tests/unit/single_table/test_copulas.py
@@ -92,7 +92,7 @@ def test___init__custom(self):
         # Setup
         metadata = Metadata()
         metadata.add_table('table')
-        metadata.add_column('table', 'field', sdtype='numerical')
+        metadata.add_column('field', 'table', sdtype='numerical')
         enforce_min_max_values = False
         enforce_rounding = False
         numerical_distributions = {'field': 'gamma'}
@@ -169,7 +169,7 @@ def test__fit_logging(self, mock_logger):
         # Setup
         metadata = Metadata()
         metadata.add_table('table')
-        metadata.add_column('table', 'col', sdtype='numerical')
+        metadata.add_column('col', 'table', sdtype='numerical')
         numerical_distributions = {'col': 'gamma'}
         instance = GaussianCopulaSynthesizer(
             metadata, numerical_distributions=numerical_distributions
@@ -197,8 +197,8 @@ def test__fit(self, mock_multivariate, mock_warnings):
         # Setup
         metadata = Metadata()
         metadata.add_table('table')
-        metadata.add_column('table', 'name', sdtype='numerical')
-        metadata.add_column('table', 'user.id', sdtype='numerical')
+        metadata.add_column('name', 'table', sdtype='numerical')
+        metadata.add_column('user.id', 'table', sdtype='numerical')
 
         numerical_distributions = {'name': 'uniform', 'user.id': 'gamma'}
         processed_data = pd.DataFrame({
diff --git a/tests/unit/single_table/test_ctgan.py b/tests/unit/single_table/test_ctgan.py
index 5cb3e6000..a3b946f4b 100644
--- a/tests/unit/single_table/test_ctgan.py
+++ b/tests/unit/single_table/test_ctgan.py
@@ -189,9 +189,9 @@ def test__estimate_num_columns(self):
         # Setup
         metadata = Metadata()
         metadata.add_table('table')
-        metadata.add_column('table', 'id', sdtype='numerical')
-        metadata.add_column('table', 'name', sdtype='categorical')
-        metadata.add_column('table', 'surname', sdtype='categorical')
+        metadata.add_column('id', 'table', sdtype='numerical')
+        metadata.add_column('name', 'table', sdtype='categorical')
+        metadata.add_column('surname', 'table', sdtype='categorical')
         data = pd.DataFrame({
             'id': np.random.rand(1_001),
             'name': [f'cat_{i}' for i in range(1_001)],
@@ -215,8 +215,8 @@ def test_preprocessing_many_categories(self, capfd):
         # Setup
         metadata = Metadata()
         metadata.add_table('table')
-        metadata.add_column('table', 'name_longer_than_Original_Column_Name', sdtype='numerical')
-        metadata.add_column('table', 'categorical', sdtype='categorical')
+        metadata.add_column('name_longer_than_Original_Column_Name', 'table', sdtype='numerical')
+        metadata.add_column('categorical', 'table', sdtype='categorical')
         data = pd.DataFrame({
             'name_longer_than_Original_Column_Name': np.random.rand(1_001),
             'categorical': [f'cat_{i}' for i in range(1_001)],
@@ -246,8 +246,8 @@ def test_preprocessing_few_categories(self, capfd):
         # Setup
         metadata = Metadata()
         metadata.add_table('table')
-        metadata.add_column('table', 'name_longer_than_Original_Column_Name', sdtype='numerical')
-        metadata.add_column('table', 'categorical', sdtype='categorical')
+        metadata.add_column('name_longer_than_Original_Column_Name', 'table', sdtype='numerical')
+        metadata.add_column('categorical', 'table', sdtype='categorical')
         data = pd.DataFrame({
             'name_longer_than_Original_Column_Name': np.random.rand(10),
             'categorical': [f'cat_{i}' for i in range(10)],

From 43290f491f22c7b43c46e7823607c542a0e9c4b2 Mon Sep 17 00:00:00 2001
From: R-Palazzo
Date: Fri, 27 Sep 2024 11:12:09 -0400
Subject: [PATCH 16/16] fix integration test with SDMetrics-0.16.0

---
 tests/integration/evaluation/test_multi_table.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/integration/evaluation/test_multi_table.py b/tests/integration/evaluation/test_multi_table.py
index 7ca797435..3075a8450 100644
--- a/tests/integration/evaluation/test_multi_table.py
+++ b/tests/integration/evaluation/test_multi_table.py
@@ -60,7 +60,7 @@ def test_evaluation():
 def test_evaluation_metadata():
     """Test ``evaluate_quality`` and ``run_diagnostic`` with Metadata."""
     # Setup
-    table = pd.DataFrame({'id': [0, 1, 2, 3], 'col': [1, 2, 3, 4]})
+    table = pd.DataFrame({'id': [0, 1, 2, 3], 'col': [1, 2, 3, 4.0]})
     slightly_different_table = pd.DataFrame({'id': [0, 1, 2, 3], 'col': [1, 2, 3, 3.5]})
     data = {
         'table1': table,