From dfd1a59fc6096f7c212ab668f38732d38fa3e4ae Mon Sep 17 00:00:00 2001 From: Andrew Montanez Date: Fri, 17 Mar 2023 10:34:34 -0700 Subject: [PATCH] Add load_from_dict to SingleTableMetadata and MultiTableMetadata (#1315) --- sdv/data_processing/data_processor.py | 2 +- sdv/datasets/demo.py | 2 +- sdv/metadata/multi_table.py | 12 +++++------ sdv/metadata/single_table.py | 10 ++++----- sdv/sequential/par.py | 2 +- tests/integration/single_table/test_base.py | 4 ++-- tests/unit/metadata/test_multi_table.py | 24 ++++++++++----------- tests/unit/metadata/test_single_table.py | 14 ++++++------ tests/utils.py | 2 +- 9 files changed, 36 insertions(+), 36 deletions(-) diff --git a/sdv/data_processing/data_processor.py b/sdv/data_processing/data_processor.py index 8fda1851c..69920889c 100644 --- a/sdv/data_processing/data_processor.py +++ b/sdv/data_processing/data_processor.py @@ -824,7 +824,7 @@ def from_dict(cls, metadata_dict, enforce_rounding=True, enforce_min_max_values= If passed, set the ``enforce_min_max_values`` on the new instance. """ instance = cls( - metadata=SingleTableMetadata._load_from_dict(metadata_dict['metadata']), + metadata=SingleTableMetadata.load_from_dict(metadata_dict['metadata']), enforce_rounding=enforce_rounding, enforce_min_max_values=enforce_min_max_values, model_kwargs=metadata_dict.get('model_kwargs') diff --git a/sdv/datasets/demo.py b/sdv/datasets/demo.py index 94fea74d4..c24a48bc6 100644 --- a/sdv/datasets/demo.py +++ b/sdv/datasets/demo.py @@ -101,7 +101,7 @@ def _get_metadata(modality, output_folder_name, in_memory_directory): else: metadict = json.loads(in_memory_directory['metadata_v1.json']) - metadata = metadata._load_from_dict(metadict) + metadata = metadata.load_from_dict(metadict) return metadata diff --git a/sdv/metadata/multi_table.py b/sdv/metadata/multi_table.py index 2f4de453f..de3072528 100644 --- a/sdv/metadata/multi_table.py +++ b/sdv/metadata/multi_table.py @@ -579,24 +579,24 @@ def _set_metadata_dict(self, metadata): Python dictionary representing a ``MultiTableMetadata`` object. """ for table_name, table_dict in metadata.get('tables', {}).items(): - self.tables[table_name] = SingleTableMetadata._load_from_dict(table_dict) + self.tables[table_name] = SingleTableMetadata.load_from_dict(table_dict) for relationship in metadata.get('relationships', []): self.relationships.append(relationship) @classmethod - def _load_from_dict(cls, metadata): + def load_from_dict(cls, metadata_dict): """Create a ``MultiTableMetadata`` instance from a python ``dict``. Args: - metadata (dict): + metadata_dict (dict): Python dictionary representing a ``MultiTableMetadata`` object. Returns: Instance of ``MultiTableMetadata``. """ instance = cls() - instance._set_metadata_dict(metadata) + instance._set_metadata_dict(metadata_dict) return instance def save_to_json(self, filepath): @@ -630,7 +630,7 @@ def load_from_json(cls, filepath): A ``MultiTableMetadata`` instance. """ metadata = read_json(filepath) - return cls._load_from_dict(metadata) + return cls.load_from_dict(metadata) def __repr__(self): """Pretty print the ``MultiTableMetadata``.""" @@ -698,7 +698,7 @@ def upgrade_metadata(cls, old_filepath, new_filepath): 'relationships': relationships, 'METADATA_SPEC_VERSION': cls.METADATA_SPEC_VERSION } - metadata = cls._load_from_dict(metadata_dict) + metadata = cls.load_from_dict(metadata_dict) metadata.save_to_json(new_filepath) try: metadata.validate() diff --git a/sdv/metadata/single_table.py b/sdv/metadata/single_table.py index a92b74acd..8bac0e2d6 100644 --- a/sdv/metadata/single_table.py +++ b/sdv/metadata/single_table.py @@ -485,11 +485,11 @@ def save_to_json(self, filepath): json.dump(metadata, metadata_file, indent=4) @classmethod - def _load_from_dict(cls, metadata): + def load_from_dict(cls, metadata_dict): """Create a ``SingleTableMetadata`` instance from a python ``dict``. Args: - metadata (dict): + metadata_dict (dict): Python dictionary representing a ``SingleTableMetadata`` object. Returns: @@ -497,7 +497,7 @@ def _load_from_dict(cls, metadata): """ instance = cls() for key in instance._KEYS: - value = deepcopy(metadata.get(key)) + value = deepcopy(metadata_dict.get(key)) if value: setattr(instance, f'{key}', value) @@ -525,7 +525,7 @@ def load_from_json(cls, filepath): 'class and version.' ) - return cls._load_from_dict(metadata) + return cls.load_from_dict(metadata) def __repr__(self): """Pretty print the ``SingleTableMetadata``.""" @@ -560,7 +560,7 @@ def upgrade_metadata(cls, old_filepath, new_filepath): old_metadata = list(tables.values())[0] new_metadata = convert_metadata(old_metadata) - metadata = cls._load_from_dict(new_metadata) + metadata = cls.load_from_dict(new_metadata) metadata.save_to_json(new_filepath) try: diff --git a/sdv/sequential/par.py b/sdv/sequential/par.py index 4127f9ba8..f3cbd52da 100644 --- a/sdv/sequential/par.py +++ b/sdv/sequential/par.py @@ -74,7 +74,7 @@ def _get_context_metadata(self): context_columns_dict[column] = self.metadata.columns[column] context_metadata_dict = {'columns': context_columns_dict} - return SingleTableMetadata._load_from_dict(context_metadata_dict) + return SingleTableMetadata.load_from_dict(context_metadata_dict) def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=False, context_columns=None, segment_size=None, epochs=128, sample_size=1, cuda=True, diff --git a/tests/integration/single_table/test_base.py b/tests/integration/single_table/test_base.py index 2cfea1a2e..a9c06780e 100644 --- a/tests/integration/single_table/test_base.py +++ b/tests/integration/single_table/test_base.py @@ -15,7 +15,7 @@ from sdv.single_table.base import BaseSingleTableSynthesizer from tests.integration.single_table.custom_constraints import MyConstraint -METADATA = SingleTableMetadata._load_from_dict({ +METADATA = SingleTableMetadata.load_from_dict({ 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1', 'columns': { 'column1': { @@ -436,7 +436,7 @@ def test_sampling(synthesizer): @pytest.mark.parametrize('synthesizer', SYNTHESIZERS) def test_sampling_reset_sampling(synthesizer): """Test ``sample`` method for each synthesizer using ``reset_sampling``.""" - metadata = SingleTableMetadata._load_from_dict({ + metadata = SingleTableMetadata.load_from_dict({ 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1', 'columns': { 'column1': { diff --git a/tests/unit/metadata/test_multi_table.py b/tests/unit/metadata/test_multi_table.py index c43a945e3..c66542e36 100644 --- a/tests/unit/metadata/test_multi_table.py +++ b/tests/unit/metadata/test_multi_table.py @@ -73,7 +73,7 @@ def get_metadata(self): } ] - return MultiTableMetadata._load_from_dict(metadata) + return MultiTableMetadata.load_from_dict(metadata) def test___init__(self): """Test the ``__init__`` method of ``MultiTableMetadata``.""" @@ -594,7 +594,7 @@ def test__validate_single_table(self): - Errors has been updated with the error message for that column. """ # Setup - table_accounts = SingleTableMetadata._load_from_dict({ + table_accounts = SingleTableMetadata.load_from_dict({ 'columns': { 'id': {'sdtype': 'numerical'}, 'branch_id': {'sdtype': 'numerical'}, @@ -1009,7 +1009,7 @@ def test__set_metadata(self, mock_singletablemetadata): Side Effects: - ``instance`` now contains ``instance.tables`` and ``instance.relationships``. - - ``SingleTableMetadata._load_from_dict`` has been called. + - ``SingleTableMetadata.load_from_dict`` has been called. """ # Setup multitable_metadata = { @@ -1038,7 +1038,7 @@ def test__set_metadata(self, mock_singletablemetadata): single_table_accounts = object() single_table_branches = object() - mock_singletablemetadata._load_from_dict.side_effect = [ + mock_singletablemetadata.load_from_dict.side_effect = [ single_table_accounts, single_table_branches ] @@ -1064,10 +1064,10 @@ def test__set_metadata(self, mock_singletablemetadata): ] @patch('sdv.metadata.multi_table.SingleTableMetadata') - def test__load_from_dict(self, mock_singletablemetadata): - """Test that ``_load_from_dict`` returns a instance of ``MultiTableMetadata``. + def test_load_from_dict(self, mock_singletablemetadata): + """Test that ``load_from_dict`` returns a instance of ``MultiTableMetadata``. - Test that when calling the ``_load_from_dict`` method a new instance with the passed + Test that when calling the ``load_from_dict`` method a new instance with the passed python ``dict`` details should be created. Setup: @@ -1080,7 +1080,7 @@ def test__load_from_dict(self, mock_singletablemetadata): - ``instance`` that contains ``instance.tables`` and ``instance.relationships``. Side Effects: - - ``SingleTableMetadata._load_from_dict`` has been called. + - ``SingleTableMetadata.load_from_dict`` has been called. """ # Setup multitable_metadata = { @@ -1109,13 +1109,13 @@ def test__load_from_dict(self, mock_singletablemetadata): single_table_accounts = object() single_table_branches = object() - mock_singletablemetadata._load_from_dict.side_effect = [ + mock_singletablemetadata.load_from_dict.side_effect = [ single_table_accounts, single_table_branches ] # Run - instance = MultiTableMetadata._load_from_dict(multitable_metadata) + instance = MultiTableMetadata.load_from_dict(multitable_metadata) # Assert assert instance.tables == { @@ -1938,7 +1938,7 @@ def test__convert_relationships(self): @patch('sdv.metadata.multi_table.read_json') @patch('sdv.metadata.multi_table.MultiTableMetadata._convert_relationships') @patch('sdv.metadata.multi_table.convert_metadata') - @patch('sdv.metadata.multi_table.MultiTableMetadata._load_from_dict') + @patch('sdv.metadata.multi_table.MultiTableMetadata.load_from_dict') def test_upgrade_metadata( self, from_dict_mock, convert_mock, relationships_mock, read_json_mock, validate_mock): """Test the ``upgrade_metadata`` method. @@ -2024,7 +2024,7 @@ def test_upgrade_metadata( @patch('sdv.metadata.multi_table.read_json') @patch('sdv.metadata.multi_table.MultiTableMetadata._convert_relationships') @patch('sdv.metadata.multi_table.convert_metadata') - @patch('sdv.metadata.multi_table.MultiTableMetadata._load_from_dict') + @patch('sdv.metadata.multi_table.MultiTableMetadata.load_from_dict') def test_upgrade_metadata_validate_error( self, from_dict_mock, convert_mock, relationships_mock, read_json_mock, validate_mock, warnings_mock): diff --git a/tests/unit/metadata/test_single_table.py b/tests/unit/metadata/test_single_table.py index ed6a464fb..4afa06694 100644 --- a/tests/unit/metadata/test_single_table.py +++ b/tests/unit/metadata/test_single_table.py @@ -1507,8 +1507,8 @@ def test_to_dict(self): result['columns']['my_column'] = 1 assert instance.columns['my_column'] == 'value' - def test__load_from_dict(self): - """Test that ``_load_from_dict`` returns a instance with the ``dict`` updated objects.""" + def test_load_from_dict(self): + """Test that ``load_from_dict`` returns a instance with the ``dict`` updated objects.""" # Setup my_metadata = { 'columns': {'my_column': 'value'}, @@ -1520,7 +1520,7 @@ def test__load_from_dict(self): } # Run - instance = SingleTableMetadata._load_from_dict(my_metadata) + instance = SingleTableMetadata.load_from_dict(my_metadata) # Assert assert instance.columns == {'my_column': 'value'} @@ -1725,7 +1725,7 @@ def test___repr__(self, mock_json): @patch('sdv.metadata.single_table.validate_file_does_not_exist') @patch('sdv.metadata.single_table.read_json') @patch('sdv.metadata.single_table.convert_metadata') - @patch('sdv.metadata.single_table.SingleTableMetadata._load_from_dict') + @patch('sdv.metadata.single_table.SingleTableMetadata.load_from_dict') def test_upgrade_metadata(self, from_dict_mock, convert_mock, read_json_mock, validate_mock): """Test the ``upgrade_metadata`` method. @@ -1764,7 +1764,7 @@ def test_upgrade_metadata(self, from_dict_mock, convert_mock, read_json_mock, va @patch('sdv.metadata.single_table.validate_file_does_not_exist') @patch('sdv.metadata.single_table.read_json') @patch('sdv.metadata.single_table.convert_metadata') - @patch('sdv.metadata.single_table.SingleTableMetadata._load_from_dict') + @patch('sdv.metadata.single_table.SingleTableMetadata.load_from_dict') def test_upgrade_metadata_multiple_tables( self, from_dict_mock, convert_mock, read_json_mock, validate_mock): """Test the ``upgrade_metadata`` method. @@ -1805,7 +1805,7 @@ def test_upgrade_metadata_multiple_tables( @patch('sdv.metadata.single_table.validate_file_does_not_exist') @patch('sdv.metadata.single_table.read_json') @patch('sdv.metadata.single_table.convert_metadata') - @patch('sdv.metadata.single_table.SingleTableMetadata._load_from_dict') + @patch('sdv.metadata.single_table.SingleTableMetadata.load_from_dict') def test_upgrade_metadata_multiple_tables_fails( self, from_dict_mock, convert_mock, read_json_mock, validate_mock): """Test the ``upgrade_metadata`` method. @@ -1847,7 +1847,7 @@ def test_upgrade_metadata_multiple_tables_fails( @patch('sdv.metadata.single_table.validate_file_does_not_exist') @patch('sdv.metadata.single_table.read_json') @patch('sdv.metadata.single_table.convert_metadata') - @patch('sdv.metadata.single_table.SingleTableMetadata._load_from_dict') + @patch('sdv.metadata.single_table.SingleTableMetadata.load_from_dict') def test_upgrade_metadata_validate_error( self, from_dict_mock, convert_mock, read_json_mock, validate_mock, warnings_mock): """Test the ``upgrade_metadata`` method. diff --git a/tests/utils.py b/tests/utils.py index 3b127b2fb..0383cccda 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -60,7 +60,7 @@ def get_multi_table_metadata(): 'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1' } - return MultiTableMetadata._load_from_dict(dict_metadata) + return MultiTableMetadata.load_from_dict(dict_metadata) def get_multi_table_data():