Skip to content

Commit

Permalink
Add load_from_dict to SingleTableMetadata and MultiTableMetadata (#1315)
Browse files Browse the repository at this point in the history
  • Loading branch information
amontanez24 authored Mar 17, 2023
1 parent e087f90 commit dfd1a59
Show file tree
Hide file tree
Showing 9 changed files with 36 additions and 36 deletions.
2 changes: 1 addition & 1 deletion sdv/data_processing/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,7 +824,7 @@ def from_dict(cls, metadata_dict, enforce_rounding=True, enforce_min_max_values=
If passed, set the ``enforce_min_max_values`` on the new instance.
"""
instance = cls(
metadata=SingleTableMetadata._load_from_dict(metadata_dict['metadata']),
metadata=SingleTableMetadata.load_from_dict(metadata_dict['metadata']),
enforce_rounding=enforce_rounding,
enforce_min_max_values=enforce_min_max_values,
model_kwargs=metadata_dict.get('model_kwargs')
Expand Down
2 changes: 1 addition & 1 deletion sdv/datasets/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def _get_metadata(modality, output_folder_name, in_memory_directory):

else:
metadict = json.loads(in_memory_directory['metadata_v1.json'])
metadata = metadata._load_from_dict(metadict)
metadata = metadata.load_from_dict(metadict)

return metadata

Expand Down
12 changes: 6 additions & 6 deletions sdv/metadata/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,24 +579,24 @@ def _set_metadata_dict(self, metadata):
Python dictionary representing a ``MultiTableMetadata`` object.
"""
for table_name, table_dict in metadata.get('tables', {}).items():
self.tables[table_name] = SingleTableMetadata._load_from_dict(table_dict)
self.tables[table_name] = SingleTableMetadata.load_from_dict(table_dict)

for relationship in metadata.get('relationships', []):
self.relationships.append(relationship)

@classmethod
def _load_from_dict(cls, metadata):
def load_from_dict(cls, metadata_dict):
"""Create a ``MultiTableMetadata`` instance from a python ``dict``.
Args:
metadata (dict):
metadata_dict (dict):
Python dictionary representing a ``MultiTableMetadata`` object.
Returns:
Instance of ``MultiTableMetadata``.
"""
instance = cls()
instance._set_metadata_dict(metadata)
instance._set_metadata_dict(metadata_dict)
return instance

def save_to_json(self, filepath):
Expand Down Expand Up @@ -630,7 +630,7 @@ def load_from_json(cls, filepath):
A ``MultiTableMetadata`` instance.
"""
metadata = read_json(filepath)
return cls._load_from_dict(metadata)
return cls.load_from_dict(metadata)

def __repr__(self):
"""Pretty print the ``MultiTableMetadata``."""
Expand Down Expand Up @@ -698,7 +698,7 @@ def upgrade_metadata(cls, old_filepath, new_filepath):
'relationships': relationships,
'METADATA_SPEC_VERSION': cls.METADATA_SPEC_VERSION
}
metadata = cls._load_from_dict(metadata_dict)
metadata = cls.load_from_dict(metadata_dict)
metadata.save_to_json(new_filepath)
try:
metadata.validate()
Expand Down
10 changes: 5 additions & 5 deletions sdv/metadata/single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,19 +485,19 @@ def save_to_json(self, filepath):
json.dump(metadata, metadata_file, indent=4)

@classmethod
def _load_from_dict(cls, metadata):
def load_from_dict(cls, metadata_dict):
"""Create a ``SingleTableMetadata`` instance from a python ``dict``.
Args:
metadata (dict):
metadata_dict (dict):
Python dictionary representing a ``SingleTableMetadata`` object.
Returns:
Instance of ``SingleTableMetadata``.
"""
instance = cls()
for key in instance._KEYS:
value = deepcopy(metadata.get(key))
value = deepcopy(metadata_dict.get(key))
if value:
setattr(instance, f'{key}', value)

Expand Down Expand Up @@ -525,7 +525,7 @@ def load_from_json(cls, filepath):
'class and version.'
)

return cls._load_from_dict(metadata)
return cls.load_from_dict(metadata)

def __repr__(self):
"""Pretty print the ``SingleTableMetadata``."""
Expand Down Expand Up @@ -560,7 +560,7 @@ def upgrade_metadata(cls, old_filepath, new_filepath):
old_metadata = list(tables.values())[0]

new_metadata = convert_metadata(old_metadata)
metadata = cls._load_from_dict(new_metadata)
metadata = cls.load_from_dict(new_metadata)
metadata.save_to_json(new_filepath)

try:
Expand Down
2 changes: 1 addition & 1 deletion sdv/sequential/par.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _get_context_metadata(self):
context_columns_dict[column] = self.metadata.columns[column]

context_metadata_dict = {'columns': context_columns_dict}
return SingleTableMetadata._load_from_dict(context_metadata_dict)
return SingleTableMetadata.load_from_dict(context_metadata_dict)

def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=False,
context_columns=None, segment_size=None, epochs=128, sample_size=1, cuda=True,
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from sdv.single_table.base import BaseSingleTableSynthesizer
from tests.integration.single_table.custom_constraints import MyConstraint

METADATA = SingleTableMetadata._load_from_dict({
METADATA = SingleTableMetadata.load_from_dict({
'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1',
'columns': {
'column1': {
Expand Down Expand Up @@ -436,7 +436,7 @@ def test_sampling(synthesizer):
@pytest.mark.parametrize('synthesizer', SYNTHESIZERS)
def test_sampling_reset_sampling(synthesizer):
"""Test ``sample`` method for each synthesizer using ``reset_sampling``."""
metadata = SingleTableMetadata._load_from_dict({
metadata = SingleTableMetadata.load_from_dict({
'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1',
'columns': {
'column1': {
Expand Down
24 changes: 12 additions & 12 deletions tests/unit/metadata/test_multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def get_metadata(self):
}
]

return MultiTableMetadata._load_from_dict(metadata)
return MultiTableMetadata.load_from_dict(metadata)

def test___init__(self):
"""Test the ``__init__`` method of ``MultiTableMetadata``."""
Expand Down Expand Up @@ -594,7 +594,7 @@ def test__validate_single_table(self):
- Errors has been updated with the error message for that column.
"""
# Setup
table_accounts = SingleTableMetadata._load_from_dict({
table_accounts = SingleTableMetadata.load_from_dict({
'columns': {
'id': {'sdtype': 'numerical'},
'branch_id': {'sdtype': 'numerical'},
Expand Down Expand Up @@ -1009,7 +1009,7 @@ def test__set_metadata(self, mock_singletablemetadata):
Side Effects:
- ``instance`` now contains ``instance.tables`` and ``instance.relationships``.
- ``SingleTableMetadata._load_from_dict`` has been called.
- ``SingleTableMetadata.load_from_dict`` has been called.
"""
# Setup
multitable_metadata = {
Expand Down Expand Up @@ -1038,7 +1038,7 @@ def test__set_metadata(self, mock_singletablemetadata):

single_table_accounts = object()
single_table_branches = object()
mock_singletablemetadata._load_from_dict.side_effect = [
mock_singletablemetadata.load_from_dict.side_effect = [
single_table_accounts,
single_table_branches
]
Expand All @@ -1064,10 +1064,10 @@ def test__set_metadata(self, mock_singletablemetadata):
]

@patch('sdv.metadata.multi_table.SingleTableMetadata')
def test__load_from_dict(self, mock_singletablemetadata):
"""Test that ``_load_from_dict`` returns a instance of ``MultiTableMetadata``.
def test_load_from_dict(self, mock_singletablemetadata):
"""Test that ``load_from_dict`` returns a instance of ``MultiTableMetadata``.
Test that when calling the ``_load_from_dict`` method a new instance with the passed
Test that when calling the ``load_from_dict`` method a new instance with the passed
python ``dict`` details should be created.
Setup:
Expand All @@ -1080,7 +1080,7 @@ def test__load_from_dict(self, mock_singletablemetadata):
- ``instance`` that contains ``instance.tables`` and ``instance.relationships``.
Side Effects:
- ``SingleTableMetadata._load_from_dict`` has been called.
- ``SingleTableMetadata.load_from_dict`` has been called.
"""
# Setup
multitable_metadata = {
Expand Down Expand Up @@ -1109,13 +1109,13 @@ def test__load_from_dict(self, mock_singletablemetadata):

single_table_accounts = object()
single_table_branches = object()
mock_singletablemetadata._load_from_dict.side_effect = [
mock_singletablemetadata.load_from_dict.side_effect = [
single_table_accounts,
single_table_branches
]

# Run
instance = MultiTableMetadata._load_from_dict(multitable_metadata)
instance = MultiTableMetadata.load_from_dict(multitable_metadata)

# Assert
assert instance.tables == {
Expand Down Expand Up @@ -1938,7 +1938,7 @@ def test__convert_relationships(self):
@patch('sdv.metadata.multi_table.read_json')
@patch('sdv.metadata.multi_table.MultiTableMetadata._convert_relationships')
@patch('sdv.metadata.multi_table.convert_metadata')
@patch('sdv.metadata.multi_table.MultiTableMetadata._load_from_dict')
@patch('sdv.metadata.multi_table.MultiTableMetadata.load_from_dict')
def test_upgrade_metadata(
self, from_dict_mock, convert_mock, relationships_mock, read_json_mock, validate_mock):
"""Test the ``upgrade_metadata`` method.
Expand Down Expand Up @@ -2024,7 +2024,7 @@ def test_upgrade_metadata(
@patch('sdv.metadata.multi_table.read_json')
@patch('sdv.metadata.multi_table.MultiTableMetadata._convert_relationships')
@patch('sdv.metadata.multi_table.convert_metadata')
@patch('sdv.metadata.multi_table.MultiTableMetadata._load_from_dict')
@patch('sdv.metadata.multi_table.MultiTableMetadata.load_from_dict')
def test_upgrade_metadata_validate_error(
self, from_dict_mock, convert_mock, relationships_mock, read_json_mock, validate_mock,
warnings_mock):
Expand Down
14 changes: 7 additions & 7 deletions tests/unit/metadata/test_single_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -1507,8 +1507,8 @@ def test_to_dict(self):
result['columns']['my_column'] = 1
assert instance.columns['my_column'] == 'value'

def test__load_from_dict(self):
"""Test that ``_load_from_dict`` returns a instance with the ``dict`` updated objects."""
def test_load_from_dict(self):
"""Test that ``load_from_dict`` returns a instance with the ``dict`` updated objects."""
# Setup
my_metadata = {
'columns': {'my_column': 'value'},
Expand All @@ -1520,7 +1520,7 @@ def test__load_from_dict(self):
}

# Run
instance = SingleTableMetadata._load_from_dict(my_metadata)
instance = SingleTableMetadata.load_from_dict(my_metadata)

# Assert
assert instance.columns == {'my_column': 'value'}
Expand Down Expand Up @@ -1725,7 +1725,7 @@ def test___repr__(self, mock_json):
@patch('sdv.metadata.single_table.validate_file_does_not_exist')
@patch('sdv.metadata.single_table.read_json')
@patch('sdv.metadata.single_table.convert_metadata')
@patch('sdv.metadata.single_table.SingleTableMetadata._load_from_dict')
@patch('sdv.metadata.single_table.SingleTableMetadata.load_from_dict')
def test_upgrade_metadata(self, from_dict_mock, convert_mock, read_json_mock, validate_mock):
"""Test the ``upgrade_metadata`` method.
Expand Down Expand Up @@ -1764,7 +1764,7 @@ def test_upgrade_metadata(self, from_dict_mock, convert_mock, read_json_mock, va
@patch('sdv.metadata.single_table.validate_file_does_not_exist')
@patch('sdv.metadata.single_table.read_json')
@patch('sdv.metadata.single_table.convert_metadata')
@patch('sdv.metadata.single_table.SingleTableMetadata._load_from_dict')
@patch('sdv.metadata.single_table.SingleTableMetadata.load_from_dict')
def test_upgrade_metadata_multiple_tables(
self, from_dict_mock, convert_mock, read_json_mock, validate_mock):
"""Test the ``upgrade_metadata`` method.
Expand Down Expand Up @@ -1805,7 +1805,7 @@ def test_upgrade_metadata_multiple_tables(
@patch('sdv.metadata.single_table.validate_file_does_not_exist')
@patch('sdv.metadata.single_table.read_json')
@patch('sdv.metadata.single_table.convert_metadata')
@patch('sdv.metadata.single_table.SingleTableMetadata._load_from_dict')
@patch('sdv.metadata.single_table.SingleTableMetadata.load_from_dict')
def test_upgrade_metadata_multiple_tables_fails(
self, from_dict_mock, convert_mock, read_json_mock, validate_mock):
"""Test the ``upgrade_metadata`` method.
Expand Down Expand Up @@ -1847,7 +1847,7 @@ def test_upgrade_metadata_multiple_tables_fails(
@patch('sdv.metadata.single_table.validate_file_does_not_exist')
@patch('sdv.metadata.single_table.read_json')
@patch('sdv.metadata.single_table.convert_metadata')
@patch('sdv.metadata.single_table.SingleTableMetadata._load_from_dict')
@patch('sdv.metadata.single_table.SingleTableMetadata.load_from_dict')
def test_upgrade_metadata_validate_error(
self, from_dict_mock, convert_mock, read_json_mock, validate_mock, warnings_mock):
"""Test the ``upgrade_metadata`` method.
Expand Down
2 changes: 1 addition & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def get_multi_table_metadata():
'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1'
}

return MultiTableMetadata._load_from_dict(dict_metadata)
return MultiTableMetadata.load_from_dict(dict_metadata)


def get_multi_table_data():
Expand Down

0 comments on commit dfd1a59

Please sign in to comment.